import colorsys
import operator
import types

import matplotlib
import matplotlib.numerix as nx
from matplotlib import collections
from matplotlib.axes import Subplot, Axes
from matplotlib.transforms import Bbox
from matplotlib.transforms import lbwh_to_bbox

import Plot

class IPlot(Plot.Plot):
  """Interactive plot built on Plot.Plot and a single matplotlib axis.

  Provides a click handler that reports the nearest data point, and
  toggle callbacks (axis, grid, legend) intended to be wired to UI
  controls; the callbacks accept and ignore arbitrary positional args.
  """

  def __init__(self, canvasFactory, axis=None):
    """
    canvasFactory -- forwarded to Plot.Plot.__init__.
    axis -- matplotlib axis to draw on; when None a single 1x1 subplot
            is added to self.figure (presumably created by Plot.Plot).
    """
    Plot.Plot.__init__(self, canvasFactory=canvasFactory, axis=axis)
    if axis is None:
      self.axis = self.figure.add_subplot(1,1,1)
    else:
      self.axis = axis

    # Current state of the toggle controls defined below.
    self.show_axes = True
    self.show_grid = False
    self.show_legend = False
    self.xlogscale = False
    self.ylogscale = False
    # NOTE(review): this instance attribute shadows the descending() method
    # defined below, so inst.descending() fails with TypeError ('bool' is
    # not callable). Left as is because other code may read the flag.
    self.descending = False

  def onClick(self, event):
    """Show the data point nearest a left click inside any figure axes."""
    # Event coordinates are used unchanged; an earlier version flipped y
    # using the figure height, which is no longer done.
    x, y = event.x, event.y
    if event.button==1:
      for ax in self.canvas.figure.axes:
        if ax.in_axes(x, y):
          # findNearestPoint is not defined in this class -- presumably
          # provided by Plot.Plot or a subclass.
          label, point = self.findNearestPoint(x,y)
          self.displayValue(label, point)

  def displayValue(self, label, point):
    """
    Display an element's name and its coordinates in data space.

    label -- element name, or None to clear the display.
    point -- (x, y) pair, or None to clear the display.

    The text is drawn as a one-cell table at the top of self.axis (and
    printed to stdout), then the plot is refreshed via self.show().
    Formerly this showed the current value instead.
    """
    if label is None or point is None:
      self.axis.table([[""]], cellLoc='left', loc='top')
    else:
      x,y = point
      status = "%s: (%3.2f, %3.2f)"%(label, x, y)
      print(status)
      # NOTE(review): each call appears to add a new table to the axis;
      # earlier tables are not removed here.
      self.axis.table([[status]], cellLoc='left', loc='top')
    self.show()

  # The next functions configure the axes
  def showAxis(self, *args):
    """Toggle drawing of the axis (ticks, spines, labels)."""
    self.show_axes = not self.show_axes
    if self.show_axes:
      self.axis.set_axis_on()
    else:
      self.axis.set_axis_off()

  def xlogScale(self, *args):
    """Not supported: always raises NotImplementedError."""
    raise NotImplementedError("Can't toggle log state")

  def ylogScale(self, *args):
    """Not supported: always raises NotImplementedError."""
    raise NotImplementedError("Can't toggle log state")

  def descending(self, *args):
    """Not supported: always raises NotImplementedError.

    Unreachable on instances because __init__ sets a same-named attribute.
    """
    raise NotImplementedError("Can't toggle descending")

  # The next functions configures the Grid
  def showGrid(self, *args):
    """Toggle the axis grid."""
    self.show_grid = not self.show_grid
    self.axis.grid(self.show_grid)

  # The next functions configures the Legend
  def showLegend(self, *args):
    """Show the legend; hiding an already shown legend is not supported.

    Raises NotImplementedError when the legend is already shown. In that
    case show_legend is left unchanged so it keeps matching the display.
    """
    if self.show_legend:
      raise NotImplementedError("Can't disable legend")
    self.show_legend = True
    self.axis.legend()

  # general canvas configuration
  def setTitle(self, title):
    """Set title for the axis contained within this iplot instance
    """
    self.axis.set_title(title)

#########
#
# The back end plotter
#
####

