########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
# Last Modified: 30-May-2002, 11:00
#

from types import *

import copy
import operator
import sys
import Numeric
import MA

from compClust.util.unique import unique
from compClust.util.FileIO import readDelimitedFile
from compClust.util.FileIO import readDelimitedData
from compClust.mlx.interfaces import IDataset, ILabeling

class Dataset(IDataset):
  """
  Implementation of the IDataset interface.

  The Dataset class wraps a two-dimensional array of data vectors in an
  object framework.  Only basic operations are exposed at the user level;
  the underscore-prefixed methods exist to support Labeling/View
  integration.

  Keys: row r is addressed by key r (0 .. numRows-1) and column c by key
  numRows + c (numRows .. numRows+numCols-1), so a single integer key
  identifies either a row vector or a column vector.
  """


  ###########################################################################
  #
  # Initialization and Reset
  #
  ###########################################################################

  def __init__(self, data, name=None, delimiter="\t"):
    """
    Build a Dataset from *data*.

    *data* may be a file name, an object with a readline() method, a list
    or tuple, a Numeric array, an MA masked array or another IDataset (see
    __castDataset for the conversion rules).  *delimiter* is only used when
    reading from a file or stream.  Raises ValueError if the data cannot be
    converted.
    """

    #
    # cast the data into our internal representation
    #

    self.__castDataset(data, delimiter)
    self.setName(name)

    #
    # If the cast produced nothing, refuse to build an empty object
    #

    if self.data is None:
      raise ValueError("Unable to build a dataset from %s" % (type(data),))

    #
    # ...otherwise life continues as normal
    #

    self.numRows = self.data.shape[0]
    self.numCols = self.data.shape[1]
    self.key2lab = {}   # key -> list of UIDs attached to that vector
    self.lab2key = {}   # UID -> list of keys carrying that UID
    self.counter = 0    # last UID handed out by _getUID()
    self.dirty   = 0

    self.resetVars()


  def resetVars(self):
    """
    Clears internal variables.

    Removes all references to labeling and view objects attached
    to this dataset.  The objects are detatch()ed to avoid dangling
    references.  The primary row/column labelings are cleared as well.
    """
    self.__primaryRowLabeling = None
    self.__primaryColumnLabeling = None

    if hasattr(self, "labelings"):
      for label in self.labelings:
        label.detatch()
    self.labelings = []

    if hasattr(self, "views"):
      for view in self.views:
        view.detatch()
    self.views = []

    self.viewsByName = {}


  def getName(self):
    """
    Returns the name of the Dataset or View.
    """
    return self.__name


  def setName(self, name):
    """
    Sets the name of the Dataset or View.
    """
    self.__name = name

  name = property(getName, setName, doc="Set or return the datasets name")

  ##########################################################################
  #
  # Caching related functions
  #
  ##########################################################################

  def isDirty(self):
    """
    Determine if the data in the Dataset or View is dirty.

    isDirty() returns a true value if a parent of the View has changed (or
    potentially changed) its data.  A dirty view should recompute any data
    which depends on the data of its parent.
    """

    return self.dirty

  def _refresh(self):
    """
    Recomputes cached data and clears the dirty flag.

    A base Dataset caches nothing, so this is a no-op here.
    """
    pass

  ##########################################################################
  #
  # Output the dataset to a stream
  #
  ##########################################################################

  def writeDataset(self, stream=None, delimiter="\t", rowLabeling=None):
    """
    Writes the dataset out to a stream, one delimited line per row.

    *stream* may be an open file-like object or a file name.  A file name is
    opened for writing and closed again once the data has been written.
    When *stream* is None the current sys.stdout is used.

    The optional *rowLabeling* prepends a label to each line.  The row
    labeling must contain at least one label per row and only the first
    label is used.  If the labels cannot be retrieved (for example when
    rowLabeling is simply set to 1), an ordinal count is prepended to each
    line instead.
    """

    numRows = self.getNumRows()
    numCols = self.getNumCols()

    # Resolve stdout at call time so later redirections of sys.stdout apply
    if stream is None:
      stream = sys.stdout

    # A string is treated as a file name; we own (and must close) that handle
    ownsStream = isinstance(stream, StringType)
    if ownsStream:
      stream = open(stream, 'w')

    try:
      labels = None
      if rowLabeling is not None:
        try:
          labels = [rowLabeling.getLabelsByRow(row)[0]
                    for row in range(numRows)]
        except Exception:
          # Deliberate fallback documented above: number the rows instead
          labels = range(numRows)

      data = self.getData()
      for row in range(numRows):
        rowData = data[row]
        fields = [str(rowData[col]) for col in range(numCols)]
        if labels is not None:
          stream.write("%s%s" % (str(labels[row]), delimiter))
        stream.write(delimiter.join(fields))
        stream.write("\n")
    finally:
      if ownsStream:
        stream.close()


  ##########################################################################
  #
  # Data Accessors
  #
  #########################################################################

  def getRowData(self, row):
    """
    Returns the data vector of a row of the dataset.  If the row is out of
    range, a ValueError() is raised.  The vector is a slice of the
    underlying array (a Numeric array or an MA masked array, depending on
    how the dataset was built).
    """

    return self.getData(self.getRowKey(row))


  def getColData(self, col):
    """
    Returns the data vector of a column of the dataset.  If the column is
    out of range, a ValueError() is raised.  The vector is a slice of the
    underlying array (a Numeric array or an MA masked array, depending on
    how the dataset was built).
    """

    return self.getData(self.getColKey(col))


  def getData(self, key=None):
    """
    Returns the full dataset, or a single row/column vector.

    With *key* None the underlying array (Numeric array or MA masked array)
    is returned.  Otherwise *key* must be a valid row or column key and the
    corresponding vector is returned; an out-of-range key raises
    ValueError().
    """

    #
    # a key of 'None' will return the full dataset
    #

    if key is None:
      return self.data

    #
    # Otherwise the key must be a valid integer
    #

    keyMax = self.getKeyMax()
    if key < 0 or key >= keyMax:
      raise ValueError('Key %s out of range:[0,%d)' % (key, keyMax))

    numRows = self.getNumRows()
    if key < numRows:
      return self.__getRow(key)
    return self.__getCol(key - numRows)


  ##########################################################################
  #
  # Key accessors
  #
  ##########################################################################

  def getRowKey(self, row):
    """
    Returns the key for a given row of the dataset.

    If the row is out of range a ValueError() is raised.
    """

    numRows = self.getNumRows()
    if row < 0 or row >= numRows:
      raise ValueError('Row key %d out of range:[0,%d)' % (row, numRows))
    return row


  def getColKey(self, col):
    """
    Returns the key for a given column of the dataset.

    If the column is out of range a ValueError() is raised.
    """

    numCols = self.getNumCols()
    if col < 0 or col >= numCols:
      raise ValueError('Column key %d out of range:[0,%d)' % (col, numCols))
    return col + self.getNumRows()


  def getRowKeys(self):
    """
    Returns the full set of row keys for the dataset.

    The keys will be returned in a list of length equal to the total number
    of rows and arranged such that the nth key in the list corresponds to the
    nth row of the dataset.  getRowKeys() is a wrapper around getKeys().
    """

    return self.getKeys()


  def getColKeys(self):
    """
    Returns the full set of column keys for the dataset.

    They will be returned in a list of length equal to the total number of
    columns and arranged such that the nth key in the list corresponds to the
    nth column of the dataset.  getColKeys() is a wrapper around getKeys().
    """

    return self.getKeys(1)


  def getKeys(self, axis=0):
    """
    Returns the valid keys for an axis in the order which they appear in the
    dataset.

    An axis equal to zero will return the row keys and any other value will
    return the column keys.
    """

    numRows = self.getNumRows()
    if axis == 0:
      return range(numRows)
    return range(numRows, numRows + self.getNumCols())

  def isRowKey(self, key):
    """Is a key a row key"""
    return 0 <= key < self.getNumRows()

  def isColKey(self, key):
    """Is a key a column key"""
    return self.getNumRows() <= key < self.getNumRows() + self.getNumCols()

  def splitRowColKeylist(self, keylist):
    """
    Split a list of keys into the component row and column key lists.

    Returns a (rowKeys, colKeys) tuple preserving the input order within
    each list.  Any key that is neither a valid row key nor a valid column
    key (including negative keys) raises ValueError.
    """
    numRows = self.getNumRows()
    numCols = self.getNumCols()
    numTotal = numRows + numCols
    rowKeys = []
    colKeys = []
    for key in keylist:
      if 0 <= key < numRows:
        rowKeys.append(key)
      elif numRows <= key < numTotal:
        colKeys.append(key)
      else:
        msg = 'Bad subset key value %d, %d by %d'
        msg %= (key, numRows, numCols)
        raise ValueError(msg)
    return (rowKeys, colKeys)

  #########################################################################
  #
  # View interfaces
  #
  #########################################################################

  def getView(self, name):
    """
    Returns the view with the given name, or None if it does not exist.
    """
    return self.viewsByName.get(name, None)


  def getViews(self):
    """
    Return a list of all the views attached to this dataset.

    This may be conceived as a list of the child views from this node where
    the view structure is a tree with the base dataset at the root. In the
    case of superset views this structure breaks down, but is still valid for
    finding all the views associated with a dataset.
    """

    return self.views


  def addView(self, view):
    """
    Add a view to the view list of the dataset (ignoring None and
    views that are already present).
    """

    if view is not None and view not in self.views:
      self.views.append(view)


  def addViewDefault(self, name, view_class, *args, **kwargs):
    """
    Either return the view with the specified name or create a new instance.

    The args and kwargs should be the standard parameters for the view_class
    provided; *name* is passed to the constructor as the 'name' keyword.
    NOTE(review): the new view is not registered in views/viewsByName here;
    presumably the view constructor does that -- confirm.
    """
    view = self.viewsByName.get(name, None)
    if view is None:
      new_kwargs = dict(kwargs)
      new_kwargs['name'] = name
      view = view_class(*args, **new_kwargs)
    return view

  def removeView(self, view):
    """
    Removes a view from a dataset (and detatch()es it) if it exists.
    """

    if view in self.views:
      v = self.views.pop(self.views.index(view))
      v.detatch()


  #########################################################################
  #
  # Labeling interfaces.
  #
  ##########################################################################


  def addLabeling(self, labeling):
    """
    Add a labeling to the list of labelings in the dataset (ignoring None
    and labelings that are already present).
    """

    if labeling is not None and labeling not in self.labelings:
      self.labelings.append(labeling)


  def removeLabeling(self, labeling):
    """
    Removes a labeling and all of its associated labels from the dataset.

    If the labeling does not live on this dataset, removal is delegated to
    the root datasets reported by getLineage().
    """

    # Get the actual representation of the labeling
    real_labeling = labeling._getBaseRef()

    # If it lives locally, remove it, otherwise try to remove it from
    # the roots
    if real_labeling in self.labelings:
      l = self.labelings.pop(self.labelings.index(real_labeling))
      l.detatch()
    else:
      roots = [path[-1] for path in self.getLineage()][1:]
      for root in roots:
        root.removeLabeling(labeling)


  def getLabeling(self, name):
    """
    Returns the labeling with the given name, or None if it does not exist.
    If multiple labelings exist with the same name, the first one encountered
    is returned.
    """

    for labeling in self.getLabelings():
      if labeling.getName() == name:
        return labeling
    return None


  def getLabelings(self):
    """
    Returns a list of all the Labelings tied to this dataset.
    """

    #
    # First build a list of local labelings, plus the labelings in the
    # root dataset(s)
    #

    raw = list(self.labelings)
    for path in self.getLineage():
      raw = raw + path[-1].labelings
    raw = unique(raw)

    #
    # Now get references to the labelings relative to this dataset and drop
    # the ones for which _getLocalRef() yields a false value (e.g. None).
    #

    labelings = [x._getLocalRef(self) for x in raw]
    return [l for l in labelings if l]


  def getPrimaryRowLabeling(self):
    """Returns the primary labeling that describes the rows of the dataset
    """
    return self.__primaryRowLabeling

  def setPrimaryRowLabeling(self, l):
    """
    Sets the primary labeling that describes the rows of the dataset.

    *l* must be None (to clear it) or an ILabeling that is row unique;
    otherwise ValueError is raised.
    """
    if l is not None:
      if not isinstance(l, ILabeling):
        raise ValueError("%s:%s must be a labeling" % (str(type(l)), str(l)))
      if not l.isRowUnique():
        raise ValueError('labeling %s is not row unique' % (str(l),))
    self.__primaryRowLabeling = l

  primaryRowLabeling = property(getPrimaryRowLabeling, setPrimaryRowLabeling, doc="the primary row labeling is a unique identifier to help track rows in the dataset")

  def getPrimaryColumnLabeling(self):
    """Returns the primary labeling that describes the columns of the dataset
    """
    return self.__primaryColumnLabeling

  def setPrimaryColumnLabeling(self, l):
    """
    Sets the primary labeling that describes the columns of the dataset.

    *l* must be None (to clear it) or an ILabeling that is column unique;
    otherwise ValueError is raised.
    """
    if l is not None:
      if not isinstance(l, ILabeling):
        raise ValueError("%s:%s must be a labeling" % (str(type(l)), str(l)))
      if not l.isColUnique():
        raise ValueError('labeling %s is not column unique' % (str(l),))
    self.__primaryColumnLabeling = l

  primaryColumnLabeling = property(getPrimaryColumnLabeling, setPrimaryColumnLabeling, doc="the primary column labeling is a unique identifier to help track columns in the dataset")

  ##########################################################################
  #
  # Meta-Information
  #
  # Returns information about the dataset
  #
  #########################################################################

  def getNumCols(self):
    """
    Returns the number of columns (dimensions or features) in the dataset.
    """

    return self.numCols


  def getNumRows(self):
    """
    Returns the number of rows (samples) in the dataset.
    """

    return self.numRows


  def getKeyMax(self):
    """
    Returns one past the largest valid key (numRows + numCols); valid keys
    lie in [0, getKeyMax()).
    """

    return self.getNumRows() + self.getNumCols()


  def getNumAxis(self, axis=0):
    """
    Returns the number of elements along a particular axis (0 for rows,
    anything else for columns).
    """

    if axis == 0:
      return self.getNumRows()
    return self.getNumCols()


  ###########################################################################
  #
  # helper methods for doing drill-down and reverse-key lookup
  #
  ###########################################################################

  def getLineage(self):
    """
    Returns a list of lists of all paths from this dataset to its root
    dataset(s).

    Because of supersets, there may be multiple base Dataset objects.  A
    base Dataset is its own root, so its lineage is [[self]].
    """

    return [[self]]


  def _mapUIDToParent(self, uid, parent=None):
    """
    Returns the uid of the parent which corresponds to the uid given.
    For a base Dataset this is the identity mapping.
    """

    return uid


  def _mapKeysFromParent(self, keys, parent=None):
    """
    Converts a list of parent keys to the keys of self.

    It is possible that the set of keys returned may be smaller than the
    number of keys passed in.  For a base Dataset this is the identity
    mapping.
    """

    return keys


  def _mapKeysToParent(self, keys, parent=None):
    """
    Converts a list of self keys to keys of the parent.

    It is possible that the set of keys returned may be larger than the
    number of keys passed in.  For a base Dataset this is the identity
    mapping.
    """

    return keys

  ###########################################################################
  #
  # Private/Friend methods
  #
  # These methods are friend methods for Labelings/View and any other class
  # which may be coupled to the Dataset.  Any Dataset-like class must
  # implement analogs of these method to be able to tie into the
  # Labeling/View framework
  #
  ###########################################################################

  def _getKey2Lab(self):
    """Returns the key -> [UID, ...] mapping."""
    return self.key2lab

  def _getLab2Key(self):
    """Returns the UID -> [key, ...] mapping."""
    return self.lab2key

  def _getKeysByUID(self, uid):
    """
    Returns the set of keys that a UID marks.  If there is no such UID,
    an empty list is returned.
    """

    return self._getLab2Key().get(uid, [])


  def _getUIDsByKey(self, key=None):
    """
    Returns a list of UIDs attached to a given key.

    If key is None then all the UIDs attached to the dataset are returned,
    with duplicates removed (an empty list when nothing is attached).  If the
    specified key does not exist, an empty list is returned.
    """

    key2lab = self._getKey2Lab()

    #
    # If a key is specified, the operation is simple: Look up the list of
    # UID's in the hash
    #

    if key is not None:
      return key2lab.get(key, [])

    #
    # Otherwise flatten every per-key UID list and remove duplicates.  The
    # explicit empty check avoids reducing an empty sequence.
    #

    allUIDs = [uid for uids in key2lab.values() for uid in uids]
    if not allUIDs:
      return []
    return unique(allUIDs)

  def _removeUID(self, uid, key):
    """
    Removes a particular UID from a vector with the given key.

    If the vector is not marked with the UID or the UID does not exist, no
    action is performed.  Lists that become empty are dropped from the
    mappings.
    """

    #
    # Missing keys/UIDs are deliberately not an error: removing from a
    # non-existent entry is a no-op.
    #

    key2lab = self._getKey2Lab()
    uids = key2lab.get(key)
    if uids is not None:
      if uid in uids:
        uids.remove(uid)
      if not uids:
        del key2lab[key]

    lab2key = self._getLab2Key()
    keys = lab2key.get(uid)
    if keys is not None:
      if key in keys:
        keys.remove(key)
      if not keys:
        del lab2key[uid]


  def _addUID(self, uid, key):
    """
    _addUID(uid, key)

    Attaches a UID to a key of the dataset, updating both the key -> UIDs
    and UID -> keys mappings.  If the key is out of range, a ValueError()
    is raised.
    """

    #
    # Make sure the key is valid
    #

    keyMax = self.getKeyMax()
    if key < 0 or key >= keyMax:
      raise ValueError('Key %s out of range:[0,%d)' % (key, keyMax))

    # Add this label to the key, and this key to the label
    self._getKey2Lab().setdefault(key, []).append(uid)
    self._getLab2Key().setdefault(uid, []).append(key)


  def _getUID(self):
    """
    Provides a counting service to the Labeling class.  The only requirement
    of this method is that it return a unique token each time it is called.
    """

    #
    # Use the simplest implementation.  self.counter is initialized to 0 in
    # the constructor and it is pre-incremented, so 0 is never returned as a
    # valid UID.  This is important, because a UID of zero is reserved for
    # future use.
    #

    self.counter += 1
    return self.counter

  ##########################################################################
  #
  # Private methods
  #
  ##########################################################################


  def __getRow(self, key):
    """
    Returns the vector corresponding to the specified row of the dataset.
    If the row is past the end, None is returned (negative keys are expected
    to be rejected by the caller).
    """

    if key < self.data.shape[0]:
      return self.data[key]
    return None


  def __getCol(self, key):
    """
    Returns the vector corresponding to the specified column of the dataset.
    If the column is past the end, None is returned (negative keys are
    expected to be rejected by the caller).
    """

    if key < self.data.shape[1]:
      return self.data[:,key]
    return None


  def __castDataset(self, obj, delimiter="\t"):
    """
    Initializes self.data from obj.  Casting (conversion) rules are checked
    in this order:

    If obj is a                  it is cast (converted) using
    ---------------              ----------------------------
    String (file name)           readDelimitedFile(obj, delimiter)
    object with readline()       readDelimitedData(obj, delimiter)
    ListType / TupleType         Numeric.array(obj, Numeric.Float)
    Numeric.ArrayType            used as-is
    IDataset instance            obj.getData()
    MA.MaskedArray               used as-is

    Any other type raises ValueError.  A 1-D result is reshaped into a
    single-row 2-D array, since datasets must always be two-dimensional.
    """

    t         = type(obj)
    self.data = None

    #
    # Simulate a switch statement
    #

    if (t == StringType):
      self.data = readDelimitedFile(obj, delimiter)

    # is it something that looks like a stream?
    elif (hasattr(obj, 'readline')):
      self.data = readDelimitedData(obj, delimiter)

    elif ((t == ListType) or (t == TupleType)):
      self.data = Numeric.array(obj, Numeric.Float)

    elif (t == Numeric.ArrayType):
      self.data = obj

    elif isinstance(obj, IDataset):
      self.data = obj.getData()

    elif isinstance(obj, MA.MaskedArray ):
      self.data = obj

    else:
      raise ValueError("Unrecognized dataset source %s" % (type(obj)))

    #
    # Check to see if the resulting dataset is a vector and reshape it if
    # it is.  Datasets _must_ always be arrays.
    #

    if len(MA.shape(self.data)) == 1:
      self.data = MA.reshape(self.data, (1, len(self.data)))


  def __str__(self):
    # e.g. "Dataset: myName, 10 by 3"; the class name follows subclasses
    return('%s: %s, %d by %d' %
           (self.__class__.__name__, self.getName(),
            self.getNumRows(), self.getNumCols()))

  def __repr__(self):
    return(self.__str__())

  def sizeof(self):
    """
    Returns an estimate of the memory used by this dataset: the number of
    data elements times the array's itemsize(), plus sizeof() of every
    labeling returned by getLabelings().
    """
    size = Numeric.size(self.data) * self.data.itemsize()
    for l in self.getLabelings():
      size += l.sizeof()
    return size
