########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
#                Christopher Hart
# Last Modified: Dec 13 23:41:29 PST 2001
#

import operator
import Numeric

from copy import copy

from compClust.util.unique import unique
from compClust.mlx.views import BaseView

##############################################################################
#
# AggregateFunctionView
#
# Takes in a list of key lists and a function.  The function is applied to the
# subset of rows/cols and replaces those rows/columns with its results.
#
# An aggregate of keys may not mix row and column keys; such aggregates are
# ignored.  For example, in a 3x3 dataset (row keys 0-2, column keys 3-5),
# [[1,2],[3,4]] is allowed, but [[2,3]] is not.
#
##############################################################################

class AggregateFunctionView(BaseView):
  """
  This view takes in a list of lists and a function.  Each sublist defines
  a set of row or column keys to merge using the passed function.  Rows and
  columns which are not in any of the lists are passed through as is, though
  their relative order may have changed.

  Keys follow the parent dataset's numbering: keys 0 .. parentRows-1 are
  rows and keys parentRows .. parentRows+parentCols-1 are columns.  Sublists
  that are empty, mix row and column keys, or contain out-of-range keys are
  ignored, and their keys are passed through instead.

  The function is called with a Numeric array whose first axis enumerates
  the rows (or columns) being merged, and it should collapse that axis.
  getData() with no key always passes 2D arrays.  getData(key) also calls
  the function on 1D slices, so the function should accept both shapes
  (e.g. Numeric.add.reduce).  With no function set, merged values are zero.
  """

  def __init__(self, dataset, keylist=None, function=None, name=None):
    """
    dataset  -- parent dataset whose rows/columns are aggregated.
    keylist  -- list of key lists to merge (see setKeylist); None means
                no aggregation.
    function -- aggregating callable (see setFunction), or None.
    name     -- optional view name, passed on to BaseView.
    """

    BaseView.__init__(self, dataset, name=name)

    # Reverse (parent -> child) key lookup, built lazily by
    # _mapKeysFromParent and cleared whenever the mapping changes.
    self.__RKLUcache = None

    self.setFunction(function)
    self.setKeylist(keylist)


  def setFunction(self, function):
    """
    Set the aggregating function and mark child views dirty.

    function -- callable applied to each group of merged rows/columns, or
                None to make every merged value zero.
    """

    self.__function = function
    self._makeDirtyChildren()


  def getData(self, key=None):
    """
    Return aggregated data from the view.

    key -- None for the whole view as a (numRows, numCols) Numeric array,
           or a single view key (row keys first, then column keys) for
           that row or column as a Numeric array with one entry per
           column/row of the view.

    Raises ValueError if key is not in range(numRows + numCols).
    """

    if self.isDirty():
      self._refresh()

    function = self.__function

    #
    # If there is no function defined, we force every aggregate to zeros.
    # The single-key path hands the function 1D groups, so the default
    # must handle those as well as the 2D stacks.

    if function is None:
      def function(x):
        if len(x.shape) > 1:
          return Numeric.zeros(x.shape[1])
        return 0

    if key is None:
      return self.__getAllData(function)

    return self.__getKeyData(key, function)


  def __isPassThrough(self, parentKeys):
    """True if parentKeys is one parent key that is not in any keyset."""

    return len(parentKeys) == 1 and parentKeys[0] in self.__singles


  def __getAllData(self, function):
    """Build the whole (numRows, numCols) aggregated array."""

    numRows = self.getNumRows()
    numCols = self.getNumCols()
    dataset = self.dataset
    mapping = self.__mapping

    #
    # First collapse the parent's columns: every view column becomes one
    # vector with an entry per parent row.

    columns = []

    for colKey in self.getColKeys():
      parentKeys = mapping[colKey]
      group      = Numeric.array([dataset.getData(k) for k in parentKeys])

      if self.__isPassThrough(parentKeys):
        columns.append(group[0])
      else:
        columns.append(function(group))

    #
    # Rows of tmpdata are parent rows; columns are view columns.

    tmpdata = Numeric.swapaxes(Numeric.array(columns), 0, 1)

    #
    # Now collapse the parent's rows.  Parent row keys equal row indices,
    # so they index tmpdata directly.

    rows = []

    for rowKey in self.getRowKeys():
      parentKeys = mapping[rowKey]
      group      = Numeric.take(tmpdata, parentKeys)

      if self.__isPassThrough(parentKeys):
        rows.append(group[0])
      else:
        rows.append(function(group))

    return Numeric.reshape(rows, (numRows, numCols))


  def __getKeyData(self, key, function):
    """Build the aggregated row or column for a single view key."""

    numRows = self.getNumRows()
    numCols = self.getNumCols()
    mapping = self.__mapping

    if key < 0 or key >= numRows + numCols:
      raise ValueError("key %r out of range [0, %d)" % (key, numRows + numCols))

    #
    # Fetch the parent rows/columns behind this key and collapse them.
    # NOTE(review): this collapses this key's own axis first and the other
    # axis second, whereas getData() collapses columns first.  The two
    # agree only when the function commutes across axes (e.g. sums); confirm
    # this is acceptable for the functions in use.

    parentKeys = mapping[key]
    v = Numeric.array([self.dataset.getData(k) for k in parentKeys])

    if self.__isPassThrough(parentKeys):
      v = v[0]
    else:
      v = function(v)

    #
    # Work out which entries of v feed each element of the result.  For a
    # row key, v runs over parent columns, so column keys are shifted down
    # by the parent's row count.  For a column key, v runs over parent rows,
    # so row keys index it directly.

    if key < numRows:
      parentRows = self.dataset.getNumRows()
      groups  = [mapping[k] for k in self.getColKeys()]
      indices = [[pk - parentRows for pk in group] for group in groups]
    else:
      groups  = [mapping[k] for k in self.getRowKeys()]
      indices = groups

    result = []

    for parentKeys, idx in zip(groups, indices):
      if self.__isPassThrough(parentKeys):
        result.append(v[idx[0]])
      else:
        result.append(function(Numeric.take(v, idx)))

    return Numeric.array(result)


  def _refresh(self):
    """
    Refresh the parent if it is dirty, rebuild the key mapping from the
    keylist originally passed to setKeylist, and clear the dirty flag.
    """

    if self.dataset.isDirty():
      self.dataset._refresh()

    self.setKeylist(self.__original_keylist)

    self.dirty = 0


  def setKeylist(self, keylist=None):
    """
    Takes in a list of lists of keys which are the sets of rows and columns
    to group together and apply the function to.

    Each keyset is sorted, and the keysets are then ordered by their keys.
    This is the relative order of the aggregates in the view.  Keysets that
    are empty, contain keys outside the parent's key range, or mix row and
    column keys are discarded.  Every parent key not named by a remaining
    keyset passes through on its own.  The caller's lists are not modified.

    keylist -- list of key lists; None is treated as an empty list.
    """

    if keylist is None:
      keylist = []

    #
    # Keep a private copy (sublists included) so _refresh can rebuild the
    # mapping later, and so sorting never touches the caller's lists.

    self.__original_keylist = [list(keyset) for keyset in keylist]

    parentRows = self.dataset.getNumRows()
    totalKeys  = parentRows + self.dataset.getNumCols()

    #
    # Sort each keyset and keep only the legal ones.  Because the keyset is
    # sorted, comparing its first and last key is enough to check both the
    # range and that it is all-rows or all-columns.

    aggregates = []

    for keyset in self.__original_keylist:
      keyset = list(keyset)
      keyset.sort()

      if not keyset:
        continue
      if keyset[0] < 0 or keyset[-1] >= totalKeys:
        continue
      if keyset[-1] < parentRows or keyset[0] >= parentRows:
        aggregates.append(keyset)

    #
    # Find the parent keys not used by any aggregate.  These are recorded in
    # __singles so getData can tell a pass-through key from an explicitly
    # listed one-element keyset (which still gets the function applied).

    used = {}
    for keyset in aggregates:
      for k in keyset:
        used[k] = 1

    singleKeys = [k for k in range(totalKeys) if k not in used]

    self.__singles = {}
    for k in singleKeys:
      self.__singles[k] = 1

    #
    # Combine the aggregates and the singles, and sort.  The first key
    # decides each keyset's position, so row keysets precede column ones.

    keylist = aggregates + [[k] for k in singleKeys]
    keylist.sort()

    #
    # View key i maps to the parent keys keylist[i].

    mapping = {}
    for i in range(len(keylist)):
      mapping[i] = keylist[i]

    #
    # Count the row aggregates directly.  This also handles an empty
    # mapping or a view with no rows.

    numRows = len([ks for ks in keylist if ks[0] < parentRows])

    self.numRows   = numRows
    self.numCols   = len(keylist) - numRows
    self.__mapping = mapping

    #
    # We've changed, so set the dirty flags and drop the reverse cache.

    self._makeDirtyChildren()
    self.__RKLUcache = None

  ####
  #
  # Helper methods & hook functions
  #
  ####

  def getNumRows(self):
    """Number of rows in the view, refreshing first if dirty."""
    if self.isDirty():
      self._refresh()
    return self.numRows


  def getNumCols(self):
    """Number of columns in the view, refreshing first if dirty."""
    if self.isDirty():
      self._refresh()
    return self.numCols


  def _addUID(self, uid, key):
    """Attach uid to every parent key merged into view key `key`."""

    for parentKey in self.__mapping[key]:
      self.dataset._addUID(uid, parentKey)


  def _mapKeysToParent(self, keys, parent=None):
    """
    Return the parent keys behind each view key in `keys`, concatenated in
    order (duplicates kept).  `parent` is accepted but not used here.
    """

    mapping    = self.__mapping
    parentKeys = []

    for key in keys:
      parentKeys.extend(mapping[key])

    return parentKeys


  def _mapKeysFromParent(self, keylist, parent=None):
    """
    Return the unique view keys whose aggregates include any parent key in
    `keylist`.  Parent keys unknown to this view are skipped.  `parent` is
    accepted but not used here.
    """

    if self.__RKLUcache is None:
      self.__RKLUcache = self._buildCache()

    RKLU      = self.__RKLUcache
    childKeys = []

    for parentKey in keylist:
      if parentKey in RKLU:
        childKeys.extend(RKLU[parentKey])

    return unique(childKeys)


  def _getUIDsByKey(self, key=None):
    """
    Return the unique UIDs the parent holds for the parent keys behind view
    key `key`, or behind every view key when key is None.
    """

    ds      = self.dataset
    mapping = self.__mapping

    if key is None:
      parentKeys = []
      for keys in mapping.values():
        parentKeys.extend(keys)
    else:
      parentKeys = mapping[key]

    uids = []
    for parentKey in parentKeys:
      uids.extend(ds._getUIDsByKey(parentKey))

    return unique(uids)


  def _getKeysByUID(self, uid):
    """Return the view keys whose parent keys carry `uid`."""
    return self._mapKeysFromParent(self.dataset._getKeysByUID(uid))


  def _buildCache(self):
    """
    Build the reverse lookup table from parent key to the view keys whose
    aggregates include it.  Aggregates may overlap, so a parent key can map
    to several view keys (a many-to-many relation).  Every such view key is
    recorded once, in increasing order.
    """

    cache   = {}
    mapping = self.__mapping

    for key in range(self.getNumRows() + self.getNumCols()):
      for parentKey in mapping[key]:
        childKeys = cache.setdefault(parentKey, [])
        if key not in childKeys:
          childKeys.append(key)

    return cache
