import matplotlib.numerix as nx
from matplotlib.numerix import mlab

from compClust.mlx.views import RowPCAView

from compClust.iplot.PlotPage import PlotPage
from compClust.iplot.IPlot import DatasetPlot
from compClust.iplot.views import DatasetRowPlotView
from compClust.iplot.mappers.DataMapper import PCADataMapper


class PCAPlot:
  """Factory for the different kinds of PCA plots of a single dataset.

  Builds (via ``dataset.addViewDefault``) the dataset's ``RowPCAView`` and a
  ``DatasetRowPlotView`` whose data mapper is a ``PCADataMapper``.  Both are
  shared by every plot returned from ``ProjectionPlot`` and
  ``EigenVectorPlot``.

  :param canvasFactory: forwarded to each plot constructor.
  :param dataset: dataset whose rows are analysed.
  :param primaryLabeling: optional labeling forwarded to the plot view.
  :param secondaryLabeling: optional labeling forwarded to the plot view.
  :param master: accepted for interface compatibility; not used.
  """
  def __init__(self, canvasFactory, dataset, primaryLabeling=None, secondaryLabeling=None, master=None):
    self.canvasFactory = canvasFactory
    self.dataset = dataset
    # addViewDefault is keyed by name, so presumably repeated PCAPlot
    # instances reuse one PCA decomposition -- TODO confirm.
    self.pcaDS = self.dataset.addViewDefault('RowPCAView', RowPCAView, self.dataset)
    self.primaryLabeling = primaryLabeling
    self.secondaryLabeling = secondaryLabeling

    # create pca view
    self.pcaView = DatasetRowPlotView(self.dataset,
                                      primaryLabeling=self.primaryLabeling,
                                      secondaryLabeling=self.secondaryLabeling)

    self.pcaDataMapper = PCADataMapper(self.pcaView)
    self.pcaView.setDataMapper(self.pcaDataMapper)

  def ProjectionPlot(self, xaxis=0, yaxis=1):
    """Return a PCAProjectionPlot of component ``xaxis`` against ``yaxis``.

    :param xaxis: 0-based index of the component on the x axis.
    :param yaxis: 0-based index of the component on the y axis.
    """
    plot = PCAProjectionPlot(self.canvasFactory, self.pcaDS, self.pcaView)
    # The constructor always draws components 0 vs 1; redraw only when the
    # caller asked for a different pair so the arguments are honoured.
    if (xaxis, yaxis) != (0, 1):
      plot.drawProjection(xaxis, yaxis)
    return plot

  def EigenVectorPlot(self):
    """Return a PCAEigenVectorPlot showing every principal component."""
    return PCAEigenVectorPlot(self.canvasFactory, self.pcaDS, self.pcaView)

  