class DatasetPlot(IPlot):

  """
  DatasetPlot class provides a common plotting tool for a variety
  of 2d plotting tasks.  PlotViews provide an extensible way to
  manipulate and control how data is rendered.

  plot() draws one line per data row of the view; scatter() draws all
  rows as points of a single scatter collection.
  """

  def __init__(self, canvasFactory, plotView=None, axis=None):
    # currentPlotViews is assigned before IPlot.__init__ runs --
    # presumably so base-class initialisation can already see it; keep
    # this ordering.
    self.currentPlotViews = []
    IPlot.__init__(self, canvasFactory, axis)
    if plotView is not None:
      self.plot(plotView)

  def _singlePlotView(self):
    """Return the only current plot view; raise ValueError otherwise."""
    if len(self.currentPlotViews) != 1:
      raise ValueError("wrong number of plot views %d" % (
                       len(self.currentPlotViews)))
    return self.currentPlotViews[0]

  def __get_primary(self):
    """Return primary labeling"""
    annotationmap = self._singlePlotView().getAnnotationMapper()
    return annotationmap.getPrimaryLabeling()
  primaryLabeling = property(__get_primary, doc="return primary labeling")

  def __get_secondary(self):
    """return secondary labeling"""
    annotationmap = self._singlePlotView().getAnnotationMapper()
    # NOTE(review): assumes the annotation mapper provides
    # getSecondaryLabeling() alongside getPrimaryLabeling(); confirm.
    return annotationmap.getSecondaryLabeling()
  secondaryLabeling = property(__get_secondary,
                               doc="return secondary labeling")

  def _activateView(self, plotView, hold, clearOnReplot):
    """Resolve plotView, register it as current and run its setup hook.

    When plotView is None the first current view is re-drawn (clearing
    the axis first if clearOnReplot is true).  With hold == 1 the view is
    added to the current views; otherwise it replaces them.
    Returns the resolved plot view.
    """
    if plotView is None:
      plotView = self.currentPlotViews[0]
      if clearOnReplot:
        self.axis.clear()

    if hold == 1:
      # Skip duplicates: re-adding an already current view would make the
      # labeling properties see more than one view.
      if plotView not in self.currentPlotViews:
        self.currentPlotViews.append(plotView)
    else:
      self.currentPlotViews = [plotView]
    plotView._setPlot(self)
    plotView.plotSetup()
    return plotView

  def _basePlotArgs(self, plotView):
    """Return a fresh kwargs dict seeded from plotView.getKWArgs()."""
    plotArgs = {}
    # get all the BLT direct plot args; non-dict results are ignored
    kwargs = plotView.getKWArgs()
    if isinstance(kwargs, dict):
      plotArgs.update(kwargs)
    return plotArgs

  def _perRow(self, value, count):
    """Broadcast a scalar value to a list with one entry per data row.

    Lists, tuples and numerix arrays are returned unchanged. None stays
    None so callers can skip the corresponding keyword entirely.
    """
    if value is None or isinstance(value, (list, tuple, nx.ArrayType)):
      return value
    return [value] * count

  def plot(self, plotView=None, hold=0, pack=1):

    """
    Plot each data row of plotView as a separate line.

    plotView -- PlotView to draw; when None the first current view is
                re-drawn after clearing the axis.
    hold -- 1 to add plotView to the current views, otherwise replace them.
    pack -- accepted for interface compatibility; unused.
    """
    plotView = self._activateView(plotView, hold, True)
    plotArgs = self._basePlotArgs(plotView)

    # extract the info first from the PlotView, once rather than per row
    dataMapper = plotView.getDataMapper()
    colorMapper = plotView.getColorMapper()
    markersMapper = plotView.getMarkersMapper()
    annotationMapper = plotView.getAnnotationMapper()

    plotData = dataMapper.getPlotData()
    trace = dataMapper.getLineTrace()
    weights = dataMapper.getWeightsAndSytles()
    colors = colorMapper.getColors('rgb')
    sizes = markersMapper.getMarkerSizes()
    markers = markersMapper.getPlotMarkers()
    primaryLabeling = annotationMapper.getPrimaryLabeling()

    if primaryLabeling is None:
      rowNames = None
    else:
      rowNames = primaryLabeling.getLabelByRows()

    numRows = len(plotData)
    sizes = self._perRow(sizes, numRows)
    markers = self._perRow(markers, numRows)

    for row in range(numRows):
      xdata = plotData[row][0]
      ydata = plotData[row][1]
      # FIXME: it would be a nice performance improvement to swap the
      # FIXME: color model at some other point.
      if colors is not None:
        plotArgs['color'] = colors[row]
      if sizes is not None:
        plotArgs['markersize'] = sizes[row]
      if markers is not None:
        plotArgs['marker'] = markers[row]
      if rowNames is None:
        plotArgs['label'] = str(row)
      else:
        plotArgs['label'] = str(rowNames[row])

      # NOTE(review): 'weights', 'styles' and 'trace' look like BLT-era
      # options; confirm the axis in use accepts them.
      if weights is not None:
        plotArgs['weights'] = weights[row][0]
        plotArgs['styles'] = weights[row][1]
      if trace is not None:
        plotArgs['trace'] = trace
      self.axis.plot(xdata, ydata, **plotArgs)

    plotView.plotFinishings()
    # NOTE(review): sets currentPlotView (singular), not currentPlotViews;
    # nothing in this class reads it. Kept in case external code does.
    self.currentPlotView = [plotView]
    self.show()

  def scatter(self, plotView=None, hold=0, pack=1):

    """
    Draw every data row of plotView as one point of a scatter collection.

    Row i contributes the point (plotData[i][0], plotData[i][1]).
    Parameters are as for plot(), except that the axis is not cleared
    when re-drawing the current view.
    """
    plotView = self._activateView(plotView, hold, False)
    plotArgs = self._basePlotArgs(plotView)

    dataMapper = plotView.getDataMapper()
    colorMapper = plotView.getColorMapper()
    markersMapper = plotView.getMarkersMapper()
    annotationMapper = plotView.getAnnotationMapper()

    plotData = dataMapper.getPlotData()
    colors = colorMapper.getColors('rgba')
    sizes = self._perRow(markersMapper.getMarkerSizes(), len(plotData))
    primaryLabeling = annotationMapper.getPrimaryLabeling()

    xdata = [point[0] for point in plotData]
    ydata = [point[1] for point in plotData]
    if colors is not None:
      plotArgs['color'] = colors
    if sizes is not None:
      plotArgs['s'] = sizes
    # Per-point markers, labels, weights and traces are not forwarded.
    collection = self.axis.scatter(xdata, ydata, **plotArgs)
    # NOTE(review): labeling stored on the collection, presumably for later
    # point lookups; confirm against readers of _labels.
    collection._labels = primaryLabeling

    plotView.plotFinishings()
    # NOTE(review): singular attribute, as in plot(); kept for callers.
    self.currentPlotView = [plotView]
    self.show()

