########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

#
#  This module accepts a dataset and creates an xclust-sorted view.  The view
#  should look like a heat map, and the rows should be sorted by an xclust
#  clustering of the dataset.  I'm going to make the view in the same vein
#  as Chris's IPlot tools, so that this becomes just another IPlot plot.
#

#
#  at some stage, I also want to add a flag that prints out the xclust
#  tree along with the sorted heat-map.  This shouldn't be hard to do, as
#  lucas has done most of the work already (see /compClust/util/dumpGTR.py)
#  I also want to add the ability to zoom into the xclust plot (very high
#  priority)
#

import colorsys

import MLab
import Pmw
import Tkinter
import types
import Numeric
import sys

from compClust.mlx.wrapper import XClust
from compClust.mlx import XClustTree
from compClust.mlx import views
from compClust.mlx import labelings
from compClust.util import unique

DEBUG=1

class heatMap:
  """
  Bundles a dataset with the mappers needed to draw it as a heat map.

  A heatMap holds the original (unsorted) dataset, the sorted view that
  is drawn, optional row/column labelings, a data mapper (an
  XClustDataMapper, built with the requested 'sortby') and a color
  mapper (a RowColorMapper).

  here's how to use this class

  hm = heatMap(ds)                       # sortby = 'both' by default
  hm.getColorMapper().setColorByValue()

  hp = HeatMapPlot()
  hp.plot(hm)
  """

  # transient pointer to the plot currently displaying this heat map.
  # Declared at class level so _getPlot() returns None (instead of raising
  # AttributeError) when no plot has been attached yet.
  __plot = None

  def __init__(self, ds, 
               sorted_ds   = None,
               rowLabeling = None,
               colLabeling = None,
               sortby = 'both',
               **kw):
    """
    ds          -- the dataset to display
    sorted_ds   -- optional pre-built sorted view of ds; when not given
                   the data mapper creates one
    rowLabeling -- optional labeling for rows; a labelings.GlobalWrapper
                   is unwrapped to its underlying '.g' attribute
    colLabeling -- same as rowLabeling, for columns
    sortby      -- passed to XClustDataMapper ('rows', 'columns', 'both';
                   any other value leaves the order untouched)
    kw          -- accepted for compatibility; not used here
    """

    # these are critical mappers.  only datamapper is required
    self.unsorted_ds  = ds
    self.sorted_ds    = sorted_ds

    if isinstance(rowLabeling, labelings.GlobalWrapper):
      rlab = rowLabeling.g
    else:
      rlab = rowLabeling

    if isinstance(colLabeling, labelings.GlobalWrapper):
      clab = colLabeling.g
    else:
      clab = colLabeling
      
    self.rowLabeling  = rlab
    self.colLabeling  = clab

    # the data mapper must be built before the color mapper: it stores the
    # sorted dataset on this object, and RowColorMapper reads it
    self._dataMapper  = XClustDataMapper(self, sortby = sortby)
    self._colorMapper = RowColorMapper(self)

  def _setPlot(self, plot):
    """
    _setPlot(self, plot)

    sets the transient plot pointer
    """
    self.__plot = plot
    
  def _getPlot(self):
    """
    _getPlot(self)

    returns the transient plot pointer (None if no plot has been set)
    """
    return self.__plot
    
  def plotSetup(self):
    """
    plotSetup(self)

    this function gets called before any plotting begins
    """
    pass

  def getDataMapper(self):
    """returns the data mapper"""
    return self._dataMapper

  def setDataMapper(self, datamapper):
    """replaces the data mapper"""
    self._dataMapper = datamapper

  def getColorMapper(self):
    """returns the color mapper"""
    return self._colorMapper

  def setColorMapper(self, colormapper):
    """replaces the color mapper"""
    self._colorMapper = colormapper

  def getSortedDataset(self):
    """returns the sorted dataset (the one that gets drawn)"""
    return self.sorted_ds

  def setSortedDataset(self, ds):
    """stores the sorted dataset"""
    self.sorted_ds = ds

  def getUnsortedDataset(self):
    """returns the dataset given at construction"""
    return self.unsorted_ds

  def setRowLabeling(self, labeling):
    """stores the row labeling (not unwrapped, unlike in __init__)"""
    self.rowLabeling = labeling

  def setColLabeling(self, labeling):
    """stores the column labeling (not unwrapped, unlike in __init__)"""
    self.colLabeling = labeling

  def getRowLabeling(self):
    """returns the row labeling"""
    return self.rowLabeling

  def getColLabeling(self):
    """returns the column labeling"""
    return self.colLabeling

