########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
#                Christopher Hart
# Last Modified: Dec 13 23:41:29 PST 2001
#

from compClust.util.unique import unique
from compClust.mlx.views import AggregateFunctionView

class RowAggregateFunctionView(AggregateFunctionView):

  """
  A View which applies a function that aggregates all of the data
  pointed to by a common label into a single row.  For example, to
  generate the mean trajectories of all rows in a dataset from a
  clustering labeling, you can use something like this:

  ds  is a Dataset
  lab is a clustering
  view = RowAggregateFunctionView(ds, lab, MLab.mean)

  """

  def __init__(self, dataset, labeling, function, name=None):
    """
    __init__(self, dataset, labeling, function, name=None)

    The initialization of the aggregate function view requires a
    labeling and a function.  The function should accept a Numeric 2d
    array and return a Numeric 1d array.  Note that the 2d array may
    contain only a single row, which breaks functions such as
    MLab.std.

    Several MLab functions work out of the box:
       MLab.mean, MLab.max, MLab.min, MLab.sum, MLab.median

    Of obvious interest is the standard deviation associated with each
    aggregate.  This is tricky, as the calculation fails when n = 1.
    A simple safety wrapper function can be used; here 0 is returned
    when the std. dev. cannot be calculated (some may prefer NaNs):

    def std(array):
      try:
        r = MLab.std(array)
      except Exception:
        r = array[0]*0
      return r

    The optional name is passed through to AggregateFunctionView.
    """
    # Start from an empty keylist; setLabeling() below derives the real
    # one from the labeling and refreshes the view.
    AggregateFunctionView.__init__(self, dataset, [], function, name=name)
    self.setLabeling(labeling)

  def setLabeling(self, labeling):
    """
    setLabeling(self, labeling)

    Replace the labeling used to group rows and rebuild the view's
    keylist immediately.  The labeling may be None, in which case the
    keylist becomes empty.
    """
    self.__labeling = labeling

    # Flag this view as stale, notify dependants through
    # _makeDirtyChildren(), then rebuild right away.
    self.dirty = 1
    self._makeDirtyChildren()
    self._refresh()

  def _refresh(self):
    """
    _refresh(self)

    Rebuild the keylist from the current labeling: one list of keys per
    unique label, in the order returned by unique().  Keys that are not
    smaller than the dataset's row count are dropped.  The parent
    dataset is refreshed first if it reports itself dirty.
    """
    if self.dataset.isDirty():
      self.dataset._refresh()

    labeling = self.__labeling
    keylist  = []

    if labeling is not None:
      rows = self.dataset.getNumRows()

      # The row-count test is the only difference between the Row
      # and Column versions.
      #
      # Each key set is built as a real list rather than with filter():
      # on Python 3 filter() yields a one-shot iterator, which would be
      # empty after its first traversal.
      for label in unique(labeling.getLabelByRows()):
        keyset = labeling.getKeysByLabel(label)
        keylist.append([key for key in keyset if key < rows])

    self.setKeylist(keylist)

    self.dirty = 0
