########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
# Last Modified: 29-Apr-2002, 11:00
#

"""
A Global Element Labeling is a pervasive Labeling available to all the
Datasets/Views in a Hierarchy.
"""

import string
import sys
import operator

from compClust.mlx.labelings import ElementLabeling
from compClust.mlx.labelings import GlobalLabeling
from compClust.mlx.labelings import GlobalWrapper
from compClust.util.listOps import cartesian, unique

#
# Multiply inherit from GlobalLabeling and ElementLabeling.  The order is
# important.  The GlobalElementLabeling will first look for method names
# in itself, then in GlobalLabeling and then in ElementLabeling.
#

class GlobalElementLabeling(GlobalLabeling, ElementLabeling):
  """
  Element labeling shared by every Dataset/View in a hierarchy.

  One anonymous ElementLabeling is kept per root dataset in
  self.aLabelings.  Elements passed in by callers are (row, column) pairs
  expressed in the coordinates of a view; they are translated to the root
  datasets before labels are stored or removed, and translated back to the
  view when elements are looked up by label.
  """

  def __init__(self, dataset, name=None):
    """
    Create the labeling and attach it to every root of `dataset`.

    dataset -- object providing getLineage(); the last entry of each
               lineage path is treated as a root dataset
    name    -- optional name stored in self.name
    """

    self.aLabelings = {}
    self.name = name

    lineage = dataset.getLineage()
    roots = [path[-1] for path in lineage]

    # Explicit loop: the calls are made for their side effects only.
    for root in roots:
      self._addLabeling(root)

  def _getLocalRef(self, target):
    """
    Return a GlobalElementWrapper that binds this labeling to `target`.
    """
    # The wrapper's keyword is spelled `glabeling`.
    return GlobalElementWrapper(target, glabeling=self)

  def _addLabeling(self, dataset):
    """
    Create the anonymous ElementLabeling for `dataset` and register it.

    Raises TypeError if `dataset` is a BaseView.
    """

    # Imported here rather than at module level (kept from the original;
    # presumably to avoid a circular import -- TODO confirm).
    from compClust.mlx.views import BaseView

    if isinstance(dataset, BaseView):
      raise TypeError("labelings can only be attached to datasets, "
                      "not to views")

    #
    # Create the new labeling

    lab = ElementLabeling(dataset, nolink=1)

    #
    # Make sure this global labeling is in the dataset we just attached to

    if self not in dataset.labelings:
      dataset.addLabeling(self)

    #
    # Put the labeling in our hash

    self.aLabelings[dataset] = lab

  ############################################################################
  #
  # We need two helper methods to map the elements to/from views
  #
  # These methods leverage the infrastructure in GlobalLabeling by simply
  # casting the element to keys, mapping keys and then joining the key sets
  # back into elements.
  #
  ############################################################################

  def _mapElementsFromView(self, view, elements):
    """
    Map (row, column) elements of `view` onto the root datasets.

    Returns a dictionary keyed by root dataset whose values are lists of
    the elements produced by cartesian() for that dataset.
    """

    elementset = {}

    for element in elements:

      rowkey = view.getRowKey(element[0])
      colkey = view.getColKey(element[1])

      rowset = self._mapKeysFromView(view, [rowkey])
      colset = self._mapKeysFromView(view, [colkey])

      #
      # Take the cartesian products of the key sets and add them to
      # the element set for each dataset
      #

      for ds in rowset.keys():

        num      = ds.getNumRows()
        rowkeys  = rowset[ds]

        # Column keys are shifted down by the dataset's row count
        # (presumably column keys are numbered after the row keys --
        # TODO confirm against the key scheme in GlobalLabeling).
        colkeys  = [key - num for key in colset[ds]]

        elementset.setdefault(ds, []).extend(cartesian(rowkeys, colkeys))

    return elementset

  def _mapElementsToView(self, path, elements):
    """
    Map elements of the dataset at path[-1] onto the view at path[0].

    Converts each element to keys on path[-1], maps the row and column
    keys down the path, and returns the accumulated cartesian products.
    """

    elementset = []

    for element in elements:
      rowkey = path[-1].getRowKey(element[0])
      colkey = path[-1].getColKey(element[1])

      rowset = self._mapKeysToView(path, [rowkey])
      colset = self._mapKeysToView(path, [colkey])

      # Shift the mapped column keys down by the row count of path[0].
      num    = path[0].getNumRows()
      colset = [key - num for key in colset]

      elementset += cartesian(rowset, colset)

    return elementset

  ############################################################################
  #
  # Extra ElementLabeling methods
  #
  # Marking the dataset
  #
  ############################################################################

  def addLabelToElement(self, view, label, element):
    """
    Attach `label` to the single `element` given in `view` coordinates.
    """

    elementset = self._mapElementsFromView(view, [element])
    while len(elementset) > 0:
      (ds, elements) = elementset.popitem()
      labeling = self._fetchAnonymousLabeling(ds)
      labeling.addLabelToElements(label, elements)

  def addLabelToElements(self, view, label, elements):
    """
    Attach `label` to every element in `elements`.
    """
    for element in elements:
      # Pass the single element, not the whole list.
      self.addLabelToElement(view, label, element)

  def addLabelsToElement(self, view, labels, element):
    """
    Attach every label in `labels` to `element`.
    """
    for label in labels:
      # Pass the single label, not the whole list.
      self.addLabelToElement(view, label, element)

  def addLabelsToElements(self, view, labels, elements):
    """
    Attach every label in `labels` to every element in `elements`.
    """
    for label in labels:
      # Pass the single label, not the whole list.
      self.addLabelToElements(view, label, elements)

  #############################################################################
  #
  # Information retrieval
  #
  #############################################################################

  def getLabelsByElement(self, view, element):
    """
    Return the list of labels attached to `element` of `view`.

    Labels are collected from every root-dataset element the view element
    maps to.
    """

    labels = []
    elementset = self._mapElementsFromView(view, [element])

    while len(elementset) > 0:
      (ds, elements) = elementset.popitem()

      labeling = self._fetchAnonymousLabeling(ds)
      for mapped in elements:
        labels.extend(labeling.getLabelsByElement(mapped))

    return labels

  def getLabelsByElements(self, view, elements=None):
    """
    Return one label list per element, in the order of `elements`.

    When `elements` is None, every (row, column) pair of `view` is used.
    """

    if elements is None:
      rows = view.getNumRows()
      cols = view.getNumCols()
      elements = cartesian(range(rows), range(cols))

    return [self.getLabelsByElement(view, element) for element in elements]

  def getLabelByElement(self, view, element, n=0):
    """
    Return the n-th label of `element`, or None if there is no such label.
    """
    try:
      label = self.getLabelsByElement(view, element)[n]
    except IndexError:
      label = None
    return label

  def getLabelByElements(self, view, elements=None, n=0):
    """
    Return the n-th label (or None) for each element, in order.

    When `elements` is None, every (row, column) pair of `view` is used.
    """

    labels = []

    if elements is None:
      rows = view.getNumRows()
      cols = view.getNumCols()
      elements = cartesian(range(rows), range(cols))

    for element in elements:
      labels.append(self.getLabelByElement(view, element, n))

    return labels

  def getElementsByLabel(self, view, label):
    """
    Return the elements of `view`, as tuples, that carry `label`.

    Every path in the view's lineage is searched, so the result is the
    concatenation of the matches found through each path.
    """

    elements = []
    lineage  = view.getLineage()

    for path in lineage:
      ds     = path[-1]
      lab    = self._fetchAnonymousLabeling(ds)
      tmp    = lab.getElementsByLabel(label)
      elements += [tuple(mapped)
                   for mapped in self._mapElementsToView(path, tmp)]

    return elements

  ############################################################################
  #
  # removing labels from elements
  #
  ############################################################################

  def removeLabelsFromElement(self, view, element):
    """
    Remove every label currently attached to `element` of `view`.
    """

    labels = self.getLabelsByElement(view, element)
    for label in labels:
      self.removeLabelFromElement(view, label, element)


  def removeLabelFromElement(self, view, label, element):
    """
    Remove `label` from `element` of `view`.
    """

    elementset = self._mapElementsFromView(view, [element])

    while len(elementset) > 0:
      (ds, elements) = elementset.popitem()
      labeling = self._fetchAnonymousLabeling(ds)
      for mapped in elements:
        labeling.removeLabelFromElement(label, mapped)

