########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
# Last Modified: 29-Apr-2002, 11:00
#

"""
A GlobalLabeling is a pervasive Labeling that is available to every
Dataset/View in a hierarchy.
"""

#
# Design docs:
#
# A single GlobalLabeling instance can be attached to multiple datasets.
#
# A GlobalLabeling may only attach to root datasets (not to views).
#

import operator
import string
import sys
import types

from Labeling import Labeling, LabelingDatasetLengthError
from compClust.util.unique import unique

def castToGlobalWrapper(lab, name=None, dataset=None, removeLocal=True):
  """
  Convert a labeling into a GlobalWrapper and return the wrapper.

  lab         -- a Labeling to convert, or an existing GlobalWrapper.
  name        -- name for the global labeling; if None, the name of *lab* is
                 used.  Otherwise it must be a string.
  dataset     -- the Dataset/View the returned wrapper is bound to; if None,
                 the dataset of *lab* is used.
  removeLocal -- if true, *lab* is removed from the dataset it belongs to
                 once its labels have been copied into the global labeling.

  If *lab* is already a GlobalWrapper, no labels are copied and *name* and
  *removeLocal* are ignored.  *lab* itself is returned when it is already
  bound to *dataset* or when *dataset* is None.  Otherwise a new wrapper
  around the same global labeling is bound to *dataset*.

  Raises ValueError if *name* is given and is not a string.
  """
  if isinstance(lab, GlobalWrapper):
    # Already global: at most re-bind it to a different view.  With no
    # target dataset there is nothing to re-bind to, so keep it as-is.
    if dataset is None or lab.v == dataset:
      return lab
    return GlobalWrapper(dataset, lab.name, lab.g)

  # The local labeling must be removed from the dataset that owns it, which
  # is not necessarily the dataset the new wrapper is bound to.
  home = lab.getDataset()
  if dataset is None:
    dataset = home
  if name is None:
    name = lab.getName()
  elif not isinstance(name, types.StringTypes):
    raise ValueError("name must be a string")

  wrapper = GlobalWrapper(dataset, name)
  wrapper.labelFrom(lab)
  if removeLocal:
    home.removeLabeling(lab)
  return wrapper


def castToGlobalLabeling(ds, labeling, name=None):
  """
  Build a new GlobalLabeling on *ds* that holds a copy of *labeling*'s labels.

  ds       -- the Dataset/View whose roots the new GlobalLabeling attaches
              to.  Keys returned by labeling.getKeysByLabel() are applied in
              the context of this dataset.
  labeling -- the (local) Labeling to copy labels from.
  name     -- name of the new global labeling; if None, the name of
              *labeling* is used.

  Returns the new GlobalLabeling itself, not a GlobalWrapper.  The local
  labeling is left untouched.

  Raises TypeError if *labeling* is not a Labeling.
  """
  if not isinstance(labeling, Labeling):
    # Previously this fell through and failed with UnboundLocalError.
    raise TypeError("%s is not a Labeling" % (labeling,))

  if name is None:
    name = labeling.getName()

  gl = GlobalLabeling(ds, name)
  for label in labeling.getLabels():
    gl.addLabelToKeys(ds, label, labeling.getKeysByLabel(label))

  return gl

    