class HeatMapPlot(Pmw.ScrolledCanvas):

  CanvasWidth = 500
  CanvasHeight = 700

  def __init__(self, master=None, **kw):
    """
    Builds the scrolled canvas the heat map is drawn on.

    master -- parent widget; when None a new Tkinter.Tk root sized
              CanvasWidth x CanvasHeight is created
    kw     -- handed straight to Pmw.ScrolledCanvas.__init__
    """
    if master is None:
      master = Tkinter.Tk()
      master.geometry('%dx%d' % (self.CanvasWidth, self.CanvasHeight))

    # ScrolledCanvas to Plot on
    Pmw.ScrolledCanvas.__init__(self, master, **kw)

    # heat map on display and the canvas id of the rubber-band rectangle
    # (used by the 'select' function, and eventually for zooming)
    self.currentHeatMap = None
    self.boundingRectangle = None

    # mouseclick/key-pressing state
    self.__zoomSelection = [None] * 4
    self.__dragging = None
    self.shiftPressed = 0

    # place to store the potentially-created subset
    self.ss = None

    # heatmap-specific caches (plotcolors, selections, colors, data, ds,
    # rowIDs, colIDs) start out in their "cleared" state
    self.clearPlot()


  def clearPlot(self):
    self.plotcolors = []
    self.selectedRows = []
    self.selectedColumns = []
    self.selectedPoints = []
    self.colors = []
    self.data = None
    self.ds  = None
    self.rowIDs = {}
    self.colIDs = {}

  def plot(self, heatMap=None, pack=1, plotColorMap=1):
    """
    Draws a heat map on this canvas using the heatMap interface.

    heatMap      -- the heatMap to draw; when None the heat map from the
                    previous call (self.currentHeatMap) is drawn again
    pack         -- when true, the subsetting buttons are created, the
                    widget is packed, the scrollbars are made static and
                    the canvas takes keyboard focus
    plotColorMap -- when true (and pack is true) a separate color-scale
                    window is opened via self.plotColorMap

    The method draws, in this order: one colored cell per data value,
    the column labels along the top, the row labels to the right of the
    cells, and then installs the mouse/keyboard bindings.  Progress
    messages are printed when the module-level DEBUG flag is set.
    """

    if DEBUG:
      print 'start plot'
      sys.stdout.flush()
    # NOTE: 'replot' is set but never read afterwards
    replot=0
    if heatMap is None:
      heatMap = self.currentHeatMap
      replot=1

    
    # forget cached state from a previous plot (canvas items are not
    # deleted by clearPlot)
    if self.currentHeatMap:
      self.clearPlot()

    self.currentHeatMap = heatMap
    heatMap._setPlot(self)
    heatMap.plotSetup()
    
    if DEBUG:
      print 'getting mappers'
      print ' getting datamapper'
      sys.stdout.flush()
    # get the mappers

    dataMapper  = heatMap.getDataMapper()
    if DEBUG:
      print 'getting color mapper'
      sys.stdout.flush()
    colorMapper = heatMap.getColorMapper()

    if DEBUG:
      print 'getting data from datamapper'
      sys.stdout.flush()
    # extract the info first from the heatMap.  
    self.ds       = dataMapper.getPlotData()
    if DEBUG:
      print 'getting colors from colormapper'
      sys.stdout.flush()
    # one (h, s, v) triple per cell; indexed below as row*numCols + col
    self.colors   = colorMapper.getColors('hsv')

    if DEBUG:
      print 'getting data'
      sys.stdout.flush()
    self.data     = self.ds.getData()

    # get dataset information
    if DEBUG:
      print 'getting numRows/numCols'
      sys.stdout.flush()
    numRows = self.ds.getNumRows()
    numCols = self.ds.getNumCols()

    # get plot information/constants
    if DEBUG:
      print 'getting plot information'
      sys.stdout.flush()
    columnLabelHeight = 100
    rowLabelWidth     = 300
    rectWidth         = 8
    rectHeight        = 12

    # (x0, y0) is the top-left corner of cell (0, 0); the column labels
    # occupy the strip above it
    xAxisZero, yAxisZero, leftSpace, bottomSpace= [0,0,6,9]    
    x0 = xAxisZero + leftSpace
    y0 = yAxisZero + columnLabelHeight
    dx = rectWidth
    dy = rectHeight

    # these 4 commands very important for box-scrolling!!!
    # (mouseDown and selectCells read them to map canvas coordinates
    # back to row/column indices)
    self.dx = dx
    self.dy = dy
    self.columnLabelHeight = columnLabelHeight
    self.leftSpace         = leftSpace

    # extent of the drawing; only maxXA is used below (to position the
    # row labels to the right of the cells)
    maxYA=dy*numRows+y0+bottomSpace
    maxYAxis=maxYA
    maxXA=dx*numCols+x0+rowLabelWidth+0.2
    maxXAxis=maxXA

    # format plot colors: reshape the flat color list into a
    # numRows x numCols list of lists
    if DEBUG:
      print 'formatting plot colors (loop over dataset)'
      sys.stdout.flush()
    self.plotcolors = MLab.zeros((numRows, numCols))
    self.plotcolors = self.plotcolors.tolist()
    for row in range(numRows):
      for col in range(numCols):
        self.plotcolors[row][col] = self.colors[ row*numCols + col ]
        

    if DEBUG:
      print 'drawing cells (loop over dataset)'
      sys.stdout.flush()
    for i in range(numRows):
      y = y0 + dy*i
      x = x0
      for j in range(numCols):

        # organize colors
        # (self.plotcolors is always a list at this point, so the 'blue'
        # fallback is never taken)
        if self.plotcolors is not None:
          cellColor =  rgbToString(apply(colorsys.hsv_to_rgb,
                                         self.plotcolors[i][j]))
        else:
          cellColor = 'blue'

        # make dull all zero-values (we don't want to focus on them...)
        # this is not the most robust way to do this.  perhaps post-processing
        # is better.
        if self.data[i][j]==0 :
          cellColor = 'lightgray'

        # overwrite our plotColors, because we don't need the hsv values
        # anymore (I hope); from here on plotcolors holds the color strings
        self.plotcolors[i][j] = cellColor

        # An earlier version drew every cell as a rectangle tagged with
        # 'NonZeroPoint', its row, its column and 'row_col'.  Each cell is
        # now a horizontal line of width dy through the cell's vertical
        # centre, without tags:
        # this makes things a bit faster!
        self.create_line(x, y + dy/2,
                         x+dx, y + dy/2,
                         fill=cellColor,
                         width=str(dy)
                         )
        x = x + dx

    
    # drawing COLUMN LABELS
    if DEBUG:
      print 'drawing column labels'
      print '  getting col labels'
      sys.stdout.flush()

    # the '******n' lines are progress markers; they are printed
    # regardless of the DEBUG flag
    try:
      print '******1'
      glob_clab = heatMap.getColLabeling()
      print '******2'
      clab = labelings.GlobalWrapper(self.ds, glabeling = glob_clab)
      print '******3'
      self.columnLabels = clab.getLabelByCols()
      print '******4'
    except:
      # any failure falls back to the column indices as labels
      # warning:  this means that different sortings of the
      #           dataset will produce uncorrelated changes
      #           in the row and column labels
      self.columnLabels = map(str, range(numCols))

    if DEBUG:
      print 'printing column labels'
      sys.stdout.flush()
    y=y0-columnLabelHeight            
    x=x0
    for j in range(numCols):
      # at most the first 7 characters of a label are shown, stacked
      # vertically (one character per line)
      r=len(self.columnLabels[j])
      if r>7:
        r=7
      columnLabel = ""
      for k in range(r):
        columnLabel = columnLabel + self.columnLabels[j][k]+'\n'
            
      # the rectangle is created immediately before its text item;
      # columnSelect relies on this when it uses (text id - 1) to find
      # the rectangle
      id = self.create_rectangle(x, y,
                                 x + dx, y + columnLabelHeight - (r-1),
                                 outline='black',
                                 width=1, fill='gray',
                                 tags=('ColumnLabel','%d'%j,
                                       'Col_%s'%j)
                                 )
      self.create_text(x+dx/2,
                       y,
                       text=columnLabel, anchor=Tkinter.N,
                       fill='black',
                       font='Times 10',
                       tags=('ColumnLabel', '%d'%j,'text'))
      x = x + dx
      self.colIDs[j] = id

    # drawing Gene labels
    if DEBUG:
      print 'printing gene labels'
      print '  getting row labels'
      sys.stdout.flush()

    try:
      glob_rlab = heatMap.getRowLabeling()
      rlab = labelings.GlobalWrapper(self.ds, glabeling=glob_rlab)
      self.rowLabels = rlab.getLabelByRows()
    except:
      # any failure falls back to the row indices as labels
      # warning:  this means that different sortings of the
      #           dataset will produce uncorrelated changes
      #           in the row and column labels
      self.rowLabels = map(str, range(numRows))
    
    # row labels sit to the right of the cells
    x=maxXA-rowLabelWidth            
    y=y0

    if DEBUG:
      print 'plotting row labels'
      sys.stdout.flush()
    for i in range(numRows) :
      # as for the columns: rectangle first, then its text item
      # (rowSelect uses text id - 1)
      id = self.create_rectangle(x, y,
                                 x+rowLabelWidth,
                                 y+dy,
                                 outline='gray',
                                 width=2, fill='gray',
                                 tags=('RowLabel','%d'%i,
                                       'Row_%s'%i)
                                 )
      # labels are truncated to 75 characters
      self.create_text( x,
                        y+dy/2, anchor=Tkinter.W,
                        text=self.rowLabels[i][:75],
                        fill='black',
                        font='Times 10',
                        tags=('RowLabel', '%d'%i, 'text'))
      y=y+dy
      self.rowIDs[i] = id


    if DEBUG:
      print 'binding tags'
      sys.stdout.flush()
    # store lists of selected points/rows/cols
    self.selectedPoints  = []
    self.selectedRows    = []
    self.selectedColumns = []

    # create bindings: button 3 on a label reports it, button 1 toggles
    # its selection
    self.tag_bind('ColumnLabel', '<3>', self.columnClick)
    self.tag_bind('RowLabel', '<3>', self.rowClick)
    self.tag_bind('ColumnLabel', '<1>', self.columnSelect)      
    self.tag_bind('RowLabel', '<1>', self.rowSelect)

    # (the per-cell 'NonZeroPoint' bindings are disabled because the cells
    # no longer carry tags)

    # track the left shift key so a drag can extend the selection
    self.component('canvas').bind('<KeyPress-Shift_L>', func = self.checkShiftPressed)
    self.component('canvas').bind('<KeyRelease-Shift_L>', func = self.checkShiftReleased)

    # rubber-band selection: press starts it, release converts the
    # rectangle into selected rows/columns
    self.component('canvas').bind("<1>",   func=self.mouseDown)
    self.component('canvas').bind("<ButtonRelease-1>", func=self.selectCells)


    ##########create subsetting buttons##################

    if DEBUG:
      print 'packing'
      sys.stdout.flush()
    if pack:
      self.createButtons()

      if DEBUG:
        print 'calling self.pack'
        sys.stdout.flush()
      self.pack(expand='yes', fill='both')

      # fixes scrollbars
      self.configure(hscrollmode='static', vscrollmode='static')

      self.resizescrollregion()
      # keyboard focus is needed for the shift-key bindings above
      self.component('canvas').focus_set()
      
      # note: the color map is only drawn inside the 'if pack' branch
      if plotColorMap:
        if DEBUG:
          print 'plotColorMap'
          sys.stdout.flush()
        # element 0 is presumably the hue column of the (h, s, v) colors
        # -- TODO confirm MLab.min/max reduce over the first axis
        self.plotColorMap(MLab.min(self.colors)[0], MLab.max(self.colors)[0])
      if DEBUG:
        print 'done'
        sys.stdout.flush()


  # button creation and associated functions
  def createButtons(self):
    """
    Creates the 'subsetRows', 'subsetCols' and 'subsetBoth' buttons inside
    the interior of the scrolled canvas and packs them top-right.

    The buttons are kept in self._subsetButtons.  They used to be stored
    as self.subsetRows / self.subsetCols / self.subset, which replaced the
    methods of the same name on the instance: a second call (plot() calls
    this every time it packs) then passed Button widgets as the button
    commands.  The buttons are also created only once now, so replotting
    no longer stacks up duplicate buttons.
    """
    if getattr(self, '_subsetButtons', None):
      return

    parent = self.interior()
    specs = (('subsetRows', self.subsetRows),
             ('subsetCols', self.subsetCols),
             ('subsetBoth', self.subset))

    buttons = []
    for text, command in specs:
      buttons.append(Tkinter.Button(parent, text = text, command = command))
    for button in buttons:
      button.pack(anchor = 'ne')

    self._subsetButtons = buttons

  def subsetRows(self):
    """
    Opens a new heat map holding only the currently selected rows.
    The subset view is also kept in self.ss.
    """
    # each selection entry is (row index, canvas id); only the index matters
    rowIndices = [entry[0] for entry in self.selectedRows]
    self.ss = views.RowSubsetView(self.ds, rowIndices)
    self.newHeatMap(self.ss)

  def subsetCols(self):
    """
    Opens a new heat map holding only the currently selected columns.
    The subset view is also kept in self.ss.
    """
    # each selection entry is (column index, canvas id)
    colIndices = [entry[0] for entry in self.selectedColumns]
    self.ss = views.ColumnSubsetView(self.ds, colIndices)
    self.newHeatMap(self.ss)
    
  def subset(self):
    """
    Opens a new heat map restricted to the selected rows and columns.

    views.SubsetView receives a single key list: the selected row
    indices followed by the selected column indices offset by the number
    of rows in the dataset.  The subset view is also kept in self.ss.
    """
    rowKeys = [entry[0] for entry in self.selectedRows]
    colKeys = [entry[0] + self.ds.getNumRows()
               for entry in self.selectedColumns]
    self.ss = views.SubsetView(self.ds, rowKeys + colKeys)
    self.newHeatMap(self.ss)

  def newHeatMap(self, ds):
    """
    Plots 'ds' in a fresh HeatMapPlot window.

    The new heat map reuses the row/column labelings of the heat map
    currently shown, is not re-sorted (sortby='none'), is colored by
    value, and is drawn without a color-scale window.  The heat map is
    attached to the dataset as ds.hm.
    """
    current = self.currentHeatMap
    hm = heatMap(ds,
                 rowLabeling = current.getRowLabeling(),
                 colLabeling = current.getColLabeling(),
                 sortby='none')
    hm.getColorMapper().setColorByValue()

    HeatMapPlot().plot(hm, plotColorMap=0)
    ds.hm = hm

