########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
#                Christopher Hart
# Last Modified: Dec 13 23:41:29 PST 2001
#

import Numeric

from compClust.mlx.views import BaseView

###############################################################################
#
# RowFunctionView
#
# Applies a function per row of a dataset.  The function should accept a
# dataset and a row index and return an N-element vector, where N is the
# number of columns of the view.
#
# This class does not need to override _mapKeysToParent
#
###############################################################################

class RowFunctionView(BaseView):

  """
  Provides a view of a dataset passed through a row-wise function.

  The function is called as ``function(dataset, rowIndex)`` and should
  return a vector whose length is the number of columns of this view.
  If ``function`` is None, each row of the parent dataset is passed through
  unchanged (via ``dataset.getRowData(rowIndex)``).
  """

  def _identityRow(dataset, row):
    # Default transform: return the parent's row data unchanged.
    return dataset.getRowData(row)
  _identityRow = staticmethod(_identityRow)

  def __init__(self, dataset, function, name=None):
    """
    Create a view of `dataset` whose rows are `function(dataset, row)`.

    dataset  -- parent dataset being viewed
    function -- callable (dataset, rowIndex) -> row vector, or None for
                the identity transform
    name     -- optional name, forwarded to BaseView
    """
    BaseView.__init__(self, dataset, name=name)
    self.setFunction(function)

  def setFunction(self, function):
    """
    Install `function` as the per-row transform and recompute numCols.

    A None function is replaced by the identity transform.  Child views
    are marked dirty because the data they derive from has changed.
    """
    if function is None:
      function = self._identityRow
    self.func = function
    self.__setNumCols()
    self._makeDirtyChildren()

  def _refresh(self):
    """
    Bring this view up to date: refresh the parent if it is dirty, then
    recompute the column count.
    """
    # The parent dataset is stored as self.dataset (see getData).
    if self.dataset.isDirty():
      self.dataset._refresh()
    # Clear the dirty flag *before* recomputing numCols.  __setNumCols calls
    # getData(), which calls _refresh() whenever isDirty() is true, so
    # clearing it afterwards could re-enter this method indefinitely.
    # NOTE(review): assumes BaseView.isDirty() reflects self.dirty -- confirm.
    self.dirty = 0
    self.__setNumCols()

  def __setNumCols(self):
    # Infer the column count from the length of the first transformed row.
    # This raises if the parent dataset has no rows.
    self.numCols = len(self.getData(0))

  def getNumCols(self):
    """Return the number of columns, refreshing the view first if dirty."""
    if self.isDirty():
      self._refresh()
    return(self.numCols)

  def getData(self, key=None):
    """
    Return the transformed data, or a single row or column of it.

    key -- None: the full matrix, one transformed row per parent row;
           0 <= key < numRows: the transformed row `key`;
           numRows <= key < dataset.getKeyMax(): column (key - numRows)
           of the full matrix.

    Returns a Numeric array.  Raises ValueError for any other key.
    """
    if self.isDirty():
      self._refresh()

    func = self.func
    if func is None:
      # setFunction never stores None, but self.func may be assigned
      # directly.  Fall back to the identity transform so that `data`
      # is always bound.
      func = self._identityRow

    dataset = self.dataset
    numRows = dataset.getNumRows()

    if key is None:
      data = [func(dataset, row) for row in range(numRows)]
    elif key < numRows:
      data = func(dataset, key)
    elif key < dataset.getKeyMax():
      # Keys at or beyond numRows select a column of the full matrix.
      # getData() already returns a Numeric array, so slice it directly.
      data = self.getData()[:, key - numRows]
    else:
      raise ValueError("key %r is out of range for this view" % (key,))

    return Numeric.array(data)
