########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Author: Christopher Hart
Date  : April, 2002

These classes provide a framework for visualization within the MLS
python schema.  Briefly, the workhorses of the toolset are two primary
plotting classes, IPlot and DatasetPlot, built on the BLT toolkit
(although another plotting backend could probably be substituted).

IPlot is a specialization of the Pmw.Blt.Graph widget that adds many
user-interface enhancements (pull-down menus, sortable legends, popup
information, etc.).

DatasetPlot is a specialization of IPlot which adds the communication
pathway for the PlotViews framework.  A DatasetPlot instance can only
plot when given a PlotView, and the two communicate via the interface
defined in the IPlotView class.  Through this interface custom
PlotViews can be built and combined to allow for a fairly modular,
interactive and extensible plotting environment.

"""

import math
import types
import colorsys
import re
import types
import operator
import sys

import Numeric
import MLab
import Tkinter
import Pmw

try:
  from Scientific.Statistics import Histogram as sciHistogram
except:
  sys.stderr.write("Scientific Python not found, plotting of Scientific Python's Histograms disabled\n")
  sciHistogram = types.NoneType


## FIXME: replace all direct uses of imported classes with module.class usage
from compClust.mlx.views import CachedView, RowAggregateFunctionView, AggregateFunctionView
from compClust.mlx.views import RowPCAView
from compClust.mlx.views import SortedView
from compClust.mlx.views import RowSubsetView
from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling
from compClust.mlx.labelings import GlobalLabeling
from compClust.mlx.labelings import GlobalWrapper
from compClust.mlx.labelings import subsetByLabeling

from compClust.score import roc
from compClust.score import ConfusionMatrix2

## soon this will replace roc

from compClust.util.InterpreterTools import safeStdDev
from compClust.util import DistanceMetrics
from compClust.util import Histogram
from compClust.util import listOps
from compClust.util import unique
from compClust.util import NaN


_ROOT = None

def startIPlot(showSplash=False):
  """
  Create the hidden Tk root window that every IPlot widget belongs to.

  The root is created at most once and kept in the module-level _ROOT;
  later calls do nothing.

  showSplash -- if true, show a "Welcome to IPlot" splash window for
                about one second after creating the root.  Defaults to
                False, which is what the function always did before this
                option existed.
  """
  global _ROOT

  if _ROOT is not None:
    # already started
    return

  print("Starting IPlot...")
  _ROOT = Tkinter.Tk(className = 'IPlot')
  # the root only parents Toplevels; never show it
  _ROOT.withdraw()

  if not showSplash:
    return

  import time
  splash = Tkinter.Toplevel()
  splash.title('Welcome to IPlot')
  greeting = Tkinter.Label(splash,
                           font=('Helvetica', 16, 'bold'),
                           relief = 'raised',
                           text = 'Welcome to IPlot...'
                                  '\n\n'
                                  'Written By: \n'
                                  'Christopher Hart \n'
                                  'Caltech Biology'
                          )
  greeting.pack(fill = 'both', expand=1)
  _ROOT.update()
  splash.update_idletasks()
  splash.deiconify()
  _ROOT.update()
  time.sleep(1)
  splash.destroy()
  
class IPlot(Pmw.Blt.Graph):

  """
  A Pmw.Blt.Graph with extra user-interface conveniences.

  Bindings installed by the constructor:
    right-click (release)        popup menu with File / View / Settings
    click on an element          show its name and value (not for subplots)
    Control-click on an element  open a SummaryWindow (not for subplots)
    double-click element/legend  edit the element's line properties
    click on a legend entry      raise that element (lower it if it is
                                 already drawn on top)
    click on the value marker    hide the marker
    Shift-drag                   rubber-band zoom in
    Control-click                zoom back out one step
  """

  def __init__(self, parent=None, subplot=False, cnf={}, **kw):
    """
    parent  -- Tk parent widget.  If None, a new Toplevel is created under
               the shared IPlot root (created on demand by startIPlot).
    subplot -- if true, skip the element click / Control-click bindings
               that display values and summary windows.
    cnf, kw -- widget options passed straight through to the BLT graph.
    """
    if parent is None:
      if _ROOT is None:
        startIPlot()
      parent = Tkinter.Toplevel(master=_ROOT)
    try:
      Pmw.Blt._loadBlt(parent)
      Tkinter.Widget.__init__(self, parent, Pmw.Blt._graphCommand, cnf, kw)
    except Exception:
      # Retry once after reloading Pmw.Blt, which discards its cached
      # BLT-loading state (presumably stale when BLT was loaded for a
      # different Tk interpreter).
      reload(Pmw.Blt)
      Pmw.Blt._loadBlt(parent)
      Tkinter.Widget.__init__(self, parent, Pmw.Blt._graphCommand, cnf, kw)

    # text marker used by displayValues / hideSelection
    self.__selectionMarker = self.marker_create('text', name='selection')

    # zoom state: the selection rectangle [x0, y0, x1, y1] in graph
    # coordinates, a stack of earlier axis limits for zoomOut, and whether
    # a Shift-drag is currently in progress
    self.__zoomSelection = [None, None, None, None]
    self.__previousZooms = []
    self.__dragging = None

    self.__defineMenus()
    self.bind('<ButtonRelease-3>',
              lambda event, menu=self.graphMenu: self.popupMenu(event, menu))

    # subplots do not get the click-for-details bindings
    if not subplot:
      self.element_bind('all', '<Button-1>', self.displayValues)
      self.element_bind('all', '<Control-Button-1>', self.displayValuesSummary)
    self.element_bind('all', '<Double-Button-1>', self.elementSetup)
    self.legend_bind("all", "<Double-Button-1>", self.elementSetup)
    self.legend_bind("all", "<Button-1>", self.raiseElement)
    self.marker_bind('selection', '<Button-1>', self.hideSelection)

    self.bind(sequence="<Shift-ButtonPress-1>", func=self.mouseDown)
    self.bind(sequence="<Shift-ButtonRelease-1>", func=self.zoomIn)
    self.bind(sequence="<Control-ButtonPress-1>", func=self.zoomOut)

  def __defineMenus(self):
    """Build the popup menu tree: File, View and Settings > Axis."""

    # Settings > Axis: scale and direction toggles
    self.axisMenu = Tkinter.Menu(self)
    self.axisMenu.add_command(label = 'x logscale', command = self.xlogScale)
    self.axisMenu.add_command(label = 'y logscale', command = self.ylogScale)
    self.axisMenu.add_command(label = 'descending', command = self.descending)

    # View: display toggles
    self.viewMenu = Tkinter.Menu(self)
    self.viewMenu.add_checkbutton(label = 'Toggle Crosshairs', command = self.showCrosshairs)
    self.viewMenu.add_checkbutton(label = 'Toggle Grid', command = self.showGrid)
    self.viewMenu.add_checkbutton(label = 'Toggle Legend', command = self.showLegend)
    self.viewMenu.add_checkbutton(label = 'Toggle Axis', command = self.showAxis)

    # File: empty here; subclasses may add entries
    self.fileMenu = Tkinter.Menu(self)

    # Settings
    self.optionsMenu = Tkinter.Menu(self)
    self.optionsMenu.add_cascade(label='Axis', menu= self.axisMenu)

    # the top-level popup menu
    self.graphMenu = Tkinter.Menu(self)
    self.graphMenu.add_cascade(label = 'File', menu= self.fileMenu)
    self.graphMenu.add_cascade(label = 'View', menu= self.viewMenu)
    self.graphMenu.add_cascade(label = 'Settings', menu= self.optionsMenu)

  def _dialogParent(self):
    """
    Return the widget that dialogs opened by this plot are parented to.

    Subclasses may set a `parent` attribute.  A plain IPlot (such as the
    one SummaryWindow embeds) does not, so fall back to the Tk master the
    widget was created in.
    """
    return getattr(self, 'parent', self.master)

  def popupMenu(self, event, menu):
    """
    popupMenu(self, event, menu)

    Post `menu` as a popup at the mouse position of `event`.
    """
    # x_root / y_root are screen coordinates, so the menu lands under the
    # pointer wherever the graph sits inside its window
    menu.tk_popup(event.x_root, event.y_root)

  def __cBox(self, f, label, items):
    """
    Create and pack a labelled Pmw.ComboBox in frame `f` (used by
    elementSetup) and return it.
    """
    box = Pmw.ComboBox(f, label_text = label,
                       labelpos = 'w', scrolledlist_items = items)
    box.pack(fill = 'both', expand = 1, padx = 8, pady = 8)
    return box

  def displayValues(self, event):
    """
    displayValues(self, event)

    Show the name and (x, y) value of the element point closest to the
    click in the 'selection' text marker at the top-left of the plot.
    """
    element = event.widget.element_closest(event.x, event.y)
    if isinstance(element, dict):
      status = "%s: (%3.2f, %3.2f)"%(element['name'], element['x'], element['y'])
      coords = (self.xaxis_limits()[0], self.yaxis_limits()[1])
      self.marker_configure('selection',
                            coords= coords,
                            background="lightblue",
                            text=status,
                            under=0,
                            anchor='w',
                            hide=0)

  def displayValuesSummary(self, event):
    """
    displayValuesSummary(self, event)

    Open a SummaryWindow for the row behind the element closest to the
    click.  Does nothing unless the plot has plot views attached (a
    `currentPlotViews` attribute, which plain IPlots do not have).
    """
    plotViews = getattr(self, 'currentPlotViews', None)
    if not plotViews:
      return
    # FIXME: which one to grab?
    plotView = plotViews[0]

    closest = event.widget.element_closest(event.x, event.y)
    if closest is None:
      return
    # the part of the element name before any '__' is the row's name
    primaryName = closest['name'].split('__')[0]

    primaryLabeling = plotView.getAnnotationMapper().getPrimaryLabeling()
    if primaryLabeling is None:
      row = int(primaryName)
    else:
      row = primaryLabeling.getRowsByLabel(primaryName)[0]

    SummaryWindow(plotView, row)

  def hideSelection(self, event):
    """
    hideSelection(self, event)

    Hide the 'selection' value marker.
    """
    # this hack makes the labels go away before a plot refresh
    self.marker_configure('selection',  coord=(-100000,-1000000))
    self.marker_configure('selection', under=1, hide=1)

  def raiseElement(self, event):
    """
    raiseElement(self, event)

    Bring the element whose legend entry was clicked to the front.  If it
    is already drawn last (on top), send it to the back instead.
    """
    displayedElements = list(self.element_show())
    selection = self.legend_get("@" + str(event.x) + "," + str(event.y))
    if selection not in displayedElements:
      # click was not on a displayed element's legend entry
      return

    if displayedElements[-1] == selection:
      displayedElements.pop()
      displayedElements.insert(0, selection)
    else:
      displayedElements.remove(selection)
      displayedElements.append(selection)
    # element show order is drawing order: last is drawn on top
    self.element_show(tuple(displayedElements))

  def setTitles(self, event=None):
    """
    create a dialogue box to edit the Graph title, and axis labels
    """
    print("not yet implimented")

  def elementSetup(self, event):
    """
    elementSetup(self, event)

    Open a modal dialog to edit the color, symbol, symbol fill/outline,
    smoothing and line width of the element under the pointer (or of the
    legend entry under the pointer).  'Apply' applies the settings and
    keeps the dialog open; 'OK' or closing the window applies them and
    closes it.
    """
    try:
      elName = self.element_closest(event.x, event.y, interpolate=1)["name"]
    except Exception:
      # not over an element (element_closest gave nothing usable); the
      # event presumably came from the legend
      elName = self.legend_get("@" + str(event.x) + "," + str(event.y))

    def applyChanges():
      self.element_configure(elName, color=colBox.get(), symbol=symBox.get(),
                             smooth=smtBox.get(), linewidth=linBox.get(),
                             fill=scoBox.get(), outline=solBox.get())

    def buttonPressed(result):
      # Pmw passes the button name, or None when the window is closed
      if result == 'Apply':
        applyChanges()
      else:
        dialog.deactivate(result)

    dialog = Pmw.Dialog(self._dialogParent())
    dialog.configure(
      buttons = ('OK','Apply'),
      title = 'Edit Line Properites - %s'%(elName),
      command = buttonPressed)
    dialog.withdraw()

    frame = Tkinter.Frame(dialog.interior())
    frame.pack()

    colBox = self.__cBox(frame, "Color:",
                         ('red', 'yellow', 'blue', 'green', 'black', 'grey', 'custom'))
    symBox = self.__cBox(frame, 'Symbols:',
                         ("", "square", "circle", "diamond", "cross", "triangle"))
    scoBox = self.__cBox(frame, 'Symbol color:',
                         ('defcolor', 'red', 'yellow', 'blue', 'green', 'black', 'custom'))
    solBox = self.__cBox(frame, 'Symbol outline:',
                         ('defcolor', 'red', 'yellow', 'blue', 'green', 'black', 'custom'))
    smtBox = self.__cBox(frame, 'Smootheness:',
                         ('step', 'linear', 'quadratic', 'natural'))
    linBox = self.__cBox(frame, 'Line thickness:', (0, 1, 2, 3, 4, 5))

    # pre-select the element's current settings
    try:
      symBox.selectitem(self.element_cget(elName, "symbol"))
      scoBox.selectitem(self.element_cget(elName, "fill"))
      solBox.selectitem(self.element_cget(elName, "outline"))
      smtBox.selectitem(self.element_cget(elName, "smooth"))
      linBox.selectitem(self.element_cget(elName, "linewidth"))
    except IndexError:
      # linewidth sometimes raises IndexError (seen on summary plots such
      # as the confusion matrix); give up on editing this element
      dialog.destroy()
      return

    # modal: returns once the dialog is deactivated
    dialog.activate()
    applyChanges()
    dialog.destroy()

  # The next functions configure the axes
  def showAxis(self):
    """Toggle both axes, using the x axis' current 'hide' state."""
    state = int(self.axis_cget("x", 'hide'))
    self.axis_configure(["x", "y"], hide = not state)

  def xlogScale(self):
    """Toggle log scaling of the x axis."""
    state = int(self.xaxis_cget('logscale'))
    self.xaxis_configure(logscale = not state)

  def ylogScale(self):
    """Toggle log scaling of the y axis."""
    state = int(self.yaxis_cget('logscale'))
    self.yaxis_configure(logscale = not state)

  def descending(self):
    """Toggle descending order on both axes, using the x axis' state."""
    state = int(self.axis_cget("x", 'descending'))
    self.axis_configure(["x", "y"], descending = not state)

  # The next functions configure the crosshairs
  def mouseMove(self, event):
    """Move the crosshairs to the pointer."""
    self.crosshairs_configure(position="@" +str(event.x) +","+str(event.y))

  def showCrosshairs(self):
    """Toggle the crosshairs and the <Motion> binding that moves them."""
    hide = not int(self.crosshairs_cget('hide'))
    self.crosshairs_configure(hide = hide, dashes="1")
    if hide:
      self.unbind("<Motion>")
    else:
      self.bind("<Motion>", self.mouseMove)

  def showGrid(self):
    """Toggle the grid."""
    self.grid_toggle()

  def showLegend(self):
    """Toggle the legend."""
    state = int(self.legend_cget('hide'))
    self.legend_configure(hide = not state)

  # these functions are for zooming

  def zoom(self):
    """Set the axis limits to the current zoom selection."""
    x0, y0, x1, y1 = self.__zoomSelection
    self.xaxis_configure(min=x0, max=x1)
    self.yaxis_configure(min=y0, max=y1)

  def mouseDrag(self, event):
    """Stretch the zoom rectangle's moving corner to the pointer."""
    x0, y0, x1, y1 = self.__zoomSelection
    (x1, y1) = self.invtransform(event.x, event.y)
    self.marker_configure("marking rectangle",
                          coords = (x0, y0, x1, y0, x1, y1, x0, y1, x0, y0))
    self.__zoomSelection = [x0, y0, x1, y1]

  def zoomIn(self, event):
    """
    zoomIn(self, event)

    Finish a Shift-drag: remove the rubber band and, if the rectangle has
    non-zero width and height, push the current limits and zoom to it.
    """
    if not self.__dragging:
      return
    self.__dragging = 0
    self.unbind(sequence="<Motion>")
    self.marker_delete("marking rectangle")

    x0, y0, x1, y1 = self.__zoomSelection
    if x0 != x1 and y0 != y1:
      # make sure the coordinates are sorted
      if x0 > x1:
        x0, x1 = x1, x0
      if y0 > y1:
        y0, y1 = y1, y0

      print("zoom in")
      self.__zoomSelection = [x0, y0, x1, y1]
      xLimits = self.xaxis_limits()
      yLimits = self.yaxis_limits()
      self.__previousZooms.append([xLimits[0], yLimits[0],
                                   xLimits[1], yLimits[1]])
      self.zoom()

  def zoomOut(self, event):
    """
    zoomOut(self, event)

    Zoom out to the previous zoom state; can only zoom as far as the
    first zoom state.
    """
    if len(self.__previousZooms) > 0:
      self.__zoomSelection = self.__previousZooms.pop()
      self.zoom()

  def mouseDown(self, event):
    """
    mouseDown(self, event)

    Start a Shift-drag zoom if the press is inside the plotting area.
    """
    self.__dragging = 0
    if not self.inside(event.x, event.y):
      return

    (x0, y0) = self.invtransform(event.x, event.y)
    # start with a degenerate rectangle so a click without a drag does
    # not zoom to a corner left over from an earlier selection
    self.__zoomSelection = [x0, y0, x0, y0]
    self.__dragging = 1
    self.marker_create("line", name="marking rectangle", dashes=(1, 2))
    self.bind(sequence="<Motion>",  func=self.mouseDrag)

  def clear(self):
    """
    clears the plot (deletes every displayed element)
    """
    for element in self.element_show():
      self.element_delete(element)

