import matplotlib.numerix as nx
from matplotlib.axes import Subplot

from PlotPage import PlotPage
from IPlot import IPlot, DatasetPlot
from views import DatasetRowPlotView

from compClust.mlx.labelings import Labeling
from compClust.mlx.labelings import GlobalLabeling
from compClust.mlx.labelings import GlobalWrapper

from compClust.score.ConfusionMatrix2 import ConfusionMatrix

from compClust.util.InterpreterTools import safeStdDev

class TrajectorySummary(PlotPage):

  """
  A fairly simple summary of a labeling: one mean/standard-deviation
  trajectory subplot per non-empty cluster, laid out on a grid in order of
  increasing cluster size.
  """

  def __init__(self, canvasFactory, dataset, clusterLabeling, primaryLabeling=None, secondaryLabeling=None, computeROC=1, numCols=3):
    """
    __init__(self, canvasFactory, dataset, clusterLabeling, primaryLabeling=None, secondaryLabeling=None, computeROC=1, numCols=3)

    canvasFactory -- forwarded to PlotPage.__init__.
    dataset -- the dataset being summarized (stored as self.dataset).
    clusterLabeling -- labeling whose labels define the clusters; one
      subplot is drawn per non-empty label.
    primaryLabeling, secondaryLabeling -- stored on the instance; they are
      assumed to be globalLabelings.
    computeROC -- forwarded to drawSummaryPlot for every subplot.
    numCols -- number of subplot columns in the grid (defaults to 3, the
      value that used to be hard-coded).
    """
    # sets up the needed information
    self.dataset = dataset
    self.labeling = clusterLabeling
    self.primaryLabeling = primaryLabeling
    self.secondaryLabeling = secondaryLabeling

    self.createMeanStdDataAndLabelings(self.labeling)

    # construct base class with enough rows to hold one cell per label;
    # integer ceiling division of numLabels / numCols.
    numLabels = len(list(self.labeling.getLabels()))
    numRows = (numLabels + numCols - 1) // numCols
    PlotPage.__init__(self, numRows, numCols, canvasFactory=canvasFactory)

    # fill the plot
    self.drawMeanStdTrajectories(computeROC)
    self.setUniformYAxis()
    self.show()

    # these are pointers to the full plot DatasetPlotter and its PlotView
    self.selectedPlot = None
    self.selectedView = None


  def drawMeanStdTrajectories(self, computeROC=1):
    """Lay out one subplot per non-empty cluster, sorted by cluster size.

    Clusters are ordered by (size, label) ascending.  Each subplot is
    filled by self.drawSummaryPlot using the row of self.meansLab that
    carries that cluster's label.

    computeROC -- forwarded to drawSummaryPlot.
    """
    clusterSizes = [ (len(self.labeling.getRowsByLabel(label)), label)
                     for label in self.labeling.getLabels() ]
    clusterSizes.sort()
    # Skip empty clusters.  The size is element 0 of each tuple; comparing
    # the whole tuple against 0 (as was done before) never removed anything.
    plotOrder = [ self.meansLab.getRowsByLabel(label)[0]
                  for size, label in clusterSizes if size > 0 ]

    # The tick labels get really busy.  They are currently stripped from
    # every subplot; set this to False to fall back to labelling only the
    # left-most column (y ticks) and the bottom row (x ticks).
    stripAllTicks = True
    numPlots = len(plotOrder)

    for index, row in enumerate(plotOrder):
      # add_subplot positions are 1-based
      axis = self.figure.add_subplot(self.numRows, self.numCols, index + 1)
      # FIXME: Something needs to handle subplots under matplotlib 0.82
      #sub = Subplot(self.figure, self.numRows*100+self.numCols*10+index+1)
      #axis = self.figure.add_subplot(sub)
      self.drawSummaryPlot(axis, row, computeROC)

      # y tick labels only belong on the first column (index 0, 3, 6, ...)
      if stripAllTicks or index % self.numCols != 0:
        axis.set_yticklabels([])
        axis.set_yticks([])
      # once fewer than numCols plots remain we are on the bottom row,
      # even if that row is jagged
      if stripAllTicks or index < numPlots - self.numCols:
        axis.set_xticklabels([])
        axis.set_xticks([])


  def setZoom(self, zoom):
    """
    Resize all the plots on the page by the proportion specified.

    Not supported with the current plotting backend.

    Raises NotImplementedError unconditionally.
    """
    # NotImplemented is a comparison sentinel, not an exception class;
    # calling it produced a TypeError instead of the intended error.
    raise NotImplementedError("Need to fix zoom")
    # Previous implementation, kept for reference until zoom is ported:
    #     if zoom > 10:
    #       print "Choose a smaller zoom"
    #       return
    #     newSize = (self.plotSize[0]*zoom, self.plotSize[1]*zoom)
    #     for plotKey in self.plots.keys():
    #      self.plots[plotKey].configure(width=newSize[0],height=newSize[1])

