
import matplotlib.numerix as nx

import MLab

from compClust.mlx.views import RowSubsetView
from compClust.mlx.labelings import GlobalLabeling
from compClust.mlx.labelings import GlobalWrapper

from compClust.score import roc
from compClust.util import DistanceMetrics
from compClust.util import Histogram
from compClust.util import listOps

from PlotPage import PlotPage
from HistogramPlotter import HistogramPlotter

class ROCPlot(PlotPage):
  """
  PlotPage showing the ROC curve for one cluster (a label in a labeling),
  next to two histograms of every row's distance to the cluster mean: one
  for rows carrying the label and one for all other rows.
  """

  def __init__(self, canvasFactory, dataset, labeling, label, distanceMetric=DistanceMetrics.EuclideanDistance):
    """
    draw our standard ROC plot for the label in labeling for dataset

    canvasFactory  -- passed to PlotPage and to the HistogramPlotters
    dataset        -- dataset whose rows are scored
    labeling       -- labeling that defines the cluster
    label          -- label in `labeling` selecting the cluster's rows
    distanceMetric -- called as distanceMetric(clusterMean, data); used both
                      for the ROC statistics and for the distance histograms
    """
    PlotPage.__init__(self, 1, 2, canvasFactory=canvasFactory)

    # get the ROC stats and draw the curve
    area, xvalues, yvalues = roc.clusterROC(dataset, labeling, label, distanceMetric=distanceMetric)
    # FIXME: need add title
    #pp.sf.configure(label_text='ROC Display for label %s in %s\n Area=%3.2f'%(str(label), str(labeling.getName()),area))
    self._plotROCCurve(area, xvalues, yvalues)

    ## build the interactive histogram plots
    insideLab, outsideLab = self._buildDistanceLabelings(dataset, labeling, label, distanceMetric)

    ## FIXME this violates the plot page a little to make this a non-square plot:
    ## the histograms bypass addPlot and are placed directly on the figure.
    # NOTE(review): key (0, 2) lies outside the 1x2 grid given to
    # PlotPage.__init__; it is only used as a self.plots key here -- confirm
    # PlotPage tolerates it.
    self._plotDistanceHistogram(222, (0, 1), insideLab, 'Inside', 'r',
                                'Distance Histogram for Cluster Members')
    self._plotDistanceHistogram(224, (0, 2), outsideLab, 'Outside', 'b',
                                'Distance Histogram for Non-Cluster Members')
    self.show()

  def _plotROCCurve(self, area, xvalues, yvalues):
    """Draw the filled ROC curve into plot slot (0, 0), titled with its area."""
    area_axis = self.addPlot(0, 0)
    # append (1, 0) and (0, 0) so the filled polygon closes down to the x axis
    area_axis.fill(xvalues.tolist() + [1, 0], yvalues.tolist() + [0, 0])
    area_axis.set_title('ROC Curve (area=%3.2f)' % (area), fontdict=self.title_font)
    area_axis.set_xlabel('% outside')
    area_axis.set_ylabel('% inside')

  def _buildDistanceLabelings(self, dataset, labeling, label, distanceMetric):
    """
    Bin every row of `dataset` by its distance to the mean of the rows that
    carry `label`, and return (insideLab, outsideLab): GlobalWrappers of that
    binning restricted to the cluster's rows and to all remaining rows.
    """
    ### These computations are partially redundant with the ROC code.
    data = dataset.getData()
    insideRows = labeling.getRowsByLabel(label)
    outsideRows = listOps.difference(range(dataset.getNumRows()), insideRows)
    clusterMean = MLab.mean(nx.take(data, insideRows))
    distances = distanceMetric(clusterMean, data)

    # presumably the bin count: 10% of the rows, but never fewer than 3
    # (confirm against Histogram.binOnRowVector)
    binned = Histogram.binOnRowVector(dataset, distances, max(dataset.getNumRows()*0.1, 3))
    distanceLab = GlobalLabeling(dataset, '__distances__')
    ## FIXME globalLabelings.labelFrom is broken, so copy the bins one label at a time
    for binLabel in binned.getLabels():
      distanceLab.addLabelToRows(dataset, binLabel, binned.getRowsByLabel(binLabel))
    # the bins now live in distanceLab; drop the intermediate labeling from dataset
    dataset.removeLabeling(binned)

    insideLab = GlobalWrapper(RowSubsetView(dataset, insideRows), glabeling=distanceLab)
    outsideLab = GlobalWrapper(RowSubsetView(dataset, outsideRows), glabeling=distanceLab)
    return insideLab, outsideLab

  def _plotDistanceHistogram(self, subplotSpec, plotKey, distLabeling, seriesName, color, title):
    """
    Draw a histogram of `distLabeling` into figure subplot `subplotSpec`
    (e.g. 222 / 224 = right column of a 2x2 grid), register the axis in
    self.plots under `plotKey`, and label it with `title`.
    """
    axis = self.canvas.figure.add_subplot(subplotSpec)
    self.plots[plotKey] = axis
    histograms = HistogramPlotter(canvasFactory=self.canvasFactory, axis=axis)
    histograms.addLabeling(distLabeling, seriesName=seriesName)
    histograms.plot(fixedColor=color)
    axis.set_title(title, fontdict=self.title_font)
    axis.set_xlabel('Distance')
    axis.set_ylabel('Count')