class SummaryWindow:

  """
  SummaryWindow

  A summary dialog for one dataset row of a PlotView:
    top left     -- an IPlot of the row's values
    bottom left  -- the dataset labelings that label this row
    bottom right -- for the selected labeling, one option menu per label of
                    the row; each menu highlights / shows only / hides the
                    rows sharing that label in the PlotView's own plot
  (the top-right pane is created but left empty).
  """

  def __init__(self, plotView, row, xdata=None, parent=None):
    """
    plotView -- the IPlotView whose dataset (and plot) is summarized
    row      -- index of the dataset row to show
    xdata    -- optional x coordinates; defaults to 0 .. numCols-1
    parent   -- optional Tk parent of the dialog
    """
    self.dialog = Pmw.Dialog(parent, hull_width=640, hull_height=480)
    self.dataset = plotView.getDataset()
    self.plotView = plotView
    if xdata is None:
      self.xdata = tuple(range(self.dataset.getNumCols()))
    else:
      self.xdata = tuple(xdata)

    self.createPanes()
    self.addRow(row)
    self.parent = parent

  def createPanes(self):
    """Lay out the paned sub-frames, the summary plot and the list box."""

    # base pane: top and bottom
    self.panes = Pmw.PanedWidget(self.dialog.interior())
    self.panes.add('top')
    self.panes.add('bottom', min=100)

    # split the top into left and right
    self.topPanes = Pmw.PanedWidget(self.panes.pane('top'),
                                    orient='horizontal')
    self.topPanes.add('left', size=300)
    self.topPanes.add('right')

    # split the bottom into left and right
    self.bottomPanes = Pmw.PanedWidget(self.panes.pane('bottom'),
                                       orient='horizontal')
    self.bottomPanes.add('left', min=100, size=300)
    self.bottomPanes.add('right')

    # the plot, upper left
    self.plot = IPlot(self.topPanes.pane('left'))
    self.plot.pack(fill='both', expand=1)

    # the labeling list, lower left
    self.listBox  = Pmw.ScrolledListBox(self.bottomPanes.pane('left'))
    self.listBox.pack(expand=1, fill='both')

    # the per-label Pmw.OptionMenus, lower right
    self.labelActions = []

    self.topPanes.pack(expand=1, fill='both')
    self.bottomPanes.pack(expand=1, fill='both')
    self.panes.pack(expand=1, fill='both')

  def __listSelection(self, row):
    """
    __listSelection(self, row)

    List-box callback: show the label actions of the selected labeling.
    """
    selections = self.listBox.getcurselection()
    if len(selections) == 0:
      return
    labelings = self.dataset.getLabelings()
    names = [str(labeling) for labeling in labelings]
    self.__populateLabelActions(labelings[names.index(selections[0])], row)

  def __populateLabelActions(self, labeling, row):
    """
    __populateLabelActions(self, labeling, row)

    Replace the option menus with one menu per label `labeling` gives `row`.
    """
    for menu in self.labelActions:
      menu.destroy()
    self.labelActions = []

    for lab in labeling.getLabelsByRow(row):
      menu = Pmw.OptionMenu (self.bottomPanes.pane('right'),
                             labelpos = 'w',
                             label_text = str(lab),
                             items = ('default display',
                                      'Highlight This Group',
                                      'Show Only This Group',
                                      'Hide This Group'),
                             initialitem = 'default display',
                             menubutton_width = 10)
      # default arguments bind this iteration's labeling/label
      menu.configure(command=lambda selection, labeling=labeling, label=lab:
                     self._groupAction(selection, labeling, label))
      menu.pack(expand=1, fill='x')
      self.labelActions.append(menu)

    Pmw.alignlabels(self.labelActions)

  def _groupAction(self, sel, labeling, label):
    """
    _groupAction(self, sel, labeling, label)

    Apply the option-menu choice `sel` in the PlotView's plot to every row
    that `labeling` marks with `label`:
      'Highlight This Group'  -- activate those elements
      'Show Only This Group'  -- show those elements, hide all the others
      'Hide This Group'       -- hide those elements
      anything else           -- unhide and deactivate every element
    """
    primaryLabeling = self.plotView.getAnnotationMapper ().getPrimaryLabeling()
    plot = self.plotView._getPlot()
    rows = labeling.getRowsByLabel(label)

    # elements are named by the row's primary label, or by the row index
    if primaryLabeling is None:
      elements = [str(r) for r in rows]
    else:
      elements = [primaryLabeling.getLabelsByRow(r)[0] for r in rows]

    if sel ==  'Highlight This Group':
      for element in elements:
        plot.element_activate(element)

    elif sel == 'Show Only This Group':
      others = listOps.difference(plot.element_names(), elements).items()
      for element in others:
        plot.element_configure(element, hide=1)
      # members may still be hidden by an earlier 'Hide This Group'
      for element in elements:
        plot.element_configure(element, hide=0)

    elif sel == 'Hide This Group':
      for element in elements:
        plot.element_configure(element, hide=1)

    else:
      for element in plot.element_names():
        plot.element_configure(element, hide=0)
      for element in plot.element_names():
        plot.element_deactivate(element)

  def __populateListBox(self, row):
    """
    Fill the list box with the names of the labelings that label `row`.
    """
    labelings = tuple([str(labeling) for labeling in self.dataset.getLabelings()
                       if len(labeling.getLabelsByRow(row)) > 0])

    self.listBox.setlist(labelings)
    self.listBox.configure(selectioncommand=lambda row=row: self.__listSelection(row))

  def __cBox(self, f, label, items):
    """
    Make a customized Combobox
    """
    box = Pmw.ComboBox(f, label_text = label,
                       labelpos = 'w', scrolledlist_items = items)
    box.pack(fill = 'both', expand = 1, padx = 8, pady = 8)
    return box

  def addRow(self, row):
    """
    addRow(self, row)

    Plot dataset row `row`, set the plot title and list the labelings of
    the row.  The line is named by the row's primary label (or the row
    index when there is no primary labeling).  The title is the row's
    first secondary label when there is one, otherwise that same name.
    """
    annotations = self.plotView.getAnnotationMapper()
    primaryLabeling = annotations.getPrimaryLabeling()
    secondaryLabeling = annotations.getSecondaryLabeling()
    if primaryLabeling is None:
      label = str(row)
    else:
      label = primaryLabeling.getLabelsByRow(row)[0]

    self.plot.line_create(label,
                          ydata=tuple(self.dataset.getRowData(row)),
                          xdata=self.xdata)

    title = label
    if secondaryLabeling is not None:
      secondaryLabels = secondaryLabeling.getLabelsByRow(row)
      if len(secondaryLabels) > 0:
        title = secondaryLabels[0]
    self.plot.configure(title=title)

    self.__populateListBox(row)

  def clear(self):
    """Remove every displayed element from the summary plot."""
    for element in self.plot.element_show():
      self.plot.element_delete(element)
    
class IPlotView:

  """
  IPlotView describes what a DatasetPlot should draw.

  A DatasetPlot plots through this interface.  The view holds a dataset
  and a set of "mappers" (data, color, markers, annotation, bindings),
  exposed through getters.  Subclasses fill in the mappers (only the data
  mapper is required) and may override plotSetup, plotFinishings and
  getKWArgs.
  """

  def __init__(self, dataset):
    """
    Initialize the view for `dataset`.

    Also ensures the dataset carries an `IPlotViews` attribute, creating
    an empty list if it is missing.
    """

    # transient pointer to the plot currently using this view
    self.__plot = None
    self.__dataset = dataset

    # the mappers; only the data mapper is required
    self._dataMapper       = None
    self._colorMapper      = None
    self._markersMapper    = None
    self._annotationMapper = None
    self._bindingsMapper   = None

    # The plotting utilities expect an IPlotViews list on the dataset
    # (added by sagar).  Best effort: any failure reading it means
    # "missing", but Ctrl-C and SystemExit are no longer swallowed.
    try:
      self.__dataset.IPlotViews
    except Exception:
      self.__dataset.IPlotViews = []

  def _setPlot(self, plot):
    """
    _setPlot(self, plot)

    Set the transient pointer to the plot using this view.
    """
    self.__plot = plot

  def _getPlot(self):
    """
    _getPlot(self)

    Return the plot currently using this view (None if there is none).
    """
    return(self.__plot)

  def getDataset(self):
    """Return the dataset this view describes."""
    return(self.__dataset)

  def getDataMapper(self):
    """Return the data mapper (None until a subclass sets one)."""
    return(self._dataMapper)

  def getColorMapper(self):
    """Return the color mapper, or None."""
    return(self._colorMapper)

  def getMarkersMapper(self):
    """Return the markers mapper, or None."""
    return(self._markersMapper)

  def getAnnotationMapper(self):
    """Return the annotation mapper, or None."""
    return(self._annotationMapper)

  def getBindingsMapper(self):
    """Return the bindings mapper, or None."""
    return(self._bindingsMapper)

  def plotSetup(self):
    """
    plotSetup(self)

    Hook called before any plotting begins.  Does nothing here.
    """
    pass

  def plotFinishings(self):
    """
    plotFinishings(self)

    Hook called last in the plot function.  Does nothing here.
    """
    pass

  def getKWArgs (self):
    """
    getKWArgs(self)

    Meant to allow a complete pass-through to the BLT element options:
    either None, or a list of tuples, one per element being plotted.
    Returns None here.
    """
    return(None)

  def refresh(self):
    """
    refresh(self)

    Signal that the view has changed: re-plot the attached plot, if any.
    """
    if self.__plot is not None:
      self.__plot.plot()

class Mapper:

  """
  Common base class of the PlotView mappers.

  A mapper is built for one IPlotView and can hand that view back.
  """

  def __init__(self, plotView):
    # kept under the name-mangled attribute _Mapper__plotView
    self.__plotView = plotView

  def getPlotView(self):
    """Return the IPlotView this mapper was created for."""
    view = self.__plotView
    return view

class IDataMapper(Mapper):

  """
  Mapper describing how an IPlotView's dataset is projected onto the
  graph.  Subclasses must implement getPlotData; the other hooks have
  defaults.
  """

  def getPlotData(self):
    """
    getPlotData(self)

    Return the data to plot as a list of (xTuple, yTuple) pairs, e.g.

      [ ((x1,x2,x3,x4), (y1,y2,y3,y4)),
        ((x1,x2,x3,x4), (y1,y2,y3,y4)) ]

    Not implemented here: raises NotImplementedError.
    """
    raise NotImplementedError()

  def getLineTrace(self):
    """
    getLineTrace(self)

    Return how the points are connected: 'increasing', 'decreasing',
    'both', or None to use BLT's default line connection.  This default
    implementation returns None.
    """
    return None

  def getWeightsAndSytles(self):
    """
    getWeightsAndSytles(self)

    Per-datapoint marker control.  (For per-element control use
    getPlotMarkers / getMarkerSizes on the MarkersMapper instead.)

    Return None, or a list as long as getPlotData()'s result whose
    entries pair a weights tuple with the corresponding styles tuple;
    see the BLT documentation on element weights and styles.  This
    default implementation returns None.
    """
    return None

class ColorMapper(Mapper):

  """
  Mapper describing how the elements of an IPlotView are colored.
  """

  def getColors(self, colorModel='hsv'):
    """
    getColors(self, colorModel='hsv')

    Return None for a single-color plot (all elements drawn in the same
    color), or a list of HSV tuples with one entry per row, e.g.
    [(.1, 1, 1), (.2, 1, 1), ...].

    This default ignores colorModel and returns None; see
    RowColorMapper for more details.
    """
    return None

class MarkersMapper(Mapper):

  """
  Mapper describing how element markers and connecting lines are drawn.
  """

  def getMarkerColors(self):
    """
    getMarkerColors(self)

    Return marker colors; this default returns None.
    """
    return None

  def getPlotMarkers(self):
    """
    getPlotMarkers(self)

    Return a list of markers to add to the plot, or None (the default).
    Weights/styles from the data mapper override these.
    """
    return None

  def getMarkerSizes(self):
    """
    getMarkerSizes(self)

    Return the desired marker size; the default is the string "0.04".
    Weights/styles from the data mapper override this.
    """
    defaultSize = "0.04"
    return defaultSize
  
class AnnotationMapper(Mapper):

  """
  Mapper describing how the elements of an IPlotView are annotated.
  """

  def getPrimaryLabeling(self):
    """
    getPrimaryLabeling(self)

    Return None or a labeling used for the primary element annotations,
    i.e. the element names -- these must each be unique.  This default
    returns None.
    """
    return None

  def getSecondaryLabeling(self):
    """
    getSecondaryLabeling(self)

    Return None or a labeling used for the secondary element annotations
    (perhaps long descriptions).  This default returns None.
    """
    return None

class BindingsMapper(Mapper):

  """
  Mapper that may add event bindings to a plot.
  """

  def addBindings(self):
    """
    addBindings(self)

    Hook for adding new event bindings to the plot after it is generated.
    IPlot's default bindings stay in place unless overridden.  This
    default adds nothing.
    """
    return None
    
###
#
# Implementations of the above interfaces.
#
########