##############################################################################
#
# GlobalElementWrapper
#
# Wrap around the GlobalElementLabeling
#
##############################################################################

class GlobalElementWrapper(GlobalWrapper):
  """
  View-bound front end to a GlobalElementLabeling.

  Holds a view (self.v) and a labeling (self.g) and forwards every call to
  the labeling with the view supplied as the first argument.
  """

  def __init__(self, view, name=None, glabeling=None):
    """
    view      -- view that all forwarded calls operate on
    name      -- optional name, applied through setName() when given
    glabeling -- existing labeling to wrap; when None a new
                 GlobalElementLabeling is created for `view`
    """

    self.v = view
    if glabeling is None:
      self.g = GlobalElementLabeling(view)
    else:
      self.g = glabeling._getBaseRef()

    if name is not None:
      self.setName(name)

  def addLabelToElement(self, label, element):
    """Attach `label` to `element` of the wrapped view."""
    self.g.addLabelToElement(self.v, label, element)

  def getElementsByLabel(self, label):
    """Return the elements of the wrapped view that carry `label`."""
    return self.g.getElementsByLabel(self.v, label)

  def getLabelsByElement(self, element):
    """Return the labels attached to `element` of the wrapped view."""
    return self.g.getLabelsByElement(self.v, element)

  def getLabelsByElements(self, elements=None):
    """
    Return one label list per element; None is passed through to the
    labeling, which then uses every element of the view.
    """
    return self.g.getLabelsByElements(self.v, elements)

  def getLabelByElement(self, element, n=0):
    """Return the n-th label of `element`, as reported by the labeling."""
    return self.g.getLabelByElement(self.v, element, n)

  def getLabelByElements(self, elements=None, n=0):
    """Return the n-th label for each element, as reported by the labeling."""
    return self.g.getLabelByElements(self.v, elements, n)

  def removeLabelsFromElement(self, element):
    """Remove every label from `element` of the wrapped view."""
    self.g.removeLabelsFromElement(self.v, element)

  def removeLabelFromElement(self, label, element):
    """Remove `label` from `element` of the wrapped view."""
    self.g.removeLabelFromElement(self.v, label, element)
