import types

from matplotlib.lines import lineMarkers

from compClust.mlx import labelings
from compClust.mlx.views import RowPCAView
from compClust.iplot.views import IPlotView
from compClust.iplot.mappers.Mapper import Mapper
from compClust.iplot.mappers.AnnotationMapper import AnnotationMapper
from compClust.iplot.mappers.BindingsMapper import BindingsMapper

class IDataMapper(Mapper):

  """
  Interface for the data mappers used by IPlotView.

  A data mapper decides how a dataset is projected onto the graph:
  which x/y values are drawn, how the points of an element are
  connected, and (optionally) per-point weights and styles.  Subclasses
  must implement getPlotData; the other hooks default to returning None.
  """

  def getPlotData(self):
    """
    getPlotData(self)

    Return a list of (xValues, yValues) tuple pairs, one pair per plotted
    element, for example:

      [ ((x1, x2, x3, x4), (y1, y2, y3, y4)),
        ((x1, x2, x3, x4), (y1, y2, y3, y4)) ]

    Subclasses must override this; the base version raises
    NotImplementedError.
    """
    raise NotImplementedError()

  def getLineTrace(self):
    """
    getLineTrace(self)

    Return how the points of an element are to be connected: one of
    'increasing', 'decreasing', 'both' or None.  None (what this default
    implementation returns) keeps the BLT default line connectors.
    """
    return None

  def getWeightsAndSytles(self):
    """
    getWeightsAndSytles(self)

    Hook for marker control on a data-point-by-data-point basis.  For
    element-by-element control use the marker and marker-size functions
    of the MarkersMappers instead.

    Return None (what this default implementation returns) or a list the
    same length as the value returned by getPlotData, where each entry is
    a (weights tuple, corresponding styles tuple) pair -- see the BLT
    documentation for how weights select styles.
    """
    return None

class RowDataMapper(IDataMapper):

  """
  RowDataMapper(IDataMapper)

  A simple data mapper which plots each row of the dataset as its own
  line and separate element.  Every row shares the same x-values (see
  setXAxisLabeling).
  """

  def __init__(self, plotView, xAxisLabeling=None):

    # Name this class in super(): the previous super(IDataMapper, self)
    # started the MRO lookup *after* IDataMapper, silently bypassing any
    # __init__ that IDataMapper might define.
    super(RowDataMapper, self).__init__(plotView)
    self.setXAxisLabeling(xAxisLabeling)

  def setXAxisLabeling(self, labeling=None):
    """
    setXAxisLabeling(self, labeling=None)

    Set the x-values shared by every plotted row.  When a labeling is
    given, the x-values are its per-column labels
    (labeling.getLabelByCols()); otherwise they are the column indices
    0 .. numCols-1 of self.dataset.
    """

    if labeling is not None:
      self.__xvalues = tuple(labeling.getLabelByCols())
    else:
      self.__xvalues = tuple(range(0, self.dataset.getNumCols()))

  def getXValues(self):
    """
    getXValues(self)

    Return the tuple of x-values stored by setXAxisLabeling.
    """
    return self.__xvalues

  def getPlotData(self):
    """
    getPlotData(self)

    Return a list with one (xValues, rowValues) tuple pair per row of
    self.dataset.
    """

    ds = self.dataset
    xvalues = self.getXValues()
    # list() keeps the documented list return type on Python 3, where zip
    # returns a lazy iterator.
    return list(zip((xvalues,) * ds.getNumRows(), map(tuple, ds.getData())))

