import colorsys
import matplotlib.numerix as nx
from matplotlib.numerix import mlab

#import TrajectorySummary

from PlotPage import PlotPage
from TrajectorySummary import TrajectorySummary

from IPlot import IPlot, DatasetPlot

from compClust.mlx.views import AggregateFunctionView

from compClust.mlx.labelings import Labeling
from compClust.mlx.labelings import GlobalLabeling
from compClust.mlx.labelings import GlobalWrapper
from compClust.mlx.labelings import subsetByLabeling

from compClust.score import roc
from compClust.score.ConfusionMatrix2 import  ConfusionMatrix

from compClust.util.InterpreterTools import safeStdDev

class ConfusionMatrixSummary(PlotPage):

  """
  Generates a trajectory summary for every cell in a confusion
  matrix
  
  """
  def __init__(self, canvasFactory, dataset, labeling1, labeling2, primaryLabeling=None, secondaryLabeling=None,l1Order=None, l2Order=None, web_safe=False, parent=None):
    """
    __init__(self, dataset, labeling1, labeling2, parent=None)
    
    web_safe makes a labeling that is just a string which is needed when the labels
    are shoved into the URL
    
    """
    computeROC=0
    # rows, cols = number of unique labels + 1 for the summary
    numRows = len(labeling1.getLabels())+1
    numCols = len(labeling2.getLabels())+1
    PlotPage.__init__(self, numRows, numCols, canvasFactory=canvasFactory)
   
    # sets up the needed information
    self.dataset = dataset
    self.labeling1 = labeling1
    self.labeling2 = labeling2
    self.primaryLabeling = primaryLabeling
    self.secondaryLabeling = secondaryLabeling
    self.l1Order = l1Order
    self.l2Order = l2Order
    
    # order the data by cluster sizes
    if self.l1Order is None:
      l1RowLabels = self.labeling1.getLabelByRows()
      l1Order = map(lambda x: (l1RowLabels.count(x), x), self.labeling1.getLabels())
      l1Order.sort()
      l1Order = map(lambda x: x[1], l1Order)
      l1Order.reverse()
      self.l1Order = l1Order
    else:
      l1Order = self.l1Order
      l1Order.reverse()

    if self.l2Order is None:
      l2RowLabels = self.labeling2.getLabelByRows()
      l2Order = map(lambda x: (l2RowLabels.count(x), x), self.labeling2.getLabels())
      l2Order.sort()
      l2Order = map(lambda x: x[1], l2Order)
      self.l2Order = l2Order
    else:
      l2Order = self.l2Order
    
    # stores annotation labels computed by draw for annotate step
    self.annotations = {} 

    clusterOrders = [l1Order, l2Order]
    self.cm = ConfusionMatrix([self.labeling1, self.labeling2], 
                              clusterOrders=clusterOrders,
                              web_safe=web_safe)
    self.title("Confusion Matrix- NMI= %3.2f, NMI'= %3.2f, LA = %3.2f"%(self.cm.NMI(),
                                                                        self.cm.transposeNMI(),
                                                                        self.cm.linearAssignment()))
    # this is a labeling of confusion matrix coord
    self.labeling = self.cm.getConfusionLabeling()
    
    ### This deals with aggregating the datasets.   
    #keylists =  [self.labeling.getRowsByLabel(label) for label in self.labeling.getLabels()]
    #self.means = AggregateFunctionView(self.dataset, keylists, mlab.mean)
    #self.stds  = AggregateFunctionView(self.dataset, keylists, safeStdDev)
    
    self.createMeanStdDataAndLabelings(self.labeling)
      
    
    # we have to recreate the keylists because it may have been modified in the AggregateFunctionView
    keylists =  [self.labeling.getRowsByLabel(label) for label in self.labeling.getLabels()]
    mapping   = [self.means._mapKeysToParent([key]) for key in  self.means.getRowKeys ()] 
    positions = [mapping.index(keylist) for keylist in keylists]
    
    self.aggregateOrder= {}
    map(self.aggregateOrder.setdefault, self.labeling.getLabels(), positions)

    # draw the summarys
    self.drawConfusionMatrix(computeROC)
    self.setUniformYAxis()
    self.annotateConfusionMatrix()
    self.show()

    # these are pointers to the full plot DatasetPlotter and its PlotView
    self.selectedPlot = None
    self.selectedView = None

  def annotateConfusionMatrix(self):
    """Add annotation information to the confusion Matrix cells

    This has to be done after we adjust all the axis sizes.
    (Otherwise the text ends up next to the data instead of along the
    edges of the plots.)
    """
    # iterate over all the plots determining what kind of plot it is
    # and then add the appropriate type of summary information.
    for index, axis in self.plots.items():
      x, y = index
      if x == 0 or y == self.numCols-1:
        self.annotateConfusionMatrixMargin(axis)
      else:
        self.annotateConfusionMatrixCell(axis)

  def annotateConfusionMatrixCell(self, axis):
    """Add annotation to the core confusion matrix cells
    """
    xmin, xmax = axis.viewLim.intervalx().get_bounds()
    ymin, ymax = axis.viewLim.intervaly().get_bounds()

    label = self.annotations.get(axis, None)
    if label is not None:
      axis.text(xmin, ymax, label, verticalalignment = 'top')

  def annotateConfusionMatrixMargin(self, axis):
    """Add annotation to the row and column margin plots
    """
    # plot labels
    xmin, xmax = axis.viewLim.intervalx().get_bounds()
    ymin, ymax = axis.viewLim.intervaly().get_bounds()

    labels = self.annotations.get(axis, None)
    if labels is not None:
      data_text, label_text = labels
      axis.text(xmin, ymax, data_text, verticalalignment='top')
      axis.text(xmax, ymin, label_text, horizontalalignment = 'right')

  def drawConfusionMatrix(self, computeROC=1, cmColorMap=None):
    """
    drawConfusionMatrix(self, computeROC=1)

    This method draws all the cells of the confusion matrix with small
    summary tragetories in each one.
    """
    if cmColorMap is None:
      cmColorMap = self.getColorByColumnMarginals

    l1Order, l2Order = self.sortLabelingsByClusterSize(self.l1Order,
                                                       self.l2Order)
    self.drawConfusionMatrixCells(l1Order, l2Order, cmColorMap)
    self.drawConfusionMatrixColumnSummary(l2Order)
    self.drawConfusionMatrixRowSummary(l1Order)

  def drawConfusionMatrixCells(self, l1Order, l2Order, cmColorMap):
    """Draw the confusion matrix cells (the things that aren't the row
    or column summary values)
    """
    from compClust.score import ConfusionMatrix2
    color_map = cmColorMap()
    total_rows_summarized = 0
    # now draw the pretty confusion matrix.
    gridRow = 1 # keep track of position in the Tk grid
    gridCol = 0
    for lab1 in l1Order:
      for lab2 in l2Order:
        # since we ne need to convert the cell labels to pure strings when in 
        # web mode we need to convert the tuple label we create back to the 
        # strings that we attached to the variouc confusion matrix cells
        cell_label = (lab1, lab2)
        if self.cm.isWebSafe:
          cell_label = ConfusionMatrix2.tuple_stringify(cell_label)
          
        dataRow = self.meansLab.getRowsByLabel(cell_label)
        if len(dataRow) > 0:
          dataRow = dataRow[0]
          color = color_map[gridRow][gridCol]
          axis = self.addPlot(gridRow, gridCol, axisbg=color)
          rows_summarized = self.drawSummaryPlot(axis, dataRow, computeROC=0, titles=0, verbose=0)
          total_rows_summarized += len(rows_summarized)
          len(rows_summarized)
          axis.set_xticklabels([])
          axis.set_xticks([])
          axis.set_yticklabels([])
          axis.set_yticks([])
        gridCol += 1

      gridRow += 1
      gridCol = 0
    assert total_rows_summarized <= self.dataset.getNumRows()
    # add the adjancy elements
    adjList = map(lambda pair :
                  (l1Order.index(pair[0])+1, l2Order.index(pair[1])),
                  self.cm.getAdjacencyList())
    for pair in adjList:
      self.plots[pair].get_frame().set_linewidth(3)

  def drawConfusionMatrixColumnSummary(self, l2Order):
    """draw the marginal cluster summaries along the top
    """
    gridRow = 0
    gridCol = 0
    for lab2 in l2Order:
      data = nx.array(map(self.dataset.getRowData,
                               self.labeling2.getRowsByLabel (lab2)))
      if nx.shape(data) != (0,):
        axis = self.addPlot(gridRow, gridCol)
        plot = self.drawMarginalPlot(axis, data, lab2)
      gridCol +=1 

  def drawConfusionMatrixRowSummary(self, l1Order):
    """draw cluster summaries along the side
    """
    gridRow = 1
    gridCol = self.numCols - 1
    for lab1 in l1Order:
      data = nx.array(map(self.dataset.getRowData,
                               self.labeling1.getRowsByLabel (lab1)))
      if nx.shape(data) != (0,):
        axis = self.addPlot(gridRow, gridCol)
        plot = self.drawMarginalPlot(axis, data, lab1)
      gridRow +=1 
    
  def drawMarginalPlot(self, axis, data, label):

    """
    drawMarginalPlot(self, data)

    draws the marginal plots for the confusion matrix summary.
    """
    mean = mlab.mean(data)
    posStd = tuple(mean + safeStdDev(data))
    negStd = tuple(mean - safeStdDev(data))
    mean = tuple(mean)
    xdata = tuple(range(len(mean)))

    axis.plot(xdata, mean, 'b', label='mean')
    if posStd != mean:
      axis.plot(xdata, posStd, 'r-', label='std_pos')
      axis.plot(xdata, negStd, 'r-', label="std_neg")

    #turn off labels
    axis.set_xticklabels([])
    axis.set_xticks([])
    axis.set_yticklabels([])
    axis.set_yticks([])
    self.annotations[axis] = (str(len(data)), str(label))
    return axis
    
  def getColorByColumnMarginals(self):
    """Color each cell of the confusion matrix by how many members
    are in it relitive to the indicated direction """
   # now to get the ordered confusion matrix counts.
    originalMatrix = nx.array(self.cm.getMatrix())
    m = []
    rowCount = 0
    rowOrder, colOrder = self.cm.getClusterOrders() 
    for row in self.l1Order:
      m.append([])
      for col in self.l2Order:
        m[rowCount].append(originalMatrix[rowOrder.index(row),
                                          colOrder.index(col)])
      rowCount+=1

    # normalize the matrix.
    matrix = nx.array(m)
    matrix = matrix.astype(nx.Float)
    matrix = matrix / (nx.sum(matrix)+.0001)   # solves the problem of occasionaly dividing by zero
    del(m)
    del(originalMatrix)
   
    # now draw the pretty confusion matrix.
    # null list is in there because the first row is a header
    color_matrix = [[]]
    matrixPos = [0,0] # keep track of position in the matrix 
    for lab1 in self.l1Order:
      color_row = []
      for lab2 in self.l2Order:
        color = colorsys.hsv_to_rgb(matrix[matrixPos]*.7, 1, 1)
        color_row.append(color)
        matrixPos[1] +=1
      color_matrix.append(color_row)
      matrixPos[0] +=1
      matrixPos[1] = 0
    return color_matrix

  def getColorByRowMarginals(self):
    """Compute the color matrix based on the Row Margins
    """
    # now to get the ordered confusion matrix counts.
    originalMatrix = nx.array(self.cm.getMatrix())
    m = []
    rowCount = 0
    rowOrder, colOrder = self.cm.getClusterOrders() 
    for row in self.l1Order:
      m.append([])
      for col in self.l2Order:
        m[rowCount].append(originalMatrix[rowOrder.index(row),
                                          colOrder.index(col)])
      rowCount+=1

    # normalize the matrix.
    matrix = nx.array(m)
    matrix = matrix.astype(nx.Float)
    matrix = nx.transpose(nx.transpose(matrix) / (nx.sum(matrix,1)+.0001))   # solves the problem of occasionaly dividing by zero
    del(m)
    del(originalMatrix)
   
    # now draw the pretty confusion matrix.
    # null list is in there because the first row is a header
    color_matrix = [[]] 
    matrixPos = [0,0] # keep track of position in the matrix 
    for lab1 in self.l1Order:
      color_row = []
      for lab2 in self.l2Order:
        color = colorsys.hsv_to_rgb(matrix[matrixPos]*.7, 1, 1)
        color_row.append(color)
        matrixPos[1] +=1
      color_matrix.append(color_row)
      matrixPos[0] +=1
      matrixPos[1] = 0
    return color_matrix
  
  def sortLabelingsByClusterSize(self, l1Order, l2Order):
    """Sort the labelings by cluster sizes
    """
    # order the data by cluster sizes
    if l1Order is None:
      l1RowLabels = self.labeling1.getLabelByRows()
      l1Order = map(lambda x: (l1RowLabels.count(x), x), self.labeling1.getLabels())
      l1Order.sort()
      l1Order = map(lambda x: x[1], l1Order)
      l1Order.reverse()
      self.l1Order = l1Order
    else:
      l1Order = self.l1Order
      l1Order.reverse()

    if self.l2Order is None:
      l2RowLabels = self.labeling2.getLabelByRows()
      l2Order = map(lambda x: (l2RowLabels.count(x), x), self.labeling2.getLabels())
      l2Order.sort()
      l2Order = map(lambda x: x[1], l2Order)
      self.l2Order = l2Order
    else:
      l2Order = self.l2Order
    return (l1Order, l2Order)
      
  def colorByInterclusterROC(self):
    """ Color each cell by the pairwise ROC area """
    
    gridPos = [1,0]
    for lab1 in self.l1Order:
      for lab2 in self.l2Order:
        RowRocValue = roc.interclusterROC(self.dataset, 
                                           self.labeling1, lab1, 
                                           self.labeling2, lab2)
        ColRocValue = roc.interclusterROC(self.dataset,
                                           self.labeling2, lab2,
                                           self.labeling1, lab1)
        print "%s vs %s ROC Area = %3.2f/%3.2f"%(str(lab1), 
                                                 str(lab2), 
                                                 RowRocValue[0],
                                                 ColRocValue[0],)
        rowColor = rgbToString(colorsys.hsv_to_rgb(RowRocValue[0]*.7, 1, 1))
        colColor = rgbToString(colorsys.hsv_to_rgb(ColRocValue[0]*.7, 1, 1))
        if self.plots.has_key(tuple(gridPos)):
          plot = self.plots[tuple(gridPos)]
          if not plot.marker_exists('rocUpper'):
            plot.marker_create('polygon', name='rocUpper')
          if not plot.marker_exists('rocLower'):
            plot.marker_create('polygon', name='rocLower')
          xmin, xmax = plot.xaxis_limits()
          ymin, ymax = plot.yaxis_limits()
          plot.marker_configure('rocUpper',
                                coords = (xmin,ymin, xmax, ymax, xmin, ymax, xmin, ymin),
                                fill = colColor,
                                under = 1)
          plot.marker_configure('rocLower',
                                coords = (xmax, ymin, xmin,ymin, xmax, ymax, xmax, ymin),
                                fill = rowColor,
                                under = 1)
          plot.configure(plotbackground='white') 
        gridPos[1]+=1
      gridPos[0]+=1
      gridPos[1] =0

  def drawSummaryPlot(self, axis, row, computeROC=1, titles=1, verbose=0):
    rows_summarized = PlotPage.drawSummaryPlot(self, axis, row, computeROC, titles, verbose)
    self.annotations[axis] = str(len(self.means._mapKeysToParent([row])))
    for line in axis.get_lines():
      line.set_color('k')
      if line.get_label() != 'mean': 
        line.set_linestyle(':')
    return rows_summarized
