"""PlotPage provides a way of holding a collection of subplots.
"""
import matplotlib
from matplotlib.numerix import mlab
from matplotlib.axes import Subplot, Axes
from matplotlib.figure import Figure
from matplotlib.transforms import Bbox
from matplotlib.transforms import lbwh_to_bbox

from compClust.score import roc
from compClust.util import NaN
from compClust.mlx.views import AggregateFunctionView
from compClust.mlx.labelings import GlobalWrapper
from compClust.mlx.labelings import GlobalLabeling
from compClust.mlx.labelings import Labeling
from compClust.mlx.labelings import subsetByLabeling
from compClust.mlx.labelings import Labeling
from compClust.util.InterpreterTools import safeStdDev

import Plot
from views import DatasetRowPlotView

__docformat__ = "restructuredtext en"

class PlotPage(Plot.Plot):
  """Contain a list of plots

  """
  # FIXME: should there be some way of controlling whether we want a
  # FIXME: scrollable window? The old non-matplotlib one defaulted to
  # FIXME: scrollable

  def __init__(self, numRows, numCols=3, canvasFactory=None, figsize=None, figscale=1, minsize=None):
    """
    Set up an empty grid of subplot slots that share a single figure.

    The grid shape has to be known up front, so the number of rows is
    required.  The number of columns defaults to 3 but can be overridden.

    :Parameters:
      - `numRows`: The number of rows in the plot grid
      - `numCols`: The number of columns in the plot grid, defaults to 3
      - `canvasFactory`: factory object for the different matplotlib
                         backends; handed on to Plot.Plot.__init__
      - `figsize`: A tuple specifying the size of the final figure. If None,
                   it is computed as
                   (numCols * figscale, numRows * figscale).
      - `figscale`: Per-cell scale factor, only used when figsize is None
      - `minsize`: The minimum size a figure should be, defaults to (6,6);
                   handed on to Plot.Plot.__init__
    """
    # Grid dimensions are recorded before the base class is initialised.
    self.numRows = numRows
    self.numCols = numCols

    # Fill in the defaults for anything the caller left unspecified.
    if minsize is None:
      minsize = (6, 6)
    if figsize is None:
      width = numCols * figscale
      height = numRows * figscale
      figsize = (width, height)

    Plot.Plot.__init__(self,
                       canvasFactory=canvasFactory,
                       figsize=figsize,
                       minsize=minsize)

    # (row, col) -> subplot, filled in by addPlot
    self.plots = {}
    # subplot axes -> cluster label, filled in by drawSummaryPlot
    self.axisToLabel = {}
    
  def createMeanStdDataAndLabelings(self, labeling):
    """Construct the mean and standard deviation summary views and labelings.

    Two AggregateFunctionView views are registered on self.dataset through
    addViewDefault: one aggregated with mlab.mean (kept in self.means) and
    one aggregated with safeStdDev (kept in self.stds).  The rows are
    grouped by the row keys self.labeling reports for each of its labels.

    Labelings attached to those two views are stored in self.meansLab and
    self.stdsLab.  How they are built depends on the type of `labeling`:

      - GlobalWrapper: a new GlobalWrapper sharing ``labeling.g``
      - GlobalLabeling: a new GlobalWrapper around `labeling` itself
      - Labeling: a new Labeling populated with ``labelFrom(labeling)``

    :Parameters:
      - `labeling`: the labeling the summary labelings are derived from.
                    Its id() is also used to name the two views.

    :Raises ValueError: if `labeling` is none of the types listed above.
                        Note that the two views have already been added
                        to self.dataset by the time this is raised.
    """
    # pre-initialize the summary views
    # FIXME: so using IDs might not be the best idea as if the dataset is reloaded the view will be broken
    means_name = "means-%s" % (id(labeling))
    std_name = "stddev-%s" % (id(labeling))

    # NOTE(review): the grouping comes from self.labeling, not from the
    # `labeling` argument -- confirm callers always pass the same labeling.
    keylists = [self.labeling.getRowsByLabel(label)
                for label in self.labeling.getLabels()]

    self.means = self.dataset.addViewDefault(means_name, AggregateFunctionView,
                                             self.dataset, keylists, mlab.mean)
    self.stds = self.dataset.addViewDefault(std_name, AggregateFunctionView,
                                            self.dataset, keylists, safeStdDev)

    if isinstance(labeling, GlobalWrapper):
      self.meansLab = GlobalWrapper(self.means, glabeling=labeling.g)
      self.stdsLab = GlobalWrapper(self.stds, glabeling=labeling.g)
      # This branch used to store the labeling under the misspelled name
      # `stdLab` only, leaving `stdsLab` unset.  The old name is kept as an
      # alias so any code that relied on it keeps working.
      self.stdLab = self.stdsLab

    elif isinstance(labeling, GlobalLabeling):
      self.meansLab = GlobalWrapper(self.means, glabeling=labeling)
      self.stdsLab = GlobalWrapper(self.stds, glabeling=labeling)

    elif isinstance(labeling, Labeling):
      self.meansLab = Labeling(self.means)
      self.stdsLab = Labeling(self.stds)
      self.meansLab.labelFrom(labeling)
      self.stdsLab.labelFrom(labeling)

    else:
      raise ValueError("%s is not a Labeling" % (type(labeling),))

  def addPlot(self, row, col, **kwargs):
    """Add a subplot on this plot page to a specific location

    will pass on extra keyword arguments to the underlying rendering code
    """
    index = self.getIndex(row, col)
    subplot = self.figure.add_subplot(self.numRows, self.numCols, index, **kwargs)
    self.plots[(row, col)]=subplot
    return subplot

  def getPlot(self, row, col):
    """Return the subplot at the specified indicies
    """
    return self.plots.get((row, col), None)

  def getIndex(self, row, col):
    """Given a row, col index return what the matplotlib plot index should be
    """
    return self.numCols * row + col + 1
  
  def drawSummaryPlot(self, axis, row, computeROC=1, titles=1, verbose=0):
    """
    drawSummaryPlot(self, row, roc=1)

    Draws the aggregate summary tragetory plot for the given row in
    the self.means dataset
    
    `Return`: rows sumarized.
    """
    # set up the basic cluster info
    label = self.meansLab.getLabelsByRow(row)[0]
    rows_summarized = self.labeling.getRowsByLabel(label)
    size = len(rows_summarized)
    if verbose:
      print "working on %s"%(str(label))
    mean = self.means.getRowData(row)
    posStd = tuple(mean + self.stds.getRowData(row))
    negStd = tuple(mean - self.stds.getRowData(row))
    mean = tuple(mean)
    xdata = tuple(range(len(mean)))
    if computeROC:
      rocArea = roc.clusterROC(self.dataset, self.labeling, label)[0]
      if verbose:
        print "ROC area = %3.2f"%(rocArea)
    else:
      rocArea = NaN.nan
    
    # draw the lines (don't need a fancy dataset plotter for these three lines
    if verbose:
        print "\t Drawing Plot"
    axis.plot(xdata, mean, 'b', label='mean')
    if posStd != mean:
      axis.plot(xdata, posStd, 'r-', label='std_pos')
      axis.plot(xdata, negStd, 'r-', label="std_neg") 
      
    if titles:
      #title = '<Cluster %s>\n  Count:%i, ROC: %3.2f'%(label, size, rocArea)
      title = '<Cluster %s>'%(label)
      axis.set_title(title, fontdict=self.title_font)
      # try embedding plot stats
      xmin, xmax = axis.viewLim.intervalx().get_bounds()
      ymin, ymax = axis.viewLim.intervaly().get_bounds()
      axis.text(xmin, ymax, str(size), verticalalignment='top')
      if computeROC:
        axis.text(xmax, ymin, "roc: %3.2f"%(rocArea), horizontalalignment = 'right')

    self.axisToLabel[axis] = label
    return rows_summarized

  def onClick(self, event):
    """Create full iplot window for whichever plot the user selected.
    """
    # get the x and y coords, flip y from top to bottom
    print event.button #, event.state, event.type
    #height = self.canvas.figure.bbox.height()
    #x, y = event.x, height-event.y
    x, y = event.x, event.y
    if event.button==1:
      for ax in self.canvas.figure.axes:
        if ax.in_axes(x, y):
          label = self.axisToLabel.get(ax, None)
          print "done show axis: ", ax,
          if label is not None:
            self.fullPlot(label)
    
  def fullPlot(self, label):
    """
    fullPlot(event, row)

    show the full row-wise plot of the cluster pointed to in the mean
    dataset at row
    """
    print "printing %s"%(str(label))
    subset = subsetByLabeling(self.dataset, self.labeling, label)
    dp = self.canvasFactory.getIPlot()
    dp.setTitle("Cluster %s\n %i elements"%(str(label),subset.getNumRows()))
    rv = DatasetRowPlotView(subset, primaryLabeling=self.primaryLabeling, 
                                    secondaryLabeling=self.secondaryLabeling)
    rv.getColorMapper().setColorByColValue(0)
    dp.plot(rv)
    self.selectedPlot = dp
    self.selectedView = rv
    self.dataset.removeView(subset)

  def setUniformYAxis(self, min=None, max=None):

    """
    setUniformYAxis(self)

    Finds the maximal range spanned by any plot on the plot page and
    sets the y-axis limists on all plots to its max and min

    """
    if (max is None) or (min is None):
      minY = 0
      maxY = 0
      # find the max and min
      for plotKey in self.plots.keys():
        yrange = self.plots[plotKey].yaxis.get_major_locator().autoscale()
        self.xrange = self.plots[plotKey].xaxis.get_major_locator().autoscale()
        ymin, ymax = yrange

        if yrange[0] < minY:
          minY = yrange[0]
        if yrange[1] > maxY:
          maxY = yrange[1]
    else:
      maxY = max
      minY = min
    self.yrange = (minY, maxY)

    # FIXME: need to handle scaling
    for plotKey in self.plots.keys():
      self.plots[plotKey].viewLim.intervaly().set_bounds(minY, maxY)

  def setOptimalYAxis(self):

    """
    setOptimalAxis(self)

    sets the axis such that all of the data is displayed.
    """
    
    for plotKey in self.plots.keys():
      self.plots[plotKey].axis_configure('y', min='', max='')
    
  def toggleZeros(self):

    """
    showZeros(self)

    toggles if each plot displays (0,0) lines.
    """

    for plotKey in self.plots.keys():
      p = self.plots[plotKey]
      xmin,xmax = p.xaxis_limits()
      ymin,ymax = p.yaxis_limits()
      if p.marker_exists ('xzero'):
        p.marker_delete('xzero')
        p.marker_delete('yzero')
      else:
        p.marker_create('line', name='xzero', coords=(xmin, 0, xmax, 0))
        p.marker_create('line', name='yzero', coords=(0, ymin, 0, ymin))