class ScatterMatrix(PlotPage):

  """
  Generates a plot page with a scatter plot between each pair of
  dimensions (columns) in a dataset.
  """

  def __init__(self, dataset, dims = None, primaryLabeling=None, secondaryLabeling=None, parent=None):
    """
    Build one scatter plot for every ordered (x, y) pair of columns in dims.

    dataset -- the dataset to plot; must provide getNumCols().
    dims -- sequence of column indices to plot; defaults to every column.
      The chosen columns are remembered so updatePlots() redraws the same
      subset the page grid was sized for.
    primaryLabeling, secondaryLabeling -- forwarded to DatasetRowPlotView.
    parent -- forwarded as the first positional argument of PlotPage.__init__.
    """
    if dims is None:
      dims = range(dataset.getNumCols())
    # Materialize once: dims is measured with len(), iterated in a nested
    # loop, and reused by every later redraw.
    dims = list(dims)

    # NOTE(review): TrajectorySummary in this file calls
    # PlotPage.__init__(self, numRows, numCols, canvasFactory=...); passing
    # parent positionally here may collide with numRows -- confirm against
    # PlotPage's signature.
    PlotPage.__init__(self, parent, numRows=len(dims), numCols =len(dims))
    self.dataset = dataset
    self.__dims = dims
    self.__plotView = DatasetRowPlotView(self.dataset, primaryLabeling, secondaryLabeling)
    self.__colorMapper = self.__plotView.getColorMapper()
    # NOTE(review): ColumnScatterDataMapper is not among this file's visible
    # imports -- confirm where it is meant to come from.
    self.__dataMapper = ColumnScatterDataMapper(self.__plotView)
    self.__plotView.setDataMapper(self.__dataMapper)
    self.__makeScatterPlots(dims)

  def __makeScatterPlots(self, dims=None):
    """
    Draw the scatter plot for every (x, y) pair, creating plots on first use.

    dims -- column indices to plot; defaults to the columns chosen when the
      page was constructed (previously this fell back to *all* columns,
      overflowing a grid built for a subset).
    """
    if dims is None:
      dims = self.__dims

    for x in dims:
      for y in dims:
        key = (x, y)
        if key in self.plots:
          p = self.plots[key]
        else:
          p = self.addPlot(x, y)
        # point the shared data mapper at this column pair before plotting
        self.__dataMapper.setXColumn(x)
        self.__dataMapper.setYColumn(y)
        p.plot(self.__plotView, pack=0)
        # presumably self.plotSize is provided by PlotPage -- it is not set
        # in this class; verify.
        p.configure(width=self.plotSize[0], height=self.plotSize[1])
        p.legend_configure(hide=1)
        p.yaxis_configure(hide=1)
        p.xaxis_configure(hide=1)

  def getColorMapper(self):
    """Return the color mapper obtained from the page's plot view."""
    return(self.__colorMapper)

  def getPlotView(self):
    """Return the DatasetRowPlotView shared by every scatter plot."""
    return(self.__plotView)

  def updatePlots(self):
    """Redraw every scatter plot for the columns chosen at construction."""
    self.__makeScatterPlots()