#    g.bind('<Button-2>',  lambda event, row=row:self.fullPlot(row)) 
#    g.configure(plotbackground = color)
#    g.yaxis_configure(hide=1)
#    g.xaxis_configure(hide=1)
#    g.element_configure('mean', color='black', pixels='0')
#    if 'std1'  in g.element_show():
#      g.element_configure('std1', hide=0 , color='black', pixels='0', dashes = (1,1) )
#      g.element_configure('std2', hide=0 , color='black', pixels='0', dashes = (1,1) )
#
#    # make a title -
#    g.marker_create('text', name='size')
#    g.marker_configure('size', coords = (g.xaxis_limits()[0], g.yaxis_limits()[1]),
#                       text=str(len(self.means._mapKeysToParent([row]))),
#                       anchor='w')
#    
#    return(g)
#
#  def setUniformYAxis(self, min=None, max=None):
#    PlotPage.setUniformYAxis(self, min, max)
#    for g in self.plots.values(): 
#      try:
#        g.marker_configure('size', coords = (g.xaxis_limits()[0], g.yaxis_limits()[1]))
#      except:
#        pass
#      try:
#        g.marker_configure('name', coords = (g.xaxis_limits()[1], g.yaxis_limits()[0]))
#      except:
#        pass
#
#  def setOptimalYAxis(self):
#    PlotPage.setOptimalYAxis(self)
#    for g in self.plots.values(): 
#      try:
#        g.marker_configure('size', coords = (g.xaxis_limits()[0], g.yaxis_limits()[1]))
#      except:
#        pass
#      try:
#        g.marker_configure('name', coords = (g.xaxis_limits()[1], g.yaxis_limits()[0]))
#      except:
#        pass
#