class SimpleBindingsMapper(BindingsMapper):

  """
  Bindings mapper giving convenient access to row annotations.

  Clicking an element shows the row's primary and secondary annotations
  in the plot's 'selection' marker, placed at the x-axis minimum / y-axis
  maximum (top-left).  Control-clicking that marker opens a SummaryWindow
  for the row; clicking it hides it.
  """

  def addBindings(self):
    """
    addBindings(self)

    Bind element clicks on the PlotView's plot to displayAnnotations and
    reset the activation counts used by activateMembers.
    """
    self.__plot = self.getPlotView()._getPlot()
    self.__annotationMapper = self.getPlotView().getAnnotationMapper()
    self.__plot.element_bind('all', '<Button-1>', self.displayAnnotations)
    # element name -> number of checked labels currently activating it
    self.__activeElements = {}

  def displayAnnotations(self, event):
    """
    displayAnnotations(self, event)

    Show "<primary>: (x, y) -- <secondary>" for the element point closest
    to the click in the 'selection' marker.
    """
    closest = event.widget.element_closest(event.x, event.y)
    if closest is None:
      return
    x = closest['x']
    y = closest['y']
    # the part of the element name before any '__' is the row's name
    primaryName = closest['name'].split('__')[0]

    primaryLabeling   = self.__annotationMapper.getPrimaryLabeling()
    secondaryLabeling = self.__annotationMapper.getSecondaryLabeling()
    if primaryLabeling is None:
      row = int(primaryName)
    else:
      row = primaryLabeling.getRowsByLabel(primaryName)[0]

    # all secondary labels of the row, concatenated; '' when there is no
    # secondary labeling or its labels cannot be fetched (best effort)
    secondaryName = ''
    if secondaryLabeling is not None:
      try:
        labels = secondaryLabeling.getLabelsByRow(row)
      except Exception:
        labels = []
      secondaryName = ''.join([str(lab) for lab in labels])

    status = "%s: (%3.2f, %3.2f) -- %s "%(primaryName.strip(), x, y, secondaryName.strip())

    coords = (self.__plot.xaxis_limits()[0], self.__plot.yaxis_limits()[1])
    self.__plot.marker_configure('selection', coords= coords, background="lightblue", text=status, anchor='w', under=0, hide=0)
    self.__plot.marker_bind('selection', '<Control-Button-1>', lambda ev, row=row: SummaryWindow(self.getPlotView(), row))
    self.__plot.marker_bind('selection', '<Button-1>', self.hideSelection)

  def hideSelection(self, event):
    """
    hideSelection(self, event)

    Hide the 'selection' annotation marker.
    """
    # this hack makes the labels go away before a plot refresh
    self.__plot.marker_configure('selection',  coord=(-100000,-1000000))
    self.__plot.marker_configure('selection', under=1, hide=1)

  def displayLabelings(self, event, row):
    """
    displayLabelings(self, event, row)

    Pop up a scrolled window listing every labeling of the dataset with
    one check button per label it gives `row`.  Toggling a check button
    calls activateMembers for that label.
    """
    ds = self.getPlotView().getDataset()
    primaryLabeling = self.__annotationMapper.getPrimaryLabeling()
    if primaryLabeling is None:
      name = str(row)
    else:
      name = primaryLabeling.getLabelsByRow(row)[0]

    win = Tkinter.Toplevel()
    sf = Pmw.ScrolledFrame(win,
                           labelpos = 'n', label_text = 'Labels for %s'%(name),
                           usehullsize = 1,
                           hull_width = 700,
                           hull_height = 500,
                           horizflex= 'expand',
                           vertflex = 'expand'
                           )
    sf.pack(fill='both', expand=1)
    frame = sf.interior()
    gridRow = 0
    for labeling in ds.getLabelings():
      tkLabel = Tkinter.Label(frame, text=str(labeling), anchor='e')
      # 'column', not 'col': '-col' is an ambiguous grid option
      tkLabel.grid(row = gridRow, column = 0)
      gridCol = 1
      for label in labeling.getLabelsByRow(row):
        # keep the IntVar on the button so it stays alive, and hand it to
        # the callback so the callback sees the current check state
        v = Tkinter.IntVar()
        button = Tkinter.Checkbutton(frame,
                                     text=str(label),
                                     anchor='e',
                                     variable = v)
        button.var = v
        button.configure(command=lambda labeling=labeling, label=label, state=v:
                         self.activateMembers(labeling, label, state))
        button.grid(row=gridRow, column= gridCol)
        gridCol += 1
      gridRow += 1

  def activateMembers(self, labeling, label, state):
    """
    activateMembers(self, labeling, label, state)

    Activate (state.get() == 1) or release the plot elements of every row
    that `labeling` marks with `label`.  Activations are reference
    counted, so an element stays active while any checked label covers it.
    """
    plot = self.getPlotView()._getPlot()
    rows = labeling.getRowsByLabel(label)
    primaryLabeling   = self.__annotationMapper.getPrimaryLabeling()
    if primaryLabeling is None:
      elements = rows
    else:
      elements = [primaryLabeling.getLabelsByRow(r)[0] for r in rows]

    if state.get() == 1:
      for element in elements:
        plot.element_activate(element)
        self.__activeElements[element] = self.__activeElements.get(element, 0) + 1
    else:
      for element in elements:
        count = self.__activeElements.get(element, 0)
        if count <= 1:
          plot.element_deactivate(element)
          # may be untracked (count 0); only forget tracked elements
          if element in self.__activeElements:
            del self.__activeElements[element]
        else:
          self.__activeElements[element] = count - 1
        
class RowAnnotationMapper(AnnotationMapper):

  """
  Annotation mapper for row-oriented plots.

  Maps each plotted row to a primary annotation (its element name) and
  an optional secondary annotation (e.g. a longer description), both
  taken from labelings.
  """

  def __init__(self, plotView, labeling1 = None, labeling2 = None):

    """
    __init__(self, plotView, labeling1 = None, labeling2 = None)

    plotView  : the plot view being annotated.
    labeling1 : optional primary labeling (unique element names).
    labeling2 : optional secondary labeling.

    Each labeling that is given is wrapped in a GlobalWrapper bound to
    the plot view's dataset; an omitted (false) labeling yields None.
    """
    AnnotationMapper.__init__(self, plotView)

    def wrap(glabeling, plotView=plotView):
      if not glabeling:
        return None
      return GlobalWrapper(plotView.getDataset(), glabeling=glabeling)

    self.__lab1 = wrap(labeling1)
    self.__lab2 = wrap(labeling2)

  def getPrimaryLabeling(self):

    """
    getPrimaryLabeling(self)

    Returns the labeling used for primary element annotations (element
    names, which must each be unique), or None when there is none.
    """

    return self.__lab1

  def getSecondaryLabeling(self):

    """
    getSecondaryLabeling(self)

    Returns the labeling used for secondary element annotations (for
    example long descriptions), or None when there is none.
    """

    return self.__lab2

  def setPrimaryLabeling(self, labeling):

    """
    setPrimaryLabeling(self, labeling)

    Stores `labeling` as-is (no wrapping) as the primary annotations.
    Its labels must be unique.
    """

    self.__lab1 = labeling

  def setSecondaryLabeling(self, labeling):

    """
    setSecondaryLabeling(self, labeling)

    Stores `labeling` as-is (no wrapping) as the secondary annotations.
    """

    self.__lab2 = labeling

class RowMarkersMapper(MarkersMapper):

  """
  Markers mapper for row-oriented plots.

  Holds the marker symbol(s) and marker size(s) for the plotted
  elements.  Each may be a single value shared by every element or a
  per-row sequence.  Per-row sizes are scaled into [minSize, maxSize].
  """

  symbols = ["plus", "square",  "diamond",  "cross", "splus", "scross", "triangle",  "circle"] 

  def __init__(self, plotView):
    MarkersMapper.__init__(self, plotView)
    self.dataset = self.getPlotView().getDataset()
    self.__maxSize = .5
    self.__minSize = .01
    self.__sizes   = 0.04
    self.__markers = RowMarkersMapper.symbols[-1]

  def __isSequence(self, value):
    """True when `value` holds per-row values rather than a single one."""
    return type(value) in [types.ListType, types.TupleType, Numeric.ArrayType]

  def setPlotMarkers(self, marker='circle'):
    """
    setPlotMarkers(self, marker='circle')

    Use `marker` for every element.  `marker` must be one of
    RowMarkersMapper.symbols or 'none'; anything else is reported and
    ignored.
    """
    if marker in RowMarkersMapper.symbols or marker == 'none':
      # a single value (not a one-element list) so the plotter expands
      # it to every row
      self.__markers = marker
    else:
      print("unrecognized marker")

  def setMaxSize(self, max):
    """
    Set the largest allowed marker size.  Per-row sizes are rescaled into
    the new range; a single size is clipped to it.
    """
    self.__maxSize = max
    if self.__isSequence(self.__sizes):
      self.__sizes = scaleList(self.__sizes, minReturn=self.__minSize, maxReturn = self.__maxSize)
    elif self.__sizes > self.__maxSize:
      self.__sizes = self.__maxSize

  def setMinSize(self, min):
    """
    Set the smallest allowed marker size.  Per-row sizes are rescaled
    into the new range; a single size is clipped to it.
    """
    self.__minSize = min
    if self.__isSequence(self.__sizes):
      self.__sizes = scaleList(self.__sizes, minReturn=self.__minSize, maxReturn = self.__maxSize)
    elif self.__sizes < self.__minSize:
      self.__sizes = self.__minSize

  def getMarkerSizes(self):
    """
    getMarkerSizes

    Returns the marker size: a single number, a per-row sequence, or None
    after clearSizes().  These are overridden by the weights/styles.
    """

    return(self.__sizes)

  def setMarkerSizeByLabeling(self, labeling, minValue=None, maxValue=None):
    """Size each row by its (numeric) label in `labeling`."""
    self.__sizes =  scaleList(labeling.getLabelByRows(), minValue, maxValue, self.__minSize, self.__maxSize)

  def setMarkerSizeByColValue(self, col, minValue=None, maxValue=None):
    """Size each row by its value in column `col`."""
    data = self.dataset.getColData(col)
    if maxValue is None:
      maxValue = MLab.max(data)
    if minValue is None:
      minValue = MLab.min(data)
    self.__sizes= scaleList(data, minValue, maxValue, self.__minSize, self.__maxSize)

  def setMarkerSizeByIndex(self, minValue=None, maxValue=None):
    """Size each row by its row index."""
    self.__sizes = scaleList(range(self.dataset.getNumRows()), minValue, maxValue, self.__minSize, self.__maxSize)

  def setMarkerSizeByFunction(self,function, minValue=None, maxValue=None):
    """
    Sets the marker size by the given function.  The function prototype
    should be:

    value = function(dataset, row)
    """
    data = [function(self.dataset, row) for row in range(self.dataset.getNumRows())]
    self.__sizes = scaleList(data, minValue, maxValue, self.__minSize, self.__maxSize)

  def invertMarkerSizes(self):
    """
    Mirror per-row sizes within their current range, so the largest
    marker becomes the smallest and vice versa: s -> max + min - s.
    A single (uniform) size is left unchanged.
    """
    sizes = self.__sizes
    if self.__isSequence(sizes) and len(sizes) > 0:
      hi = MLab.max(sizes)
      lo = MLab.min(sizes)
      # mirroring about the midpoint keeps sizes inside [lo, hi]; the
      # old "max - s" form collapsed the largest marker to size 0
      self.__sizes = [hi + lo - size for size in sizes]

  def setUniformMarkerSize(self, size= .04):
    """Use the same marker size for every element."""
    self.__sizes = size

  def clearSizes(self):
    """Forget any marker size setting."""
    self.__sizes = None

  def getPlotMarkers(self):
    """
    getPlotMarkers

    Returns the marker symbol(s): a single symbol or a per-row list.
    These are overridden by the weights/styles.
    """

    return(self.__markers)

  def setPlotMarkersByLabeling(self, labeling, labels=None):

    """
    Sets the plot markers according to labeling.  If optional `labels`
    are supplied, distinct marker shapes are assigned only to those
    labels; rows with any other label get the default 'circle'.  There
    are only 8 marker styles, so with more than 8 labels the symbols
    cycle.

    Returns the label -> symbol dictionary that was used.
    """

    symbols = RowMarkersMapper.symbols
    numSymbols = len(symbols)
    default = symbols[-1]

    if labels is None:
      labels = listOps.unique(labeling.getLabelByRows())

    # cycle through the symbols with a modulus; the previous
    # ceil(len/numSymbols) used integer division and left labels past
    # the 8th mapped to None
    labelToSymbol = {}
    for i in range(len(labels)):
      labelToSymbol.setdefault(labels[i], symbols[i % numSymbols])

    self.__markers = [labelToSymbol.get(label, default) for label in labeling.getLabelByRows()]
    return(labelToSymbol)

  def setUniformPlotMarker(self, marker='circle'):

    """
    Sets the marker symbol to be the same for all points.

    The available shapes are
    'plus', 'square', 'diamond', 'cross', 'splus', 'scross', 'triangle', 'circle' and 'none'
    """
    self.__markers = marker

class RowColorMapper(ColorMapper):

  """
  A flexible, dynamically adjustable color mapper.

  Colors are held as three independent per-row lists of HSV components
  (hue 'h', saturation 's', value 'v').  Each setColorBy* method derives
  one list from an attribute of the dataset or a labeling, scales it
  into `colorRange`, and stores it in the component named by
  `component`.  Components never set default to h=0, s=1, v=1 in
  getColors().
  """

  def __init__(self, plotView):

    """
    Calls ColorMapper's constructor and caches the plot view's dataset.
    """

    ColorMapper.__init__(self, plotView)
    self.dataset = self.getPlotView().getDataset()

    # per-row HSV component lists; None means "use the default"
    self.__h = None
    self.__s = None
    self.__v = None

  def __setComponent(self, component, values):
    """
    Store `values` as the HSV component named by `component` ('h', 's'
    or 'v').  Any other component name is silently ignored.
    """
    if component   == 'h':
      self.__h = values
    elif component == 's':
      self.__s = values
    elif component == 'v':
      self.__v = values

  def setColorByLabelingCounts(self, labeling, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):

    """
    Sets the given color component of each row proportional to the size
    (number of rows) of that row's class in `labeling`.  Rows whose label
    is unknown get a count of 0.
    """

    sizeDict = {}
    for label in labeling.getLabels():
      sizeDict.setdefault(label, len(labeling.getRowsByLabel(label)))
    counts = [sizeDict.get(labeling.getLabelByRow(row), 0)
              for row in range(self.dataset.getNumRows())]
    self.__setComponent(component, scaleList(counts, minValue, maxValue, colorRange[0], colorRange[1]))

  def setColorByLabeling(self, labeling, labels=None, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):

    """
    setColorByLabeling(self, labeling, labels=None, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    Uses the labeling to produce a color list for plotting.  If `labels`
    is given, rows whose label is not among them are all mapped to the
    single placeholder '__background__'.  If the labeling contains real
    values, minValue and maxValue can be used to scale the colorRange.
    Only the first label attached to every row is used.
    """
    if labels is None:
      rowLabels = labeling.getLabelByRows()
    else:
      keep = {}
      for label in labels:
        keep.setdefault(label, label)
      rowLabels = [keep.get(label, '__background__') for label in labeling.getLabelByRows()]

    self.__setComponent(component, scaleList(rowLabels, minValue, maxValue, colorRange[0], colorRange[1]))

  def setColorByTupleLabeling(self, labeling, element=0, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):

    """
    setColorByTupleLabeling(self, labeling, element=0, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    Like setColorByLabeling, but each row's label is a tuple and only its
    `element`-th entry is used as the value to scale.
    """
    picked = [label[element] for label in labeling.getLabelByRows()]
    self.__setComponent(component, scaleList(picked, minValue, maxValue, colorRange[0], colorRange[1]))

  def setColorByColValue(self, col, component = 'h', minValue=None, maxValue=None, colorRange=(0,.7)):

    """
    setColorByColValue(self, col, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    Sets the color component to be a function of each row's value in
    column `col`.  minValue/maxValue default to the column's min/max.
    """

    data = self.dataset.getColData(col)
    if maxValue is None:
      maxValue = MLab.max(data)
    if minValue is None:
      minValue = MLab.min(data)

    self.__setComponent(component, scaleList(data, minValue, maxValue, colorRange[0], colorRange[1]))

  def setColorByFunction(self, function, component='h',  minValue=None, maxValue=None, colorRange=(0,.7)):

    """
    setColorByFunction(self, function, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    `function` should take a dataset and a row and return a single
    float; the color component is scaled from those values.  Prototype:

    value = function(dataset, row)
    """

    data = [function(self.dataset, row) for row in range(self.dataset.getNumRows())]
    self.__setComponent(component, scaleList(data, minValue, maxValue, colorRange[0], colorRange[1]))

  def setColorByIndex(self, component = 'h', minValue=None, maxValue = None, colorRange=(0,.7)):

    """
    setColorByIndex(self, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    Sets the color component of each row proportional to its row index.
    """
    indices = range(self.dataset.getNumRows())
    self.__setComponent(component, scaleList(indices, minValue, maxValue, colorRange[0], colorRange[1]))

  def setColors(self, list):

    """
    setColors(self, list)

    Takes a list of (h, s, v) tuples, one per row, and uses them as the
    colors for plotting.
    """

    self.clearColors()
    self.__h = []
    self.__s = []
    self.__v = []
    for h,s,v in list:
      self.__h.append(h)
      self.__s.append(s)
      self.__v.append(v)

  def clearColors (self):
    """
    clearColors(self)

    Clears any color setting so every component reverts to its default.
    """
    self.__h = None
    self.__s = None
    self.__v = None

  def getColors(self, colorModel='hsv'):
    """
    getColors(self, colorModel='hsv')

    Implementation of the ColorMapper interface: returns a list of
    (h, s, v) tuples, one per row.  Unset components default to h=0,
    s=1, v=1.  Only the 'hsv' color model is supported.
    """
    assert colorModel == 'hsv'

    # "is None" rather than "== None": a component may be a Numeric
    # array, for which == is an elementwise comparison
    numRows = self.dataset.getNumRows()
    if self.__h is None:
      self.__h = [0] * numRows
    if self.__s is None:
      self.__s = [1] * numRows
    if self.__v is None:
      self.__v = [1] * numRows

    return(zip(self.__h, self.__s, self.__v))

