import copy
import operator
import os
import string
import urllib

from compClust.mlx import labelings
from compClust.util import DistanceMetrics
from compClust.util import FileIO

from compClust.iplot.views import DatasetRowPlotView
from compClust.iplot import IPlotAgg as IPlot

from quixote.errors import PublishError

class WebPlot:
  """WebPlot will lazily create an IPlot based plot with support for being put into web pages

  A plot is configured by calling one of the setter methods
  (``dataset_plot``, ``cluster_trajectories``, ``confusion_matrix_summary``,
  ``roc_plot``, ``pca_projection``, ...).  Each setter records the name of
  the matching ``create_*`` factory method in ``self.func``, its keyword
  arguments in ``self.kwargs`` and a descriptive string in ``self.id``.
  The IPlot object itself is only built when ``iplot`` or
  ``image_filename`` is first accessed.

  The IPlot object is excluded from the pickled state and rebuilt on
  demand after unpickling.
  """
  def __init__(self, datasource):
    """Instantiate a WebPlot object

    :Parameters:
      - `datasource`: a datasource object from compClust.gui.DataSource
    """
    self.datasource = datasource
    # name of the create_* method to call and the arguments to call it
    # with; filled in by one of the plot setter methods
    self.func = None
    self.args = ()
    self.kwargs = {}
    # description of the configured plot; replaced by the setter methods
    self.id = None

    self.image_extension = ".png"
    self.__image_filename = None
    self.__imagemap = None
    self.__iplot = None

  def __cast_subset_label(self, subset_label, subset):
    """Return `subset_label` in the form in which `subset` stores it.

    A label that is not found as given (for example a string taken from
    a URL) is re-parsed with FileIO._parseToken and looked up again.

    :Parameters:
      - `subset_label`: the label to look for
      - `subset`: the labeling to search, or None
    :Returns: the (possibly converted) label, or None when `subset` is None
    :Raises:
      - ValueError: `subset` is None but a `subset_label` was supplied
      - PublishError: the label is not in `subset` even after conversion
    """
    if subset is None:
      if subset_label is not None:
        raise ValueError("Can't search None for a label")
      return None

    if subset_label not in subset.getLabels():
      # if we don't have subset_label, perhaps it needs to be casted?
      subset_label = FileIO._parseToken(subset_label)
      if subset_label not in subset.getLabels():
        # Hmm... this seems to be a bogus label.  Only the second
        # (private) message includes the full list of labels.
        msg = "%s was not found in %s %s"
        safe_msg = msg % (subset_label, subset.getName(), "")
        private_msg = msg % (subset_label, subset.getName(), str(subset.getLabels()))
        raise PublishError(safe_msg, private_msg)
    return subset_label

  def __del__(self):
    """When we're destroyed try to delete our image file.

    This is best effort. The file may already be gone, for example
    because another copy of this object restored from a pickle removed
    it, and an exception escaping __del__ could not be handled anyway.
    """
    # getattr: __init__ may not have completed (or state not restored)
    filename = getattr(self, '_WebPlot__image_filename', None)
    if filename is not None:
      try:
        os.remove(filename)
      except OSError:
        pass
      self.__image_filename = None

  def __getstate__(self):
    """Return the state of the plot, blocking out the unserializable iplot object
    """
    state = self.__dict__.copy()
    state.pop('_WebPlot__iplot', None)
    return state

  def __setstate__(self, d):
    """Set the state after being restored from being pickled

    The iplot slot defaults to None so the plot is rebuilt lazily.
    """
    self.__dict__.setdefault('_WebPlot__iplot', None)
    self.__dict__.update(d)

  def __get_iplot(self):
    """Get iplot object, creating it (once) with the configured factory
    """
    if self.__iplot is None:
      factory = getattr(self, self.func)
      self.__iplot = factory(*self.args, **self.kwargs)
    return self.__iplot
  iplot = property(__get_iplot, doc="returns iplot reference")

  def __get_image_filename(self):
    """Return a image filename, which as a side effect ends up saving the plot

    When no filename is known or the file no longer exists, a fresh IPlot
    is built with the configured factory.  The cached ``iplot`` is neither
    reused nor replaced here.
    """
    if self.__image_filename is None or not os.path.exists(self.__image_filename):
      factory = getattr(self, self.func)
      iplot = factory(*self.args, **self.kwargs)
      iplot.image_extension = self.image_extension
      self.__image_filename = iplot.image_filename
    return self.__image_filename
  image_filename = property(__get_image_filename, doc="get a filename to a rendering of the plot on disk")

  def sizeof(self):
    """Estimate memory usage.

    If an image map has been generated, its length is returned.
    Otherwise the lengths of the positional arguments are summed with the
    lengths of the keyword argument names and values.  Objects that do not
    support ``len()`` (None, numbers, ...) count as 0.

    :Returns: an integer estimate (never None)
    """
    if self.__imagemap is not None:
      return len(self.__imagemap)

    def length_or_zero(obj):
      try:
        return len(obj)
      except TypeError:
        return 0

    size = sum([length_or_zero(arg) for arg in self.args])
    size += sum([length_or_zero(key) + length_or_zero(value)
                 for key, value in self.kwargs.items()])
    return size

  # create image maps
  def get_axes_imagemap(self, url):
    """Try to create a client-side image map for a PlotPage
    (Allow the user to click on a plot summary to see a detail plot)

    Each axis of the figure becomes a rectangular area linking to
    ``<url>/cluster_detail_plot`` for the label drawn on that axis.  The
    result is cached: later calls return the first map whatever `url` is.
    """
    if self.__imagemap is None:
      plot = self.iplot
      area = '<area href="%s" shape="rect" coords="%s,%s,%s,%s" title="%s" alt="%s" />'
      href_params = "%s/cluster_detail_plot?subset_label=%s&amp;set=%s"
      # figure y coordinates grow upwards while image-map coordinates grow
      # downwards, so y values are flipped against the figure's top edge
      fig_top = plot.figure.bbox.ymax()
      set_name = self.datasource.get_label_name(plot.labeling)

      entries = ['<map id="graph" name="graph">']
      for axis in plot.figure.axes:
        left = int(axis.bbox.xmin())
        right = int(axis.bbox.xmax())
        bottom = int(fig_top - axis.bbox.ymin())
        top = int(fig_top - axis.bbox.ymax())
        label = plot.axisToLabel.get(axis, None)
        href = href_params % (url, label, set_name)
        entries.append(area % (href, left, top, right, bottom, str(label), str(label)))

      entries.append('</map>')
      self.__imagemap = os.linesep.join(entries)
    return self.__imagemap

  def get_data_imagemap(self, url, subset):
    """Create a client-side image map for a detail plot

    Every data point of every line on the plot's axis becomes a circular
    area (radius 4) linking to ``<url>/labelplot`` for that line's label.
    The result is cached: later calls return the first map.

    :Parameters:
      - `url`: base url for the generated links
      - `subset`: labeling whose name is used as the ``set`` parameter
    """
    if self.__imagemap is None:
      plot = self.iplot
      # resolve the labeling's name once, outside the loops
      set_name = self.datasource.get_label_name(subset)
      area = '<area  href="%s" shape="circle" coords="%s,%s,%s" title="%s" alt="%s"/>'
      href_params = "%s/labelplot?subset=%s&amp;set=%s"
      fig_top = plot.figure.bbox.ymax()
      transform = plot.axis.transData

      entries = ['<map id="graph" name="graph">']
      for line in plot.axis.get_lines():
        label = line.get_label()
        href = href_params % (url, label, set_name)
        point_label = str(label)
        for data_point in zip(line.get_xdata(), line.get_ydata()):
          coord_x, coord_y = transform.xy_tup(data_point)
          entries.append(area % (href, int(coord_x), int(fig_top - coord_y),
                                 4, point_label, point_label))
      entries.append('</map>')
      self.__imagemap = os.linesep.join(entries)
    return self.__imagemap

  def get_scatter_imagemap(self, url, set):
    """Create a client-side image map for a scatter plot

    Every point of every collection on the first axis becomes a circular
    area (radius 4) linking to ``<url>/labelplot`` for that point's label.
    The result is cached: later calls return the first map.

    :Parameters:
      - `url`: base url for the generated links
      - `set`: labeling whose name is used as the ``set`` parameter
    """
    if self.__imagemap is None:
      plot = self.iplot
      axis = plot.figure.axes[0]
      try:
        collections = axis.collections
      except AttributeError:
        # backwards compatibility with pre-0.70 matplotlib
        collections = axis._collections
      area = '<area  href="%s" shape="circle" coords="%s,%s,%s" title="%s" alt="%s"/>'
      href_params = "%s/labelplot?subset=%s&amp;set=%s"
      fig_top = plot.figure.bbox.ymax()
      set_name = self.datasource.get_label_name(set)

      entries = ['<map id="graph" name="graph">']
      for collection in collections:
        points = collection._offsets
        # each row entry is a sequence of labels; only the first is used
        labels = [row[0] for row in collection._labels.getAllRowLabels()]
        for i, data_point in enumerate(points):
          coord_x, coord_y = axis.transData.xy_tup(data_point)
          href = href_params % (url, labels[i], set_name)
          point_label = str(labels[i])
          entries.append(area % (href, int(coord_x), int(fig_top - coord_y),
                                 4, point_label, point_label))
      entries.append('</map>')
      self.__imagemap = os.linesep.join(entries)
    return self.__imagemap

  # plot creation code
  def dataset_plot(self, primary=None, secondary=None, subset_label=None, subset=None):
    """Add a dataset plot, if we're provided a subset and subset label
    construct a subset of the dataset
    """
    subset_label = self.__cast_subset_label(subset_label, subset)
    self.func = "create_dataset_plot"
    self.kwargs = {'primary': primary,
                   'secondary': secondary,
                   'subset_label': subset_label,
                   'subset': subset}
    self.id = str(('dataset_plot', primary, secondary, subset_label, subset))

  def create_dataset_plot(self, primary=None, secondary=None, subset_label=None, subset=None):
    """Create a dataset plot based on the parameters saved by self.dataset_plot

    :Raises: PublishError if the (sub)dataset has no rows or no columns
    """
    if primary is None:
      primary = self.datasource.get_labeling_by_name(self.datasource.primary)
    if secondary is None:
      secondary = self.datasource.get_labeling_by_name(self.datasource.secondary)
    dataset = self.datasource.dataset

    if subset_label is not None and subset is not None:
      dataset = labelings.subsetByLabeling(dataset, subset, [subset_label])
      dataset.setName(subset_label)

    if dataset.getNumRows() == 0 or dataset.getNumCols() == 0:
      msg = "Can't plot dataset %s with %d rows by %d columns"
      msg %= (dataset.getName(), dataset.getNumRows(), dataset.getNumCols())
      raise PublishError(msg)

    # make the row plot view, coloured by the value in column 0
    pv = DatasetRowPlotView(dataset, primaryLabeling=primary, secondaryLabeling=secondary)
    pv.getColorMapper().setColorByColValue(0)

    iplot = IPlot.DatasetPlot(pv)
    iplot.id = self.id
    self.__image_filename = iplot.image_filename
    return iplot

  def cluster_trajectories(self, set, primary=None, secondary=None):
    """Construct a trajectory summary broken out by set
    """
    self.func = "create_cluster_trajectories"
    self.kwargs = {'set': set,
                   'primary': primary,
                   'secondary': secondary}
    self.id = str(('cluster_trajectories', set, primary, secondary))

  def create_cluster_trajectories(self, set, primary=None, secondary=None):
    """Create the trajectory summary configured by cluster_trajectories
    """
    if primary is None:
      primary = self.datasource.get_labeling_by_name(self.datasource.primary)
    if secondary is None:
      secondary = self.datasource.get_labeling_by_name(self.datasource.secondary)

    iplot = IPlot.TrajectorySummary(self.datasource.dataset, clusterLabeling=set, primaryLabeling=primary, secondaryLabeling=secondary)
    iplot.id = self.id
    self.__image_filename = iplot.image_filename
    return iplot

  def confusion_matrix_summary(self, labeling1, labeling2, primary=None, secondary=None):
    """Configure a confusion matrix summary comparing two labelings
    """
    self.func = "create_confusion_matrix_summary"
    self.kwargs = {'labeling1': labeling1,
                   'labeling2': labeling2,
                   'primary': primary,
                   'secondary': secondary}
    self.id = str(('confusion_matrix', labeling1, labeling2, primary, secondary))

  def create_confusion_matrix_summary(self, labeling1, labeling2, primary=None, secondary=None):
    """Create confusion matrix summary
    """
    if primary is None:
      primary = self.datasource.get_labeling_by_name(self.datasource.primary)
    if secondary is None:
      secondary = self.datasource.get_labeling_by_name(self.datasource.secondary)

    iplot = IPlot.ConfusionMatrixSummary(self.datasource.dataset,
                                         labeling1=labeling1,
                                         labeling2=labeling2,
                                         primaryLabeling=primary,
                                         secondaryLabeling=secondary,
                                         web_safe=True)
    iplot.id = self.id
    self.__image_filename = iplot.image_filename
    return iplot

  def roc_plot(self, subset_label, subset, distance=None):
    """Configure an ROC plot for `subset_label` within labeling `subset`
    """
    subset_label = self.__cast_subset_label(subset_label, subset)
    self.func = "create_roc_plot"
    self.kwargs = {'subset_label': subset_label,
                   'subset': subset,
                   'distance': distance}
    self.id = str(('roc', subset_label, subset, distance))

  def create_roc_plot(self, subset_label, subset, distance=None):
    """Create the ROC plot; `distance` defaults to Euclidean distance
    """
    if distance is None:
      distance = DistanceMetrics.EuclideanDistance

    iplot = IPlot.ROCPlot(self.datasource.dataset,
                          labeling=subset,
                          label=subset_label,
                          distanceMetric=distance)
    iplot.id = self.id
    self.__image_filename = iplot.image_filename
    return iplot

  def pca_projection(self, xaxis, yaxis, primary=None, secondary=None):
    """Configure a PCA projection plot of component `xaxis` vs `yaxis`
    """
    self.func = "create_pca_projection"
    self.kwargs = {'xaxis': xaxis,
                   'yaxis': yaxis,
                   'primary': primary,
                   'secondary': secondary}
    self.id = 'pca_projection-%d-%d' % (xaxis, yaxis)

  def create_pca_projection(self, xaxis, yaxis, primary=None, secondary=None):
    """Create the PCA projection plot configured by pca_projection
    """
    if primary is None:
      primary = self.datasource.get_labeling_by_name(self.datasource.primary)
    if secondary is None:
      secondary = self.datasource.get_labeling_by_name(self.datasource.secondary)
    iplot = IPlot.PCAPlot(self.datasource.dataset,
                          primaryLabeling=primary,
                          secondaryLabeling=secondary).ProjectionPlot()
    iplot.id = self.id
    iplot.drawProjection(xaxis, yaxis)
    self.__image_filename = iplot.image_filename
    return iplot

  def pca_eigenvector(self, primary=None, secondary=None):
    """Configure a PCA eigenvector plot
    """
    self.func = "create_pca_eigenvector"
    self.kwargs = {'primary': primary, 'secondary': secondary}
    self.id = 'pca_eigenvector'

  def create_pca_eigenvector(self, primary=None, secondary=None):
    """Create the PCA eigenvector plot configured by pca_eigenvector
    """
    if primary is None:
      primary = self.datasource.get_labeling_by_name(self.datasource.primary)
    if secondary is None:
      secondary = self.datasource.get_labeling_by_name(self.datasource.secondary)

    iplot = IPlot.PCAPlot(self.datasource.dataset,
                          primaryLabeling=primary,
                          secondaryLabeling=secondary).EigenVectorPlot()
    iplot.id = self.id
    self.__image_filename = iplot.image_filename
    return iplot

  def pc_vs_pc_with_extreme_genes(self, ginzu, pcNumX, pcNumY):
    """Create PCAGinzu plot plotPCvsPCWithOutliersInY

    :Parameters:
      - `ginzu`: a pca ginzu object (we don't want to create it here since
                    it has a number of options and can be a bit slow)
      - `pcNumX`: which principal component to look at on the X axis
      - `pcNumY`: which principal component to look at on the Y axis
    """
    self.func = "create_pc_vs_pc_with_extreme_genes"
    self.kwargs = {'ginzu': ginzu,
                   'pcNumX': pcNumX,
                   'pcNumY': pcNumY}
    self.id = 'pc_vs_pc_with_extreme_genes-%d-%d' % (pcNumX, pcNumY)

  def create_pc_vs_pc_with_extreme_genes(self, ginzu, pcNumX, pcNumY):
    """(actually create the extreme gene plot once the application is ready for it.)
    """
    iplot = ginzu.plotPCvsPCWithOutliersInY(pcNumX, pcNumY)
    iplot.id = self.id
    self.__image_filename = iplot.image_filename
    return iplot

  def pceg_in_significance_order(self, ginzu, pcNum):
    """Create PCAGinzu plot plotPCNOutlierRowsInSigGroupOrder

    :Parameters:
      - `ginzu`: a pca ginzu object (we don't want to create it here since
                    it has a number of options and can be a bit slow)
      - `pcNum`: which principal component to look at
    """
    self.func = "create_pceg_in_significance_order"
    self.kwargs = {'ginzu': ginzu,
                   'pcNum': pcNum}
    self.id = 'pceg_in_significance_order-%s' % (pcNum,)

  def create_pceg_in_significance_order(self, ginzu, pcNum):
    """(actually create the extreme gene plot once the application is ready for it.)
    """
    iplot = ginzu.plotPCNOutlierRowsInSigGroupOrder(pcNum)
    iplot.id = self.id
    self.__image_filename = iplot.image_filename
    return iplot

  def pceg_in_native_order(self, ginzu, pcNum):
    """Create PCAGinzu plot plotPCNOutlierRowsInOriginalColumnOrder

    :Parameters:
      - `ginzu`: a pca ginzu object (we don't want to create it here since
                    it has a number of options and can be a bit slow)
      - `pcNum`: which principal component to look at
    """
    self.func = "create_pceg_in_native_order"
    self.kwargs = {'ginzu': ginzu,
                   'pcNum': pcNum}
    self.id = 'pceg_in_native_order-%s' % (pcNum,)

  def create_pceg_in_native_order(self, ginzu, pcNum):
    """(actually create the extreme gene plot once the application is ready for it.)
    """
    iplot = ginzu.plotPCNOutlierRowsInOriginalColumnOrder(pcNum)
    iplot.id = self.id
    self.__image_filename = iplot.image_filename
    return iplot