import md5
import operator
import os
import string
import types

from compClust.mlx import datasets
from compClust.mlx import labelings
from compClust.util import DistanceMetrics

from compClust.gui.LabelingSource import LabelingSource
from compClust.gui.PlotCache import PlotCache
from compClust.gui.WebPlot import WebPlot

from compClust.iplot.IPlotAgg import PCAGinzu

class DataSource(object):
  """Manage a dataset and its associated meta information.

  A DataSource bundles an mlx ``Dataset`` with:

    - the location it was loaded from (``source``), so the ``dataset``
      property can lazily (re)load it,
    - a dictionary of ``LabelingSource`` objects keyed by labeling name,
    - default ``primary`` / ``secondary`` annotation labelings,
    - a ``PlotCache`` holding the ``WebPlot`` objects created for it,
    - a lazily constructed ``PCAGinzu`` analysis (``pceg_analysis``).
  """
  def __init__(self, dataset=None, delimiter="\t"):
    """Create a data source, optionally attaching a dataset right away.

    :Parameters:
      - `dataset`: None, a ``datasets.Dataset`` instance, or anything that
                   ``datasets.Dataset(...)`` accepts.  String values are
                   also remembered as the dataset source.
      - `delimiter`: character used to separate fields in the dataset;
                     only stored here, nothing in this class reads it back.
    """
    self.__dataset = None
    self.__dataset_name = None
    self.__dataset_source = None
    self.__delimiter = delimiter
    self._labeling_sources = {}
    self.__primary = None
    self.__secondary = None
    self.__id = None
    self.cache_dir = None
    self.is_cached_file = False  # did we write out the file?

    # stuff for handling plotting
    self.plots = PlotCache()

    # ginzu (built on first access of pceg_analysis)
    self.__ginzu = None

    # goes through the same code path as assigning to the dataset property
    self.__set_dataset(dataset)

  ##########
  # persistence
  def __getstate__(self):
    """Return the state dictionary used when this object is serialized.

    When a dataset is loaded, a source is known and ``cache_dir`` is set,
    the dataset is written to ``<cache_dir>/<id>/dataset``, the source is
    pointed at that file and the in-memory dataset is dropped from the live
    object (the ``dataset`` property reloads it on next access).

    Raises OSError when the per-dataset cache directory cannot be used.
    """
    # NOTE(review): the snapshot is taken *before* the attributes below are
    # modified, so the returned state keeps the in-memory dataset, the
    # original source and the old is_cached_file flag, while the live object
    # is switched over to the cached copy.  Also, the comment further down
    # talks about datasets that cannot be reloaded, yet the test requires a
    # source to be set.  Both look unintended, but the FIXME notes that
    # dropping the dataset from the serialized state broke plots -- confirm
    # before changing either.
    serial_dict = self.__dict__.copy()
    # ditch our copy of dataset if we can load it again.
    if self.__dataset is not None:
      ## FIXME: when we jettison the dataset when constantly serializing
      ## FIXME: plots don't show up, so i stopped unloading the dataset.
      ## FIXME: now the question is when do we unload it?
      ## if we can reload from somewhere jettison the dataset
      #if self.__dataset_source is not None:
      #  self.__dataset = None
      # if we can't reload the dataset see if we can save it.
      if self.__dataset_source is not None and self.cache_dir is not None:
        persist_dir = os.path.join(self.cache_dir, self.id)
        if not os.path.exists(persist_dir):
          os.mkdir(persist_dir)
        # something other than a directory may already live at that path
        if not os.path.isdir(persist_dir):
          raise OSError("couldn't find directory to store dataset")
        self.__dataset_source = os.path.join(persist_dir, 'dataset')
        self.__dataset.writeDataset(self.__dataset_source)
        self.__dataset = None
        self.is_cached_file = True
    return serial_dict

  # manage reference to dataset
  def __get_dataset(self):
    """Return the dataset, loading it from the source if it is not loaded.

    Returns None when no dataset is loaded and no source is known.
    """
    if self.__dataset is None and self.__dataset_source is not None:
      # instantiate ourselves
      self.__dataset = datasets.Dataset(self.__dataset_source)
      self.__attach_labelings()
    return self.__dataset
  def __set_dataset(self, dataset):
    """Attach (or, with None, unload) the dataset.

    A ``datasets.Dataset`` instance is used as is and its labelings are
    merged into the labeling sources; any other value is handed to
    ``datasets.Dataset(...)``.  String values are remembered as the source.
    Finally every known labeling source is attached to the dataset.
    """
    # allow unloading the dataset
    if dataset is None:
      self.__dataset = None
      return

    # load dataset, Dataset is expected to complain if it can't load it
    if isinstance(dataset, datasets.Dataset):
      self.__dataset = dataset
      self.merge_dataset_labelings()
    else:
      self.__dataset = datasets.Dataset(dataset)

    if self.__dataset is None:
      raise ValueError("Couldn't construct dataset")
    # if it's a string (preferably a url), save where it came from
    elif isinstance(dataset, types.StringTypes):
      self.__dataset_source = dataset
    # attach labelings
    self.__attach_labelings()
  def __del_dataset(self):
    """Drop the in-memory dataset; the source is left untouched."""
    self.__dataset = None
  dataset = property(__get_dataset, __set_dataset, __del_dataset, "stores reference to mlx dataset")

  def sizeof(self):
    """Estimate memory allocation.

    Sums the dataset's own ``sizeof()`` (when loaded) and the plot cache's
    ``sizeof()``.  Labeling sources are not included in the estimate.
    """
    size = 0
    # size of dataset
    if self.__dataset is not None:
      size += self.__dataset.sizeof()
    # size of labelings: not accounted for yet

    return size + self.plots.sizeof()

  # information about the dataset
  def __is_loaded(self):
    """True when a dataset is currently held in memory."""
    # identity test: never invoke the dataset's own comparison operators
    return self.__dataset is not None
  is_loaded = property(__is_loaded, doc="is the dataset loaded")

  # specify dataset source
  def __get_source(self):
    """Return the location the dataset is (re)loaded from, or None."""
    return self.__dataset_source
  def __set_source(self, source):
    """Set the dataset location; allowed only while no source is set.

    Raises ValueError if a source is already set or `source` is None.
    """
    if self.__dataset_source is not None:
      raise ValueError("a dataset can only be created once")
    if source is None:
      raise ValueError("must specify a valid location for a dataset")
    self.__dataset_source = source
  def __del_source(self):
    """Forget the dataset location."""
    self.__dataset_source = None
  source = property(__get_source, __set_source, __del_source, "Reference to where to load a dataset from")

  # heuristic to make dataset id
  def __get_id(self):
    """Return a relatively unique id, computing and caching it on first use.

    With a known source the shortened name is used; otherwise the id is the
    md5 hex digest over ``str(row)`` of every row of the dataset.  Stays
    None when there is neither a source nor a dataset.
    """
    if self.__id is None:
      if self.__dataset_source is not None:
        self.__id = self.getShortenedName()
        #self.__id = md5.md5(self.__dataset_source).hexdigest()
      elif self.dataset is not None:
        digest = md5.md5()
        for row in self.dataset.getData():
          digest.update(str(row))
        self.__id = digest.hexdigest()
    return self.__id
  id = property(__get_id, doc="create a relatively unique id")

  # heuristic to generate name of dataset
  def __get_name(self):
    """Return the dataset name, deriving it on first use.

    Preference order: an explicitly set name, the loaded dataset's own
    name, the dataset source.  Returns None if none of these is available.
    """
    if self.__dataset_name is None:
      if self.__dataset is not None and self.dataset.getName() is not None:
        self.name = self.dataset.getName()
      elif self.__dataset_source is not None:
        self.name = self.__dataset_source
      else:
        return None
    return self.__dataset_name
  def __set_name(self, name):
    """Store the name and push it to the dataset when one is loaded."""
    self.__dataset_name = name
    if self.__dataset is not None:
      self.dataset.setName(name)
  name = property(__get_name, __set_name, None, "The name of the dataset")

  def getShortenedName(self, length=20):
    """Return a version of the dataset name at most `length` characters long.

    Names longer than `length` are reduced to their last path component
    and, if that is still too long, truncated.  Falls back to ``id`` when
    the data source has no name.
    """
    dsname = self.name
    if dsname is None:
      return self.id
    if len(dsname) > length:
      dsname = os.path.split(dsname)[1]
      # compare the *length* of the name: comparing the string itself with
      # an int is always true on Python 2 and a TypeError on Python 3
      if len(dsname) > length:
        dsname = dsname[:length]
    return dsname

  ##manage annotation labelings
  def __get_primary(self):
    """Return the default primary labeling source (or None)."""
    return self.__primary
  def __set_primary(self, labeling):
    """Set the primary labeling from a LabelingSource or a labeling name.

    Raises KeyError when a name is given that is not a labeling source.
    """
    if isinstance(labeling, LabelingSource):
      self.__primary = labeling
    else:
      self.__primary = self._labeling_sources[labeling]
  primary = property(__get_primary, __set_primary, doc="store default primary labeling for this data source")

  def __get_secondary(self):
    """Return the default secondary labeling source (or None)."""
    return self.__secondary
  def __set_secondary(self, labeling):
    """Set the secondary labeling from a LabelingSource or a labeling name.

    Raises KeyError when a name is given that is not a labeling source.
    """
    if isinstance(labeling, LabelingSource):
      self.__secondary = labeling
    else:
      self.__secondary = self._labeling_sources[labeling]
  secondary = property(__get_secondary, __set_secondary, doc="store default secondary labeling for this data source")

  # manage labeling sources
  def merge_dataset_labelings(self):
    """If the attached dataset has labelings make them available to the DataSource
    """
    if self.__dataset is None:
      return

    for label in self.__dataset.getLabelings():
      label_name = label.getName()
      # don't copy unnamed labelings
      if label_name is None:
        continue

      isRow = label.isRowLabeling()
      isCol = label.isColLabeling()
      # skip labelings that are either empty or have both rows and columns
      if not (isRow ^ isCol):
        continue

      # if we already have this labeling as a labeling source, skip it
      if label_name in self._labeling_sources:
        continue

      self._labeling_sources[label_name] = LabelingSource(label_name, label, isRow, True)

  def add_labeling(self, name, source, isrow=True, isannotation=None, lookup_url=None, description=None):
    """Add a labeling to this dataset source (and return the currently
    added item)

    If source is pointing to a FieldStorage file handle, save it somewhere.
    (as I'm assuming any fieldstorage file is temporary)

    Raises ValueError when a labeling source called `name` already exists.
    """
    if name in self._labeling_sources:
      raise ValueError("label names must be unique: %s" % (name))

    labeling_source = LabelingSource(name, source, isrow, isannotation, lookup_url, description)
    self._labeling_sources[name] = labeling_source

    # if we have a dataset update
    if self.__dataset is not None:
      self.__attach_label(name, labeling_source)

    return labeling_source

  def __attach_label(self, label_name, labeling_source):
    """Attach a labeling to the dataset

    Creates a ``GlobalLabeling`` called `label_name` and labels either the
    rows or the columns of the dataset, depending on
    ``labeling_source.isrow``.
    """
    label = labelings.GlobalLabeling(self.dataset, label_name)
    if labeling_source.isrow:
      label.labelRows(self.dataset, labeling_source.source)
    else:
      label.labelCols(self.dataset, labeling_source.source)

  def __attach_labelings(self):
    """Once a dataset is loaded this will attach all the currently
    defined labeling sources to it.
    """
    # index the labelings the dataset already carries by name
    loaded_labelings = {}
    for loaded in self.dataset.getLabelings():
      loaded_labelings[loaded.getName()] = loaded

    for label_name, labeling_source in self._labeling_sources.items():
      # already on the dataset, don't attach it a second time
      if label_name in loaded_labelings:
        continue
      self.__attach_label(label_name, labeling_source)

  def get_labeling_by_name(self, name):
    """Returning underlying dataset labeling by a labeling name
    (or anything that can reasonably be transformed into a labeling)

    None is returned for None, a ``Labeling`` is returned unchanged and a
    ``LabelingSource`` is looked up by its ``name`` attribute.
    """
    if name is None:
      return None
    elif isinstance(name, labelings.Labeling):
      return name
    elif isinstance(name, LabelingSource):
      return self.dataset.getLabeling(name.name)

    return self.dataset.getLabeling(name)

  def has_labeling_by_name(self, name):
    """Return True if the dataset can look up a labeling called `name`.

    A KeyError from ``dataset.getLabeling`` is treated as "not present".
    """
    try:
      self.dataset.getLabeling(name)
    except KeyError:
      return False
    return True

  def get_labeling_names(self):
    """Return list of all the names of labelings attached to this dataset

    The list is sorted; None is returned when there is no dataset.
    """
    if self.dataset is None:
      return None
    return sorted([ l.getName() for l in self.dataset.getLabelings() ])

  def get_labelingsource_names(self):
    """return sorted list of all the labeling sources attached to this
    datasource (should be a subset of get_labeling_names)
    """
    return sorted(self._labeling_sources)

  def get_labeling_source(self, labeling):
    """Attempt to get a labelsource object for the specified labeling
    (mostly useful to get the lookup url)

    Raises KeyError when `labeling` is not the name of a labeling source.
    """
    return self._labeling_sources[labeling]

  def get_labeling_sources(self):
    """Return list of all labeling sources attached to this datasource
    """
    return list(self._labeling_sources.values())

  def get_labelings_by_primary(self, primaryKey):
    """Get all the labelings attached to a particular slice defined by the
    primary labeling.

    Returns a list of ``(labeling, label)`` pairs holding, for every
    labeling with at least one label on the matching row, its first label.

    Raises ValueError when `primaryKey` matches no row or more than one.
    """
    primaryLabeling = self.dataset.getLabeling(self.primary.name)
    rows = primaryLabeling.getRowsByLabel(primaryKey)

    if len(rows) > 1:
      raise ValueError("non-unique labeling")
    elif len(rows) == 0:
      raise ValueError("label %s not found in %s" %(primaryKey, primaryLabeling.getName()))
    row = rows[0]

    found = []
    for l in self.dataset.getLabelings():
      # look the row up once per labeling and reuse the result
      row_labels = l.getLabelsByRow(row)
      if len(row_labels) > 0:
        found.append((l, row_labels[0]))
    return found

  def get_labelingsource_by_id(self, labelingsourceid):
    """Get labeling source by a labeling source id

    Returns None when no labeling source has that id.
    """
    for l in self.get_labeling_sources():
      if l.id == labelingsourceid:
        return l
    return None

  def get_labelingsource_by_labeling(self, labeling):
    """Return a labeling source for the provided mlx labeling.

    Raises KeyError when no labeling source wraps `labeling`.
    """
    for ls in self._labeling_sources.values():
      if ls.labeling == labeling:
        return ls
    raise KeyError("the labeling is not attached to this dataset source")

  def get_labels_for_labeling(self, labeling_name):
    """Get all the labels in a particular labeling

    `labeling_name` may be a LabelingSource, a key of the labeling sources
    or a plain labeling name known to the dataset.
    """
    if isinstance(labeling_name, LabelingSource):
      # lookup by labelingsource object
      labeling_name = labeling_name.name
    elif labeling_name in self._labeling_sources:
      # lookup by labeling source key
      labeling_name = self._labeling_sources[labeling_name].name
    # else, I hope we have a name
    labeling = self.dataset.getLabeling(labeling_name)

    return labeling.getLabels()

  ############
  # Things specific for providing web access
  def get_label_name(self, label):
    """Return a display name for a Labeling, a LabelingSource or anything
    else (which is simply converted with ``str``).
    """
    if isinstance(label, labelings.Labeling):
      return label.getName()
    elif isinstance(label, LabelingSource):
      return label.name
    else:
      return str(label)

  #############
  # Plot Management
  #
  # Every plot method below builds a fresh WebPlot and registers it with
  # self.plots.setdefault(p.id, p).  Presumably PlotCache.setdefault behaves
  # like dict.setdefault, so an already cached plot with the same id wins
  # over the new one -- verify against PlotCache.
  def __len__(self):
    """Return number of plots we have set
    """
    return len(self.plots)

  def dataset_plot(self, primary=None, secondary=None, subset_label=None, subset=None):
    """Create (or fetch from the plot cache) a WebPlot of the dataset."""
    p = WebPlot(self)
    p.dataset_plot(primary, secondary, subset_label, subset)
    return self.plots.setdefault(p.id, p)

  def cluster_trajectories(self, set, primary=None, secondary=None):
    """Create (or fetch from the plot cache) a cluster trajectory WebPlot.

    Raises ValueError when `set` is not a ``labelings.Labeling``.
    """
    if not isinstance(set, labelings.Labeling):
      raise ValueError("we need a labeling to subset our summary on")
    p = WebPlot(self)
    p.cluster_trajectories(set, primary, secondary)
    return self.plots.setdefault(p.id, p)

  def confusion_matrix_summary(self, labeling1, labeling2, primary=None, secondary=None):
    """Create (or fetch from the plot cache) a confusion matrix summary
    WebPlot for the two labelings.
    """
    p = WebPlot(self)
    p.confusion_matrix_summary(labeling1, labeling2, primary, secondary)
    return self.plots.setdefault(p.id, p)

  def roc_plot(self, subset_label, subset, distance=None):
    """Create (or fetch from the plot cache) a ROC WebPlot."""
    p = WebPlot(self)
    p.roc_plot(subset_label, subset, distance)
    return self.plots.setdefault(p.id, p)

  def pca_projection(self, primary=None, secondary=None):
    """Create (or fetch from the plot cache) a PCA projection WebPlot."""
    p = WebPlot(self)
    p.pca_projection(primary, secondary)
    return self.plots.setdefault(p.id, p)

  def pca_eigenvector(self, primary=None, secondary=None):
    """Create (or fetch from the plot cache) a PCA eigenvector WebPlot."""
    p = WebPlot(self)
    p.pca_eigenvector(primary, secondary)
    return self.plots.setdefault(p.id, p)


  #####
  # PCA EG support
  def __get_pceg_analysis(self):
    """Return the PCAGinzu analysis, building it on first access from the
    dataset and the current primary/secondary labelings.
    """
    if self.__ginzu is None:
      # FIXME: There are too many different ways of passing around
      # FIXME: the annotation labelings, there's attaching them to the
      # FIXME: dataset as (primaryRow and primaryCol) there's
      # FIXME: attaching them to the datasource as labeling sources
      # FIXME: and of course there's just passing them as parameters
      # FIXME: this code really needs to pick one of them and stick to
      # FIXME: it. (unfortunately under a deadline isn't the time to
      # FIXME: fix that, which is why there's this rant)
      primary = self.get_labeling_by_name(self.primary)
      secondary = self.get_labeling_by_name(self.secondary)
      self.__ginzu = PCAGinzu(self.dataset, primaryLabeling=primary, secondaryLabeling=secondary)
    return self.__ginzu
  def __set_pceg_analysis(self, ginzu):
    """Replace the analysis object; raises ValueError for non-PCAGinzu values."""
    if isinstance(ginzu, PCAGinzu):
      self.__ginzu = ginzu
    else:
      raise ValueError("Expected instance of PCAGinzu class")
  pceg_analysis = property(__get_pceg_analysis, __set_pceg_analysis, doc="The PCEG Analysis object")

  def pceg_in_sig_order(self, pcNum, primary=None, secondary=None):
    """
    Create WebPlot of PCA Ginzu extreme genes sorted by significance

    :Parameters:
      - `pcNum`: 0 based index for which principal component we should
                 investigate
      - `primary`: primary labeling for plot annotations (currently
                   unused; the labelings of `pceg_analysis` apply)
      - `secondary`: secondary plot labeling for plot annotations
                     (currently unused as well)
    """
    p = WebPlot(self)
    p.pceg_in_significance_order(self.pceg_analysis, pcNum)
    return self.plots.setdefault(p.id, p)

  def pceg_in_native_order(self, pcNum):
    """
    Create WebPlot of PCA Ginzu PCEGs in original data order

    :Parameters:
      - `pcNum`: 0 based index for which principal component we should
                 investigate
    """
    p = WebPlot(self)
    p.pceg_in_native_order(self.pceg_analysis, pcNum)
    return self.plots.setdefault(p.id, p)

  def pc_vs_pc_with_extreme_genes(self, pcNumX, pcNumY):
    """
    Create scatter plot of two PCA Ginzu components highlighting extreme genes in X

    :Parameters:
      - `pcNumX`: 0 based index for which principal component we should
                  display on the X axis of the scatter plot.
      - `pcNumY`: 0 based index for which principal component we should
                  display on the Y axis of the scatter plot.
    """
    p = WebPlot(self)
    p.pc_vs_pc_with_extreme_genes(self.pceg_analysis, pcNumX, pcNumY)
    return self.plots.setdefault(p.id, p)

  ####
  # Look up plots
  def get_plot_by_id(self, plotid):
    """Return the cached plot registered under `plotid`, or None."""
    return self.plots.get(plotid, None)