class ProbabilityColorMapper(ColorMapper):

  """
  A color mapper driven by a fixed per-row probability labeling.

  Each row's label is indexed as a sequence of cluster-membership
  probabilities (MLab.argmax picks the most likely cluster,
  x[int(cluster)] picks one cluster's probability).
  """

  def __init__(self, plotView, probabilities):

    """
    plotView      : the plot view being colored.
    probabilities : labeling whose per-row label is a probability
                    sequence.

    Colors are initialized with setHueByPartitioning().
    """

    ColorMapper.__init__(self, plotView)
    self.dataset = self.getPlotView().getDataset()
    self.setProbabilityLabeling(probabilities)
    # initializes the HSV component lists used to color the plot
    self.setHueByPartitioning()

  def setProbabilityLabeling(self, probabilities):
    """Replace the probability labeling (colors are not recomputed)."""
    self.probabilities = probabilities

  def setHueByPartitioning(self, colorRange=(0,.7)):

    """
    setHueByPartitioning(self, colorRange=(0,.7))

    Colors each point's hue by its most likely cluster (scaled into
    `colorRange`).  The probability of membership in that cluster is
    scaled into [0, 1] and used as the saturation; the value component
    keeps its default.
    NOTE(review): earlier wording said value (intensity); the code has
    always set saturation.
    """

    rows = self.probabilities.getLabelByRows()
    self.__h = scaleList([MLab.argmax(p) for p in rows], minReturn=colorRange[0], maxReturn=colorRange[1])
    self.__s = scaleList([p[MLab.argmax(p)] for p in rows], minReturn=0, maxReturn=1)
    self.__v = None

  def setHueByProbability(self, cluster, colorRange=(0,.7)):

    """
    setHueByProbability(self, cluster, colorRange=(0,.7))

    Colors each point's hue by its probability of membership in
    `cluster`, where cluster is an index into the probability sequence.
    Saturation and value revert to their defaults.
    """
    index = int(cluster)
    self.__h = scaleList([p[index] for p in self.probabilities.getLabelByRows()], minReturn=colorRange[0], maxReturn=colorRange[1])
    self.__s = None
    self.__v = None

  def getColors(self, colorModel='hsv'):
    """
    Returns a list of (h, s, v) tuples, one per row; unset components
    default to h=0, s=1, v=1.  Only the 'hsv' model is supported.
    """
    assert colorModel == 'hsv'
    # "is None": a component may be a Numeric array, where == compares
    # elementwise
    numRows = self.dataset.getNumRows()
    if self.__h is None:
      self.__h = [0] * numRows
    if self.__s is None:
      self.__s = [1] * numRows
    if self.__v is None:
      self.__v = [1] * numRows

    return(zip(self.__h, self.__s, self.__v))
    
class RowDataMapper(IDataMapper):

  """
  RowDataMapper(IDataMapper)

  A simple data mapper that plots every row of the dataset as its own
  line element.  All rows share the same x values.
  """

  def __init__(self, plotView, xAxisLabeling=None):

    IDataMapper.__init__(self, plotView)
    self.dataset = plotView.getDataset()
    self.setXAxisLabeling(xAxisLabeling)

  def setXAxisLabeling(self, labeling=None):
    """
    setXAxisLabeling(self, labeling=None)

    Chooses the x values shared by every row: the column labels of
    `labeling` when one is given, otherwise the column indices.
    """

    if labeling is None:
      xs = range(0, self.dataset.getNumCols())
    else:
      xs = labeling.getLabelByCols()
    self.__xvalues = tuple(xs)

  def getXValues(self):
    """Return the tuple of x values used for every row."""
    return self.__xvalues

  def getPlotData(self):
    """Return one (xvalues, row-as-tuple) pair per dataset row."""

    ds = self.dataset
    xs = self.getXValues()
    count = ds.getNumRows()
    rowTuples = map(tuple, ds.getData())
    return zip([xs] * count, rowTuples)

class SortedRowDataMapper(RowDataMapper):
  """
  SortedRowDataMapper(RowDataMapper)

  Data mapper that plots each row with its columns permuted into the
  order given by a column labeling (see sortByLabeling).
  """
  def __init__(self, plotView, xAxisLabeling=None):
    IDataMapper.__init__(self, plotView)
    self.dataset = plotView.getDataset()
    self.setXAxisLabeling(xAxisLabeling)
    self.sortedDataset = SortedView(self.dataset)

  def setXAxisLabeling(self, labeling=None):
    """
    Sets xAxisLabeling to be the GlobalWrapper labeling for the unsorted
    dataset; it is re-bound to the sorted view (via its `.g` global
    labeling) whenever x values are computed.
    """

    self.xAxisLabeling = labeling

  def sortByLabeling(self, labeling=None):
    """
    Sets the column labeling on which sorting is performed: columns are
    grouped label by label, in the order returned by getLabels().
    """

    self.sortLabeling=labeling
    labels = labeling.getLabels()
    sortedCols = []
    for label in labels:
      sortedCols += self.sortLabeling.getColsByLabel(label)

    self.sortedDataset.permuteCols(sortedCols)

  def getXValues(self):
    """
    Return the x values in the current (sorted) column order: the
    x-axis labeling's column labels on the sorted view when one is set,
    otherwise the column indices.

    This override is required: the inherited RowDataMapper.getXValues
    reads an attribute this class never initializes.
    """
    ds = self.sortedDataset
    if self.xAxisLabeling:
      labeling = GlobalWrapper(ds, glabeling=self.xAxisLabeling.g)
      xvalues = tuple(labeling.getLabelByCols())
    else:
      xvalues = tuple(range(0, ds.getNumCols()))
    # cached for backward compatibility with code reading this attribute
    self.__xvalues = xvalues
    return xvalues

  def getPlotData(self):
    """Return one (xvalues, sorted-row-as-tuple) pair per row."""
    ds = self.sortedDataset
    xvalues = self.getXValues()
    return [(xvalues, tuple(ds.getRowData(row))) for row in range(ds.getNumRows())]

class MeanDataMapper(RowDataMapper, AnnotationMapper, BindingsMapper):

  """
  MeanDataMapper plots the mean (and optionally mean +/- std) of either
  the whole dataset or of groups of rows given by a labeling.
  """

  def __init__(self, plotView, labeling=None):
    """
    plotView : the plot view whose dataset is summarized.
    labeling : optional row labeling defining the groups; when omitted
               all rows form a single group.
    """

    RowDataMapper.__init__(self, plotView)
    # The un-aggregated dataset is kept separately because self.dataset
    # is replaced by the aggregated mean view in setLabeling().
    self.__sourceDataset = plotView.getDataset()
    self.__meanDataset = None
    self.__stdDataset = None
    self.__showStd = 0
    self.setLabeling(labeling)

  def setLabeling(self, labeling=None):
    """
    setLabeling(self, labeling=None)

    Set the labeling used to partition the dataset.  With no labeling, a
    '__meanDataMapper__' labeling that puts every row in one group is
    looked up on (or added to) the source dataset and the aggregate views
    are rebuilt.
    """

    source = self.__sourceDataset
    if labeling is None or self.__meanDataset is None:
      if labeling is None:
        labeling = source.getLabeling('__meanDataMapper__')
        if not labeling:
          labeling = Labeling(source, '__meanDataMapper__')
          labeling.labelRows  ([0]*source.getNumRows())
      # aggregate the original dataset, never a previous aggregate
      self.__meanDataset = RowAggregateFunctionView(source, labeling, MLab.mean)
      self.__stdDataset  = RowAggregateFunctionView(source, labeling, safeStdDev)
    else:
      self.__meanDataset.setLabeling(labeling)
      self.__stdDataset.setLabeling(labeling)

    self.dataset = self.__meanDataset

  def getLineTrace(self):
    return('increasing')

  def getPlotData(self):
    """
    Return (xvalues, row) pairs: one per group mean or, when the standard
    deviation is shown, all (mean - std) rows, then all mean rows, then
    all (mean + std) rows.
    """
    xvalues = self.getXValues()
    if self.__showStd:
      meanData = self.__meanDataset.getData()
      stdData = self.__stdDataset.getData()
      bands = Numeric.concatenate((meanData - stdData, meanData, meanData + stdData))
    else:
      bands = self.dataset.getData()
    rows = [tuple(row) for row in bands]
    # pair every row with the x values (the previous "[x]*3" silently
    # dropped rows when there was more than one group)
    return zip([xvalues] * len(rows), rows)

  def showStdDev(self, state=None):

    """
    showStdDev(self, state=None)

    Show or hide the mean +/- std lines; with no state, toggle.
    """
    if state is None:
      self.__showStd = not self.__showStd
    else:
      self.__showStd = state

class ColumnScatterDataMapper(IDataMapper):

  """
  ColumnScatterDataMapper scatters one column of the dataset against
  another: every row becomes an (x, y) point read from the chosen X and
  Y columns.
  """

  def __init__(self, plotView, xCol=0, yCol=1):
    """
    plotView : the plot view whose dataset is plotted.
    xCol     : column used for x values (default 0).
    yCol     : column used for y values (default 1).
    """

    IDataMapper.__init__(self, plotView)

    self.dataset = plotView.getDataset()
    self.setXColumn(xCol)
    self.setYColumn(yCol)

  def getPlotData(self):
    """Return the (x, y) pairs, one per row."""
    ds = self.dataset
    xs = ds.getColData(self.__xCol)
    ys = ds.getColData(self.__yCol)
    return zip(xs, ys)

  def getXColumn(self):
    """Return the column plotted as X data."""
    return self.__xCol

  def getYColumn(self):
    """Return the column plotted as Y data."""
    return self.__yCol

  def setXColumn(self, col):
    """
    setXColumn(self, col)

    Plot column `col` as the X data.
    """
    self.__xCol = col

  def setYColumn(self, col):
    """
    setYColumn(self, col)

    Plot column `col` as the Y data.
    """
    self.__yCol = col

class PCADataMapper(IDataMapper):

  """
  PCADataMapper scatters two columns of a RowPCAView of the plot view's
  dataset against each other (presumably two principal components of
  each row -- confirm against RowPCAView).
  """

  def __init__(self, plotView, xCol=0, yCol=1):
    """
    plotView : the plot view whose dataset is projected.
    xCol     : column of the RowPCAView used for x values (default 0).
    yCol     : column of the RowPCAView used for y values (default 1).
    """

    IDataMapper.__init__(self, plotView)

    self.dataset = RowPCAView(plotView.getDataset())
    self.setXColumn(xCol)
    self.setYColumn(yCol)

  def getPlotData(self):
    """Return the (x, y) pairs, one per row."""
    view = self.dataset
    xs = view.getColData(self.__xCol)
    ys = view.getColData(self.__yCol)
    return zip(xs, ys)

  def getXColumn(self):
    """Return the column plotted as X data."""
    return self.__xCol

  def getYColumn(self):
    """Return the column plotted as Y data."""
    return self.__yCol

  def setXColumn(self, col):
    """
    setXColumn(self, col)

    Plot column `col` of the projection as the X data.
    """
    self.__xCol = col

  def setYColumn(self, col):
    """
    setYColumn(self, col)

    Plot column `col` of the projection as the Y data.
    """
    self.__yCol = col

#################
#
# A general-purpose dataset plot view, well suited to exploratory plotting.
#
#########

class DatasetRowPlotView(IPlotView):

  """
  General-purpose plot view for exploratory analysis of a dataset's
  rows.  It is assembled from pluggable mappers (data, color, markers,
  bindings and annotation) that can be swapped on the fly to change how
  the dataset is rendered.
  """

  def __init__(self, dataset, primaryLabeling=None, secondaryLabeling=None):
    """
    dataset           : dataset to plot.
    primaryLabeling   : optional labeling naming each element.
    secondaryLabeling : optional labeling with secondary annotations.
    """

    IPlotView.__init__(self, dataset)
    # default mappers, created in this order; each receives this view
    self._dataMapper = RowDataMapper(self)
    self._colorMapper = RowColorMapper(self)
    self._markersMapper = RowMarkersMapper(self)
    self._bindingsMapper = SimpleBindingsMapper(self)
    self._annotationMapper = RowAnnotationMapper(self, primaryLabeling, secondaryLabeling)

  def setDataMapper(self, dataMapper):
    """
    setDataMapper(self, dataMapper)

    Replace the mapper that supplies the plotted data.
    """
    self._dataMapper = dataMapper

  def setColorMapper(self, colorMapper):
    """
    setColorMapper(self, colorMapper)

    Replace the mapper that supplies element colors.
    """
    self._colorMapper = colorMapper

  def setMarkersMapper(self, markersMapper):
    """
    setMarkersMapper(self, markersMapper)

    Replace the mapper that supplies marker symbols and sizes.
    """
    self._markersMapper = markersMapper

#########
#
# The back end plotter
#
####

