########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
#                Christopher Hart
# Last Modified: Dec 13 23:41:29 PST 2001
#

import Numeric

from compClust.mlx.views import SupersetView

##############################################################################
#
# RowSupersetView
#
# 'Stacks' dataset1 on top of dataset2.
#
##############################################################################

class RowSupersetView(SupersetView):
  """
  Concatenates the rows of two datasets.

  The view 'stacks' dataset1 on top of dataset2: view rows
  0 .. ds1Rows-1 come from dataset1, and view rows
  ds1Rows .. ds1Rows+ds2Rows-1 come from dataset2.  Both datasets must
  have the same number of columns.

  Keys follow the convention used by getData(): a key below
  getNumRows() addresses a row, and a key k >= getNumRows() addresses
  column (k - getNumRows()).
  """

  def __init__(self, dataset1, dataset2, name=None):
    """
    Build a view whose rows are dataset1's rows followed by dataset2's.

    dataset1, dataset2 -- datasets to stack; must have equal column counts.
    name               -- optional view name, forwarded to SupersetView.

    Raises ValueError if the column counts differ.
    """

    if dataset1.getNumCols() != dataset2.getNumCols():
      raise ValueError("RowSupersetView: datasets have different numbers "
                       "of columns (%s vs %s)"
                       % (dataset1.getNumCols(), dataset2.getNumCols()))

    SupersetView.__init__(self, dataset1, dataset2, name=name)

    #
    # Set the basic variables.  ds1Rows, ds2Rows and ds1Cols are not
    # assigned in this class; they come from the SupersetView base.
    #

    self.numRows = self.ds1Rows + self.ds2Rows
    self.numCols = self.ds1Cols

    #
    # Make the key mapping: dataset2's keys are shifted past dataset1's
    # rows so the combined list indexes the stacked view.  A list
    # comprehension (rather than map) guarantees a concatenable list.
    #

    ds1keys = list(dataset1.getKeys())
    ds2keys = [k + self.ds1Rows for k in dataset2.getKeys()]

    self.rowMap = ds1keys + ds2keys

  def getNumRows(self):
    """Return the number of rows in the stacked view (refreshing if dirty)."""
    if self.isDirty():
      self._refresh()
    return self.numRows

  def getNumCols(self):
    """Return the number of columns in the view (refreshing if dirty)."""
    if self.isDirty():
      self._refresh()
    return self.numCols

  def getData(self, key=None):
    """
    Return data from the stacked view.

    key -- None for the whole stacked matrix; a row key (< getNumRows())
           for that single row; or a column key (>= getNumRows()) for the
           concatenation of that column from dataset1 and dataset2.

    Raises ValueError if key is outside [0, getKeyMax()).
    """

    if self.isDirty():
      self._refresh()

    if key is None:
      return Numeric.concatenate((self.ds1.getData(),
                                  self.ds2.getData()))

    if key < 0 or key >= self.getKeyMax():
      raise ValueError("RowSupersetView.getData: key %s out of range"
                       % (key,))

    numRows = self.getNumRows()
    if key < numRows:
      # Row key: dispatch to whichever dataset holds this row.
      if key < self.ds1Rows:
        return self.ds1.getData(key)
      return self.ds2.getData(key - self.ds1Rows)

    # Column key: the column exists in both datasets; stack the halves.
    col = key - numRows
    v1 = self.ds1.getColData(col)
    v2 = self.ds2.getColData(col)
    return Numeric.concatenate((v1, v2))

  def _mapKeysToParent(self, keys, parent):
    """
    Translate a sequence of view keys into keys of `parent`.

    Row keys are kept only when the row belongs to `parent`.  Column keys
    are shared by both datasets and are always translated.
    """

    numRows = self.getNumRows()

    # A view column key is numRows + col, while the same column in a
    # parent has key parentRows + col.  The offset to subtract is the row
    # count of the *other* dataset (mirrors _mapKeyFromParent).
    if parent is self.ds2 and parent is not self.ds1:
      colOffset = self.ds1Rows
    else:
      colOffset = self.ds2Rows

    keyset = []
    for key in keys:
      if key >= numRows:
        keyset.append(key - colOffset)
      else:
        (ds, k) = self._mapKeyToParent(key)
        if ds is parent:
          keyset.append(k)

    return keyset

  def _mapKeyToParent(self, key):
    """
    Map a single view key to (dataset, key-in-that-dataset).

    Column keys are always resolved against dataset1.  Row keys go
    through rowMap and are assigned to dataset2 when the mapped key lies
    past dataset1's rows.
    """

    ds      = self.ds1
    numRows = self.getNumRows()

    if key >= numRows:
      # Column key: numRows + col  ->  ds1Rows + col
      newKey = key - self.ds2Rows
    else:
      newKey = self.rowMap[key]
      if newKey >= self.ds1Rows:
        ds = self.ds2
        newKey -= self.ds1Rows

    return (ds, newKey)


  def _mapKeyFromParent(self, key, parent):
    """
    Map a key of `parent` (dataset1 or dataset2) into this view's key space.

    dataset2 keys (rows and columns alike) shift by ds1Rows.  dataset1 row
    keys are unchanged, while its column keys shift by ds2Rows.
    """

    newKey = key

    if parent is self.ds2:
      newKey = key + self.ds1Rows
    if parent is self.ds1 and key >= self.ds1Rows:
      newKey = key + self.ds2Rows

    return newKey
