########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
# Last Modified: 29-May-2002, 11:00
#

"""
Enables labeling of arbitrary elements of a dataset

The trick is to leverage the existing UID infrastructure to make 2D mapping
available.

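   * For each label, create a new UID for each element and add
     said UID to both the row and the column that intersect at it.

   * The labels on an element are recovered by intersecting the UIDs
     attached to its row with the UIDs attached to its column.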

"""

import operator

from compClust.mlx.labelings import Labeling
from compClust.util.listOps import unique, cartesian, intersection

class ElementLabeling(Labeling):

  def __init__(self, data, name=None, nolink=0):
    """
    Construct the labeling by delegating to Labeling.__init__.

    All three arguments are passed through unchanged and in the same
    order; this class adds no state of its own at construction time.

    NOTE(review): the meaning of data / name / nolink is defined by
    Labeling -- presumably the dataset to label, the labeling's name and
    a flag controlling attachment to the dataset; confirm there.
    """
    Labeling.__init__(self, data, name, nolink)

  ############################################################################
  #
  # Override some critical Labeling methods
  #

  def getKeysByLabel(self, label):
    """
    Return every dataset key that carries one of this label's UIDs.

    The key lists reported by the dataset for each UID are joined into a
    single flat list, following the order in which the UIDs are stored
    in lab2uid.  No de-duplication is done, so a key shows up once for
    every UID of the label attached to it.  An unknown label yields an
    empty list.
    """
    ds   = self.getDataset()
    keys = []

    #
    # extend() grows one list in place (linear in the number of keys)
    # rather than building a fresh list per UID, and does not depend on
    # the reduce() builtin.
    #

    for uid in self.lab2uid.get(label, []):
      keys.extend(ds._getKeysByUID(uid))

    return keys


  def removeLabel(self, label):
    """
    Remove a label entirely.

    Each UID registered for the label is detached from every dataset key
    it is attached to, then the label is dropped from lab2uid and its
    UIDs from uid2lab.  An unknown label is silently ignored.

    Only KeyError from the dataset is tolerated while detaching; any
    other exception propagates instead of being swallowed.
    """
    uids = self.lab2uid.get(label)
    if uids is None:
      # Unknown label: nothing to remove.
      return

    ds = self.getDataset()

    for uid in uids:
      try:
        #
        # Walk a copy of the key list: _removeUID may shrink the very
        # list _getKeysByUID handed back.
        #
        for key in list(ds._getKeysByUID(uid)):
          ds._removeUID(uid, key)
      except KeyError:
        #
        # Best effort, as before.  Presumably the UID was already
        # detached from the dataset (e.g. through
        # removeLabelFromElement) -- TODO confirm what the dataset
        # raises for an unknown UID.
        #
        pass

    #
    # Always clean the bookkeeping, even if a UID could not be found in
    # the dataset, so the label never ends up half-removed.
    #

    del self.lab2uid[label]
    for uid in uids:
      if uid in self.uid2lab:
        del self.uid2lab[uid]

    
  def removeLabelsFromKey(self, key):
    """
    Detach every label of this labeling from a single dataset key.

    For each label reported by getLabelsByKey(key), _removeUID is called
    on the dataset for every UID registered under that label, whether or
    not that particular UID is attached to key.  The lab2uid / uid2lab
    tables are not modified.

    Raises KeyError if getLabelsByKey reports a label that is missing
    from lab2uid.
    """
    labels = self.getLabelsByKey(key)
    ds     = self.getDataset()

    #
    # Explicit loops: the removal is a side effect, so it must not hinge
    # on map() eagerly building (and discarding) a result list.
    #

    for label in labels:
      for uid in self.lab2uid[label]:
        ds._removeUID(uid, key)


  def removeLabelFromKey(self, label, key):
    """
    Detach one label from a single dataset key.

    Nothing happens unless getLabelsByKey(key) reports the label.  When
    it does, _removeUID is called on the dataset for every UID
    registered under the label, whether or not that particular UID is
    attached to key.  The lab2uid / uid2lab tables are not modified.
    """
    labels = self.getLabelsByKey(key)
    ds     = self.getDataset()

    if label not in labels:
      return

    #
    # Explicit loop: the removal is a side effect, so it must not hinge
    # on map() eagerly building (and discarding) a result list.
    #

    for uid in self.lab2uid[label]:
      ds._removeUID(uid, key)

    
  def _getUID(self, label):
    """
    Allocate a fresh UID for the label and register it.

    Unlike a lookup, every call obtains a new UID from the dataset,
    appends it to the label's list in lab2uid (creating the list on
    first use) and records the reverse mapping in uid2lab.  The new UID
    is returned.
    """
    fresh = self.getDataset()._getUID()

    self.lab2uid.setdefault(label, []).append(fresh)
    self.uid2lab[fresh] = label

    return fresh

  ############################################################################
  #
  # Extra ElementLabeling methods
  #  
  # Marking the dataset
  #
  ############################################################################
  
  def addLabelToElement(self, label, coord):
    """
    Label a single element of the dataset.

    coord is indexed as coord[0] (row) and coord[1] (column).  A new UID
    is allocated for the label through _getUID and attached to both the
    row key and the column key of the element.

    Raises IndexError if coord has fewer than two entries; in that case
    no UID is allocated and the dataset is left untouched.
    """
    ds = self.getDataset()

    #
    # Resolve both keys before allocating anything: a malformed coord
    # then fails here instead of leaving a registered-but-unused UID
    # and a row that was labelled without its column.
    #

    rowkey = ds.getRowKey(coord[0])
    colkey = ds.getColKey(coord[1])

    #
    # Get a new UID for this element label and mark the row and column
    #

    uid = self._getUID(label)

    ds._addUID(uid, rowkey)
    ds._addUID(uid, colkey)
    
  def addLabelToElements(self, label, coords):
    """
    Attach the same label to each coordinate in coords, issuing one
    addLabelToElement call per entry, in iteration order.
    """
    for position in coords:
      self.addLabelToElement(label, position)

  def addLabelsToElement(self, labels, coord):
    """
    Attach several labels to one element.

    addLabelToElement is called once per entry of labels, in order, each
    time with the same coord.  labels may be any iterable.
    """
    #
    # Explicit loop: labeling is a side effect, so it must not hinge on
    # map() eagerly building (and discarding) a result list.
    #

    for label in labels:
      self.addLabelToElement(label, coord)

  def addLabelsToElements(self, labels, coords):
    """
    Attach groups of labels to elements, pairwise.

    labels and coords are walked in lockstep: the i-th entry of labels
    (itself a collection of labels) is applied to the i-th entry of
    coords through addLabelsToElement.  Pairing stops at the end of the
    shorter input; surplus entries of the longer one are ignored.
    """
    #
    # Explicit pairing: labeling is a side effect, so it must not hinge
    # on map() eagerly building a result list, nor on map() padding the
    # shorter input with None.
    #

    for group, coord in zip(labels, coords):
      self.addLabelsToElement(group, coord)

  #############################################################################
  #
  # Information retrieval
  #
  #############################################################################
  
  def getLabelsByElement(self, coord):
    """
    Return the labels of this labeling attached to one element.

    coord is indexed as coord[0] (row) and coord[1] (column).  A UID
    counts when the dataset reports it for both the row key and the
    column key and it is registered in uid2lab.  The result is a list
    that follows the order of the column's UIDs; a label appears once
    per matching UID.
    """
    ds = self.getDataset()

    rowUIDs = ds._getUIDsByKey(ds.getRowKey(coord[0]))
    colUIDs = ds._getUIDsByKey(ds.getColKey(coord[1]))

    #
    # Find the intersection of the UIDs: index the row's UIDs in a
    # dict, then keep the column UIDs found there.
    #

    onRow = {}
    for uid in rowUIDs:
      onRow[uid] = 1

    #
    # UIDs missing from uid2lab are not owned by this labeling and are
    # skipped.
    #

    uid2lab = self.uid2lab
    return [uid2lab[uid] for uid in colUIDs
            if uid in onRow and uid in uid2lab]

  def getLabelsByElements(self, coords):
    """
    Return, for every coordinate in coords, the result of
    getLabelsByElement, as a list in the same order as coords.
    """
    # A comprehension always yields a real list, whatever map() returns.
    return [self.getLabelsByElement(coord) for coord in coords]

  def getLabelByElement(self, coord, n=0):
    """
    Return the n-th label attached to an element, or None.

    The labels come from getLabelsByElement(coord); n indexes into that
    result (default: the first label).  None is returned whenever an
    IndexError occurs, which covers an n outside the result as well as
    an IndexError raised while the labels are being collected.
    """
    try:
      return self.getLabelsByElement(coord)[n]
    except IndexError:
      return None

  def getLabelByElements(self, coords=None, n=0):
    """
    Return the n-th label (or None) for each coordinate.

    Each entry of the returned list is getLabelByElement(coord, n) for
    the matching entry of coords.  When coords is None the coordinates
    are produced by cartesian() from the dataset's row and column
    counts.
    """
    if coords is None:
      ds = self.getDataset()
      rows = ds.getNumRows()
      cols = ds.getNumCols()
      #
      # NOTE(review): cartesian() is handed the two counts exactly as
      # before; presumably it expands them to every (row, col) pair --
      # confirm it accepts counts rather than index sequences.
      #
      coords = cartesian(rows, cols)

    #
    # A comprehension always yields a real list and does not need
    # len(coords).
    #

    return [self.getLabelByElement(coord, n) for coord in coords]
  
  def getElementsByLabel(self, label):
    """
    Return the (row, col) pairs of the elements carrying the label.

    Candidate rows and columns come from getRowsByLabel / getColsByLabel
    (de-duplicated with unique()).  A row and a column form an element
    when they share a UID that is registered for this label in lab2uid.
    Pairs are emitted column by column, in the order of each column's
    UIDs.

    UIDs that belong to other labels are ignored; without that
    restriction a row and a column that both happen to hold this label
    for different elements, and also share another label's UID, would
    be reported as an element of this label.
    """
    ds = self.getDataset()

    #
    # Only UIDs allocated for this label may pair a row with a column
    #

    wanted = {}
    for uid in self.lab2uid.get(label, []):
      wanted[uid] = 1

    rows = unique(self.getRowsByLabel(label))
    cols = unique(self.getColsByLabel(label))

    #
    # Enter the label's UIDs in a hash which maps each one to the rows
    # that carry it.
    #
    # NOTE(review): rows are passed to _getUIDsByKey as-is while columns
    # go through getColKey; kept as found -- confirm whether rows should
    # go through getRowKey as in addLabelToElement.
    #

    rowsByUID = {}
    for row in rows:
      for uid in ds._getUIDsByKey(row):
        if uid in wanted:
          rowsByUID.setdefault(uid, []).append(row)

    #
    # Now scan the column UIDs and build up matched pairs
    #

    elements = []
    for col in cols:
      for uid in ds._getUIDsByKey(ds.getColKey(col)):
        for row in rowsByUID.get(uid, []):
          elements.append((row, col))

    return elements

  ############################################################################
  #
  # Label removal
  #
  ############################################################################
  
  def removeLabelsFromElement(self, coord):
    """
    Strip every label of this labeling from one element.

    Each entry returned by getLabelsByElement(coord) is passed, together
    with coord, to removeLabelFromElement.
    """
    for attached in self.getLabelsByElement(coord):
      self.removeLabelFromElement(attached, coord)


  def removeLabelFromElement(self, label, coord):
    """
    Remove one label from a single element.

    coord is indexed as coord[0] (row) and coord[1] (column).  The UIDs
    to drop are those present on the row key, on the column key, and in
    the label's lab2uid entry; each is removed from both keys in the
    dataset.  lab2uid and uid2lab themselves are left untouched here.
    """
    ds = self.getDataset()

    rowkey = ds.getRowKey(coord[0])
    colkey = ds.getColKey(coord[1])

    #
    # As in getElementsByLabel, pin down the UIDs that tie this exact
    # row/column pair to the label.
    #

    shared = intersection(ds._getUIDsByKey(rowkey), ds._getUIDsByKey(colkey))
    doomed = intersection(shared, self.lab2uid.get(label, []))

    for uid in doomed:
      ds._removeUID(uid, rowkey)
      ds._removeUID(uid, colkey)
    