class DatasetPlot(IPlot):

  """
  DatasetPlot class provides a common plotting tool for a variety
  of 2d plotting tasks.  PlotViews provide an extensible way to
  manipulate and control how data is rendered.

  """

  def __init__(self, parent=None, cnf={}, **kw):
    # set before delegating to IPlot.__init__ so the attribute exists
    # no matter what the base constructor does
    self.currentPlotViews = []
    IPlot.__init__(self, parent, cnf, **kw)

  def plot(self, plotView=None, hold=0, pack=1):

    """
    plot(self, plotView=None, hold=0, pack=1)

    Renders `plotView` through the PlotView interface: the data, color,
    markers, bindings and annotation mappers supply per-element values
    which are gathered into BLT element options.

    plotView : view to plot; None re-plots the first current view
               (ValueError if nothing has been plotted yet).
    hold     : 1 adds the view to the current views (once); otherwise it
               replaces them.
    pack     : currently unused (packing is disabled, see FIXME below).

    NOTE: the actual element create/configure calls are still disabled
    (FIXME); element options are computed but not sent to BLT.
    """

    replot = 0
    if plotView is None:
      if not self.currentPlotViews:
        raise ValueError("DatasetPlot.plot: no plotView given and nothing to replot")
      plotView = self.currentPlotViews[0]
      replot = 1

    if hold == 1:
      if plotView in self.currentPlotViews:
        # already held: its elements exist and only need reconfiguring
        replot = 1
      else:
        self.currentPlotViews.append(plotView)
    else:
      self.currentPlotViews = [plotView]
      replot = 1

    plotView._setPlot(self)
    plotView.plotSetup()

    # direct BLT element options supplied by the view
    plotArgs = {}
    kwArgs = plotView.getKWArgs()
    if type(kwArgs) == types.DictionaryType:
      plotArgs.update(kwArgs)

    # get all the mappers
    dataMapper       = plotView.getDataMapper()
    colorMapper      = plotView.getColorMapper()
    markersMapper    = plotView.getMarkersMapper()
    bindingsMapper   = plotView.getBindingsMapper()
    annotationMapper = plotView.getAnnotationMapper()

    # query each mapper once, not once per element
    plotData = dataMapper.getPlotData()
    trace    = dataMapper.getLineTrace()
    weights  = dataMapper.getWeightsAndSytles()
    colors   = colorMapper.getColors()
    sizes    = markersMapper.getMarkerSizes()
    markers  = markersMapper.getPlotMarkers()
    primaryLabeling = annotationMapper.getPrimaryLabeling()

    numRows = len(plotData)
    sequenceTypes = [types.ListType, Numeric.ArrayType, types.TupleType]

    if primaryLabeling is None:
      rowNames = None
    else:
      rowNames = primaryLabeling.getLabelByRows()

    # a single size/marker is broadcast to every element; sizes become
    # strings with an 'i' unit suffix.  A cleared size (None) sets no
    # 'pixels' option instead of the bogus string 'Nonei'.
    if sizes is None:
      pixels = None
    else:
      if type(sizes) not in sequenceTypes:
        sizes = [sizes] * numRows
      pixels = [str(size) + 'i' for size in sizes]

    if type(markers) not in sequenceTypes:
      markers = [markers] * numRows

    # NOTE: BLT vectors are not used for x/y data; they had referencing
    # issues where every new vector overwrote the previous one.
    for row in range(numRows):
      plotArgs['xdata'] = plotData[row][0]
      plotArgs['ydata'] = plotData[row][1]
      if pixels is not None:
        plotArgs['pixels'] = pixels[row]
      plotArgs['symbol'] = markers[row]

      if colors is not None:
        plotArgs['color'] = rgbToString(colorsys.hsv_to_rgb(*colors[row]))

      # element name: the primary label when available, else row index
      if rowNames is None:
        name = row
      else:
        name = rowNames[row]
      name = '%s' % (name,)

      if weights is not None:
        plotArgs['weights'] = weights[row][0]
        plotArgs['styles']  = weights[row][1]

      if trace is not None:
        plotArgs['trace'] = trace

      # draw the line
      if not replot:
        # FIXME: self.line_create(name, **plotArgs)
        pass
      else:
        # FIXME: self.element_configure(name, **plotArgs), falling back
        #        to self.line_create(name, **plotArgs) if that fails
        pass

    bindingsMapper.addBindings()
    # FIXME: self.legend_configure(hide = 1)
    # FIXME: if pack: self.pack(expand=1, fill='both')

    plotView.plotFinishings()
    # NOTE(review): distinct from currentPlotViews; kept in case other
    # code reads this attribute -- confirm and remove if unused
    self.currentPlotView = [plotView]

####
#
# Big Plot Pages
#
########


class PlotPage:
  """
  A fairly simple scrolled frame that holds a grid of plots.

  Plots added without an explicit position fill the grid column by
  column: down `numRows` rows, then on to the top of the next column.
  """

  def __init__(self, parent=None, numRows=0, numCols=0):
    """
    parent  : Tk container; when None a new Toplevel is created under the
              module-level _ROOT (calling startIPlot() first if needed).
    numRows : rows to fill before wrapping to the next column when plots
              are added without an explicit position (grows as needed).
    numCols : initial number of columns (grows as needed).
    """

    global _ROOT
    if parent is None:
      if _ROOT is None:
        startIPlot()
      self.parent = Tkinter.Toplevel(master=_ROOT)
    else:
      self.parent = parent

    self.sf = Pmw.ScrolledFrame(self.parent,
                                labelpos = 'n', label_text = 'PlotsPage',
                                usehullsize = 1,
                                hull_width = 600,
                                hull_height = 800,
                                horizflex= 'expand',
                                vertflex = 'expand'
                                )
    self.frame = self.sf.interior()
    self.plotSize = (200,200)
    self.__numRows = numRows
    self.__numCols = numCols
    self.plots = {}
    self.__focus = None
    self.__next = (0,0)
    self.sf.pack(padx = 5, pady = 3, fill = 'both', expand = 1)

  def setUniformYAxis(self, min=None, max=None):

    """
    setUniformYAxis(self, min=None, max=None)

    Sets the y-axis limits of every plot on the page to [min, max].  A
    bound that is not given is taken from the widest range spanned by
    any plot (the computed range always includes 0).
    """

    if (max is None) or (min is None):
      lowest = 0
      highest = 0
      for plot in self.plots.values():
        limits = plot.axis_limits('y')
        if limits[0] < lowest:
          lowest = limits[0]
        if limits[1] > highest:
          highest = limits[1]
      # honour whichever bound the caller did supply
      if min is None:
        min = lowest
      if max is None:
        max = highest

    for plot in self.plots.values():
      plot.axis_configure('y', min=min, max=max)

  def toggleZeros(self):

    """
    toggleZeros(self)

    Toggles the x = 0 and y = 0 reference lines on each plot.
    """

    for plot in self.plots.values():
      xmin, xmax = plot.xaxis_limits()
      ymin, ymax = plot.yaxis_limits()
      if plot.marker_exists('xzero'):
        plot.marker_delete('xzero')
        plot.marker_delete('yzero')
      else:
        plot.marker_create('line', name='xzero', coords=(xmin, 0, xmax, 0))
        # vertical line spanning the full y range (was ymin..ymin)
        plot.marker_create('line', name='yzero', coords=(0, ymin, 0, ymax))

  def setOptimalYAxis(self):

    """
    setOptimalYAxis(self)

    Resets the y-axis limits so all of each plot's data is displayed.
    """

    for plot in self.plots.values():
      plot.axis_configure('y', min='', max='')

  def addPlot(self, row=None, col=None):

    """
    plot = addPlot(self, row=None, col=None)

    Adds a DatasetPlot to the page and returns a handle to it.
    """

    dp = DatasetPlot(parent = self.frame)
    self.addWidget(dp, row, col)
    dp.configure(width=self.plotSize[0],height=self.plotSize[1] )
    return(dp)

  def addWidget(self, plot, row=None, col=None):

    """
    addWidget(self, plot, row=None, col=None)

    Places the widget into the plot page at (row, col).  Unless both are
    given, the next free grid location is used.  The grid dimensions
    grow to include the new position.
    """

    if row is None or col is None:
      pos = self.__next
    else:
      pos = (row, col)

    # track grid size as counts, from the position actually used
    if pos[0] + 1 > self.__numRows:
      self.__numRows = pos[0] + 1
    if pos[1] + 1 > self.__numCols:
      self.__numCols = pos[1] + 1

    plot.parent = self.frame
    plot.grid(row = pos[0], column=pos[1], sticky='nsew')

    self.sf.reposition()
    self.plots[pos] = plot
    self.__focus = pos

    # column-major fill: wrap to the next column after the last row
    if pos[0] + 1 >= self.__numRows:
      self.__next = (0, pos[1] + 1)
    else:
      self.__next = (pos[0] + 1, pos[1])

  def setZoom(self, zoom):
    """
    Resizes every plot on the page to `zoom` times the base plot size
    (self.plotSize).  Zooms above 10 are refused.
    """
    if zoom > 10:
      print("Choose a smaller zoom")
      return
    newSize = (self.plotSize[0]*zoom, self.plotSize[1]*zoom)
    for plot in self.plots.values():
      plot.configure(width=newSize[0],height=newSize[1])

class TrajectorySummary(PlotPage):

  """
  A fairly simple summary very of a labeling.
  """

  def __init__(self, dataset, clusterLabeling, primaryLabeling=None, secondaryLabeling=None, computeROC=1, parent=None):
    """
    __init__(self, dataset, clusterLabeling, primaryLabeling=None, secondaryLabeling=None, computeROC=1, parent=None):

    the primary and secondaryLabelings are assumed to be globalLabelings.
    
    """
    PlotPage.__init__(self, parent)
    # some resonable hard-coded plotting parameters.
    self.plotSize = (200,200)
    self.numPlotCols = 3
    
    # sets up the needed information
    self.dataset = dataset
    self.labeling = clusterLabeling
    self.primaryLabeling= primaryLabeling
    self.secondaryLabeling = secondaryLabeling


    keylists =  [self.labeling.getRowsByLabel(label) for label in self.labeling.getLabels()]
    self.means = AggregateFunctionView(self.dataset, keylists, MLab.mean)
    self.stds  = AggregateFunctionView(self.dataset, keylists, safeStdDev)
    
    # we have to recreate the keylists because it may have been modified in the AggregateFunctionView
    keylists =  [self.labeling.getRowsByLabel(label) for label in self.labeling.getLabels()]
    mapping   = [self.means._mapKeysToParent([key]) for key in  self.means.getRowKeys ()] 
    positions = [mapping.index(keylist) for keylist in keylists]
    self.aggregateOrder= {}
    map(self.aggregateOrder.setdefault, self.labeling.getLabels(), positions)


    #self.means = RowAggregateFunctionView(self.dataset, self.labeling, MLab.mean)
    #self.stds = RowAggregateFunctionView(self.dataset, self.labeling, safeStdDev)
    
    #if isinstance(self.labeling, GlobalWrapper):
    #  self.meansLab = GlobalWrapper(self.means, glabeling=self.labeling.g)
    #  self.stdLab = GlobalWrapper(self.stds, glabeling=self.labeling.g)

    #elif isinstance(self.labeling, GlobalLabeling):
    #  self.meansLab= GlobalWrapper(self.means, glabeling=self.labeling)
    #  self.stdsLab = GlobalWrapper(self.stds,  glabeling=self.labeling)

    #elif isinstance(self.labeling, labeling):
    #  self.meansLab = Labeling(self.means)
    #  self.stdsLab  = Labeling(self.stds)
    #  self.meansLab.labelFrom(self.labeling)
    #  self.stdsLab.labelFrom(self.labeling)
    #else:
    #  print "labeling of wrong type"
    #  return

    self.drawMeanStdTrajectories(computeROC)
    self.setUniformYAxis()

    # these are pointers to the full plot DatasetPlotter and its PlotView
    self.selectedPlot = None
    self.selectedView = None
    
  def drawMeanStdTrajectories(self, computeROC=1):


    clusterSizes = map(lambda x: (len(self.labeling.getRowsByLabel(x)), x), self.labeling.getLabels())
    clusterSizes.sort()
    plotOrder = [self.aggregateOrder[cluster] for size, cluster in clusterSizes if size > 0] 

    gridPos = [0,0]
    count = 0
    for row in plotOrder:
      plot = self.drawSummaryPlot(row, computeROC)
      plot.bind('<Button-1>',  lambda event, row=row:self.fullPlot(row)) 
      self.addWidget(plot, gridPos[0], gridPos[1])
      #button = Tkinter.Button(parent = self.frame,
      #                        text='Plot All',
      #                        command=(lambda row=row:self.fullPlot(row))) 
      #button.grid(row=gridPos[0]+1, column=gridPos[1])
      # work out the location of the plots.
      
      if (count+1) % self.numPlotCols == 0:
        gridPos[0] =  0
        gridPos[1] += 1
      else:
        gridPos[0] += 2
      count +=1
      
  def drawSummaryPlot(self, row, computeROC=1, titles=1, verbose=1):

    """
    drawSummaryPlot(self, row, roc=1)

    Draws the aggregate summary tragetory plot for the given row in
    the self.means dataset
    """
    # set up the basic cluster info
    #label = self.meansLab.getLabelsByRow(row)[0]
    label = [l for l,r in self.aggregateOrder.items() if r == row][0]
    if verbose:
      print "working on %s"%(str(label))
    mean = self.means.getRowData(row)
    posStd = tuple(mean + self.stds.getRowData(row))
    negStd = tuple(mean - self.stds.getRowData(row))
    mean = tuple(mean)
    xdata = tuple(range(len(mean)))
    if computeROC:
      if verbose:
        print "\t calculating ROC... ",
        rocArea = roc.clusterROC(self.dataset, self.labeling, label)[0]
        print "ROC area = %3.2f"%(rocArea)
    else:
      rocArea = NaN.nan
    size = len(self.labeling.getRowsByLabel(label))
    
    # draw the lines (don't need a fancy dataset plotter for these three lines
    if verbose:
        print "\t Drawing Plot"
    g = IPlot(parent = self.frame, subplot=True)
    g.line_create('mean', xdata=xdata, ydata=mean  , color='blue', pixels = "0.04i" )
    if posStd != mean:
      g.line_create('std1', xdata=xdata, ydata=posStd, color='red', pixels = "0.04i", dashes=(5,1))
      g.line_create('std2', xdata=xdata, ydata=negStd, color='red', pixels = "0.04i", dashes=(5,1))
    if titles:
      title = '<Cluster %s>\n  Count:%i, ROC: %3.2f'%(label, size, rocArea)
    else:
      title = ''
    g.configure(width=self.plotSize[0],height=self.plotSize[1], title=title)
    g.legend_configure(hide=1)
    g.pack(expand=1, fill='both')
    return(g)

  def fullPlot(self, row):
    """
    fullPlot(event, row)

    show the full row-wise plot of the cluster pointed to in the mean
    dataset at row
    """

    win = Tkinter.Toplevel()
    #label = self.meansLab.getLabelsByRow(row)[0]
    label = [l for l,r in self.aggregateOrder.items() if r == row][0]
    print "printing %s"%(str(label))
    subset = subsetByLabeling(self.dataset, self.labeling, label)
    dp = DatasetPlot(parent=win)
    dp.configure(title = "Cluster %s\n %i elements"%(str(label),subset.getNumRows()))
    rv = DatasetRowPlotView(subset, primaryLabeling=self.primaryLabeling, 
                                    secondaryLabeling=self.secondaryLabeling)
    rv.getColorMapper().setColorByColValue(0)
    dp.plot(rv)
    self.selectedPlot = dp
    self.selectedView = rv
    # FIXME: this is a memory leak
    #self.dataset.removeView(subset)
    
  def setZoom(self, zoom):
    """
    resizes all the plots on the page by the propotation specified.
    """
    if zoom > 10:
      print "Choose a smaller zoom"
      return
    newSize = (self.plotSize[0]*zoom, self.plotSize[1]*zoom)
    for plotKey in self.plots.keys():
      self.plots[plotKey].configure(width=newSize[0],height=newSize[1])