class PCAProjectionPlot(DatasetPlot):
  """
  A simple exploratory tool to better understand the significance of PCA.

  Draws a scatter plot of the dataset projected onto two principal
  components, using the shared PCA plot view and its data mapper.
  """

  def __init__(self, canvasFactory, pcaDS, pcaView):
    DatasetPlot.__init__(self, canvasFactory=canvasFactory)

    self.pcaDS = pcaDS
    self.pcaView = pcaView

    self.drawProjection()

  def getEigenVariances(self):
    """Return how much variance is captured by each eigenvector.

    (The length of the returned sequence is the total number of
    eigenvectors.)
    """
    return self.pcaDS.getVariances()

  def setAxisSquare(self):
    """Adjust the axes so each axis covers the same range.

    The wider axis is left untouched; the narrower one is widened to the
    same span, re-centred on its own midpoint.
    """
    ylim = self.axis.yaxis.get_major_locator().autoscale()
    xlim = self.axis.xaxis.get_major_locator().autoscale()

    yspan = ylim[1] - ylim[0]
    xspan = xlim[1] - xlim[0]

    if yspan > xspan:
      xmid = xspan / 2.0 + xlim[0]
      self.axis.viewLim.intervalx().set_bounds(xmid - yspan / 2.0,
                                               xmid + yspan / 2.0)
    else:
      ymid = yspan / 2.0 + ylim[0]
      self.axis.viewLim.intervaly().set_bounds(ymid - xspan / 2.0,
                                               ymid + xspan / 2.0)

  def drawProjection(self, xaxis=0, yaxis=1, cluster_labeling=None):
    """Draw the scatter plot comparing two eigenvectors.

    :param xaxis: 0-based component index plotted on the x axis.
    :param yaxis: 0-based component index plotted on the y axis.
    :param cluster_labeling: optional labeling used to color the points.
        The color mapper is installed on the shared view and persists for
        later redraws.
    """
    self.axis.cla()

    pcaDataMapper = self.pcaView.getDataMapper()
    pcaDataMapper.setXColumn(xaxis)
    pcaDataMapper.setYColumn(yaxis)

    # The color mapper must be attached before scatter() draws the points,
    # otherwise it cannot influence this plot.
    # FIXME: coloring by cluster_labeling needs more testing and some glue
    # FIXME: to attach it to the GUI.
    if cluster_labeling is not None:
      from mappers import RowColorMapper
      pcaColorMapper = RowColorMapper(self.pcaView)
      pcaColorMapper.setColorByLabeling(cluster_labeling)
      self.pcaView.setColorMapper(pcaColorMapper)

    self.scatter(self.pcaView)

    self.axis.set_title('PCA Space', fontdict=self.title_font)
    # Component indices are 0-based; labels are 1-based (PC-1 is the first).
    self.axis.set_xlabel('PC-%d' % (xaxis + 1))
    self.axis.set_ylabel('PC-%d' % (yaxis + 1))

    self.setAxisSquare()

    