#  def propogateGLabeling(self, ds, globalLabeling):
#    glab = globalLabeling.g
#    lab = labelings.GlobalWrapper(ds, glabeling=glab)
#    return lab


  # canvas button-click / key-press functions
  def checkShiftPressed(self, event):
    """
    Key-press handler (bound to <KeyPress-Shift_L> in plot): records that
    shift is held, which makes selectCells extend the current selection.
    """
    self.shiftPressed = 1

  def checkShiftReleased(self, event):
    """
    Key-release handler (bound to <KeyRelease-Shift_L> in plot): records
    that shift is no longer held.
    """
    self.shiftPressed = 0

  def nonZeroPointSelect(self, event):
    """
    not currently being used

    Toggles the selection of the canvas item under the pointer, treating
    it as a non-zero cell whose tags carry the row at index 1 and the
    column at index 2.  Any row/column selection is cleared first.

    NOTE: this class defines nonZeroPointSelect a second time further
    down; that later definition is the one Python keeps.
    """
    self.clearSelectedRows()
    self.clearSelectedColumns()

    tags = self.gettags(Tkinter.CURRENT)
    itemId = self.find_withtag(Tkinter.CURRENT)[0]
    point = (1, itemId, int(tags[1]), int(tags[2]))

    if point in self.selectedPoints:
      self.clearSelectedPoint(point)
    else:
      self.itemconfigure(itemId, fill='gray20')
      self.selectedPoints.append(point)

  def nonZeroClick(self, event):
    """
    not currently being used

    Reads the tags of the canvas item under the pointer and does nothing
    further with them; the code that forwarded the cell's labels to a
    main window is disabled.
    """
    a=self.gettags( Tkinter.CURRENT)
    #columnLabel=self.dataset.getLabeling(
    #  self.primaryColumnLabelingName).getLabelByCol(int(a[2]))
    #rowLabel=self.dataset.getLabeling(
    #  self.primaryRowLabelingName).getLabelByRow(int(a[1]))
    #self.mainRoot.activatePlotMember(str((columnLabel,rowLabel)),1)

  def clearSelected(self):
    """
    Deselects all selected columns and rows (selected points are left
    alone).
    """
    self.clearSelectedColumns()
    self.clearSelectedRows()
        
  def columnSelect(self, event):
    """
    Button-1 handler for column labels: toggles the selection of the
    column under the pointer.  Selected points are cleared first.
    """
    self.clearSelectedPoints()
    tags = self.gettags(Tkinter.CURRENT)

    itemId = self.find_withtag(Tkinter.CURRENT)[0]
    if 'text' in tags:
      # the click hit the label text; its rectangle was created just
      # before it, so it has the preceding canvas id
      itemId = itemId - 1

    column = (int(tags[1]), itemId)
    if column in self.selectedColumns:
      self.clearSelectedColumn(column)
    else:
      self.itemconfigure(itemId, fill='white')
      self.selectedColumns.append(column)
            
  def zeroPointSelect(self, event):
    """
    not currently being used

    Toggles the selection of the canvas item under the pointer, treating
    it as a zero-valued cell (point type 0) whose tags carry the row at
    index 1 and the column at index 2.  Any row/column selection is
    cleared first.
    """
    self.clearSelectedRows()
    self.clearSelectedColumns()

    tags = self.gettags(Tkinter.CURRENT)
    itemId = self.find_withtag(Tkinter.CURRENT)[0]
    point = (0, itemId, int(tags[1]), int(tags[2]))

    if point in self.selectedPoints:
      self.clearSelectedPoint(point)
    else:
      self.itemconfigure(itemId, fill='gray40')
      self.selectedPoints.append(point)

  def nonZeroPointSelect(self, event):
    """
    not currently being used

    Toggles the selection of the canvas item under the pointer, treating
    it as a non-zero cell (point type 1) whose tags carry the row at
    index 1 and the column at index 2.  Any row/column selection is
    cleared first.

    NOTE: an identical method of the same name is defined earlier in this
    class; this later definition is the one that takes effect.
    """
    self.clearSelectedRows()
    self.clearSelectedColumns()

    tags = self.gettags(Tkinter.CURRENT)
    itemId = self.find_withtag(Tkinter.CURRENT)[0]
    point = (1, itemId, int(tags[1]), int(tags[2]))

    if point in self.selectedPoints:
      self.clearSelectedPoint(point)
    else:
      self.itemconfigure(itemId, fill='gray20')
      self.selectedPoints.append(point)

  def rowSelect(self, event):
    """
    Button-1 handler for row labels: toggles the selection of the row
    under the pointer.  Selected points are cleared first.
    """
    self.clearSelectedPoints()
    tags = self.gettags(Tkinter.CURRENT)

    itemId = self.find_withtag(Tkinter.CURRENT)[0]
    if 'text' in tags:
      # the click hit the label text; its rectangle was created just
      # before it, so it has the preceding canvas id
      itemId = itemId - 1

    row = (int(tags[1]), itemId)
    if row in self.selectedRows:
      self.clearSelectedRow(row)
    else:
      self.itemconfigure(itemId, fill='white')
      self.selectedRows.append(row)
                
  def clearSelectedRow(self, element):
    """
    Deselects one row.  'element' is the (row index, canvas id) tuple
    stored in self.selectedRows; its label rectangle is painted gray
    again.  list.remove raises ValueError if the element is not selected.
    """
    self.itemconfigure(element[1], fill='gray')
    self.selectedRows.remove(element)

  def clearSelectedColumn(self, element):
    """
    Deselects one column.  'element' is the (column index, canvas id)
    tuple stored in self.selectedColumns; its label rectangle is painted
    gray again.  list.remove raises ValueError if the element is not
    selected.
    """
    self.itemconfigure(element[1], fill='gray')
    self.selectedColumns.remove(element)
        
  def clearSelectedRows(self):
    """
    Deselects every row: paints each selected row's label rectangle gray
    and empties self.selectedRows.
    """
    for el in self.selectedRows:
      # el is (row index, canvas id)
      self.itemconfigure(el[1], fill='gray')
    self.selectedRows=[]

  def clearSelectedColumns(self):
    """
    Deselects every column: paints each selected column's label rectangle
    gray and empties self.selectedColumns.
    """
    for el in self.selectedColumns:
      # el is (column index, canvas id)
      self.itemconfigure(el[1], fill='gray')
    self.selectedColumns=[]

  def clearSelectedPoint(self,element):        
    """
    Deselects a single point and restores its original fill color.

    element -- the tuple stored in self.selectedPoints:
               (type of point (0 or 1), canvas id, row, column)

    Points of type 0 are repainted in the canvas background color; the
    others get 'lightgray' when their data value is 0, otherwise the color
    recorded in self.plotcolors.  list.remove raises ValueError if the
    element is not selected.
    """
    if element[0]==0 :
      # 'background' is an option of the canvas component, not of the
      # ScrolledCanvas megawidget itself (same lookup as in
      # clearSelectedPoints)
      color = self.component('canvas')['background']
    else:
      i=element[2]
      j=element[3]
      if self.data[i][j] == 0:
        color = 'lightgray'
      else:
        color = self.plotcolors[i][j]
    self.itemconfigure(element[1], fill=color)
    self.selectedPoints.remove(element)


  def clearSelectedPoints(self):
    """
    Deselects every selected point, restoring each one's fill color, and
    empties self.selectedPoints.

    format of selectedPoints entries:
      (type of point (0 or 1), canvas id, i-coordinate, j-coordinate)
    """
    canvasBackground = self.component('canvas')['background']

    for point in self.selectedPoints:
      if point[0] == 0:
        fillColor = canvasBackground
      else:
        row, col = point[2], point[3]
        if self.data[row][col] == 0:
          fillColor = 'lightgray'
        else:
          fillColor = self.plotcolors[row][col]
      self.itemconfigure(point[1], fill=fillColor)

    self.selectedPoints = []


        
    #######################################################################
    #### Click functions (bound to <3> on the labels in plot) ############
        
  def columnClick(self, event):
    """
    Button-3 handler for column labels: writes the index of the clicked
    column to stdout as 'Tissue ---- <index>'.
    """
    tags = self.gettags(Tkinter.CURRENT)
    columnIndex = tags[1]

    # the label is looked up but not displayed; the lookup still raises
    # if the index is out of range
    ligandName = self.columnLabels[int(columnIndex)]

    sys.stdout.write('Tissue ---- %s\n' % columnIndex)
    sys.stdout.flush()

  def rowClick(self, event):
    """
    Button-3 handler for row labels: writes the index of the clicked row
    to stdout as 'GENE ---- <index>'.
    """
    tags = self.gettags(Tkinter.CURRENT)
    rowIndex = tags[1]

    sys.stdout.write("GENE ---- %s\n" % rowIndex)
    sys.stdout.flush()

    # the label is looked up (after the message is written) but not
    # displayed
    geneName = self.rowLabels[int(rowIndex)]
    #print geneName
      




  def getColor(self, value):
    """
    Maps a value onto one of three color names:

      value > 9        -> 'red'
      7 <= value <= 9  -> 'yellow'
      value < 7        -> 'green'

    (the original intent was a red-green spectrum for values that run
    roughly between 5 and 10, flooring anything outside that range)
    """
    if value > 9:
      color = 'red'
    if 7 <= value <= 9:
      color = 'yellow'
    if value < 7:
      color = 'green'

    return color

  def mouseDown(self, event):
    """
    Button-1 press on the canvas: starts a rubber-band selection.

    Removes the previous bounding rectangle, records the press position
    (in canvas coordinates) as the start of the selection, starts
    listening for drag motion and draws an initial rectangle one cell
    (self.dx by self.dy) in size.
    """
    if self.boundingRectangle:
      self.delete(self.boundingRectangle)

    pressX = self.canvasx(event.x)
    pressY = self.canvasy(event.y)

    # only the start corner changes; the end corner is kept as it was
    previous = self.__zoomSelection
    self.__dragging = 0
    self.__zoomSelection = [pressX, pressY, previous[2], previous[3]]

    self.component('canvas').bind("<B1-Motion>",  func=self.mouseDrag)

    # draw the bounding rectangle!
    self.boundingRectangle = self.create_rectangle(pressX, pressY,
                                                   pressX + self.dx,
                                                   pressY + self.dy)
    self.startx = pressX
    self.lastx = pressX
    self.starty = pressY
    self.lasty = pressY

  def mouseDrag(self, event):
    """
    Button-1 motion on the canvas: marks the gesture as a drag and
    redraws the bounding rectangle from the press position to the
    current pointer position (canvas coordinates).
    """
    self.__dragging = 1
    self.lastx, self.lasty = self.canvasx(event.x), self.canvasy(event.y)

    if not self.boundingRectangle:
      return

    self.delete(self.boundingRectangle)
    self.boundingRectangle = self.create_rectangle(self.startx, self.starty,
                                                   self.lastx, self.lasty)

  def selectCells(self, event):
    """
    Button-1 release on the canvas: turns the dragged rectangle into a
    row and column selection.

    The rectangle runs from the position recorded by mouseDown to the
    release position.  Nothing is selected unless the mouse was actually
    dragged and the rectangle has a non-zero width and height.  Unless
    shift is held, the previous row/column selection is cleared first.
    Selected label rectangles are painted white and recorded in
    self.selectedRows / self.selectedColumns as (index, canvas id).

    Fixes relative to the earlier version:
      * the drag handler is unbound with the sequence mouseDown bound it
        with ("<B1-Motion>"; "<Motion>" was unbound before, which left
        the handler installed)
      * the 'bottomRow' diagnostic prints the computed row (it used to
        evaluate y1 - y0/dy)
      * a row/column that is already selected (shift-extended selection)
        is not recorded a second time
      * the dragging flag is reset once the release has been handled
    """
    x0, y0, x1, y1 = self.__zoomSelection
    x1,y1 = self.canvasx(event.x), self.canvasy(event.y)

    if self.__dragging:
      self.component('canvas').unbind("<B1-Motion>")
      sys.stdout.write('x0:%s, x1:%s, y0:%s, y1:%s\n'%(x0,x1,y0,y1))
      sys.stdout.flush()

      if x0 != x1 and y0 != y1:

        # make sure the coordinates are sorted
        if x0 > x1:
          x0, x1 = x1, x0
        if y0 > y1:
          y0, y1 = y1, y0

        # convert canvas coordinates to cell indices using the layout
        # constants stored by plot().
        # NOTE(review): columns start one cell after the cell containing
        # x0 while rows start at the cell containing y0 -- confirm that
        # this asymmetry is intended.
        leftCol   = int((x0 - self.leftSpace)/self.dx) + 1
        rightCol  = int((x1 - x0)/self.dx) + leftCol
        topRow    = int((y0 - self.columnLabelHeight)/self.dy)
        bottomRow = int((y1 - y0)/self.dy) + topRow

        sys.stdout.write('bottomRow' + str(bottomRow) + '\n')

        sys.stdout.write('L%d R%d T%d B%d'%(leftCol, rightCol, topRow, bottomRow))
        sys.stdout.flush()

        # clear all selected points (if shift isn't being held down)
        if self.shiftPressed == 0:
          self.clearSelected()

        for i in range(topRow, bottomRow):
          # rowIDs maps a row index to the single canvas id of its label
          # rectangle; indices outside the plot simply yield None.
          # itemconfigure is handed the id wrapped in a 1-tuple.
          idr = self.rowIDs.get(i)
          if idr:
            row_element = (i,idr)
            self.itemconfigure((idr,), fill='white')
            if row_element not in self.selectedRows:
              self.selectedRows.append(row_element)
        for j in range(leftCol, rightCol):
          idc = self.colIDs.get(j)
          if idc:
            col_element = (j, idc)
            self.itemconfigure((idc,), fill='white')
            if col_element not in self.selectedColumns:
              self.selectedColumns.append(col_element)

        # individual cells are deliberately not collected into
        # selectedPoints: it is slow and not needed for subsetting

      # the gesture has been consumed
      self.__dragging = 0

    self.__zoomSelection = [None,None,None,None]        
    return

  def plotColorMap(self, min, max):
    """
    Opens a small separate window showing the color scale between two
    hue values.

    min, max -- the hue values at the two ends of the scale (the names
                shadow the builtins but are part of the call signature)

    The scale is drawn as resolution+1 adjacent rectangles whose hue runs
    linearly from min to max (each hue truncated to two decimals, full
    saturation and value); every 20th step is annotated with its hue.

    The step between hues is computed in floating point: with plain
    integer endpoints the former (max-min)/100 was truncated to 0 under
    Python 2 integer division, giving a single flat color.  The number of
    steps now follows 'resolution' instead of separate literals.
    """

    print(min)
    print(max)
    sys.stdout.flush()
    
    width = 300
    height = 30
    resolution = 100

    x0=0
    y0=0
    # pixel size of one step of the scale (integer division intended)
    dx=width//resolution
    dy=height//3

    # NOTE: this creates an additional Tk root window for the scale
    self.colorMapPlot = Tkinter.Canvas(master = Tkinter.Tk())
    tl = self.colorMapPlot.winfo_toplevel()
    tl.geometry ( '%dx%d'%(width,height) )

    d_color = float(max-min)/resolution
    for i in range(resolution + 1):
      x = x0 + i*dx
      y = y0
      colorVal = min + d_color*i
      # truncate to two decimals (also the text shown on the scale)
      colorVal = int(colorVal*100)/100.0
      color = rgbToString(colorsys.hsv_to_rgb(colorVal,1,1))
      self.colorMapPlot.create_rectangle(x,y,
                                         x+dx,y+dy,
                                         fill=color,
                                         outline=color,
                                         width=0
                                         )

      if i%20 == 0:
        self.colorMapPlot.create_text(x,y+dy*1.6,
                                      text=colorVal,
                                      anchor = Tkinter.N,
                                      fill='black',
                                      font='Times 10'
                                      )
    self.colorMapPlot.pack(fill='both', expand='yes')