class SortedRowDataMapper(RowDataMapper):
  """
  SortedRowDataMapper(RowDataMapper)

  Data mapper which plots each row with its columns reordered by a
  sort order (see sortByLabeling).  Plot data is read from
  self.sortedDataset, a SortedView wrapped around the mapper's dataset.
  """
  def __init__(self, plotView, xAxisLabeling=None):
    # RowDataMapper.__init__ calls (this class's override of)
    # setXAxisLabeling, so pass the labeling through rather than setting
    # it a second time afterwards.
    super(SortedRowDataMapper, self).__init__(plotView, xAxisLabeling)
    # NOTE(review): SortedView is not imported in this module as shown;
    # confirm where it is expected to come from.
    self.sortedDataset = SortedView(self.dataset)

  def setXAxisLabeling(self, labeling=None):
    """
    setXAxisLabeling(self, labeling=None)

    Store the column labeling (of the unsorted dataset) used for the
    x-axis.  It is wrapped onto the sorted view with
    labelings.GlobalWrapper whenever x-values are requested.  None means
    the x-values are the column indices.
    """

    self.xAxisLabeling = labeling

  def sortByLabeling(self, labeling=None):
    """
    sortByLabeling(self, labeling)

    Set the column labeling on which sorting is performed and permute the
    sorted view's columns so that columns sharing a label are grouped
    together, in the order given by labeling.getLabels().

    Raises ValueError if labeling is None.
    """

    if labeling is None:
      # Previously this failed with an opaque AttributeError on None.
      raise ValueError('sortByLabeling requires a column labeling')

    self.sortLabeling = labeling
    sortedCols = []
    for label in labeling.getLabels():
      sortedCols.extend(labeling.getColsByLabel(label))

    self.sortedDataset.permuteCols(sortedCols)

  def getXValues(self):
    """
    getXValues(self)

    Return the x-values for the sorted view: the labels of xAxisLabeling
    wrapped onto the sorted view (presumably in its current column order
    -- confirm GlobalWrapper semantics), or the column indices when no
    x-axis labeling is set.

    Overridden because this class never stores the private x-values that
    the inherited RowDataMapper.getXValues reads, so the inherited
    version raises AttributeError here.
    """
    ds = self.sortedDataset
    if self.xAxisLabeling is not None:
      labeling = labelings.GlobalWrapper(ds, glabeling=self.xAxisLabeling.g)
      return tuple(labeling.getLabelByCols())
    return tuple(range(0, ds.getNumCols()))

  def getPlotData(self):
    """
    getPlotData(self)

    Return a list with one (xValues, rowValues) tuple pair per row of the
    sorted view.
    """
    xvalues = self.getXValues()
    ds = self.sortedDataset
    return [(xvalues, tuple(ds.getRowData(row)))
            for row in range(ds.getNumRows())]

class MeanDataMapper(RowDataMapper, AnnotationMapper, BindingsMapper):

  """
  MeanDataMapper(RowDataMapper, AnnotationMapper, BindingsMapper)

  Provides a view of the mean (and optionally mean +/- std) of either
  the whole dataset or of row groupings specified by a labeling.
  """

  def __init__(self, plotView, labeling=None):

    super(MeanDataMapper, self).__init__(plotView)
    self.__showStd = 0
    # setLabeling() replaces self.dataset with the mean view.  Keep the
    # original dataset so that later setLabeling() calls aggregate the
    # raw rows rather than the previous aggregate.
    self.__sourceDataset = self.dataset
    self.__meanDataset = None
    self.__stdDataset = None
    self.setLabeling(labeling)

  def setLabeling(self, labeling=None):
    """
    setLabeling(self, labeling=None)

    Set the row labeling used to partition the dataset into groups.
    With labeling=None, a labeling named '__meanDataMapper__' that puts
    every row in a single group is used (reused if the dataset already
    has one).  Afterwards self.dataset is the per-group mean view.
    """

    source = self.__sourceDataset
    # The aggregate views must be (re)built when no labeling was given or
    # when they do not exist yet (first call with an explicit labeling).
    if labeling is None or self.__meanDataset is None:
      if labeling is None:
        labeling = source.getLabeling('__meanDataMapper__')
        if not labeling:
          labeling = labelings.Labeling(source, '__meanDataMapper__')
          labeling.labelRows([0] * source.getNumRows())
      # NOTE(review): RowAggregateFunctionView, nx and safeStdDev are not
      # imported in this module as shown; confirm where they come from.
      self.__meanDataset = RowAggregateFunctionView(source, labeling, nx.mlab.mean)
      self.__stdDataset = RowAggregateFunctionView(source, labeling, safeStdDev)
    else:
      self.__meanDataset.setLabeling(labeling)
      self.__stdDataset.setLabeling(labeling)

    self.dataset = self.__meanDataset

  def getLineTrace(self):
    """
    getLineTrace(self)

    Return 'increasing' as the line-connection mode for every element.
    """
    return 'increasing'

  def getPlotData(self):
    """
    getPlotData(self)

    Return a list of (xValues, yValues) tuple pairs.  Normally there is one
    pair per group mean.  With std display on, the pairs are all
    (mean - std) rows, then all mean rows, then all (mean + std) rows.
    """
    xvalues = self.getXValues()
    if self.__showStd:
      meanData = self.__meanDataset.getData()
      stdData = self.__stdDataset.getData()
      rows = nx.concatenate((meanData - stdData,
                             meanData,
                             meanData + stdData))
      # Pair every concatenated row (three per group) with the x-values;
      # zipping against a fixed three copies would drop all rows past the
      # third when there is more than one group.
      return [(xvalues, tuple(row)) for row in rows]

    ds = self.dataset
    return list(zip((xvalues,) * ds.getNumRows(), map(tuple, ds.getData())))

  def showStdDev(self, state=None):

    """
    showStdDev(self, state=None)

    Turn display of the mean +/- std rows on or off.  With state=None the
    current setting is toggled; otherwise it is set to state.
    """
    if state is None:
      self.__showStd = not self.__showStd
    else:
      self.__showStd = state