class ScatterMatrix(PlotPage):

  """
  Generates a plot page with a scatter plot between each pair of the
  selected dimensions (columns) of a dataset.  The plot at grid
  position (x, y) is drawn with the data mapper set to X column x and
  Y column y.  All plots share one DatasetRowPlotView and therefore
  one color mapper.
  """

  def __init__(self, dataset, dims = None, primaryLabeling=None, secondaryLabeling=None, parent=None):
    """
    dataset           - dataset to plot
    dims              - column indices to include (default: all columns)
    primaryLabeling,
    secondaryLabeling - handed to the shared DatasetRowPlotView
    parent            - parent Tk widget of the page
    """
    if dims is None:
      dims = range(dataset.getNumCols())

    PlotPage.__init__(self, parent, numRows=len(dims), numCols =len(dims))
    self.plotSize=(100,100)
    self.dataset = dataset
    # remembered so that updatePlots redraws the same dimensions
    self.__dims = list(dims)
    self.__plotView = DatasetRowPlotView(self.dataset, primaryLabeling, secondaryLabeling)
    self.__colorMapper = self.__plotView.getColorMapper()
    self.__dataMapper = ColumnScatterDataMapper(self.__plotView)
    self.__plotView.setDataMapper(self.__dataMapper)
    self.__makeScatterPlots(self.__dims)

  def __makeScatterPlots(self, dims=None):
    """
    Creates (or reuses) the plot at grid position (x, y) for every pair
    of dims and redraws it from the shared plot view with the data
    mapper set to columns x and y.  dims defaults to the dimensions
    given at construction.
    """
    if dims is None:
      dims = self.__dims

    for x in dims:
      for y in dims:
        if self.plots.has_key((x,y)):
          p = self.plots[(x,y)]
        else:
          p = self.addPlot(x,y)
        self.__dataMapper.setXColumn(x)
        self.__dataMapper.setYColumn(y)
        p.plot(self.__plotView, pack=0)
        p.configure(width=self.plotSize[0],height=self.plotSize[1])
        p.legend_configure(hide=1)
        p.yaxis_configure(hide=1)
        p.xaxis_configure(hide=1)

  def getColorMapper(self):
    """Returns the color mapper of the shared plot view."""
    return(self.__colorMapper)

  def getPlotView(self):
    """Returns the DatasetRowPlotView shared by all scatter plots."""
    return(self.__plotView)

  def updatePlots(self):
    """
    Redraws every scatter plot from the shared plot view, for the same
    dimensions the page was created with.
    """
    self.__makeScatterPlots()

class ConfusionMatrixSummary(TrajectorySummary):

  """
  Generates a trajectory summary for every cell in a confusion
  matrix
  
  """

  def __init__(self, dataset, labeling1, labeling2, primaryLabeling=None, secondaryLabeling=None,l1Order=None, l2Order=None, parent=None):
    """
    __init__(self, dataset, labeling1, labeling2, parent=None)
    
    """
    computeROC=0
    PlotPage.__init__(self, parent)
    # some resonable hard-coded plotting parameters.
    self.plotSize = (75,75)
    self.numPlotCols = 3
    
    # sets up the needed information
    self.dataset = dataset
    #self.cm = ConfusionMatrix()
    
    self.labeling1 = labeling1
    self.labeling2 = labeling2
    self.primaryLabeling = primaryLabeling
    self.secondaryLabeling = secondaryLabeling
    self.l1Order = l1Order
    self.l2Order = l2Order
    
    # order the data by cluster sizes
    if self.l1Order is None:
      l1RowLabels = self.labeling1.getLabelByRows()
      l1Order = map(lambda x: (l1RowLabels.count(x), x), self.labeling1.getLabels())
      l1Order.sort()
      l1Order = map(lambda x: x[1], l1Order)
      l1Order.reverse()
      self.l1Order = l1Order
    else:
      l1Order = self.l1Order
      l1Order.reverse()

    if self.l2Order is None:
      l2RowLabels = self.labeling2.getLabelByRows()
      l2Order = map(lambda x: (l2RowLabels.count(x), x), self.labeling2.getLabels())
      l2Order.sort()
      l2Order = map(lambda x: x[1], l2Order)
      self.l2Order = l2Order
    else:
      l2Order = self.l2Order
   
    clusterOrders = [l1Order, l2Order]
    
    #self.cm.createConfusionMatrixFromLabeling(self.labeling1, self.labeling2)
    self.cm = ConfusionMatrix2.ConfusionMatrix([self.labeling1, self.labeling2], clusterOrders=clusterOrders)
    self.sf.configure(label_text="Confusion Matrix- NMI= %3.2f, NMI'= %3.2f, LA = %3.2f"%(self.cm.NMI(),
                                                                                          self.cm.transposeNMI(),
                                                                                          self.cm.linearAssignment()))
    # this is a labeling of confusion matrix coord
    #self.labeling = GlobalWrapper(self.dataset, 'ConfusionMatrix (%s %s)'%(labeling1.getName(), labeling2.getName() ))
    #for cell in self.cm.hypercube.keys():
    #  label = (self.cm.rowClassNames[cell[0]], self.cm.colClassNames[cell[1]])
    #  for row in self.cm.getConfusionHypercubeCell(cell):
    #    self.labeling.addLabelToRow(label, row)
    self.labeling = self.cm.getConfusionLabeling()
   
    # pre-initialize the summary views
    
    ## now to work out the aggregated datasets and the complicate row -> label mappings.  
    ## we can't use labelings, because they get confused and don't guerntee uniqueness 
    ## with the many-to-many mapping.
    keylists =  [self.labeling.getRowsByLabel(label) for label in self.labeling.getLabels()]
    self.means = AggregateFunctionView(self.dataset, keylists, MLab.mean)
    self.stds  = AggregateFunctionView(self.dataset, keylists, safeStdDev)
    
    # we have to recreate the keylists because it may have been modified in the AggregateFunctionView
    keylists =  [self.labeling.getRowsByLabel(label) for label in self.labeling.getLabels()]
    mapping   = [self.means._mapKeysToParent([key]) for key in  self.means.getRowKeys ()] 
    positions = [mapping.index(keylist) for keylist in keylists]
    self.aggregateOrder= {}
    map(self.aggregateOrder.setdefault, self.labeling.getLabels(), positions)

   
    #self.meansLab = Labeling(self.means, 'meansLab')
    #self.meansLab.labelRows(aggregateOrder)
    #self.stdLab = Labeling(self.means, 'stdLab')
    #self.stdLab.labelRows(aggregateOrder)

    #if isinstance(self.labeling, GlobalWrapper):
    #  self.meansLab = GlobalWrapper(self.means, glabeling=self.labeling.g)
    #  self.stdLab = GlobalWrapper(self.stds, glabeling=self.labeling.g)

    #elif isinstance(self.labeling, GlobalLabeling):
    #  self.meansLab= GlobalWrapper(self.means, glabeling=self.labeling)
    #  self.stdsLab = GlobalWrapper(self.stds,  glabeling=self.labeling)

    #elif isinstance(self.labeling, labeling):
    #self.meansLab = Labeling(self.means)
    #self.stdsLab  = Labeling(self.stds)
    #self.meansLab.labelFrom(self.labeling)
    #self.stdsLab.labelFrom(self.labeling)
    #else:
    #  print "labeling of wrong type"
    #  return

    
    
    
    # draw the summarys
    self.drawMeanStdTrajectories(computeROC)
    self.setUniformYAxis()

    # these are pointers to the full plot DatasetPlotter and its PlotView
    self.selectedPlot = None
    self.selectedView = None
  
  def drawMeanStdTrajectories(self, computeROC=1):
    """
    drawMeanStdTrajectories(self, computeROC=1)

    This method draws all the cells of the confusion matrix with small
    summary tragetories in each one.
    """
    
    # order the data by cluster sizes
    # if self.l1Order is None:
    #   l1RowLabels = self.labeling1.getLabelByRows()
    #   l1Order = map(lambda x: (l1RowLabels.count(x), x), self.labeling1.getLabels())
    #   l1Order.sort()
    #   l1Order = map(lambda x: x[1], l1Order)
    #   l1Order.reverse()
    #   self.l1Order = l1Order
    # else:
    #   l1Order = self.l1Order
    #   l1Order.reverse()

    # if self.l2Order is None:
    #   l2RowLabels = self.labeling2.getLabelByRows()
    #   l2Order = map(lambda x: (l2RowLabels.count(x), x), self.labeling2.getLabels())
    #   l2Order.sort()
    #   l2Order = map(lambda x: x[1], l2Order)
    #   self.l2Order = l2Order
    # else:
    #   l2Order = self.l2Order
    
    l1Order, l2Order = self.cm.getClusterOrders() 
    # now draw the pretty confusion matrix.
    gridPos = [1,0]   # keep track of position in the Tk grid
    matrixPos = [0,0] # keep track of position in the matrix 
    for lab1 in l1Order:
      for lab2 in l2Order:
        #dataRow = self.meansLab.getRowsByLabel((lab1, lab2))
        dataRow = self.aggregateOrder.get((lab1, lab2))
        if dataRow is not None:
          print "found %s"%(str((lab1,lab2)))
          plot = self.drawSummaryPlot(dataRow, 'white', computeROC=0, titles=0, verbose=0)
          self.addWidget(plot, gridPos[0], gridPos[1])
        else:
          print "not found %s"%(str((lab1,lab2)))
          blank = Tkinter.Label(master=self.frame, bg = 'white', relief='sunken')
          blank.grid(row=gridPos[0], column=gridPos[1] )
        gridPos[1] += 1
        matrixPos[1] +=1

      gridPos[0] += 1
      gridPos[1] = 0
      matrixPos[0] +=1
      matrixPos[1] = 0

    # add the adjancy elements
    adjList = map(lambda pair :
                  (l1Order.index(pair[0])+1, l2Order.index(pair[1])),
                  self.cm.getAdjacencyList())
    
    for pair in adjList:
      self.plots[pair].configure(background = "black" )

    # draw the marginal cluster summaries.
    col = 0
    for lab2 in l2Order:
      data = Numeric.array(map(self.dataset.getRowData,
                               self.labeling2.getRowsByLabel (lab2)))
      if Numeric.shape(data) != (0,):
        plot = self.drawMarginalPlot(data, lab2)
        self.addWidget(plot, 0, col)
      col +=1 

    row = 1
    for lab1 in l1Order:
      data = Numeric.array(map(self.dataset.getRowData,
                               self.labeling1.getRowsByLabel (lab1)))
      if Numeric.shape(data) != (0,):
        plot = self.drawMarginalPlot(data, lab1)
        self.addWidget(plot, row, col)
      row +=1 
    
    self.colorByColumnMarginals()

  def colorByInterclusterROC(self):
    """ Color each cell by the pairwise ROC area """
    
    gridPos = [1,0]
    for lab1 in self.l1Order:
      for lab2 in self.l2Order:
        RowRocValue = roc.interclusterROC(self.dataset, 
                                           self.labeling1, lab1, 
                                           self.labeling2, lab2)
        ColRocValue = roc.interclusterROC(self.dataset,
                                           self.labeling2, lab2,
                                           self.labeling1, lab1)
        print "%s vs %s ROC Area = %3.2f/%3.2f"%(str(lab1), 
                                                 str(lab2), 
                                                 RowRocValue[0],
                                                 ColRocValue[0],)
        rowColor = rgbToString(colorsys.hsv_to_rgb(RowRocValue[0]*.7, 1, 1))
        colColor = rgbToString(colorsys.hsv_to_rgb(ColRocValue[0]*.7, 1, 1))
        if self.plots.has_key(tuple(gridPos)):
          plot = self.plots[tuple(gridPos)]
          if not plot.marker_exists('rocUpper'):
            plot.marker_create('polygon', name='rocUpper')
          if not plot.marker_exists('rocLower'):
            plot.marker_create('polygon', name='rocLower')
          xmin, xmax = plot.xaxis_limits()
          ymin, ymax = plot.yaxis_limits()
          plot.marker_configure('rocUpper',
                                coords = (xmin,ymin, xmax, ymax, xmin, ymax, xmin, ymin),
                                fill = colColor,
                                under = 1)
          plot.marker_configure('rocLower',
                                coords = (xmax, ymin, xmin,ymin, xmax, ymax, xmax, ymin),
                                fill = rowColor,
                                under = 1)
          plot.configure(plotbackground='white') 
        gridPos[1]+=1
      gridPos[0]+=1
      gridPos[1] =0
  
  def colorByColumnMarginals(self):
    """Color each cell of the confusion matrix by how many members are in it relitive to the indicated direction """
   # now to get the ordered confusion matrix counts.
    originalMatrix = Numeric.array(self.cm.getMatrix())
    m = []
    rowCount = 0
    rowOrder, colOrder = self.cm.getClusterOrders() 
    for row in self.l1Order:
      m.append([])
      for col in self.l2Order:
        m[rowCount].append(originalMatrix[rowOrder.index(row),
                                          colOrder.index(col)])
      rowCount+=1

    # normalize the matrix.
    matrix = Numeric.array(m)
    matrix = matrix.astype(Numeric.Float)
    matrix = matrix / (Numeric.sum(matrix)+.0001)   # solves the problem of occasionaly dividing by zero
    del(m)
    del(originalMatrix)
   
    # now draw the pretty confusion matrix.
    gridPos = [1,0]   # keep track of position in the Tk grid
    matrixPos = [0,0] # keep track of position in the matrix 
    for lab1 in self.l1Order:
      for lab2 in self.l2Order:
        color = rgbToString(colorsys.hsv_to_rgb(matrix[matrixPos]*.7, 1, 1))
        if self.plots.has_key(tuple(gridPos)):
          plot = self.plots[tuple(gridPos)]
          plot.configure(plotbackground=color)
          plot.marker_delete('rocUpper')
          plot.marker_delete('rocLower')
        gridPos[1] += 1
        matrixPos[1] +=1
      gridPos[0] += 1
      gridPos[1] = 0
      matrixPos[0] +=1
      matrixPos[1] = 0

  def colorByRowMarginals(self):
    # now to get the ordered confusion matrix counts.
    originalMatrix = Numeric.array(self.cm.getMatrix())
    m = []
    rowCount = 0
    rowOrder, colOrder = self.cm.getClusterOrders() 
    for row in self.l1Order:
      m.append([])
      for col in self.l2Order:
        m[rowCount].append(originalMatrix[rowOrder.index(row),
                                          colOrder.index(col)])
      rowCount+=1

    # normalize the matrix.
    matrix = Numeric.array(m)
    matrix = matrix.astype(Numeric.Float)
    matrix = Numeric.transpose(Numeric.transpose(matrix) / (Numeric.sum(matrix,1)+.0001))   # solves the problem of occasionaly dividing by zero
    del(m)
    del(originalMatrix)
   
    # now draw the pretty confusion matrix.
    gridPos = [1,0]   # keep track of position in the Tk grid
    matrixPos = [0,0] # keep track of position in the matrix 
    for lab1 in self.l1Order:
      for lab2 in self.l2Order:
        color = rgbToString(colorsys.hsv_to_rgb(matrix[matrixPos]*.7, 1, 1))
        if self.plots.has_key(tuple(gridPos)):
          plot = self.plots[tuple(gridPos)]
          plot.configure(plotbackground=color)
          plot.marker_delete('rocUpper')
          plot.marker_delete('rocLower')
        gridPos[1] += 1
        matrixPos[1] +=1
      gridPos[0] += 1
      gridPos[1] = 0
      matrixPos[0] +=1
      matrixPos[1] = 0
    
  def drawMarginalPlot(self, data, label):

    """
    drawMarginalPlot(self, data)

    draws the marginal plots for the confusion matrix summary.
    """
    mean = MLab.mean(data)
    posStd = tuple(mean + safeStdDev(data))
    negStd = tuple(mean - safeStdDev(data))
    mean = tuple(mean)
    xdata = tuple(range(len(mean)))

    g = IPlot(parent = self.frame, subplot=True)
    g.line_create('mean', xdata=xdata, ydata=mean  , color='blue', pixels = "0.0" )
    if posStd != mean:
      g.line_create('std1', xdata=xdata, ydata=posStd, color='red', pixels = "0.0", dashes=(5,1))
      g.line_create('std2', xdata=xdata, ydata=negStd, color='red', pixels = "0.0", dashes=(5,1))

    g.configure(width=self.plotSize[0],height=self.plotSize[1])

    # make a title -
    g.marker_create('text', name='size')
    g.marker_create('text', name='name')
    #coords = ( (g.xaxis_limits()[1] - g.xaxis_limits()[0])/2 , g.yaxis_limits()[1])
    g.marker_configure('size', coords = (g.xaxis_limits()[0], g.yaxis_limits()[1]),
                       text=str(len(data)),
                       anchor='w')
    g.marker_configure('name', coords = (g.xaxis_limits()[1], g.yaxis_limits()[0]), 
                       text=label,
                       anchor='e')


    g.legend_configure(hide=1)
    g.yaxis_configure(hide=1)
    g.xaxis_configure(hide=1)
    g.pack(expand=1, fill='both')


    return(g)
    
  def drawSummaryPlot(self, row, color, computeROC=1, titles=1, verbose=1):

    g = TrajectorySummary.drawSummaryPlot(self, row, computeROC, titles, verbose)
    g.bind('<Button-1>',  lambda event, row=row:self.fullPlot(row)) 
    g.bind('<Button-2>',  lambda event, row=row:self.fullPlot(row)) 
    g.configure(plotbackground = color)
    g.yaxis_configure(hide=1)
    g.xaxis_configure(hide=1)
    g.element_configure('mean', color='black', pixels='0')
    if 'std1'  in g.element_show():
      g.element_configure('std1', hide=0 , color='black', pixels='0', dashes = (1,1) )
      g.element_configure('std2', hide=0 , color='black', pixels='0', dashes = (1,1) )

    # make a title -
    g.marker_create('text', name='size')
    
    g.marker_configure('size', coords = (g.xaxis_limits()[0], g.yaxis_limits()[1]),
                       text=str(len(self.means._mapKeysToParent([row]))),
                       anchor='w')
    
    return(g)

  def setUniformYAxis(self, min=None, max=None):
    PlotPage.setUniformYAxis(self, min, max)
    for g in self.plots.values(): 
      try:
        g.marker_configure('size', coords = (g.xaxis_limits()[0], g.yaxis_limits()[1]))
      except:
        pass
      try:
        g.marker_configure('name', coords = (g.xaxis_limits()[1], g.yaxis_limits()[0]))
      except:
        pass

  def setOptimalYAxis(self):
    PlotPage.setOptimalYAxis(self)
    for g in self.plots.values(): 
      try:
        g.marker_configure('size', coords = (g.xaxis_limits()[0], g.yaxis_limits()[1]))
      except:
        pass
      try:
        g.marker_configure('name', coords = (g.xaxis_limits()[1], g.yaxis_limits()[0]))
      except:
        pass