class Mapper:
  """
  Common base of the heat-map mapper classes: it only remembers the
  heatMap the mapper belongs to.
  """

  def __init__(self, heatMap):
    """heatMap -- the heat map this mapper serves"""
    self.__heatMap = heatMap

  def getHeatMap(self):
    """returns the heat map given at construction"""
    owner = self.__heatMap
    return owner

class RowColorMapper(Mapper):
  """
  Color mapper for heat maps; the colors can be adjusted dynamically.

  It keeps one hue, one saturation and one value entry per cell of the
  heat map's sorted dataset.
  """
  def __init__(self, heatMap):
    """
    heatMap -- the heat map to color; its sorted dataset must already be
               available, since its size fixes the length of the color
               arrays
    """
    Mapper.__init__(self, heatMap)
    self.dataset = self.getHeatMap().getSortedDataset()
    cellCount = self.dataset.getNumRows() * self.dataset.getNumCols()

    # these are the HSV arrays that are used to color plots:
    # hue 0, saturation 1 and value 1 for every cell
    self.__h = MLab.zeros(cellCount)
    self.__s = MLab.ones(cellCount)
    self.__v = MLab.ones(cellCount)

  def setColorRange(self, minValue = None, maxValue = None):
    """
    setColorRange(self, minValue = None, maxValue = None)

    meant to set the min/max range of colors on the colorwheel to be the
    min/max range of colors used in the rectangles of the heatmap.

    Not implemented: the call currently does nothing.
    """
    pass

  def setColorByValue(self, component = 'h', minValue=None,
                      maxValue=None, colorRange=(0,0.4)):
    """
    setColorByValue(self, component, minValue, maxValue, colorRange)

    makes one color component a function of the data values.

    component  -- 'h', 's' or 'v'; any other value leaves the colors
                  unchanged
    minValue   -- lower data bound handed to scaleList; defaults to the
                  data minimum
    maxValue   -- upper data bound handed to scaleList; defaults to the
                  data maximum
    colorRange -- (low, high) output range handed to scaleList

    The min/max values in use are printed to stdout.
    """
    self.dataset = self.getHeatMap().getSortedDataset()
    flatData = MLab.ravel(self.dataset.getData())

    if maxValue is None:
      maxValue = MLab.max(flatData)
    if minValue is None:
      minValue = MLab.min(flatData)

    print('minValue: ' + str(minValue))
    print('maxValue: ' + str(maxValue))
    sys.stdout.flush()

    lowColor, highColor = colorRange[0], colorRange[1]
    scaled = scaleList(flatData, minValue=minValue, maxValue=maxValue,
                       minReturn=lowColor, maxReturn=highColor)

    if component == 'h':
      self.__h = scaled
    elif component == 's':
      self.__s = scaled
    elif component == 'v':
      self.__v = scaled

  def getColors(self, colorModel='hsv'):
    """
    returns the colors as an array with one (h, s, v) row per cell.

    colorModel -- must be 'hsv' (checked with an assert)
    """
    assert colorModel == 'hsv'

    self.dataset = self.getHeatMap().getSortedDataset()
    # the cell count is computed but not needed for the result
    datalen = self.dataset.getNumRows() * self.dataset.getNumCols()

    hsvRows = Numeric.array((self.__h, self.__s, self.__v))
    return Numeric.transpose(hsvRows)