class GlobalLabeling(Labeling):

  """
  A Labeling shared by every Dataset/View in a hierarchy.

  A GlobalLabeling can only attach itself to root datasets (objects that are
  not BaseView instances).  Once attached, all Views descended from such a
  root list this labeling among their labelings.

  Labels are not stored per view.  For every attached root the object keeps
  an anonymous Labeling in self.aLabelings (root -> Labeling).  Most methods
  take the view that supplies the row/column/key context as their first
  argument.  Keys are mapped from that view to each root, and back, along the
  paths returned by view.getLineage(); each path ends at a root.
  """

  def __init__(self, dataset, name=None):
    """
    Create the labeling and attach it to every root of *dataset*.

    dataset -- a Dataset or View.  The labeling attaches to the last element
               of each path in dataset.getLineage().
    name    -- optional name of the labeling.
    """

    # root dataset -> anonymous Labeling holding the labels for that root
    self.aLabelings = {}
    self.name = name

    for path in dataset.getLineage():
      self._addLabeling(path[-1])

  def _getLocalRef(self, target):
    """Return a GlobalWrapper that binds this labeling to *target*."""
    return GlobalWrapper(target, glabeling=self)

  def _addLabeling(self, dataset):
    """
    Attach to the root *dataset*.

    Creates an anonymous Labeling(dataset, nolink=1) to hold the labels for
    this root.  Registers this GlobalLabeling in dataset.labelings if it is
    not already there, and records the anonymous labeling.

    Raises TypeError if *dataset* is a BaseView; only roots are allowed.
    """

    from compClust.mlx.views import BaseView

    if isinstance(dataset, BaseView):
      raise TypeError("a GlobalLabeling can only attach to a root dataset, "
                      "not to a %s" % dataset.__class__.__name__)

    lab = Labeling(dataset, nolink=1)

    # Make sure the dataset we just attached to lists this global labeling.
    if self not in dataset.labelings:
      dataset.addLabeling(self)

    self.aLabelings[dataset] = lab

  def __removeLabeling(self, dataset):
    """Detach and forget the anonymous labeling kept for *dataset*, if any."""

    if dataset in self.aLabelings:
      self.aLabelings[dataset].detatch()
      del self.aLabelings[dataset]

  def __mapKeysOnPath(self, path, keys):
    """
    Map *keys* from path[0] up to the root at path[-1].

    Each step calls child._mapKeysToParent(keys, parent) for consecutive
    elements of *path*.  Returns the keys expressed in the root's key space.
    """

    keylist = list(keys)
    for child, parent in zip(path[:-1], path[1:]):
      keylist = child._mapKeysToParent(keylist, parent)
    return keylist

  def _mapKeysFromView(self, view, keys):
    """
    Map *keys* of *view* onto every root reachable through view.getLineage().

    Returns a dict {root: list of unique root keys}.  Keys reached via several
    paths to the same root are merged.
    """

    roots = {}
    for path in view.getLineage():
      root = path[-1]
      keyset = roots.get(root, [])
      keyset += self.__mapKeysOnPath(path, keys)
      roots[root] = unique(keyset)
    return roots

  def _mapKeysToView(self, path, keys):
    """
    Map *keys* from the root at path[-1] down to the view at path[0].

    Each step calls path[i]._mapKeysFromParent(keys, path[i+1]), walking from
    the root end of the path towards the view.
    """

    for i in range(len(path) - 2, -1, -1):
      keys = path[i]._mapKeysFromParent(keys, path[i + 1])
    return keys

  def _fetchAnonymousLabeling(self, ds):
    """
    Return the anonymous labeling for root *ds*, attaching to *ds* first if
    needed.

    To track UIDs and labels we create anonymous labelings which reference
    a dataset, but the dataset does not know about them.
    """

    if ds not in self.aLabelings:
      self._addLabeling(ds)
    return self.aLabelings[ds]

  ############################################################################
  #
  # Labeling interface implementations
  #

  def sortDatasetByLabel(self, view):
    """
    Permute the rows of *view* so they are ordered by their first label.

    Ties keep the original row order.  Raises TypeError unless *view* is a
    SortedView.
    """

    from compClust.mlx.views import SortedView

    if not isinstance(view, SortedView):
      raise TypeError("sortDatasetByLabel requires a SortedView")

    rows = view.getNumRows()
    # Sorting (label, row) pairs orders by label, then by original row.
    order = list(zip(self.getLabelByRows(view), range(rows)))
    order.sort()
    view.permuteRows([pair[1] for pair in order])

  def getDataset(self, view=None):
    """A global labeling has no single dataset; return the context *view*."""
    return view

  def writeLabels(self, view, stream=sys.stdout, delimiter="\t"):
    """
    Write one line per key of *view* to *stream*.

    Row keys come first, then column keys.  Each line holds that key's labels,
    converted with str() and joined by *delimiter*.
    """

    for key in view.getRowKeys() + view.getColKeys():
      labels = [str(label) for label in self.getLabelsByKey(view, key)]
      stream.write(delimiter.join(labels))
      stream.write("\n")

  def labelRows(self, view, obj):
    """
    Label every row of *view*; *obj* is cast with self._castLabels().

    Raises LabelingDatasetLengthError if the number of labels does not match
    the number of rows.
    """

    labels = self._castLabels(obj)
    num_labels = len(labels)
    num_rows = view.getNumRows()
    if num_labels != num_rows:
      error_msg = "The number of row labels [%d] does not match the number of rows [%d] for labeling %s"
      error_msg %= (num_labels, num_rows, self.name)
      raise LabelingDatasetLengthError(error_msg)
    for row, rowLabels in zip(range(num_labels), labels):
      self.addLabelsToRow(view, rowLabels, row)

  def labelCols(self, view, obj):
    """
    Label every column of *view*; *obj* is cast with self._castLabels().

    Raises LabelingDatasetLengthError if the number of labels does not match
    the number of columns.
    """

    labels = self._castLabels(obj)
    num_labels = len(labels)
    num_cols = view.getNumCols()
    if num_labels != num_cols:
      error_msg = "The number of column labels [%d] does not match the number of columns [%d] for labeling %s"
      error_msg %= (num_labels, num_cols, self.name)
      raise LabelingDatasetLengthError(error_msg)
    for col, colLabels in zip(range(num_labels), labels):
      self.addLabelsToCol(view, colLabels, col)

  def addLabelToRow(self, view, label, row):
    self.addLabelToKey(view, label, view.getRowKey(row))

  def addLabelToCol(self, view, label, col):
    self.addLabelToKey(view, label, view.getColKey(col))

  def addLabelToKey(self, view, label, key):
    """Add *label* to *key* of *view*, stored on every reachable root."""

    keyset = self._mapKeysFromView(view, [key])
    while keyset:
      (ds, keys) = keyset.popitem()
      self._fetchAnonymousLabeling(ds).addLabelToKeys(label, keys)

  def addLabelToRows(self, view, label, rowList):
    for row in rowList:
      self.addLabelToRow(view, label, row)

  def addLabelToCols(self, view, label, colList):
    for col in colList:
      self.addLabelToCol(view, label, col)

  def addLabelToKeys(self, view, label, keyList):
    for key in keyList:
      self.addLabelToKey(view, label, key)

  def addLabelsToRow(self, view, labels, row):
    for label in labels:
      self.addLabelToRow(view, label, row)

  def addLabelsToCol(self, view, labels, col):
    for label in labels:
      self.addLabelToCol(view, label, col)

  def addLabelsToKey(self, view, labels, key):
    for label in labels:
      self.addLabelToKey(view, label, key)

  def addLabelsToRows(self, view, labels, rows):
    for row in rows:
      self.addLabelsToRow(view, labels, row)

  def addLabelsToCols(self, view, labels, cols):
    for col in cols:
      self.addLabelsToCol(view, labels, col)

  def addLabelsToKeys(self, view, labels, keys):
    for key in keys:
      self.addLabelsToKey(view, labels, key)

  def getLabels(self, view=None):
    """
    Return the unique labels stored on all attached roots.

    *view* is accepted for interface compatibility but is not used.
    """
    labels = []
    for l in self.aLabelings.values():
      labels += l.getLabels()
    return unique(labels)

  def getLabelsByRow(self, view, row):
    return self.getLabelsByKey(view, view.getRowKey(row))

  def getLabelsByCol(self, view, col):
    return self.getLabelsByKey(view, view.getColKey(col))

  def getLabelsByKey(self, view, key):
    """
    Return the labels of *key* of *view*.

    The key is mapped to every reachable root, and the labels of the mapped
    root keys are concatenated.
    """

    labels = []
    keyset = self._mapKeysFromView(view, [key])
    while keyset:
      (ds, keys) = keyset.popitem()
      labeling = self._fetchAnonymousLabeling(ds)
      for rootKey in keys:
        labels.extend(labeling.getLabelsByKey(rootKey))
    return labels

  def getLabelsByRows(self, view, rows):
    """
    Returns a list of lists of all labels from a set of rows.
    """
    return [self.getLabelsByRow(view, row) for row in rows]

  def getLabelsByCols(self, view, cols):
    """
    Returns a list of lists of all labels from a set of columns.
    """
    return [self.getLabelsByCol(view, col) for col in cols]

  def getLabelsByKeys(self, view, keys):
    """
    Returns a list of lists of all labels from a set of keys.
    """
    return [self.getLabelsByKey(view, key) for key in keys]

  def getLabelByRow(self, view, row, n=0):
    """Return the n-th label of *row*, or None if it has fewer labels."""
    try:
      return self.getLabelsByRow(view, row)[n]
    except IndexError:
      return None

  def getLabelByCol(self, view, col, n=0):
    """Return the n-th label of *col*, or None if it has fewer labels."""
    try:
      return self.getLabelsByCol(view, col)[n]
    except IndexError:
      return None

  def getLabelByKey(self, view, key, n=0):
    """Return the n-th label of *key*, or None if it has fewer labels."""
    try:
      return self.getLabelsByKey(view, key)[n]
    except IndexError:
      return None

  def getLabelByRows(self, view, rows=None, n=0):
    """Return the n-th label of each row; all rows of *view* by default."""
    if rows is None:
      rows = range(view.getNumRows())
    return [self.getLabelByRow(view, row, n) for row in rows]

  def getLabelByCols(self, view, cols=None, n=0):
    """Return the n-th label of each column; all columns of *view* by default."""
    if cols is None:
      cols = range(view.getNumCols())
    return [self.getLabelByCol(view, col, n) for col in cols]

  def getLabelByKeys(self, view, keys=None, n=0):
    """Return the n-th label of each key; all row then column keys by default."""
    if keys is None:
      keys = view.getRowKeys() + view.getColKeys()
    return [self.getLabelByKey(view, key, n) for key in keys]

  def getRowsByLabel(self, view, label):
    """
    Return the row keys of *view* that carry *label*.

    NOTE(review): unlike getColsByLabel, the intersected keys are returned
    without being converted to positions via the row-key list.  This only
    equals row indices if row keys coincide with row numbers; verify.
    """

    rowKeys = view.getRowKeys()
    keys = self.getKeysByLabel(view, label)
    return self._intersect(keys, rowKeys)

  def getColsByLabel(self, view, label):
    """Return the column indices of *view* whose keys carry *label*."""

    colKeys = view.getColKeys()
    keys = self.getKeysByLabel(view, label)
    vKeys = self._intersect(colKeys, keys)
    return [colKeys.index(key) for key in vKeys]

  def getKeysByLabel(self, view, label):
    """
    Return the unique keys of *view* that carry *label*.

    The keys are collected on every root and mapped down to *view*.
    """

    keys = []
    for path in view.getLineage():
      lab = self._fetchAnonymousLabeling(path[-1])
      keys += self._mapKeysToView(path, lab.getKeysByLabel(label))
    return unique(keys)

  def detatch(self):
    """Remove all labels, then detach from every root."""

    self.removeAll()
    # Iterate over a snapshot: __removeLabeling deletes from the dict.
    for ds in list(self.aLabelings.keys()):
      self.__removeLabeling(ds)

  def removeAll(self):
    """Remove every label from every attached root."""
    for l in self.aLabelings.values():
      l.removeAll()

  def removeLabel(self, label):
    """Remove *label* from every attached root."""
    for l in self.aLabelings.values():
      l.removeLabel(label)

  def removeLabelsFromRow(self, view, row):
    self.removeLabelsFromKey(view, view.getRowKey(row))

  def removeLabelsFromCol(self, view, col):
    self.removeLabelsFromKey(view, view.getColKey(col))

  def removeLabelsFromKey(self, view, key):
    """Remove all labels from *key* of *view* on every reachable root."""

    keyset = self._mapKeysFromView(view, [key])
    while keyset:
      (ds, keys) = keyset.popitem()
      labeling = self._fetchAnonymousLabeling(ds)
      for rootKey in keys:
        labeling.removeLabelsFromKey(rootKey)

  def removeLabelFromRow(self, view, label, row):
    self.removeLabelFromKey(view, label, view.getRowKey(row))

  def removeLabelFromCol(self, view, label, col):
    self.removeLabelFromKey(view, label, view.getColKey(col))

  def removeLabelFromKey(self, view, label, key):
    """Remove *label* from *key* of *view* on every reachable root."""

    keyset = self._mapKeysFromView(view, [key])
    while keyset:
      (ds, keys) = keyset.popitem()
      labeling = self._fetchAnonymousLabeling(ds)
      for rootKey in keys:
        labeling.removeLabelFromKey(label, rootKey)

  def sizeof(self):
    return self._sizeoflabels(self.aLabelings)

  def _findCommonKeys(self, ds, label):
    return self.getKeysByLabel(ds, label)

  ###########
  # useful labeling type information
  def isRowLabeling(self, view):
    """Does this labeling have a label for each row?

    True only if the anonymous labeling of every root reachable from *view*
    reports itself as a row labeling.
    """
    for path in view.getLineage():
      if not self._fetchAnonymousLabeling(path[-1]).isRowLabeling():
        return False
    return True

  def isColLabeling(self, view):
    """Does this labeling have a label for each column?

    True only if the anonymous labeling of every root reachable from *view*
    reports itself as a column labeling.
    """
    for path in view.getLineage():
      if not self._fetchAnonymousLabeling(path[-1]).isColLabeling():
        return False
    return True

  def isNumeric(self, view):
    """
    Return True if every (unique, non-empty) label is exactly an int or float.

    *view* is not used; the labels of all attached roots are considered.
    """
    # Only the unique labels need testing (JCR: switched from
    # getAllKeyLabels to getLabels).
    for l in self.getLabels():
      if type(l) not in (int, float):
        return False
    return True
        
