#!/usr/bin/python

"""testIPlot is a base class used by all the various IPlot<toolkit> tests.

In this module IPlot is bound to the Agg backend (compClust.iplot.IPlotAgg).
"""
import os
import unittest

from compClust.mlx import datasets
from compClust.mlx import labelings

from compClust.iplot import IPlotAgg as IPlot
from compClust.iplot.views import DatasetRowPlotView

class DatasetPlotTestCases(unittest.TestCase):
  """Tests for IPlot.DatasetPlot drawn from a DatasetRowPlotView."""

  # Row label -> row values for the small fixture shared by several tests.
  ROW_DATA = {'zero': [0, 0, 0, 0],
              'one': [1, 1, 1, 1],
              'two': [2, 2, 2, 2],
              'three': [3, 3, 3, 3]}
  # Explicit row order: dict iteration order is not guaranteed on older
  # Pythons, and some tests depend on which row sits at which index.
  ROW_NAMES = ['zero', 'one', 'two', 'three']

  def _fixtureRows(self):
    """Return fresh copies of the fixture rows in ROW_NAMES order."""
    # Copy each row so nothing the library does can mutate the class constant.
    return [list(self.ROW_DATA[name]) for name in self.ROW_NAMES]

  def _makeLabeledDataset(self):
    """Build the fixture dataset with ROW_NAMES as its primary row labeling.

    Returns a (dataset, primary_labeling) tuple.
    """
    ds = datasets.Dataset(self._fixtureRows())
    primary = labelings.Labeling(ds)
    primary.labelRows(list(self.ROW_NAMES))
    ds.primaryRowLabeling = primary
    return ds, primary

  def _assertRowEqual(self, row, expected):
    """Assert that the 1-D sequence `row` equals `expected` element by element.

    A single `==` between containers is not used: on arrays its truth value
    is ambiguous or vacuous, and it hides shape/length mismatches.
    """
    self.assertEqual(len(row), len(expected))
    for actual, want in zip(row, expected):
      self.assertEqual(actual, want)

  def testDatasetReplot(self):
    """Test that calling plot() again doesn't double the number of lines.

    Constructing a DatasetPlot already draws the rows (the other tests read
    the axis lines without calling plot()), so plot() here is a redraw.
    """
    ds, _ = self._makeLabeledDataset()
    iplot = IPlot.DatasetPlot(DatasetRowPlotView(ds))
    iplot.plot()

    self.assertEqual(len(iplot.axis.get_lines()), ds.numRows)

  def testDatasetPlotNoLabels(self):
    """Test that, without a primary row labeling, lines are labeled by row index."""
    ds = datasets.Dataset(self._fixtureRows())
    iplot = IPlot.DatasetPlot(DatasetRowPlotView(ds))

    line_labels = [line.get_label() for line in iplot.axis.get_lines()]
    # In the case of no labels, a line just gets str(row_index).
    self.assertEqual(line_labels, ['0', '1', '2', '3'])

  def testDatasetPlotLabels(self):
    """Test that the plot's line labels match the dataset's primary row labels.

    Also checks that subsetting by each label (through the original labeling
    and through the view's annotation mapper) yields the matching row values,
    and that plotting such a one-row subset draws exactly those values.
    """
    ds, primary = self._makeLabeledDataset()
    row_view = DatasetRowPlotView(ds)
    iplot = IPlot.DatasetPlot(row_view)

    line_labels = [line.get_label() for line in iplot.axis.get_lines()]
    self.assertEqual(line_labels, self.ROW_NAMES)

    # Test our original labeling.
    for line_label in line_labels:
      expected = self.ROW_DATA[line_label]
      # Accessing the data through the subset should give the same values.
      subset_ds = labelings.subsetByLabeling(ds, primary, [line_label])
      subset_rows = subset_ds.getData()
      self.assertEqual(len(subset_rows), 1)
      self._assertRowEqual(subset_rows[0], expected)
      # Accessing the data through a new plot should give the same values.
      subset_plot = IPlot.DatasetPlot(DatasetRowPlotView(subset_ds))
      subset_lines = subset_plot.axis.get_lines()
      self.assertEqual(len(subset_lines), 1)
      self._assertRowEqual(subset_lines[0].get_ydata(), expected)

    # See if the annotation mapper broke things.
    mapped_labeling = row_view.getAnnotationMapper().getPrimaryLabeling()
    for line_label in line_labels:
      subset_ds = labelings.subsetByLabeling(ds, mapped_labeling, [line_label])
      self._assertRowEqual(subset_ds.getData()[0], self.ROW_DATA[line_label])

  def testDatasetSubsetPlot(self):
    """Test that subsetting by a secondary row labeling selects two rows per label."""
    ds, _ = self._makeLabeledDataset()
    high_low = labelings.Labeling(ds)
    # Rows are in ROW_NAMES order, so zero/one are 'low' and two/three 'high'.
    high_low.labelRows(['low', 'low', 'high', 'high'])

    for subset_label in ['low', 'high']:
      subset_ds = labelings.subsetByLabeling(ds, high_low, [subset_label])
      self.assertEqual(len(subset_ds.getData()), 2)
    
def suite(*args, **kwargs):
  """Return a TestSuite with every test* method of DatasetPlotTestCases.

  Any positional or keyword arguments are accepted and ignored.
  """
  # unittest.makeSuite is deprecated and was removed in Python 3.13. This
  # loader call is the documented equivalent and also exists on Python 2.
  return unittest.TestLoader().loadTestsFromTestCase(DatasetPlotTestCases)

if '__main__' == __name__:
  # Script entry point: hand control to unittest's command-line runner,
  # defaulting to the suite() factory defined in this module.
  unittest.main(module='__main__', defaultTest="suite")
  