class IDataMapper(Mapper):

  """
  Interface for data mappers: support classes that describe how a
  dataset is projected onto the plot.  Subclasses must provide
  getPlotData; the remaining hooks have do-nothing defaults.
  """

  def getPlotData(self):

    """
    getPlotData(self)

    returns the data to plot.  Must be overridden.

    HeatMapPlot.plot calls getData(), getNumRows() and getNumCols() on
    the returned object.  (The IPlot interface this was modelled on
    described the result as a list of X,Y tuple tuples, e.g.
    [ ((x1,x2,x3,x4), (y1,y2,y3,y4)), ... ].)
    """
    raise NotImplementedError()


  def getLineTrace(self):
    """
    getLineTrace(self)

    returns how the points will be connected.  If None, then the BLT
    default line connectors are used.

    Must be:
      'increasing', 'decreasing', 'both' or  None

    This default returns None.
    """
    return None


  def getWeightsAndSytles(self):

    """
    getWeightsAndSytles(self)

    This allows for marker control on a datapoint by datapoint
    basis...  If you want control on an element by element basis use
    the getPlotMarkers and getMarkerSize functions inside of the
    MarkersMappers.

    returns either None or a list of tuples equal in length to the
    plotData return value, where each tuple is the weights tuple and
    the corresponding styles tuple (see the BLT documentation for more
    information; this is very powerful).

    This default returns None.
    """
    return None
  