##############################################################################
#
# GlobalWrapper
#

class GlobalWrapper(Labeling):

  """
  A pass-through object that wraps a GlobalLabeling and binds it to one
  particular view.

  Almost every method forwards to the wrapped global labeling (self.g) and
  supplies the bound view (self.v) as the context argument.  A GlobalWrapper
  can therefore be used where a view-local Labeling interface is expected.
  Name access (getName/setName/name) is forwarded too, so renaming through
  any wrapper renames the shared global labeling.
  """

  def __init__(self, view, name=None, glabeling=None):
    """
    view      -- the Dataset/View this wrapper provides context for.
    name      -- if not None, set as the name of the global labeling.
    glabeling -- None to create a new GlobalLabeling on *view*.  A
                 GlobalLabeling is wrapped via its _getBaseRef().  A plain
                 Labeling is converted with castToGlobalLabeling(view, ...).

    Raises ValueError if *glabeling* is a GlobalWrapper or not a Labeling.
    """

    self.v = view
    if glabeling is None:
      self.g = GlobalLabeling(view)
    elif isinstance(glabeling, GlobalLabeling):
      self.g = glabeling._getBaseRef()
    elif isinstance(glabeling, GlobalWrapper):
      raise ValueError("%s is a global wrapper" % (glabeling))
    elif glabeling.__class__ == Labeling:
      self.g = castToGlobalLabeling(view, glabeling)
    else:
      raise ValueError("%s was not a Labeling" % (str(glabeling)))

    if name is not None:
      self.setName(name)

  def __eq__(self, other):
    """Do a deep equality instead of just comparing object pointers.

    Two wrappers are equal when they bind equal views to equal global
    labelings.
    """
    if not isinstance(other, GlobalWrapper):
      return False
    return self.v == other.v and self.g == other.g

  def __ne__(self, other):
    # Python 2 does not derive != from __eq__; keep the two consistent.
    # NOTE(review): __hash__ is still identity-based, so equal wrappers may
    # hash differently.
    return not self.__eq__(other)

  def _getBaseRef(self):
    return self.g

  def _getLocalRef(self, target):
    """Return self if bound to exactly *target*, else None."""
    if self.v is not target:
      return None
    return self

  def getGlobalLabeling(self):
    """
    Returns the actual global labelings without the context provided by the
    wrapper.
    """
    return self.g

  def getName(self):
    return self.g.getName()

  def setName(self, name):
    self.g.setName(name)

  name = property(getName, setName, doc="Name property")

  def sortDatasetByLabel(self):
    self.g.sortDatasetByLabel(self.v)

  def getDataset(self):
    return self.g.getDataset(self.v)

  def writeLabels(self, stream=sys.stdout, delimiter="\t"):
    self.g.writeLabels(self.v, stream, delimiter)

  def labelRows(self, obj):
    self.g.labelRows(self.v, obj)

  def labelCols(self, obj):
    self.g.labelCols(self.v, obj)

  def _getUID(self, label):
    # NOTE(review): GlobalLabeling does not define a view-taking _getUID in
    # this file; confirm the inherited signature accepts (view, label).
    return self.g._getUID(self.v, label)

  def addLabelToRow(self, label, row):
    self.g.addLabelToRow(self.v, label, row)

  def addLabelToCol(self, label, col):
    self.g.addLabelToCol(self.v, label, col)

  def addLabelToKey(self, label, key):
    self.g.addLabelToKey(self.v, label, key)

  def addLabelToRows(self, label, rowList):
    self.g.addLabelToRows(self.v, label, rowList)

  def addLabelToCols(self, label, colList):
    self.g.addLabelToCols(self.v, label, colList)

  def addLabelToKeys(self, label, keyList):
    self.g.addLabelToKeys(self.v, label, keyList)

  def addLabelsToRow(self, labels, row):
    self.g.addLabelsToRow(self.v, labels, row)

  def addLabelsToCol(self, labels, col):
    self.g.addLabelsToCol(self.v, labels, col)

  def addLabelsToKey(self, labels, key):
    self.g.addLabelsToKey(self.v, labels, key)

  def addLabelsToRows(self, labels, rows):
    self.g.addLabelsToRows(self.v, labels, rows)

  def addLabelsToCols(self, labels, cols):
    self.g.addLabelsToCols(self.v, labels, cols)

  def addLabelsToKeys(self, labels, keys):
    self.g.addLabelsToKeys(self.v, labels, keys)

  def getLabels(self):
    return self.g.getLabels()

  def getAllRowLabels(self):
    """
    Returns a list of lists of the labels for all the rows.
    """
    return [self.getLabelsByKey(key) for key in self.v.getRowKeys()]

  def getAllColLabels(self):
    """
    Returns a list of lists of the labels for all the columns.
    """
    return [self.getLabelsByKey(key) for key in self.v.getColKeys()]

  def getAllKeyLabels(self):
    """
    Returns a list of lists of the labels for all the keys (rows, then
    columns).
    """
    keys = self.v.getRowKeys() + self.v.getColKeys()
    return [self.getLabelsByKey(key) for key in keys]

  def getLabelsByRow(self, row):
    return self.g.getLabelsByRow(self.v, row)

  def getLabelsByCol(self, col):
    return self.g.getLabelsByCol(self.v, col)

  def getLabelsByKey(self, key):
    return self.g.getLabelsByKey(self.v, key)

  def getLabelsByRows(self, rows):
    # Was passing an undefined name `cols` (NameError).
    return self.g.getLabelsByRows(self.v, rows)

  def getLabelsByCols(self, cols):
    return self.g.getLabelsByCols(self.v, cols)

  def getLabelsByKeys(self, keys):
    return self.g.getLabelsByKeys(self.v, keys)

  def getLabelByRow(self, row, n=0):
    return self.g.getLabelByRow(self.v, row, n)

  def getLabelByCol(self, col, n=0):
    return self.g.getLabelByCol(self.v, col, n)

  def getLabelByKey(self, key, n=0):
    return self.g.getLabelByKey(self.v, key, n)

  def getLabelByRows(self, rows=None, n=0):
    return self.g.getLabelByRows(self.v, rows, n)

  def getLabelByCols(self, cols=None, n=0):
    # Was passing an undefined name `rows` (NameError).
    return self.g.getLabelByCols(self.v, cols, n)

  def getLabelByKeys(self, keys=None, n=0):
    return self.g.getLabelByKeys(self.v, keys, n)

  def getRowsByLabel(self, label):
    return self.g.getRowsByLabel(self.v, label)

  def getColsByLabel(self, label):
    return self.g.getColsByLabel(self.v, label)

  def getKeysByLabel(self, label):
    return self.g.getKeysByLabel(self.v, label)

  def detatch(self):
    self.g.detatch()

  def removeAll(self):
    self.g.removeAll()

  def removeLabel(self, label):
    self.g.removeLabel(label)

  def removeLabelsFromRow(self, row):
    self.g.removeLabelsFromRow(self.v, row)

  def removeLabelsFromCol(self, col):
    self.g.removeLabelsFromCol(self.v, col)

  def removeLabelsFromKey(self, key):
    self.g.removeLabelsFromKey(self.v, key)

  def removeLabelFromRow(self, label, row):
    self.g.removeLabelFromRow(self.v, label, row)

  def removeLabelFromCol(self, label, col):
    self.g.removeLabelFromCol(self.v, label, col)

  def removeLabelFromKey(self, label, key):
    # Was passing an undefined name `col` (NameError).
    self.g.removeLabelFromKey(self.v, label, key)

  def sizeof(self):
    return self.g.sizeof()

  def _findCommonKeys(self, ds, label):
    return self.g.getKeysByLabel(ds, label)

  ###########
  # useful labeling type information
  def isRowLabeling(self):
    """Does this labeling have a label for each row?
    """
    return self.g.isRowLabeling(self.v)

  def isColLabeling(self):
    """Does this labeling have a label for each column?
    """
    return self.g.isColLabeling(self.v)