class HistogramPlotter(IPlot):

  """
  A nice interactive histogram tool.
  """


  def __init__(self, labeling=None, seriesName=None, primaryLabeling=None, secondaryLabeling=None, sortBy = 'labels', plot = 1, parent=None, cnf={}, **kw):  
    IPlot.__init__(self, parent, cnf, **kw)
    self.configure(barmode="aligned")
    self.primaryLabeling=primaryLabeling
    self.secondaryLabeling=secondaryLabeling
    self.labelings = [] 
    self.seriesNames= []
    self.element_bind('all', '<Button-2>', self.fullPlot)
    if labeling:
      self.addLabeling(labeling, seriesName)
      self.plot(sortBy=sortBy)
    
  def addLabeling(self, labeling, seriesName=None):
    self.labelings.append(labeling)
    self.seriesNames.append(seriesName)

  def plot(self, sortBy='labels', color=1, pack=1):

    """
    plot(self, sortXBy='labels')

    sortXBy can be any of these: ['labels', 'size']
    """

    values = scaleList(range(len(self.labelings)),
                       minReturn=0, maxReturn=.7)

    colors = map(rgbToString,
                 map(colorsys.hsv_to_rgb,
                     values, [1]*len(values), [1]*len(values)))
      
    
    count = 0
    for labeling, name, color in zip( self.labelings, self.seriesNames, colors):
      self.xlabels = labeling.getLabels()
      ydata = map(len, map(labeling.getRowsByLabel,  self.xlabels))
      
      if sortBy == 'labels':
        data = zip(self.xlabels, ydata)
        data.sort()
        self.xlabels = map(lambda x: x[0], data)
        ydata   = map(lambda x: x[1], data)

      elif sortBy == 'size':
        data = zip(ydata, self.xlabels)
        data.sort()
        ydata   = map(lambda x: x[0], data)
        self.xlabels = map(lambda x: x[1], data)

      ydata   = tuple(ydata)
      self.xlabels = tuple(self.xlabels)
      xdata   = tuple(range(len(self.xlabels)))
      if name is None:
        name = "%s: %s"%(str(count), str(labeling.getName()))
      else:
        name = "%s: %s"%(str(count), name)
      try:
        self.bar_create(name, ydata=ydata, xdata=xdata, fg=color, bg=color)
      except:
        self.element_configure(name, ydata=ydata, xdata=xdata, fg=color, bg=color)
      count +=1
      try:
        tickLabels = ["%3.2f"%(self.xlabels[index]) for index in xdata]
      except:
        tickLabels = [str(self.xlabels[index]) for index in xdata]
      
      if MLab.max(map(len, tickLabels)) > 4:
        tickrotation=90
      else:
        tickrotation=0 
      self.xaxis_configure(command=lambda widgetPath, ticLabel, xlabels=tickLabels: xlabels[int(ticLabel)],
                        rotate=90,
                        stepsize=1,
                        minorticks=0)
                        #tickfont = "*-Courier-Bold-R-Normal-*-140-*")  

  
      
    if pack:
      self.pack(expand=1, fill='both')
      
  def fullPlot(self, event):
    """
    fullPlot(event, label)

    show the full row-wise plot of the cluster pointed to in the mean
    dataset at row
    """
    global ev
    ev = event
    win = Tkinter.Toplevel()
    index = event.widget.element_closest(event.x, event.y)['index']
    x = event.widget.element_closest (event.x, event.y)['x']
    y = event.widget.element_closest (event.x, event.y)['y']
    name   = event.widget.element_closest (event.x, event.y)['name']

    print name.split(": ")[1]
    labeling = self.labelings[int(name.split(": ")[0])]
    dataset = labeling.getDataset()
    label = self.xlabels[index]

    print "printing %s"%(str(label))
    subset = subsetByLabeling(dataset, labeling, label)

    dp = DatasetPlot(parent=win)
    dp.configure(title = "Group %s\n from %s \n %i elements"%(str(label),name,subset.getNumRows()))
    rv = DatasetRowPlotView(subset, primaryLabeling=self.primaryLabeling, 
                                    secondaryLabeling=self.secondaryLabeling)
    dp.plot(rv)
    self.selectedPlot = dp
    self.selectedView = rv
    # FIXME: this is a memory leak
    #self.dataset.removeView(subset)

##
#
# Just a few module helper functions
#

def rgbToString(rgb):

  """
  rgbToString(rgb):

  Turns a sequence of color components in the 0..1 range (e.g. an
  (r, g, b) tuple from colorsys) into a hex color string such as
  '#ff7f00'.  Each component is multiplied by 255, truncated to an
  int and written as two lower-case hex digits.  Components outside
  0..1 are clamped, so the result is always a well-formed color string.
  """
  digits = []
  for component in rgb:
    value = int(component * 255)
    # clamp: negative or >1 components would otherwise produce strings
    # like '#x33...' or 3-digit fields that Tk rejects
    value = min(max(value, 0), 255)
    digits.append("%02x" % value)
  return "#%s" % "".join(digits)

def scaleList(values, minValue=None, maxValue=None, minReturn=0, maxReturn=1):

  """
  scaleList(values, minValue=None, maxValue=None, minReturn=0, maxReturn=1)

  Maps every item of `values` onto the range minReturn..maxReturn and
  returns the result as a Numeric array with one entry per item.

  If every distinct item is an int, long or float, the values are
  scaled linearly:
    - minValue maps to minReturn and maxValue to (just under)
      maxReturn
    - values outside minValue..maxValue are clipped to that range
    - minValue / maxValue default to the smallest / largest value
      present

  Otherwise each item is replaced by its position among the distinct
  items (unique.unique(values)) and scaled over 0..number of distinct
  items.  The arbitrary categories are thus spread over the range
  (the last one stays below maxReturn), and minValue / maxValue are
  ignored.  The items must be hashable in that case.

  A 0.0001 epsilon in the denominator avoids dividing by zero when
  minValue == maxValue.
  """
  distinct = unique.unique(values)
  numericTypes = [types.IntType, types.LongType, types.FloatType]
  isNumeric = 1
  for item in distinct:
    if type(item) not in numericTypes:
      isNumeric = 0
      break

  if isNumeric:
    try:
      distinct = Numeric.array(distinct)
      data = Numeric.array(values)
    except:
      isNumeric = 0

  if not isNumeric:
    # categorical data: replace every item by its index among the
    # distinct items
    minValue = 0
    maxValue = len(distinct)
    position = {}
    for index, item in zip(range(maxValue), distinct):
      position[item] = index
    data = [position[item] for item in values]
  else:
    if minValue is None:
      minValue = MLab.min(distinct)
    if maxValue is None:
      maxValue = MLab.max(distinct)

  # shift so minValue -> 0 (clipping below), normalize by the span,
  # then clip above at 1
  shifted = Numeric.array(map(lambda x: max(x, 0), Numeric.array(data) - minValue))
  fractions = shifted.astype('f') / ((maxValue - minValue) + .0001)
  fractions = Numeric.array(map(lambda x: min(x, 1), fractions))

  # map the 0..1 fractions onto minReturn..maxReturn
  return (fractions * (maxReturn - minReturn)) + minReturn