class XClustDataMapper(IDataMapper):
  """
  XClustDataMapper(heatMap, sortby='rows', ...)

  A datamapper that clusters the heat map's dataset with XClust and
  keeps a sorted view of it whose rows and/or columns are permuted into
  the leaf order of the resulting cluster trees.

  sortby = 'rows', 'columns' or 'both'; any other value leaves the
  ordering of the dataset untouched.
  """

  def __init__(self, heatMap, sortby = 'rows',
               row_xtree   = None,
               col_xtree   = None,
               rowLabeling = None,
               colLabeling = None):
    """
    heatMap      -- object providing getUnsortedDataset(),
                    getSortedDataset() and setSortedDataset()
    sortby       -- 'rows', 'columns' or 'both' (see class docstring)
    row_xtree    -- existing row cluster tree; when given, XClust is
                    not re-run for the rows
    col_xtree    -- existing column cluster tree; when given, XClust is
                    not re-run for the columns
    rowLabeling,
    colLabeling  -- accepted but currently not used; call
                    setRowLabeling()/setColLabeling() explicitly

    The (possibly newly created) sorted view is handed back to the heat
    map through setSortedDataset().
    """
    Mapper.__init__(self, heatMap)

    self.row_xtree = row_xtree
    self.col_xtree = col_xtree

    self.unsorted_ds = heatMap.getUnsortedDataset()
    self.sorted_ds = heatMap.getSortedDataset()
    if not self.sorted_ds:
      self.sorted_ds = views.SortedView(self.unsorted_ds)

    # work out which axes have to be clustered and sorted
    if sortby == 'both':
      axes = ('rows', 'columns')
    elif sortby in ('rows', 'columns'):
      axes = (sortby,)
    else:
      axes = ()

    for axis in axes:
      self.createXTree(cluster_on=axis)
      self.sortDataset(sortRows=(axis == 'rows'))

    heatMap.setSortedDataset(self.sorted_ds)

    self.__heatMap = heatMap

  def setHeatMap(self, heatMap):
    """
    Replaces the heat map this mapper reports through getHeatMap().
    The datasets and trees already held by the mapper are not touched.
    """
    self.__heatMap = heatMap

  def getHeatMap(self):
    """
    Returns the heat map this mapper is attached to.
    """
    return self.__heatMap

  def createXTree(self, cluster_on = 'rows'):
    """
    createXTree(self, cluster_on='rows')

    Runs XClust for the requested axis ('rows' or 'columns') unless a
    tree for that axis is already held, in which case a notice is
    printed instead.  Any other cluster_on value also ends up at the
    notice and nothing is run.
    """
    if cluster_on == 'rows':
      missing = not self.row_xtree
    elif cluster_on == 'columns':
      missing = not self.col_xtree
    else:
      missing = 0

    if missing:
      self.runXClust(cluster_on = cluster_on)
    else:
      print('XClustering of ' + cluster_on + ' already exists')

  def removeRow_xtree(self):
    """
    Forgets the row tree so the next createXTree('rows') re-runs XClust.
    """
    self.row_xtree = None

  def removeCol_xtree(self):
    """
    Forgets the column tree so the next createXTree('columns') re-runs
    XClust.
    """
    self.col_xtree = None

  def remove_xtrees(self):
    """
    Forgets both the row and the column tree.
    """
    self.row_xtree = None
    self.col_xtree = None

  def runXClust(self, cluster_on = 'columns'):
    """
    runXClust(self, cluster_on='columns')

    Runs XClust on the unsorted dataset and reads the resulting tree
    file back in.  For cluster_on == 'rows' the '.gtr' file is read and
    the tree and parameters are stored as row_xtree/row_parameters; for
    any other value the '.atr' file is read and they are stored as
    col_xtree/col_parameters.

    The tree file name is relative ('Xclust_<cluster_on><suffix>'), so
    it is looked up in the current working directory; presumably XClust
    writes it there because of the results_dir and
    save_intermediate_files settings.
    """
    base = 'Xclust_' + cluster_on
    pars = {'agglomerate_method': 'none',
            'cluster_on': cluster_on,
            'distance_metric': 'euclidean',
            'k': 10,
            'results_dir': '.',
            'save_intermediate_files': 'yes',
            'save_intermediate_files_base': base,
            'transform_method': 'none'}

    for_rows = (cluster_on == 'rows')
    if for_rows:
      suffix = '.gtr'
    else:
      suffix = '.atr'

    # the same dataset is clustered either way; cluster_on in the
    # parameters selects the axis
    xcl = XClust(dataset = self.unsorted_ds, parameters = pars)
    xcl.run()

    xtree = XClustTree.XClustTree()
    xtree.read(base + suffix)

    if for_rows:
      self.row_xtree = xtree
      self.row_parameters = xcl.getParameters()
    else:
      self.col_xtree = xtree
      self.col_parameters = xcl.getParameters()

  def sortDataset(self, sortRows=1):
    """
    sortDataset(self, sortRows=1)

    Permutes the rows (sortRows true) or the columns (sortRows false)
    of the sorted view into the leaf order of the matching cluster
    tree, creating the view first if there is none yet.  The leaf order
    is remembered in row_leaves / col_leaves.

    The matching tree must already exist (see createXTree).

    Returns the sorted view.
    """
    if sortRows:
      leaves = self.row_leaves = self.getLeaves(self.row_xtree, 'GENE')
    else:
      leaves = self.col_leaves = self.getLeaves(self.col_xtree, 'ARRY')

    # create sorted Dataset if needed
    if not self.sorted_ds:
      self.sorted_ds = views.SortedView(self.unsorted_ds)

    # permute rows/cols as required
    if sortRows:
      self.sorted_ds.permuteRows(leaves)
    else:
      self.sorted_ds.permuteCols(leaves)

    return self.sorted_ds

  def getLeaves(self, tree, prefix='GENE'):
    """
    getLeaves(self, tree, prefix='GENE')

    Walks tree with its iterator and returns the integer indices of the
    leaf nodes in the order they are visited.  The index is the number
    found between prefix and the following 'X' in the node key
    (GENE188X --> 188, ARRY33X --> 33).
    """
    walker = tree.iterator()
    keys = []

    # the tree iterator signals the end of the walk by returning None
    node = walker.next()
    while node is not None:
      key = node.key()
      if tree.isLeaf(key):
        keys.append(key)
      node = walker.next()

    print('cleaning leaf-list')
    return [int(key.split(prefix)[1].split('X')[0]) for key in keys]

  def setRowLabeling(self, labeling=None):
    """
    Sets the row labels: labeling.getLabelByRows() when a labeling is
    given, otherwise the tuple (0, 1, ..., numRows-1) of the sorted
    dataset.
    """
    if labeling is None:
      self.__rowvalues = tuple(range(0, self.sorted_ds.getNumRows()))
    else:
      self.__rowvalues = labeling.getLabelByRows()

  def setColLabeling(self, labeling=None):
    """
    Sets the column labels: labeling.getLabelByCols() when a labeling
    is given, otherwise the tuple (0, 1, ..., numCols-1) of the sorted
    dataset.
    """
    if labeling is None:
      self.__colvalues = tuple(range(0, self.sorted_ds.getNumCols()))
    else:
      self.__colvalues = labeling.getLabelByCols()

  def getRowLabels(self):
    """
    Returns the labels stored by setRowLabeling(), which must have been
    called first.
    """
    return self.__rowvalues

  def getColLabels(self):
    """
    Returns the labels stored by setColLabeling(), which must have been
    called first.
    """
    # must be the same (name-mangled) attribute that setColLabeling
    # writes; a single-underscore name is never assigned
    return self.__colvalues

  def getPlotData(self):
    """
    gets the sorted dataset.  Colors will be added to this dataset
    by the colormapper
    """
    return self.getHeatMap().getSortedDataset()
    
    
