########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
#                Christopher Hart
# Last Modified: Dec 13 23:41:29 PST 2001
#

import Numeric

from compClust.mlx.views import BaseView

##############################################################################
#
# Transform View
#
# Applies an arbitrary matrix to the dataset.  The matrix must be N x M,
# where N is the number of columns in the parent dataset; the resulting view
# then has M columns.
#
# This class does not need to override _mapKeysToParent
#
##############################################################################

class TransformView(BaseView):
  """
  View whose data is the parent dataset multiplied by a matrix.

  The view exposes Numeric.dot(parentData, matrix).  The matrix must
  therefore have as many rows as the parent dataset has columns; the view
  itself reports matrix.shape[1] columns.  When no matrix is supplied, an
  identity matrix sized to the parent's column count is used.
  """

  def __init__(self, dataset, matrix=None, name=None):
    """
    Build the view on top of `dataset`.

    dataset -- parent dataset (or view) whose data is transformed
    matrix  -- transformation matrix, or None to use the identity
    name    -- forwarded unchanged to BaseView.__init__

    Raises ValueError if the matrix row count differs from the parent's
    column count.
    """
    BaseView.__init__(self, dataset, name=name)
    self.setMatrix(matrix)

  def setMatrix(self, matrix):
    """
    Install a new transformation matrix and mark child views dirty.

    matrix -- object exposing `.shape` and usable with Numeric.dot, or None
              to install an identity matrix sized to the parent's columns

    Raises ValueError if matrix.shape[0] differs from the parent's column
    count; in that case the previously installed matrix is left untouched.
    """
    if matrix is None:
      matrix = Numeric.identity(self.dataset.getNumCols())

    # Validate before touching any state so that a rejected matrix never
    # replaces a working one.
    self.__checkMatrix(matrix)

    self.matrix = matrix
    self.numCols = matrix.shape[1]
    self._makeDirtyChildren()

  def _refresh(self):
    """
    Bring the view up to date with its parent and clear the dirty flag.

    The parent is refreshed first when it reports itself dirty, then the
    matrix is re-validated against the parent's current column count.

    Raises ValueError if the matrix no longer matches the parent; the dirty
    flag is not cleared in that case.
    """
    if self.dataset.isDirty():
      self.dataset._refresh()

    self.__setNumCols()
    self.dirty = 0

  def __checkMatrix(self, matrix):
    """Raise ValueError unless `matrix` has one row per parent column."""
    if matrix.shape[0] != self.dataset.getNumCols():
      raise ValueError("Mismatched dimensions")

  def __setNumCols(self):
    """Validate the installed matrix and cache its column count."""
    self.__checkMatrix(self.matrix)
    self.numCols = self.matrix.shape[1]

  def getNumCols(self):
    """Return the number of columns of the view (matrix.shape[1])."""
    if self.isDirty():
      self._refresh()

    return self.numCols

  def getData(self, key=None):
    """
    Return transformed data for the whole view or for a single key.

    key -- None for the entire dataset; otherwise an integer key in
           [0, getKeyMax()).  Keys below getNumRows() are treated as row
           keys, the remaining keys as column keys.

    Returns:
      key is None -- dot(parent data, matrix)
      row key     -- dot(parent.getData(key), matrix)
      column key  -- dot(parent data, column (key - numRows) of matrix)

    Raises ValueError if key is outside [0, getKeyMax()).
    """
    if self.isDirty():
      self._refresh()

    # Whole dataset: a single matrix product.
    if key is None:
      return Numeric.dot(self.dataset.getData(), self.matrix)

    keyMax = self.getKeyMax()
    if key < 0 or key >= keyMax:
      raise ValueError("key %s is out of range [0, %s)" % (key, keyMax))

    numRows = self.getNumRows()

    # Row key: only that parent row needs to be transformed.
    if key < numRows:
      return Numeric.dot(self.dataset.getData(key), self.matrix)

    # Column key: every parent row contributes, but only one column of the
    # transformation matrix is needed.
    return Numeric.dot(self.dataset.getData(), self.matrix[:, key - numRows])
  
