import inspect
import os
import unittest
import Numeric

from quixote.errors import PublishError

from compClust.mlx import datasets
from compClust.mlx import labelings
import compClust.gui.DataManager as DataManager
from compClust.gui.DataSource import DataSource
from compClust.gui.LabelingSource import LabelingSource
from compClust.gui.PlotCache import PlotCache

class PlotTestCases(unittest.TestCase):
  """Tests for the plotting entry points exposed by DataSource.

  Covers casting of the ``subset_label`` argument for dataset and ROC
  plots, and checks that the lines drawn by dataset plots carry labels
  and y-data that match the rows of the underlying dataset.
  """

  def _assertSubsetLabelCast(self, make_plot):
    """Check that a plot built by `make_plot` casts its subset label.

    `make_plot(subset_label)` must return a plot whose
    ``kwargs['subset_label']`` equals 1 for both the int 1 and the string
    '1', and building it with an unconvertible label must raise
    PublishError.
    """
    def get_plot_subset_label(subset_label):
      return make_plot(subset_label).kwargs['subset_label']

    self.failUnlessEqual(get_plot_subset_label(1), 1)
    self.failUnlessEqual(get_plot_subset_label('1'), 1)
    self.failUnlessRaises(PublishError, get_plot_subset_label, 'bleem')

  def _assertLinesMatchRows(self, axis, labeling, dataset):
    """Check every line on `axis` against the dataset row it is labeled with.

    Each line's label is looked up in `labeling` with getRowsByLabel;
    exactly one row must carry that label, and the line's y-data must
    match that row of `dataset` point by point.  The axis must hold at
    least one line, otherwise the per-line checks would pass vacuously.
    """
    lines = axis.get_lines()
    self.failUnless(len(lines) > 0, "plot contains no lines")
    for line in lines:
      row_keys = labeling.getRowsByLabel(line.get_label())
      self.failUnlessEqual(len(row_keys), 1)
      data = dataset.getRowData(row_keys[0])
      line_data = line.get_ydata()
      # Only the first len(data) points of the line are compared.
      for point in range(len(data)):
        self.failUnlessAlmostEqual(line_data[point], data[point])

  def testDatasetPlotSubsetLabelCast(self):
    data = datasets.Dataset([[1,1,1],[2,2,2],[3,3,3]])
    source = DataSource(data)
    row_labeling = labelings.Labeling(data)
    row_labeling.labelRows([1,2,3])

    def make_plot(subset_label):
      return source.dataset_plot(primary=row_labeling,
                                 subset_label=subset_label,
                                 subset=row_labeling)

    self._assertSubsetLabelCast(make_plot)

  def testDatasetPlotLabelNamesSimple(self):
    data = datasets.Dataset([[1,1,1],[2,2,2],[3,3,3]])
    source = DataSource(data)
    # NOTE(review): this Labeling is never referenced again.  It is kept in
    # case constructing it registers a 'row_id' labeling on `data` that the
    # add_labeling call below depends on -- confirm and drop if not needed.
    row_labeling = labelings.Labeling(data, 'row_id')
    row_source = source.add_labeling('row_id', [1,2,3], True, True)
    source.primary = row_source

    plot = source.dataset_plot()
    lines = plot.iplot.axis.get_lines()
    self.failUnless(len(lines) > 0, "plot contains no lines")
    for line in lines:
      # Rows are [1,1,1], [2,2,2], [3,3,3], presumably labeled 1, 2, 3 in
      # order, so each line's first y-value should equal its numeric label.
      self.failUnlessAlmostEqual(line.get_ydata()[0], float(line.get_label()))

  def testDatasetPlotLabelNamesCho(self):
    from compClust.util.LoadExample import LoadChoSource
    source = LoadChoSource()
    orfs = source.dataset.getLabeling(source.primary.name)

    # Every line of the unrestricted plot must match its labeled row.
    full_plot = source.dataset_plot()
    self._assertLinesMatchRows(full_plot.iplot.axis, orfs, source.dataset)

    diagem = source.dataset.getLabeling('em')
    self.failUnless(diagem is not None)
    subset_plot = source.dataset_plot(subset_label=5, subset=diagem)
    subset_iplot = subset_plot.iplot
    # The dataset behind the subset plot must still carry the 'orfs' labeling.
    subset_view = subset_iplot.currentPlotView[0].getDataset()
    self.failUnless(subset_view.getLabeling('orfs') is not None)

    # Subset lines are checked against rows of the full dataset.
    self._assertLinesMatchRows(subset_iplot.axis, orfs, source.dataset)

  def testPCAOutliersPlotCho(self):
    # Disabled: the early return below skips the whole check (the reason is
    # not recorded here).  Delete the return to re-enable it.
    return
    from compClust.gui.CompClustWebSession import load_cho
    source = load_cho()
    orfs = source.dataset.getLabeling(source.primary.name)
    plot = source.pca_outliers_in_native_order(2)
    self._assertLinesMatchRows(plot.iplot.axis, orfs, source.dataset)

    
    

def suite(**kw):
  """Return a TestSuite with every ``test*`` method of PlotTestCases.

  Keyword arguments are accepted for caller compatibility and ignored.
  """
  loader = unittest.TestLoader()
  return loader.loadTestsFromTestCase(PlotTestCases)

if "__main__" == __name__:
  # Run as a script: hand suite() to unittest's command-line runner.
  unittest.main(defaultTest="suite")