class heatPlotMapper(Mapper):
  """
  heatPlotMapper(heatMap, rowLabeling=None, colLabeling=None)

  a datamapper that plots each row of the dataset as a set of colored
  rectangles
  """
  def __init__(self, heatMap, rowLabeling=None, colLabeling=None):
    """
    heatMap      -- heat map whose sorted dataset is plotted
    rowLabeling,
    colLabeling  -- accepted but currently not used; call
                    setRowLabeling()/setColLabeling() explicitly
    """
    Mapper.__init__(self, heatMap)

    self.sorted_ds = heatMap.getSortedDataset()
    self.__heatMap = heatMap

  def setRowLabeling(self, labeling=None):
    """
    Sets the row labels: labeling.getLabelByRows() when a labeling is
    given, otherwise the tuple (0, 1, ..., numRows-1) of the sorted
    dataset.
    """
    if labeling is None:
      self.__rowvalues = tuple(range(0, self.sorted_ds.getNumRows()))
    else:
      self.__rowvalues = labeling.getLabelByRows()

  def setColLabeling(self, labeling=None):
    """
    Sets the column labels: labeling.getLabelByCols() when a labeling
    is given, otherwise the tuple (0, 1, ..., numCols-1) of the sorted
    dataset.
    """
    if labeling is None:
      self.__colvalues = tuple(range(0, self.sorted_ds.getNumCols()))
    else:
      self.__colvalues = labeling.getLabelByCols()

  def getRowLabels(self):
    """
    Returns the labels stored by setRowLabeling(), which must have been
    called first.
    """
    return self.__rowvalues

  def getColLabels(self):
    """
    Returns the labels stored by setColLabeling(), which must have been
    called first.
    """
    # must be the same (name-mangled) attribute that setColLabeling
    # writes; a single-underscore name is never assigned
    return self.__colvalues

  def getPlotData(self):
    """
    Returns the heat map's sorted dataset.
    """
    return self.getHeatMap().getSortedDataset()
    

