########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
#                Christopher Hart
# Last Modified: Dec 13 23:41:29 PST 2001
#

import Numeric

from compClust.mlx.views import BaseView

###############################################################################
#
# SubsetView
#
# Gives Row/Column subset views of an underlying dataset
#
###############################################################################

class SubsetView(BaseView):
  """
  Provides a reduced (row/column subset) view of a dataset.

  The SubsetView class is derived from the BaseView class and overrides
  the getData() method to provide for dynamically changing the appearance
  of the underlying dataset.  This allows for datasets (which are immutable)
  to appear to the user as mutable.

  Keys in this view follow the same layout used with the parent: row keys
  come first (0 .. numRows-1) and column keys follow them, offset by the
  number of rows.  The requested rows/columns are remembered relative to
  the parent (base_rows / base_cols) and the view's own key lists are
  rebuilt lazily by _refresh() whenever the view is marked dirty.
  """

  def __init__(self, dataset, keyset=None, name=None):
    """
    Create a view of `dataset` restricted to the keys in `keyset`.

    dataset -- parent dataset/view being subset.
    keyset  -- list of parent keys (row and/or column keys) to keep; it is
               split into rows and columns by dataset.splitRowColKeylist().
               Defaults to an empty list.
    name    -- optional name, forwarded to BaseView.
    """
    BaseView.__init__(self, dataset, name=name)
    self.RKLUcache = None

    # Avoid a mutable default argument shared between calls.
    if keyset is None:
      keyset = []

    numRows = dataset.getNumRows()

    rowKeys, colKeys = dataset.splitRowColKeylist(keyset)

    # Rows are kept as parent row keys; columns are kept as parent column
    # indices (row offset removed).  _refresh() re-applies the parent's
    # current row count when deriving the view's column keys.
    self.base_rows = rowKeys
    self.base_cols = [x - numRows for x in colKeys]

    # rowKeys/colKeys/numRows/numCols are derived lazily by _refresh().
    self.dirty = 1

  def _refresh(self):
    """
    Rebuild the derived key lists from base_rows / base_cols.

    Refreshes the parent first if it is dirty, drops requested rows and
    columns that are out of range for the parent, recomputes rowKeys,
    colKeys, numRows and numCols, and invalidates the reverse-lookup cache.
    """
    if self.dataset.isDirty():
      self.dataset._refresh()

    numRows = self.dataset.getNumRows()
    numCols = self.dataset.getNumCols()

    # Row keys map directly onto parent row keys.
    self.rowKeys = [x for x in self.base_rows if x < numRows]
    self.numRows = len(self.rowKeys)

    # Column keys are parent column indices shifted past the row keys.
    self.colKeys = [x + numRows for x in self.base_cols if x < numCols]
    self.numCols = len(self.colKeys)

    self.RKLUcache = None
    self.dirty = 0

  def __extractRows(self, data):
    """
    Return a Numeric array built from data[row] for each row key of the view.
    """
    return Numeric.array([data[row] for row in self.rowKeys])

  def __extractColumns(self, data):
    """
    Select this view's columns from `data`.

    A 1-D `data` is assumed to represent a single _row_ of data, so the
    columns are taken along axis 0; otherwise they are taken along axis 1.
    """
    if len(data.shape) == 1:
      axis = 0
    else:
      axis = 1

    # colKeys carry the parent's row-count offset; remove it to obtain
    # parent column indices.
    rows = self.dataset.getNumRows()
    columns = [x - rows for x in self.colKeys]
    return Numeric.take(data, columns, axis)

  def getData(self, key=None):
    """
    Return data from the view.

    key -- None for the whole view as a 2-D (numRows x numCols) array,
           a row key (0 <= key < numRows) for that row restricted to the
           view's columns, or a column key (numRows <= key < getKeyMax())
           for that column restricted to the view's rows.

    Returns Numeric.array([]) when the view has no rows or no columns.
    Raises ValueError if `key` is outside [0, getKeyMax()).
    """
    if self.isDirty():
      self._refresh()

    numRows = self.getNumRows()
    numCols = self.getNumCols()

    if numRows == 0 or numCols == 0:
      return Numeric.array([])

    if key is None:
      # Get the whole parental dataset; promote a 1-D result to a single
      # row so row/column extraction works uniformly on a 2-D array.
      data = self.dataset.getData()
      if len(data.shape) == 1:
        data = Numeric.reshape(data, (1, data.shape[0]))

      # Select the rows, then take all wanted columns in one pass along
      # axis 1 (rather than once per row).
      return self.__extractColumns(self.__extractRows(data))

    keyMax = self.getKeyMax()
    if key < 0 or key >= keyMax:
      raise ValueError("key %s out of range [0, %s)" % (key, keyMax))

    data = self.dataset.getData(self._translateKey(key))

    if key < numRows:
      # A parent row: keep only this view's columns.
      return self.__extractColumns(data)
    # A parent column: keep only this view's rows.
    return self.__extractRows(data)

  def getNumRows(self):
    """Return the number of rows visible through this view."""
    if self.isDirty():
      self._refresh()

    return self.numRows

  def getNumCols(self):
    """Return the number of columns visible through this view."""
    if self.isDirty():
      self._refresh()

    return self.numCols

  def _mapKeysFromParent(self, keyList, parent=None):
    """
    Map parent keys to this view's keys.

    Parent keys that are not part of this view are silently dropped; the
    order of the remaining keys is preserved.  `parent` is unused here.
    """
    if self.isDirty():
      self._refresh()

    #
    # Use a lazy evaluation approach.  Only build the reverse mapping if it
    # is needed.  If any operation is performed which would invalidate
    # the reverse mapping (i.e. a sort), simply set self.RKLUcache to None
    # (RKLU == Reverse Key Look-Up)
    #

    if self.RKLUcache is None:
      self.RKLUcache = self._buildCache()

    RKLU = self.RKLUcache
    return [RKLU[k] for k in keyList if k in RKLU]

  def _getKeysByUID(self, uid):
    """Return this view's keys carrying `uid`, via the parent's lookup."""
    if self.isDirty():
      self._refresh()

    return self._mapKeysFromParent(self.dataset._getKeysByUID(uid))

  def _getUIDsByKey(self, key=None):
    """
    Return the parent's UIDs for the parent key corresponding to `key`.

    NOTE(review): the default key=None is passed to _translateKey(), which
    indexes a list with it and would raise TypeError -- confirm whether
    callers ever rely on the default.
    """
    if self.isDirty():
      self._refresh()

    return self.dataset._getUIDsByKey(self._translateKey(key))

  def _removeUID(self, uid, key):
    """Remove `uid` from the parent key corresponding to view key `key`."""
    if self.isDirty():
      self._refresh()

    self.dataset._removeUID(uid, self._translateKey(key))

  def _addUID(self, uid, key):
    """Attach `uid` to the parent key corresponding to view key `key`."""
    if self.isDirty():
      self._refresh()

    self.dataset._addUID(uid, self._translateKey(key))

  def _getUID(self):
    """Delegate UID allocation to the parent dataset."""
    return self.dataset._getUID()

  def _translateKey(self, key):
    """
    Translate a view key into the corresponding parent key.

    Uses plain list indexing on rowKeys + colKeys, so a negative key
    indexes from the end and an out-of-range key raises IndexError.
    """
    if self.isDirty():
      self._refresh()

    keys = self.rowKeys + self.colKeys
    return keys[key]

  def _buildCache(self):
    """Build the reverse look-up dict {parent key: view key}."""
    if self.isDirty():
      self._refresh()

    keys = self.rowKeys + self.colKeys
    return dict(zip(keys, range(len(keys))))

  def _mapKeysToParent(self, keys, parent=None):
    """
    Translate a list of view keys into parent keys (see _translateKey).

    The combined key list is built once per call instead of once per key.
    `parent` is unused here.
    """
    if self.isDirty():
      self._refresh()

    allKeys = self.rowKeys + self.colKeys
    return [allKeys[k] for k in keys]