def plot(values, yvalues=None,
         xerror=None, yerror=None,
         xmin=None, xmax=None, ymin=None, ymax=None,
         plotStyle="line", color=None, fileName=None,
         seriesName = None,  previousPlot=None , parent=None, pack=1):

    """
    A wrapper around the IPlot class to provide a fast easy interface for plotting.

    Usage: plot(x, y, <options>)
              Creates a plot of the x-vector vs the y-vector.  x and y
              can be either Numeric arrays or standard python lists.

           plot(y, <options>)
              - 1d array/list/tuple y: plots the values vs their index.
              - 2d array/list/tuple y: plots each row as its own data
                series vs its index.
              - dataset y: plots it with a DatasetPlot and returns the
                tuple (DatasetPlot, DatasetRowPlotView).
              - Scientific Python Histogram y: draws its bins as bars.

        Optional parameters:
              plotStyle -> one of 'line', 'points', 'bar'
              xmin, xmax, ymin, ymax -> float/int axis limits; None keeps
                 the automatic limit
              xerror, yerror -> accepted (one entry per row for 2d input)
                 but not currently drawn
              fileName -> name of a postscript file to create
              color -> Tk color name or RGB string (e.g. '#FFFFFF' for
                 white).  For 2d input either a single color for every row
                 or a sequence with one color per row.  By default colors
                 are taken from a fixed palette.
              seriesName -> legend name of the series (1d input)
              previousPlot -> an existing IPlot to add the series to
              parent -> parent widget of a newly created plot
              pack -> pack the plot into its parent when true

    Returns the IPlot (or the (DatasetPlot, PlotView) pair for a dataset),
    or None when values cannot be converted to an array.
    """
    # Datasets are drawn by a DatasetPlot; no IPlot is created for them
    # (one created up front would be left behind as an empty widget).
    if isinstance(values, Dataset):
      if parent is None:
        dp = DatasetPlot()
      else:
        dp = DatasetPlot(parent=parent)
      v = DatasetRowPlotView(values)
      dp.plot(v)
      return(dp,v)

    # sciHistogram is types.NoneType when Scientific Python is not
    # installed.  NoneType has no Histogram attribute, so look it up
    # defensively instead of failing on every non-Dataset input.
    histogramClass = getattr(sciHistogram, 'Histogram', None)
    isHistogram = histogramClass is not None and isinstance(values, histogramClass)

    if isHistogram:
      plotStyle = 'bar'
      yvalues = values.array[:,1]
      data = values.array[:,0]
    else:
      try:
        data = Numeric.array(values)
      except:
        sys.stderr.write("Not a supported plotting type\n")
        return

    # palette ordered such that similar colors are well separated
    numColors = 7
    colors = ['blue', 'red', 'green', 'orange', 'yellow', 'cyan', 'purple',]

    if previousPlot:
      g = previousPlot
      if color is None:
        numElements = len(g.element_names())
        index = numElements%(numColors//3) + ((numElements%3)*numColors//3)
        color = colors[index]
    else:
      g = IPlot(parent=parent)

    if isHistogram:
      g.configure(barwidth=str(values.bin_width))

    if len(Numeric.shape(data)) == 1:
      # single vector, plotted against yvalues or its index
      if yvalues is None:
        yvalues = tuple(data)
        xvalues = tuple(range(len(yvalues)))
      else:
        xvalues = tuple(data)
        yvalues = tuple(yvalues)
      if seriesName is None:
        seriesName = str(len(g.element_names()))

      if plotStyle == 'line':
        g.line_create(seriesName, xdata = xvalues, ydata = yvalues, color=color)
      elif plotStyle == 'points':
        g.line_create(seriesName, xdata = xvalues, ydata = yvalues, color=color, linewidth=0)
      elif plotStyle == 'bar':
        g.bar_create(seriesName, xdata = xvalues, ydata = yvalues, background=color, foreground=color)
      else:
        sys.stderr.write("Unknown plot style (%s), try: 'line', 'bar', or 'points'\n"%(plotStyle))

    else:
      # plot each row of the data matrix as its own series against its index
      for i in range(len(data)):
        if xerror:
          xerrTmp = xerror[i]
        else:
          xerrTmp = None
        if yerror:
          yerrTmp = yerror[i]
        else:
          yerrTmp = None
        if not color:
          colorTmp = None
        elif isinstance(color, types.StringTypes):
          # a single color applies to every row
          colorTmp = color
        else:
          colorTmp = color[i]
        # the rows share g; packing and axis limits are applied once below
        plot(data[i],
             xerror=xerrTmp,
             yerror=yerrTmp,
             plotStyle = plotStyle,
             color = colorTmp,
             previousPlot=g,
             pack=0)

    # apply the requested axis limits; None keeps the automatic limit
    xlimits = {}
    if xmin is not None:
      xlimits['min'] = xmin
    if xmax is not None:
      xlimits['max'] = xmax
    if xlimits:
      g.xaxis_configure(**xlimits)
    ylimits = {}
    if ymin is not None:
      ylimits['min'] = ymin
    if ymax is not None:
      ylimits['max'] = ymax
    if ylimits:
      g.yaxis_configure(**ylimits)

    if pack:
      g.pack(expand=1, fill='both')

    if fileName:
        g.postscript(fileName)
    return(g)
  
def boxPlot(ds, parent=None, dimension=0):
  """
  Generate a box plot, dimension = 0 (column-wise) dimension=1 (row-wise) 
  ds can either be a dataset or an numeric array

  One box is drawn per column of the data (after the optional transpose
  for dimension=1).  For each column the plot shows:
    - the mean, as a black circle;
    - a box from the 25th to the 75th percentile, split at the median;
    - whiskers reaching the most extreme data values that still lie within
      1.5*IQR of the box (Tukey's rule);
    - every value beyond those fences, as a red cross.

  The percentiles are taken as single order statistics of the sorted column
  (index int(rows*.25) and int(rows*.75)), not interpolated.

  Returns the packed IPlot widget.
  """
  if not isinstance(ds, Dataset):
    ds = Dataset(ds)
  
  data = ds.getData()  
  if dimension == 1:
    data = Numeric.transpose(data)
 
  ## compute the statistics
  
  rows, cols = Numeric.shape(data)
  medians = MLab.median(data)
  means = MLab.mean(data)
  sortedData = Numeric.sort(data,0)
  quartiles25 = sortedData[int(rows*.25),:]
  quartiles75 = sortedData[int(rows*.75),:]
  iqr = quartiles75-quartiles25

  upperOutliers = []
  lowerOutliers = []
  upperIQRextreme = []
  lowerIQRextreme = []
  sortedColumns = Numeric.transpose(sortedData)
  for col in range(cols):
    # every column is judged against its OWN fences; comparing a value with
    # the whole vector of fences would mix thresholds from other columns
    columnValues = list(sortedColumns[col])     # ascending order
    upperFence = quartiles75[col] + iqr[col]*1.5
    lowerFence = quartiles25[col] - iqr[col]*1.5
    upperOutliers.append([x for x in columnValues if x > upperFence])
    lowerOutliers.append([x for x in columnValues if x < lowerFence])
    # whiskers stop at the most extreme values that are NOT outliers.  The
    # quartile values themselves are column members inside the fences
    # (iqr >= 0 because the column is sorted), so these lists are non-empty.
    upperIQRextreme.append([x for x in columnValues if x <= upperFence][-1])
    lowerIQRextreme.append([x for x in columnValues if x >= lowerFence][0])
  
  ## render the plot
  g = IPlot(parent=parent)
  g.configure(title='Box Plot')
  g.legend_configure(hide=1)
  g.xaxis_configure(min=-1, max=cols+1)

  # x coordinates for the two-point vertical segments drawn at each column
  pairedX = tuple(listOps.unravel([[x,x] for x in range(cols)]))

  g.line_create('means',  xdata=tuple(range(cols)),
                          ydata = tuple(means),
                          linewidth=0,
                          pixels='0.04i',
                          color='black',
                          symbol='circle')
  
  # the box is drawn as two thick segments: median->25th and median->75th
  g.line_create('25thQuarts', xdata=pairedX,
                              ydata=tuple(listOps.unravel(map(lambda x,y: [x,y], medians, quartiles25 ))),
                              linewidth=40,
                              pixels='0.0i',
                              color='lightblue',
                              trace='decreasing')
  g.line_create('75thQuarts', xdata=pairedX,
                              ydata=tuple(listOps.unravel(map(lambda x,y: [x,y], medians, quartiles75 ))),
                              linewidth=40,
                              pixels='0.0i',
                              color='blue',
                              trace='decreasing')
  
  # whiskers: from each quartile out to the most extreme non-outlier
  g.line_create('iqrHigh', xdata=pairedX,
                           ydata=tuple(listOps.unravel(map(lambda x,y: [x,y], 
                                                           quartiles75, upperIQRextreme ))),
                           linewidth=3,
                           pixels='0.0',
                           color='black',
                           trace='decreasing')
  g.line_create('iqrLow', xdata=pairedX,
                          ydata=tuple(listOps.unravel(map(lambda x,y: [x,y], 
                                                          quartiles25, lowerIQRextreme ))),
                          linewidth=3,
                          pixels='0.0',
                          color='black',
                          trace='decreasing')
                           
  g.line_create('upperOutliers', xdata=tuple(listOps.unravel([[x]*len(outs) for x,outs in zip(range(cols),upperOutliers)])),
                                 ydata=tuple(listOps.unravel(upperOutliers)),
                                 linewidth=0,
                                 pixels ='.08i',
                                 symbol = 'scross',
                                 color  = 'red')

  g.line_create('lowerOutliers', xdata=tuple(listOps.unravel([[x]*len(outs) for x,outs in zip(range(cols),lowerOutliers)])),
                                 ydata=tuple(listOps.unravel(lowerOutliers)),
                                 linewidth=0,
                                 pixels ='.08i',
                                 symbol = 'scross',
                                 color  = 'red')
 
  g.pack(expand=1, fill='both')
  return(g)
 
class ROCPlot(PlotPage):
  """
  PlotPage showing how well one cluster separates from the rest of a
  dataset.  It contains the ROC curve computed by roc.clusterROC, next to
  histograms of each row's distance from the cluster mean, split into rows
  inside and rows outside the cluster.
  """
  def __init__(self, dataset, labeling, label, distanceMetric=DistanceMetrics.EuclideanDistance, parent=None):
    """
    draw our standard ROC plot for the label in labeling for dataset 

    dataset        -- dataset whose rows are scored
    labeling       -- labeling that contains `label`
    label          -- label whose rows form the "inside" set
    distanceMetric -- called as distanceMetric(center, data).  It is passed
                      to roc.clusterROC and is also used here to compute the
                      distances that are histogrammed.
    parent         -- Tk parent handed to PlotPage.__init__
    """
    PlotPage.__init__(self, parent=parent)
    # get the ROC stats and start a plotpage
    area, xvalues, yvalues = roc.clusterROC(dataset, labeling, label, distanceMetric=distanceMetric)
    self.sf.configure(label_text='ROC Display for label %s in %s\n Area=%3.2f'%(str(label), str(labeling.getName()),area))
    
    # make the ROC curve 
    areaCurve = plot(xvalues, yvalues, seriesName='ROC curve', parent=self.frame, pack=0)
    areaCurve.element_configure('ROC curve', areapattern='solid', 
                                areaforeground='lightblue', pixels='.02i') 
    areaCurve.legend_configure(hide=1)
    areaCurve.configure(title='ROC Curve', height=400,width=400)
    areaCurve.xaxis_configure(title='% outside')
    areaCurve.yaxis_configure(title='% inside')
    self.addWidget(areaCurve)
   
    ## build the interactive histogram plots
    ### These computations are partially redundant with the ROC code..   **
    data = dataset.getData()
    insideRows = labeling.getRowsByLabel(label)
    outsideRows = listOps.difference(range(dataset.getNumRows()), insideRows)
    clusterMean = MLab.mean(Numeric.take(data, insideRows))
    distances = distanceMetric(clusterMean, data)
    # roughly ten rows per bin, but never fewer than one bin: plain integer
    # division would request zero bins for datasets with fewer than 10 rows
    numBins = max(1, dataset.getNumRows() // 10)
    binLabeling = Histogram.binOnRowVector(dataset, distances, numBins)
    distanceLab = GlobalLabeling(dataset, '__distances__')
    ## FIXME globalLabelings.labelFrom is broken 
    # Copy the bin labels into the global labeling by hand.  The loop
    # variable is deliberately not called `label`, so that it cannot
    # overwrite the constructor argument.
    for binLabel in binLabeling.getLabels():
      distanceLab.addLabelToRows(dataset, binLabel, binLabeling.getRowsByLabel(binLabel))
   
    # the temporary binning labeling is no longer needed on the dataset
    dataset.removeLabeling(binLabeling)
    insideView = RowSubsetView(dataset, insideRows)
    outsideView =RowSubsetView(dataset, outsideRows)
    histograms = HistogramPlotter(parent=self.frame)
    insideLab = GlobalWrapper(insideView, glabeling=distanceLab)
    outsideLab = GlobalWrapper(outsideView, glabeling=distanceLab)
    histograms.addLabeling(insideLab, seriesName='Inside')
    histograms.addLabeling(outsideLab, seriesName='Outside')
    histograms.plot()
    histograms.configure(barmode='aligned')
    histograms.configure(title='Histogram of Distances from Cluster Center')
    histograms.xaxis_configure(title='Distance')
    histograms.yaxis_configure(title='Count')
    self.addWidget(histograms)
   
class PCAExplorer:
  """
  A simple exploratory tool to better understand the significance of PCA.

  Three views are built:
    - a PlotPage with one row per principal component: a radio button, a
      plot of that component's loadings (column i of the rotation matrix)
      and a slider holding that component's coefficient;
    - a 'Native Space' plot of the vector reconstructed from the slider
      coefficients;
    - a 'PCA Space' DatasetPlot of the dataset projected on the two
      components chosen in the X-Axis / Y-Axis menus, with a cross marker.

  Moving a slider updates the native-space plot and the marker.  Releasing
  mouse button 2 over the PCA plot moves the marker there and updates the
  sliders and the native-space plot.  The radio buttons are created but not
  otherwise wired up in this class.
  """

  def __init__(self, dataset, primaryLabeling=None, secondaryLabeling=None, parent=None):
    """
    dataset           -- dataset to analyse; it is wrapped in a RowPCAView
    primaryLabeling   -- optional labeling passed to DatasetRowPlotView
    secondaryLabeling -- optional labeling passed to DatasetRowPlotView
    parent            -- optional Tk parent.  When given, the component page
                         and the native-space plot each get their own Toplevel.
    """
    self.pcaDS = RowPCAView(dataset)
    if parent is not None:
      toplevel1 = Tkinter.Toplevel(parent)
      self.pp = PlotPage(parent=toplevel1)
    else:
      self.pp = PlotPage(parent=None)
    self.scales = []
    self.buttons = []
    variances = self.pcaDS.getVariances()
    # column i of rotMatrix holds the loadings of component i (that column
    # is what gets plotted as "PC-i" further down)
    rotMatrix = self.pcaDS.matrix

    if parent is not None:
      toplevel2 = Tkinter.Toplevel(parent)
      self.nativePlot = IPlot(toplevel2)
    else:
      self.nativePlot = IPlot()
    # one y value per native dimension, all 0 until a slider/plot update
    self.nativePlot.line_create('point', ydata=tuple([0]*len(rotMatrix)),
                                         xdata=tuple(range(len(rotMatrix)))
                                         )
    self.nativePlot.configure(title='Native Space')
    self.nativePlot.legend_configure(hide=1)
    self.nativePlot.pack(expand=1, fill='both')
   
    self.v = DatasetRowPlotView(dataset, 
                                primaryLabeling=primaryLabeling, 
                                secondaryLabeling=secondaryLabeling)
    self.dm = PCADataMapper(self.v)
    self.v.setDataMapper(self.dm)
    self.pcaFrame = Tkinter.Toplevel(parent)
    # NOTE(review): these menu labels print variances[i] as a percentage,
    # while the per-component titles below print variances[i]*100; one of
    # the two is probably off by a factor of 100 -- confirm what
    # getVariances() returns.
    pcaDimLabels = ['PC-%i (%3.2f%%)'%(idx, var) for idx, var in zip(range(len(variances)), variances)]
    self.pcaControlFrame = Tkinter.Frame(self.pcaFrame)
    self.pcaXChooser = Pmw.OptionMenu(self.pcaControlFrame, 
                                      labelpos = 'w',
                                      label_text = 'X-Axis:',
                                      command=self.__xChooserCB,
                                      items=pcaDimLabels)
    self.pcaXChooser.pack(anchor='s', side='left', expand=1, fill='x')
    self.pcaYChooser = Pmw.OptionMenu(self.pcaControlFrame, 
                                      labelpos = 'w',
                                      label_text = 'Y-Axis:',
                                      command=self.__yChooserCB,
                                      items=pcaDimLabels)
    self.pcaYChooser.pack(anchor='s', side='left', expand=1, fill='x')

    self.pcaPlot = DatasetPlot(parent=self.pcaFrame)
    self.pcaPlot.pack(side='top', anchor='w', expand=1, fill='both')
    self.pcaControlFrame.pack(side='top', anchor='w', expand=1, fill='x')
    # the cross marker that follows clicks / slider changes
    self.pcaPlot.line_create('point', ydata=(0,), xdata=(0,), symbol='cross', pixels='.22i')
    self.pcaPlot.plot(self.v)
    self.pcaPlot.element_activate('point')
    self.pcaPlot.legend_configure(hide=1)
    self.pcaPlot.configure(title='PCA Space')
    # NOTE(review): the menu labels are 0-based ('PC-0', ...); confirm which
    # columns PCADataMapper shows by default so these titles match.
    self.pcaPlot.xaxis_configure(title='PC-1')
    self.pcaPlot.yaxis_configure(title='PC-2')
    self.pcaPlot.bind('<ButtonRelease-2>', self.plotUpdate)

    for i in range(len(self.pcaDS.matrix)):
      b = Tkinter.Radiobutton(self.pp.frame)
      self.buttons.append(b)
      self.pp.addWidget(b, row=i, col=0)
      g = plot(rotMatrix[:,i], pack=0, parent=self.pp.frame)
      g.configure(title='PC-%i %3.2f%% of variance captured'%(i, variances[i]*100),
                  height=200, width=400)
      g.legend_configure(hide=1)
      self.pp.addWidget(g, row=i, col=1)
      columnData = self.pcaDS.getColData(i)
      s = Tkinter.Scale(self.pp.frame, 
                        # both ends of the range include 0, so that
                        # plotUpdate() can reset unused components to 0
                        from_= min([0, MLab.min(columnData)]),
                        to = max([0, MLab.max(columnData)]),
                        resolution = .1,
                        command = lambda value, component=i: self.sliderUpdate(value,component))
      self.scales.append(s)
      self.pp.addWidget(s, row=i, col=2)
    
  def sliderUpdate(self, value, component):
    """Update from sliders plot.

    Reads ALL slider values (the `value`/`component` arguments supplied by
    the Scale callback are not used).  It redraws the native-space vector
    sum_k s_k * column_k of the rotation matrix, and moves the PCA-plot
    marker to the sliders of the currently displayed components.
    """

    vector = Numeric.dot(Numeric.array([s.get() for s in self.scales ]),
                         Numeric.transpose(self.pcaDS.matrix))
    self.nativePlot.element_configure('point', ydata=tuple(vector))
    
    xaxis = self.dm.getXColumn()
    yaxis = self.dm.getYColumn()
    self.pcaPlot.element_configure('point', xdata=(self.scales[xaxis].get(),), ydata=(self.scales[yaxis].get(),))
                                            
  def plotUpdate(self, event):
    """update from pca plot

    Bound to '<ButtonRelease-2>' on the PCA plot.  It moves the marker to
    the clicked point (x, y) and redraws the native-space vector
    x*column_X + y*column_Y of the rotation matrix.  This uses the same
    column convention as sliderUpdate.  Finally it sets the sliders to
    match: x and y for the displayed components, 0 for all others.
    """
    x,y = self.pcaPlot.invtransform (event.x, event.y)
    self.pcaPlot.element_configure('point', xdata=(x,), ydata=(y,))     
 
    ## FIX ME TO DEAL WITH PCA PLOTS OF DIMENSIONS OTHER THAN 1 AND 2
    xaxis = self.dm.getXColumn()
    yaxis = self.dm.getYColumn()

    # coefficient vector: x and y on the displayed components, 0 elsewhere
    coefficients = [0.0]*len(self.scales)
    coefficients[xaxis] = x
    coefficients[yaxis] = y
    # components are the COLUMNS of the rotation matrix, exactly as in
    # sliderUpdate (taking rows here would use the transposed basis)
    vector = Numeric.dot(Numeric.array(coefficients),
                         Numeric.transpose(self.pcaDS.matrix))
    self.nativePlot.element_configure('point', ydata=tuple(vector))
    self.scales[xaxis].set(x)
    self.scales[yaxis].set(y)
    for i in range(len(self.scales)):
      if i not in [xaxis, yaxis]:
        self.scales[i].set(0)

  def _componentIndex(self, dimString):
    """Return the component number encoded in a menu label like 'PC-12 (3.40%)'.

    The whole number after 'PC-' is parsed, so components >= 10 are handled.
    """
    return int(dimString.split()[0][3:])
  
  def __xChooserCB(self, dimString):
    """Show the component chosen in the X-Axis menu on the PCA plot's x axis."""
    self.dm.setXColumn(self._componentIndex(dimString))
    self.pcaPlot.xaxis_configure(title=dimString)
    self.pcaPlot.plot()
  def __yChooserCB(self, dimString):
    """Show the component chosen in the Y-Axis menu on the PCA plot's y axis."""
    self.pcaPlot.yaxis_configure(title=dimString)
    self.dm.setYColumn(self._componentIndex(dimString))
    self.pcaPlot.plot()

# Backward-compatible alias: older code still refers to PCAExplorer by the
# name `pcaExplore`, so keep that name pointing at the class.
pcaExplore = PCAExplorer
