########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
#                Christopher Hart
# Last Modified: Dec 13 23:41:29 PST 2001
#

import Numeric

from compClust.mlx.views import SupersetView

##############################################################################
#
# ColumnSupersetView
#
# Places dataset1 left of dataset2
#
##############################################################################

class ColumnSupersetView(SupersetView):
  """
  Concatenates the columns of two datasets.

  The view places dataset1 to the left of dataset2.  It has the same rows
  as both parents and ds1Cols + ds2Cols columns.

  Key convention used by every method of this class: a key k with
  0 <= k < numRows addresses row k, and a key k >= numRows addresses
  view column k - numRows.  View columns 0 .. ds1Cols-1 come from
  dataset1; the remaining columns come from dataset2.
  """

  def __init__(self, dataset1, dataset2):
    """
    Build a view whose columns are dataset1's followed by dataset2's.

    Raises ValueError if the two datasets do not have the same number of
    rows.
    """
    rows1 = dataset1.getNumRows()
    rows2 = dataset2.getNumRows()
    if rows1 != rows2:
      raise ValueError("ColumnSupersetView: datasets must have the same "
                       "number of rows (got %s and %s)" % (rows1, rows2))

    # NOTE(review): self.ds1 / self.ds2 / ds1Rows / ds1Cols / ds2Cols are
    # assumed to be set by SupersetView.__init__ -- confirm in the base class.
    SupersetView.__init__(self, dataset1, dataset2)

    self.numRows = self.ds1Rows
    self.numCols = self.ds1Cols + self.ds2Cols

    # Shift dataset2's column keys past dataset1's columns so the two sets
    # of column keys do not collide.  Explicit lists are built so the
    # concatenation below is a list concatenation regardless of what
    # getKeys() returns (list, tuple, array) and of whether map() is eager
    # (Python 2) or lazy (Python 3).
    offset = self.ds1Cols
    ds1keys = list(dataset1.getKeys(1))
    ds2keys = [k + offset for k in dataset2.getKeys(1)]

    self.colMap = ds1keys + ds2keys

  def getNumRows(self):
    """Return the number of rows, refreshing the view first if dirty."""
    if self.isDirty():
      self._refresh()
    return self.numRows

  def getNumCols(self):
    """Return the number of columns, refreshing the view first if dirty."""
    if self.isDirty():
      self._refresh()
    return self.numCols

  def getData(self, key=None):
    """
    Return data from the view.

    key is None -> the full matrix: dataset1's data followed by dataset2's,
                   concatenated along axis 1.
    row key     -> the row vector: dataset1's part followed by dataset2's.
    column key  -> the column vector taken from whichever parent owns it.

    Raises ValueError if key is outside [0, getKeyMax()).
    """
    if self.isDirty():
      self._refresh()

    if key is None:
      return Numeric.concatenate((self.ds1.getData(),
                                  self.ds2.getData()), 1)

    keyMax = self.getKeyMax()
    if key < 0 or key >= keyMax:
      raise ValueError("ColumnSupersetView.getData: key %s out of range "
                       "[0, %s)" % (key, keyMax))

    numRows = self.getNumRows()
    if key < numRows:
      # Row key: both parents share the row index, so stitch the halves.
      return Numeric.concatenate((self.ds1.getData(key),
                                  self.ds2.getData(key)))

    # Column key: convert to a 0-based view column, then to a column index
    # inside the parent that owns it.
    col = key - numRows
    if col < self.ds1Cols:
      return self.ds1.getColData(col)
    return self.ds2.getColData(col - self.ds1Cols)

  def _mapKeysToParent(self, keys, parent):
    """
    Translate a sequence of view keys into keys of `parent`.

    Row keys pass through unchanged.  Column keys are kept only if they
    belong to `parent`, translated into that parent's key space.  Column
    keys owned by the other parent are dropped.
    """
    numRows = self.getNumRows()

    keyset = []
    for key in keys:
      if key < numRows:
        keyset.append(key)
        continue
      ds, parentKey = self._mapKeyToParent(key)
      if ds is parent:
        keyset.append(parentKey)
    return keyset

  def _mapKeyToParent(self, key):
    """
    Translate one view key into a (parent dataset, parent key) pair.

    Row keys map to (self.ds1, key) unchanged.  Column keys are looked up
    in colMap.  Mapped keys at or beyond numRows + ds1Cols belong to
    dataset2 and are shifted back by ds1Cols.
    """
    numRows = self.getNumRows()

    if key < numRows:
      return (self.ds1, key)

    mapped = self.colMap[key - numRows]
    if mapped >= numRows + self.ds1Cols:
      return (self.ds2, mapped - self.ds1Cols)
    return (self.ds1, mapped)

  def _mapKeyFromParent(self, key, parent):
    """
    Translate a key of `parent` into the corresponding view key.

    Row keys and dataset1 column keys are unchanged.  dataset2 column keys
    are shifted past dataset1's columns by ds1Cols.
    """
    if parent is self.ds2 and key >= self.getNumRows():
      return key + self.ds1Cols
    return key