class ColumnScatterDataMapper(IDataMapper):

  """
  ColumnScatterDataMapper(IDataMapper)

  Plots two columns of a dataset against each other as a scatter plot.
  """

  def __init__(self, plotView, xCol=0, yCol=1):

    super(ColumnScatterDataMapper, self).__init__(plotView)
    self.setXColumn(xCol)
    self.setYColumn(yCol)

  def getPlotData(self):
    """
    getPlotData(self)

    Return a list of (x, y) pairs, pairing the X-column value with the
    Y-column value of each row.
    """
    # list() keeps a list return type on Python 3, where zip is lazy.
    return list(zip(self.dataset.getColData(self.__xCol),
                    self.dataset.getColData(self.__yCol)))

  def getXColumn(self):
    """
    returns the column of the data being plotted as X-data
    """
    return self.__xCol

  def getYColumn(self):
    """
    returns the column of the data being plotted as Y-data
    """
    return self.__yCol

  def setXColumn(self, col):
    """
    setXColumn(self, col)

    sets col to be the column of the data to be plotted as X-data
    """
    self.__xCol = col

  def setYColumn(self, col):
    """
    setYColumn(self, col)

    sets col to be the column of the data to be plotted as Y-data
    """
    self.__yCol = col


class PCADataMapper(IDataMapper):

  """
  PCADataMapper(IDataMapper)

  Scatter-plots two columns of the mapper's dataset against each other,
  after registering a 'RowPCAView' default view on that dataset.
  """

  def __init__(self, plotView, xCol=0, yCol=1):

    super(PCADataMapper, self).__init__(plotView)

    # NOTE(review): self.dataset is not switched to the RowPCAView here, so
    # getPlotData reads self.dataset's columns directly -- confirm that
    # addViewDefault makes those the principal-component columns.
    self.dataset.addViewDefault('RowPCAView', RowPCAView, plotView.getDataset())
    self.setXColumn(xCol)
    self.setYColumn(yCol)

  def getPlotData(self):
    """
    getPlotData(self)

    Return a list of (x, y) pairs, pairing the X-column value with the
    Y-column value of each row of self.dataset.
    """
    # list() keeps a list return type on Python 3, where zip is lazy.
    return list(zip(self.dataset.getColData(self.__xCol),
                    self.dataset.getColData(self.__yCol)))

  def getXColumn(self):
    """
    returns the column of the data being plotted as X-data
    """
    return self.__xCol

  def getYColumn(self):
    """
    returns the column of the data being plotted as Y-data
    """
    return self.__yCol

  def setXColumn(self, col):
    """
    setXColumn(self, col)

    sets col to be the column of the data to be plotted as X-data
    """
    self.__xCol = col

  def setYColumn(self, col):
    """
    setYColumn(self, col)

    sets col to be the column of the data to be plotted as Y-data
    """
    self.__yCol = col


