########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
# Last Modified: 30-May-2002, 11:00
#

from Dataset import Dataset
import Numeric

class ExtendedDataset(Dataset):
  """A Dataset holding a stack of equally-shaped 2-D layers.

  The layered data lives in ``self.extdata``, indexed as
  ``extdata[layer, row, col]``. One layer is "current" (``self._layer``,
  layer 0 initially). ``getData`` exposes the current layer, or an
  explicitly requested one, through the base ``Dataset.getData`` by
  temporarily pointing ``self.data`` at that layer.
  """

  def __init__(self, data):
    """Build the layered dataset from *data*.

    A 2-D input is promoted to a single layer of shape (1, rows, cols).
    Other inputs are assumed to already be 3-D (layers, rows, cols).
    NOTE(review): inputs of any other rank are not handled here; confirm
    that Dataset.__init__ rejects them.
    """

    #
    # Do the normal Dataset initialization
    # (assumed to set self.data to an array-like with a .shape)
    #

    Dataset.__init__(self, data)

    #
    # Now modify things: promote 2-D data to a single-layer 3-D stack
    #

    shape = self.data.shape
    if len(shape) == 2:
      self.data = Numeric.reshape(self.data, (1, shape[0], shape[1]))

    # The layered array is kept in extdata; self.data is only populated
    # on demand by getData().
    self.extdata = self.data
    self.data    = None

    self.numRows = self.extdata.shape[1]
    self.numCols = self.extdata.shape[2]

    #
    # The names are now a hash (layer index -> name) and the locked
    # layer is layer 0
    #

    self.layerName = {}
    self._layer    = 0

  #
  # Override some Dataset methods with our own generalizations
  #

  #def getName(self):
  #  return self.getLayerName(self._layer)

  #def setName(self, name):
  #  return self.setLayerName(self._layer, name)

  def getData(self, key=None, layer=None):
    """Return ``Dataset.getData(key)`` evaluated on a single layer.

    *layer* defaults to the current layer (see setLayer). As a side
    effect, ``self.data`` is left pointing at the selected layer.
    """

    if layer is None:
      self.data = self.extdata[self.getLayer()]
    else:
      self.data = self.extdata[layer]

    return Dataset.getData(self, key)

  #
  # ExtDataset-specific methods
  #

  def _normalizeLayer(self, layer):
    """Return *layer* as a non-negative index, or raise IndexError.

    Negative indices count from the last layer (Python-style). This keeps
    the stored index in line with the keys used by the layer-name lookup.
    """
    numLayers = self.getNumLayers()
    if layer < -numLayers or layer >= numLayers:
      raise IndexError("layer index out of range")
    if layer < 0:
      layer = layer + numLayers
    return layer

  def setLayer(self, layer):
    """Make *layer* the current layer.

    Negative indices count from the end. Raises IndexError if *layer* is
    outside ``[-getNumLayers(), getNumLayers())``.
    """
    self._layer = self._normalizeLayer(layer)

  def setLayerByName(self, name):
    """Make the first layer whose name equals *name* current.

    Raises ValueError (from list.index) if no layer has that name.
    """
    self.setLayer(self.getLayerNames().index(name))

  def getLayerNames(self):
    """Return a list with one name per layer (None where a name is unset)."""
    # A real list (not map()) so callers can use list methods such as
    # .index() regardless of the Python version.
    return [self.getLayerName(i) for i in range(self.getNumLayers())]

  def getNumLayers(self):
    """Return the number of layers (the first axis of extdata)."""
    return self.extdata.shape[0]

  def getLayerName(self, layer = None):
    """Return the name of *layer* (default: the current layer), or None."""
    if layer is None:
      layer = self._layer
    return self.layerName.get(layer)

  def setLayerName(self, layer, name):
    """Associate *name* with layer index *layer*."""
    self.layerName[layer] = name

  def getLayer(self):
    """Return the index of the current layer."""
    return self._layer