class PCAEigenVectorPlot(PlotPage):
  """
  A simple exploratory tool to better understand the significance of PCA.

  Lays out one line plot per principal component (column ``i`` of the
  rotation matrix) on a grid three plots wide.  Each plot is titled with the
  percentage of variance its component captures, and all plots share one
  y range so they can be compared directly.
  """

  def __init__(self, canvasFactory, pcaDS, pcaView):
    self.pcaDS = pcaDS
    self.pcaView = pcaView

    variances = self.pcaDS.getVariances()
    rotMatrix = self.pcaDS.matrix

    self.numCols = 3
    # Ceiling division: a full last row must not leave an empty row below
    # it.  Keep at least one row so PlotPage always receives a valid grid.
    self.numRows = max(1, (len(variances) + self.numCols - 1) // self.numCols)
    figsize = (self.numCols * 2, self.numRows * 2)
    PlotPage.__init__(self, self.numRows, self.numCols, canvasFactory=canvasFactory, figsize=figsize)

    axes = []
    for i in range(len(rotMatrix)):
      row, col = divmod(i, self.numCols)
      axis = self.addPlot(row, col)

      axis.set_title('PC-%i %3.2f%%' % (i + 1, variances[i] * 100),
                     fontdict=self.title_font)
      axis.plot(rotMatrix[:, i])
      axes.append(axis)

    # Give every subplot the union of their autoscaled y ranges.
    if axes:
      yaxis_min = min([axis.get_ylim()[0] for axis in axes])
      yaxis_max = max([axis.get_ylim()[1] for axis in axes])
      for axis in axes:
        axis.set_ylim((yaxis_min, yaxis_max))


class PCAPlotOld(PlotPage):
  """
  A simple exploratory tool to better understand the significance of PCA.

  Older single-page variant of the PCA plots: draws the projection onto two
  principal components in the only cell of a 1x1 PlotPage.
  """

  def __init__(self, canvasFactory, dataset, primaryLabeling=None, secondaryLabeling=None, master=None):
    self.pcaDS = dataset.addViewDefault('RowPCAView', RowPCAView, dataset)
    PlotPage.__init__(self, 1, 1, canvasFactory=canvasFactory)

    # Widgets created by drawEigenPlots.
    self.scales = []
    self.buttons = []

    self.dataset = dataset
    self.primaryLabeling = primaryLabeling
    self.secondaryLabeling = secondaryLabeling

    # Created lazily by the first drawProjection call.
    self.pcaPlot = None

    # create pca view
    self.v = DatasetRowPlotView(self.dataset,
                                primaryLabeling=self.primaryLabeling,
                                secondaryLabeling=self.secondaryLabeling)

    self.pcaDataMapper = PCADataMapper(self.v)
    self.v.setDataMapper(self.pcaDataMapper)

    self.drawProjection()

  def getEigenVariances(self):
    """Return how much variance is captured by each eigenvector.

    (The length of the returned sequence is the total number of
    eigenvectors.)
    """
    return self.pcaDS.getVariances()

  def drawProjection(self, xaxis=0, yaxis=1):
    """Draw the scatter plot comparing two eigenvectors.

    :param xaxis: 0-based component index plotted on the x axis.
    :param yaxis: 0-based component index plotted on the y axis.
    """
    if self.pcaPlot is None:
      self.pcaPlot = DatasetPlot(self.canvasFactory, axis=self.addPlot(0, 0))
    else:
      self.pcaPlot.axis.cla()

    self.pcaDataMapper.setXColumn(xaxis)
    self.pcaDataMapper.setYColumn(yaxis)
    plotData = self.v.getDataMapper().getPlotData()

    xdata = [point[0] for point in plotData]
    ydata = [point[1] for point in plotData]

    collection = self.pcaPlot.axis.scatter(xdata, ydata)
    # Labeling stashed on the collection; presumably read back by picking
    # code elsewhere -- TODO confirm.
    collection._labels = self.primaryLabeling

    self.pcaPlot.axis.set_title('PCA Space', fontdict=self.title_font)
    # Component indices are 0-based; labels are 1-based (PC-1 is the first
    # component), matching PCAProjectionPlot.
    self.pcaPlot.axis.set_xlabel('PC-%d' % (xaxis + 1))
    self.pcaPlot.axis.set_ylabel('PC-%d' % (yaxis + 1))

  def drawNativePlot(self):
    """Plot one point marker per row of the rotation matrix along y = 0.

    The plot is drawn into grid cell (0, 0).
    """
    rotMatrix = self.pcaDS.matrix

    self.nativePlot = self.addPlot(0, 0)
    xdata = tuple(range(len(rotMatrix)))
    ydata = tuple([0] * len(rotMatrix))
    # Axes.plot takes x, y and a format string; '.' draws point markers.
    self.nativePlot.plot(xdata, ydata, '.')
    self.nativePlot.set_title('Native Space')

  def drawEigenPlots(self):
    """Build one radio button, component plot and slider per component.

    NOTE(review): this is leftover Tk/BLT code.  ``Tkinter``, ``plot`` and
    ``self.pp`` are not defined anywhere in this module or class, so calling
    this method currently raises NameError/AttributeError.
    """
    variances = self.pcaDS.getVariances()
    rotMatrix = self.pcaDS.matrix

    for i in range(len(rotMatrix)):
      b = Tkinter.Radiobutton(self.pp.frame)
      self.buttons.append(b)
      self.pp.addWidget(b, row=i, col=0)
      g = plot(rotMatrix[:, i], pack=0, master=self.pp.frame)
      g.configure(title='PC-%i %3.2f%% of variance captured' % (i + 1, variances[i] * 100),
                  height=200, width=400)
      g.legend_configure(hide=1)
      self.pp.addWidget(g, row=i, col=1)
      # component=i binds the loop value now (avoids late-binding closures).
      s = Tkinter.Scale(self.pp.frame,
                        from_=min([0, mlab.min(self.pcaDS.getColData(i))]),
                        to=max([mlab.max(self.pcaDS.getColData(i))]),
                        resolution=.1,
                        command=lambda value, component=i: self.sliderUpdate(value, component))
      self.scales.append(s)
      self.pp.addWidget(s, row=i, col=2)


#####
# This code was left here when I started porting the BLT PCA explorer to
# matplotlib. However, there is no good toolkit-independent way of handling
# the various user-interface options, so it is probably doomed to failure.
# I'm leaving it here for the moment because there might be a way to make
# it work in the future.
class PCAExplorer:
  """
  A simple exploratory tool to better understand the significance of PCA.

  Unfinished port of a BLT/Tk explorer.  Several names it uses
  (``Tkinter``, ``Pmw``, ``Numeric``, ``plot``) are not imported at the top
  of this module, so constructing it currently fails with NameError.
  """
  def __init__(self, canvasFactory, dataset, primaryLabeling=None, secondaryLabeling=None, parent=None):
    self.canvasFactory = canvasFactory
    self.dataset = dataset
    self.pcaDS = self.dataset.addViewDefault('RowPCAView', RowPCAView, self.dataset)
    self.primaryLabeling = primaryLabeling
    self.secondaryLabeling = secondaryLabeling

    # One slider and one radio button per principal component.
    self.scales = []
    self.buttons = []
    variances = self.pcaDS.getVariances()
    rotMatrix = self.pcaDS.matrix
    self.pp = PlotPage(numRows=len(variances), numCols=1, canvasFactory=self.canvasFactory)

    self.nativePlot = DatasetPlot(canvasFactory=self.canvasFactory)
    #self.nativePlot.line_create('point', ydata=tuple([0]*len(rotMatrix)),
    #                                     xdata=tuple(range(len(rotMatrix)))
    #                                     )
    self.nativePlot.title('Native Space')
    #self.nativePlot.legend_configure(hide=1)
    #self.nativePlot.pack(expand=1, fill='both')

    self.pcaView = DatasetRowPlotView(self.dataset,
                                      primaryLabeling=self.primaryLabeling,
                                      secondaryLabeling=self.secondaryLabeling)
    self.pcaDataMapper = PCADataMapper(self.pcaView)
    self.pcaView.setDataMapper(self.pcaDataMapper)
    # self.pcaFrame = Tkinter.Toplevel()
    self.pcaFrame = self.canvasFactory.getCanvas()

    # Labels are 1-based and show the variance as a percentage, matching the
    # other PCA plots; _dimFromLabel converts them back to 0-based indices.
    pcaDimLabels = ['PC-%i (%3.2f%%)' % (i + 1, v * 100) for i, v in enumerate(variances)]
    #self.pcaControlFrame = Tkinter.Frame(self.pcaFrame)
    self.pcaControlFrame = self.canvasFactory.getCanvas()

    self.pcaXChooser = Pmw.OptionMenu(self.pcaControlFrame,
                                      labelpos='w',
                                      label_text='X-Axis:',
                                      command=self.__xChooserCB,
                                      items=pcaDimLabels)
    self.pcaXChooser.pack(anchor='s', side='left', expand=1, fill='x')
    self.pcaYChooser = Pmw.OptionMenu(self.pcaControlFrame,
                                      labelpos='w',
                                      label_text='Y-Axis:',
                                      command=self.__yChooserCB,
                                      items=pcaDimLabels)
    self.pcaYChooser.pack(anchor='s', side='left', expand=1, fill='x')

    self.pcaPlot = DatasetPlot(parent=self.pcaFrame)
    self.pcaPlot.pack(side='top', anchor='w', expand=1, fill='both')
    self.pcaControlFrame.pack(side='top', anchor='w', expand=1, fill='x')
    self.pcaPlot.line_create('point', ydata=(0,), xdata=(0,), symbol='cross', pixels='.22i')
    self.pcaPlot.plot(self.pcaView)
    #self.pcaPlot.element_activate('point')
    #self.pcaPlot.legend_configure(hide=1)
    #self.pcaPlot.configure(title='PCA Space')
    #self.pcaPlot.xaxis_configure(title='PC-1')
    #self.pcaPlot.yaxis_configure(title='PC-2')
    #self.pcaPlot.bind('<ButtonRelease-2>', self.plotUpdate)

    for i in range(len(rotMatrix)):
      b = Tkinter.Radiobutton(self.pp.frame)
      self.buttons.append(b)
      self.pp.addWidget(b, row=i, col=0)
      g = plot(rotMatrix[:, i], pack=0, parent=self.pp.frame)
      #g.configure(title='PC-%i %3.2f%% of variance captured'%(i, variances[i]*100),
      #            height=200, width=400)
      #g.legend_configure(hide=1)
      self.pp.addWidget(g, row=i, col=1)
      # component=i binds the loop value now (avoids late-binding closures).
      s = Tkinter.Scale(self.pp.frame,
                        from_=min([0, mlab.min(self.pcaDS.getColData(i))]),
                        to=max([mlab.max(self.pcaDS.getColData(i))]),
                        resolution=.1,
                        command=lambda value, component=i: self.sliderUpdate(value, component))
      self.scales.append(s)
      self.pp.addWidget(s, row=i, col=2)

  def sliderUpdate(self, value, component):
    """Move the marker in the PCA plot to the current slider values.

    Only the sliders of the two currently plotted components are used.
    """
    # Native-space reconstruction (components weighted by the slider values)
    # stays disabled until nativePlot can update a line element:
    #   vector = Numeric.dot(Numeric.array([s.get() for s in self.scales]),
    #                        Numeric.transpose(self.pcaDS.matrix))
    #   self.nativePlot.element_configure('point', ydata=tuple(vector))

    xaxis = self.pcaDataMapper.getXColumn()
    yaxis = self.pcaDataMapper.getYColumn()
    self.pcaPlot.element_configure('point', xdata=(self.scales[xaxis].get(),), ydata=(self.scales[yaxis].get(),))

  def plotUpdate(self, event):
    """Update the native plot and sliders from a click in the PCA plot."""
    x, y = self.pcaPlot.invtransform(event.x, event.y)
    #self.pcaPlot.element_configure('point', xdata=(x,), ydata=(y,))

    xaxis = self.pcaDataMapper.getXColumn()
    yaxis = self.pcaDataMapper.getYColumn()

    # Components are the COLUMNS of the rotation matrix (sliderUpdate and
    # PCAEigenVectorPlot both use matrix[:, i]).  Take those two columns and
    # combine them with the clicked (x, y) weights.
    basis = Numeric.take(self.pcaDS.matrix, [xaxis, yaxis], 1)
    vector = Numeric.dot(basis, Numeric.array([x, y]))
    self.nativePlot.element_configure('point', ydata=tuple(vector))
    self.scales[xaxis].set(x)
    self.scales[yaxis].set(y)
    for i in range(len(self.scales)):
      if i not in [xaxis, yaxis]:
        self.scales[i].set(0)

  def _dimFromLabel(self, dimString):
    """Return the 0-based component index encoded in a chooser label.

    Labels look like ``'PC-<n> (<pct>%)'`` with ``n`` 1-based.  The whole
    number is parsed, so components numbered 10 and above work too.
    """
    return int(dimString.split()[0][3:]) - 1

  def __xChooserCB(self, dimString):
    self.pcaDataMapper.setXColumn(self._dimFromLabel(dimString))
    self.pcaPlot.xaxis_configure(title=dimString)
    self.pcaPlot.plot()

  def __yChooserCB(self, dimString):
    self.pcaPlot.yaxis_configure(title=dimString)
    self.pcaDataMapper.setYColumn(self._dimFromLabel(dimString))
    self.pcaPlot.plot()

  