#
# helper functions
#


def rgbToString(rgb):

  """
  rgbToString(rgb):

  turns an rgb tuple into a hex string of the form '#rrggbb'

  Every component is expected to lie between 0 and 1 and is scaled to
  0..255 (truncating).  Components outside that range are clamped to
  it, so each one always contributes exactly two hex digits.
  """
  digits = []
  for component in rgb:
    # clamp first: a negative or >1 component would otherwise yield a
    # malformed color string
    level = int(min(max(component, 0), 1) * 255)
    digits.append('%02x' % level)
  return '#' + ''.join(digits)


def scaleList(values, minValue = None, maxValue= None, minReturn=0,
              maxReturn=1):
  """
  scaleList(values, minValue=None, maxValue=None, minReturn=0, maxReturn=1)

  Linearly rescales values into the range minReturn..maxReturn and
  returns the result as a Numeric array.

  When every distinct value is an int, long or float, the values are
  scaled between minValue and maxValue, and anything beyond either
  bound is clamped to it.  A bound left as None is taken from the data:
  minValue is the element 10% of the way through the sorted values and
  maxValue the element 90% of the way through.  If both bounds
  coincide, values above the bound map to maxReturn and all others to
  minReturn.

  Otherwise the values are treated as labels: each one is replaced by
  the position of its label in the list of unique labels, and
  minValue/maxValue are overridden with 0 and the number of labels.

  NOTE(review): values is assumed to be a flat sequence - confirm
  against the callers.
  """

  labels =  unique.unique (values)
  numericLabels = 1
  for label in labels:
    labelType = type(label)
    if labelType not in [types.IntType, types.LongType, types.FloatType]:
      numericLabels = 0
      break

  data = None
  if numericLabels:
    try:
      data = Numeric.array(values)
    except (TypeError, ValueError, OverflowError):
      # not representable as an array after all: fall back to treating
      # the values as labels
      numericLabels = 0

  if numericLabels:
    if minValue is None or maxValue is None:
      # the percentile bounds are only meaningful on sorted data, so
      # sort a private copy and leave the order of data alone
      ordered = data.tolist()
      if not ordered:
        # nothing to derive the bounds from, and nothing to scale
        return Numeric.array([], 'f')
      ordered.sort()

      if minValue is None:
        # let minValue be the bottom 10-percentile
        minValue = ordered[int(len(ordered) * .1)]

      if maxValue is None:
        # let maxValue be the top 10-percentile
        maxValue = ordered[int(len(ordered) * .9)]

  else:
    minValue = 0
    maxValue = len(labels)
    labelMap = {}
    for label, count in zip(labels, range(maxValue)):
      labelMap[label] = count
    data = [labelMap[x] for x in values]

  span = maxValue - minValue

  # shift so that minValue sits at zero and clamp everything below it
  data = Numeric.array(data)
  data = Numeric.array([max(x, 0) for x in data - minValue])

  if span:
    data = data.astype('f') / span
  else:
    # degenerate range: avoid dividing by zero; only values above the
    # single bound count as "high"
    data = Numeric.greater(data, 0).astype('f')

  # clamp everything above maxValue
  data = Numeric.array([min(x, 1) for x in data])

  # this maps the data into the minReturn..maxReturn range.
  data = (data * (maxReturn - minReturn)) + minReturn

  return data
