#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
A few simple test cases for the View schema objects
"""

import os
import unittest
import copy
import inspect

import MLab
import Numeric
import math

from compClust.mlx.datasets import Dataset
from compClust.mlx.views import *
from compClust.util import Debug

from compClust.mlx.labelings import Labeling

class ViewTestCases(unittest.TestCase):

  """
  Test cases for the views in compClust.mlx.views (PCA, cached,
  transform, transpose, function, sorted, aggregate and filtered views)
  and for propagating labelings through hierarchies of views.
  """

  def setUp(self):
    """Load the 2-D line dataset and build the PCA views used by the tests."""

    datapath = os.path.split(inspect.getsourcefile(ViewTestCases))[0]
    self.linearColDataset = Dataset(os.path.join(datapath, 'simple2dLineDataset.txt'))
    # The same data with rows and columns swapped.
    self.linearRowDataset = Dataset(MLab.transpose(self.linearColDataset.getData()))

    self.colLinePCA = RowPCAView(self.linearColDataset)
    self.rowLinePCA = ColumnPCAView(self.linearRowDataset)
    # Largest absolute difference accepted when comparing float results.
    self.tolerance = .005

  def _allWithinTolerance(self, values):
    """
    Return 1 if every element of the sequence values lies within
    self.tolerance of values[0], otherwise 0.
    """
    first = values[0]
    for value in values:
      if MLab.absolute(first - value) > self.tolerance:
        return 0
    return 1

  def _aggregateFunctions(self):
    """
    Return the list of aggregate functions exercised by the aggregate
    view tests.

    The std entry wraps MLab.std. When MLab.std raises, it falls back to
    zeros shaped like one row of its input. This presumably covers
    single-row groups such as label 'b' in those tests.
    """
    def std(array):
      try:
        return MLab.std(array)
      except Exception:
        return array[0]*0
    return [MLab.mean, MLab.min, MLab.max, MLab.sum, MLab.median, std]

  def testColRotationOfALine(self):
    """testing Column PCA rotation of a simple line"""
    # All variance should be captured in the 1st PC, so every value in
    # the second column should be (nearly) the same.
    self.failUnless(self._allWithinTolerance(self.colLinePCA.getColData(1)))

  def testRowRotationOfALine(self):
    """testing Row PCA rotation of a simple line"""
    # All variance should be captured in the 1st PC, so every value in
    # the second row should be (nearly) the same.
    self.failUnless(self._allWithinTolerance(self.rowLinePCA.getRowData(1)))

  def testCachedView(self):
    """CachedView returns the same data as the view it caches"""

    data = Numeric.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
    ds = Dataset(data)
    v1 = FunctionView(ds, math.log)
    v2 = CachedView(v1)

    assert Numeric.alltrue(Numeric.ravel(v1.getData() == v2.getData()))
    # Keys 0..numRows-1 address rows; the next numCols keys address columns.
    for i in range(ds.getNumRows() + ds.getNumCols()):
      assert Numeric.alltrue(Numeric.ravel(v1.getData(i) == v2.getData(i)))

  def testCaching(self):
    """CachedView agrees with its TransformView parent after setMatrix"""

    data = Numeric.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
    ds = Dataset(data)
    v1 = TransformView(ds, Numeric.array([[1,0,0],[0,1,0],[0,0,1]]))
    v2 = CachedView(v1)

    assert Numeric.alltrue(Numeric.ravel(v1.getData() == v2.getData()))
    for i in range(ds.getNumRows() + ds.getNumCols()):
      assert Numeric.alltrue(Numeric.ravel(v1.getData(i) == v2.getData(i)))

    # Changing the parent's matrix must not leave stale values in the cache.
    v1.setMatrix(Numeric.array([[2,0,0],[0,3,0],[0,0,4]]))
    assert Numeric.alltrue(Numeric.ravel(v1.getData() == v2.getData()))
    for i in range(ds.getNumRows() + ds.getNumCols()):
      assert Numeric.alltrue(Numeric.ravel(v1.getData(i) == v2.getData(i)))

  def testTransposeView(self):
    """TransposeView swaps the rows and columns of its parent"""

    data = Numeric.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
    ds = Dataset(data)
    v1 = TransposeView(ds)

    assert Numeric.alltrue(Numeric.ravel(v1.getData() == Numeric.transpose(data)))
    # v1 is 3x4: key 4 is column 1 and key 1 is row 1.
    assert Numeric.alltrue(Numeric.ravel(v1.getData(4) == [4,5,6]))
    assert Numeric.alltrue(Numeric.ravel(v1.getData(1) == [2,5,8,11]))
    assert v1.getData().shape == (3,4)

  def testTransformView(self):
    """TransformView passes data through with no matrix and applies a set one"""

    data = Numeric.array([[1,2,3],[4,5,6],[7,8,9]])
    ds = Dataset(data)
    v1 = TransformView(ds, None)

    assert Numeric.alltrue(Numeric.ravel(v1.getData() == data))
    assert Numeric.alltrue(Numeric.ravel(v1.getData(1) == [4,5,6]))
    # Key 4 of a 3x3 view is column 1.
    assert Numeric.alltrue(Numeric.ravel(v1.getData(4) == [2,5,8]))

    v1.setMatrix(Numeric.array([[2,0,0],[0,3,0],[0,0,4]]))

    assert Numeric.alltrue(Numeric.ravel(v1.getData(1) == [8,15,24]))
    assert Numeric.alltrue(Numeric.ravel(v1.getData(4) == [6,15,24]))

  def testFuncView(self):
    """FunctionView with x - 1 differs from its parent by exactly 1 per entry"""

    f = FunctionView(self.linearRowDataset, lambda x : x - 1)

    # Every entry of the difference should be 1, so the total equals the
    # number of entries.  The inputs are floats, so compare within a
    # tolerance instead of exactly.
    diff = self.linearRowDataset.getData() - f.getData()
    total = MLab.sum(MLab.sum(diff))
    shape = self.linearRowDataset.getData().shape

    self.failUnless(abs(total - shape[0] * shape[1]) <= self.tolerance)

  def testSortedView(self):
    """SortedView permutes/sorts rows and columns and can be reset"""

    ds = Dataset([[1,2,3],[4,5,6],[7,8,9]])
    v  = SortedView(ds)

    assert Numeric.alltrue(Numeric.ravel(ds.getData() == v.getData()))

    v.permuteRows([1,2,0])
    assert Numeric.alltrue(Numeric.ravel(v.getData() == Numeric.array([[4,5,6],[7,8,9],[1,2,3]])))

    v.permuteCols([1,2,0])
    assert Numeric.alltrue(Numeric.ravel(v.getData() == Numeric.array([[5,6,4],[8,9,7],[2,3,1]])))

    # Keys 0-2 are rows, keys 3-5 are columns of the permuted view.
    assert Numeric.alltrue(Numeric.ravel(v.getData(0) == [5,6,4]))
    assert Numeric.alltrue(Numeric.ravel(v.getData(1) == [8,9,7]))
    assert Numeric.alltrue(Numeric.ravel(v.getData(2) == [2,3,1]))
    assert Numeric.alltrue(Numeric.ravel(v.getData(3) == [5,8,2]))
    assert Numeric.alltrue(Numeric.ravel(v.getData(4) == [6,9,3]))
    assert Numeric.alltrue(Numeric.ravel(v.getData(5) == [4,7,1]))

    v.sortRowsByFunction(lambda x : 1.0 / Numeric.sum(x))
    assert Numeric.alltrue(Numeric.ravel(v.getData() == [[8,9,7],[5,6,4],[2,3,1]]))
    v.sortRowsByFunction(Numeric.sum)
    assert Numeric.alltrue(Numeric.ravel(v.getData() == [[2,3,1],[5,6,4],[8,9,7]]))
    v.sortColsByFunction(lambda x : 1.0 / Numeric.sum(x))
    assert Numeric.alltrue(Numeric.ravel(v.getData() == [[3,2,1],[6,5,4],[9,8,7]]))
    v.sortColsByFunction(Numeric.sum)
    assert Numeric.alltrue(Numeric.ravel(v.getData() == ds.getData()))

    v.reset()
    assert Numeric.alltrue(Numeric.ravel(v.getData() == ds.getData()))

  def testName(self):
    """Views keep the name given by setName and the shape of their parent"""

    f = FunctionView(self.linearRowDataset, lambda x : x - 1)
    g = FunctionView(self.linearColDataset, lambda x : x - 1)

    f.setName("foo")
    g.setName("bar")

    assert f.getName() == "foo"
    assert g.getName() == "bar"
    assert f.getData().shape == (2,100)
    assert g.getData().shape == (100,2)

  def testNonStringLabels(self):
    """Labelings accept non-string labels (ints, floats, tuples)"""

    l = Labeling(self.colLinePCA)

    l.addLabelToRow(1, 2)
    l.addLabelToRow('x', 3)
    l.addLabelToRow((1,2,3), 4)
    l.addLabelToRow(1, 4)
    l.addLabelToRow('x', 2)

    l.addLabelToCol(0.5, 0)
    l.addLabelToCol(1, 1)

    assert l.getRowsByLabel(1) == [2,4]
    assert l.getRowsByLabel('x') == [2, 3]
    assert l.getRowsByLabel((1,2,3)) == [4]

    # Label 1 is also attached to column 1.  Presumably getKeysByLabel
    # reports row and column keys together, so it must differ from the
    # row-only result.
    assert l.getKeysByLabel(1) != [2,4]
    assert l.getColsByLabel(1) == [1]
    assert l.getColsByLabel(0.5) == [0]

  def testAggregateFunctionView(self):
    """Testing AggregateFunctionView"""
    functions = self._aggregateFunctions()

    data = MLab.rand(10,5)
    ds = Dataset(data)

    lab = Labeling(ds)
    lab.labelRows(['a','a','b',1,1,'2','2','2','2',1])

    # One list of parent row keys per distinct label.
    keylist = map(lab.getRowsByLabel, lab.getLabels())
    aView = AggregateFunctionView(ds, keylist, functions[0])

    for function in functions:
      aView.setFunction(function)
      # Expected data: one row per label group, aggregated over its rows.
      correctData = []
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getRowsByLabel('a')))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getRowsByLabel('b')))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getRowsByLabel(1)))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getRowsByLabel('2')))))

      correctData = Numeric.array(correctData)

      # getData
      assert Numeric.alltrue(Numeric.ravel(aView.getData() == correctData))
      assert Numeric.alltrue(Numeric.ravel(aView.getRowData(0) == correctData[0,:]))

      assert Numeric.alltrue(Numeric.ravel(aView.getColData(0) == MLab.transpose(correctData[:,0])))
      assert Numeric.alltrue(Numeric.ravel(aView.getRowData(2) == correctData[2,:]))
      assert Numeric.alltrue(Numeric.ravel(aView.getColData(3) == MLab.transpose(correctData[:,3])))

      # Labeling parent --> view: set up the labeling
      lab.addLabelToCol('foo', 1)
      lab.addLabelToCol('fee', 2)
      lab.addLabelToCol('fee', 3)

      # apply the labeling
      aLab = Labeling(aView)
      aLab.labelFrom(lab)

      # test the labeling
      assert aLab.getLabelByRows() == ['a', 'b', 1, '2']
      assert aLab.getLabelsByRow(0)     == ['a']
      assert aLab.getLabelsByRow(1)     == ['b']
      assert aLab.getLabelsByRow(2)     == [1]
      assert aLab.getLabelsByRow(3)     == ['2']
      assert aLab.getRowsByLabel('a')   == [0]

      # Just to make sure the col labelings are still working too
      assert aLab.getColsByLabel('foo') == [1]
      assert aLab.getLabelsByCol(1)     == ['foo']

      for col in aLab.getColsByLabel('fee') :
        assert col in [3,2]
      assert aLab.getLabelsByCol(2)     == ['fee']

      # Labeling view --> parent
      aLab2 = Labeling(aView)
      aLab2.labelRows([1,2,3,4])
      aLab2.addLabelToCol('foo', 1)

      lab2 = Labeling(ds)
      lab2.labelFrom(aLab2)
      assert lab2.getLabelByRows() == [1,1,2,3,3,4,4,4,4,3]
      assert lab2.getColsByLabel('foo') == [1]

  def testRowAggregateFunctionView(self):
    """Testing RowAggregateFunctionView"""
    functions = self._aggregateFunctions()
    data = MLab.rand(10,5)
    ds = Dataset(data)
    lab = Labeling(ds)
    lab.labelRows(['a','a','b',1,1,'2','2','2','2',1])
    aView = RowAggregateFunctionView(ds, lab, functions[0])
    for function in functions:
      aView.setFunction(function)
      # Expected data: one row per label group, aggregated over its rows.
      correctData = []
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getRowsByLabel('a')))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getRowsByLabel('b')))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getRowsByLabel(1)))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getRowsByLabel('2')))))

      correctData = Numeric.array(correctData)

      # getData
      assert Numeric.alltrue(Numeric.ravel(aView.getData() == correctData))
      assert Numeric.alltrue(Numeric.ravel(aView.getRowData(0) == correctData[0,:]))
      assert Numeric.alltrue(Numeric.ravel(aView.getColData(0) == MLab.transpose(correctData[:,0])))
      assert Numeric.alltrue(Numeric.ravel(aView.getRowData(2) == correctData[2,:]))
      assert Numeric.alltrue(Numeric.ravel(aView.getColData(3) == MLab.transpose(correctData[:,3])))

      # Labeling parent --> view: set up the labeling
      lab.addLabelToCol('foo', 1)
      lab.addLabelToCol('fee', 2)
      lab.addLabelToCol('fee', 3)

      # apply the labeling
      aLab = Labeling(aView)
      aLab.labelFrom(lab)

      # test the labeling
      assert aLab.getLabelByRows() == ['a', 'b', 1, '2']
      assert aLab.getLabelsByRow(0)     == ['a']
      assert aLab.getLabelsByRow(1)     == ['b']
      assert aLab.getLabelsByRow(2)     == [1]
      assert aLab.getLabelsByRow(3)     == ['2']
      assert aLab.getRowsByLabel('a')   == [0]

      # Just to make sure the col labelings are still working too
      assert aLab.getColsByLabel('foo') == [1]
      assert aLab.getLabelsByCol(1)     == ['foo']

      for col in aLab.getColsByLabel('fee') :
        assert col in [3,2]
      assert aLab.getLabelsByCol(2)     == ['fee']

      # Labeling view --> parent
      aLab2 = Labeling(aView)
      aLab2.labelRows([1,2,3,4])
      aLab2.addLabelToCol('foo', 1)

      lab2 = Labeling(ds)
      lab2.labelFrom(aLab2)
      assert lab2.getLabelByRows() == [1,1,2,3,3,4,4,4,4,3]
      assert lab2.getColsByLabel('foo') == [1]

  def testColumnAggregateFunctionView(self):
    """Testing ColumnAggregateFunctionView"""
    functions = self._aggregateFunctions()
    data = MLab.rand(5,10)
    ds = Dataset(data)
    lab = Labeling(ds)
    lab.labelCols(['a','a','b',1,1,'2','2','2','2',1])
    aView = ColumnAggregateFunctionView(ds, lab, functions[0])
    # Work on the transpose so Numeric.take selects columns of the original.
    data = Numeric.transpose(data)
    for function in functions:
      # Expected data: one column per label group, aggregated over its columns.
      correctData = []
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getColsByLabel('a')))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getColsByLabel('b')))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getColsByLabel(1)))))
      correctData.append(function(Numeric.take(data,
                                               tuple(lab.getColsByLabel('2')))))

      correctData = Numeric.transpose(Numeric.array(correctData))
      aView.setFunction(function)
      # getData
      assert Numeric.alltrue(Numeric.ravel(aView.getData() == correctData))
      assert Numeric.alltrue(Numeric.ravel(aView.getColData(0) == correctData[:,0]))
      assert Numeric.alltrue(Numeric.ravel(aView.getRowData(0) == MLab.transpose(correctData[0,:])))
      assert Numeric.alltrue(Numeric.ravel(aView.getColData(2) == correctData[:,2]))
      assert Numeric.alltrue(Numeric.ravel(aView.getRowData(3) == MLab.transpose(correctData[3,:])))

      # Labeling parent --> view: set up the labeling
      lab.addLabelToRow('foo', 1)
      lab.addLabelToRow('fee', 2)
      lab.addLabelToRow('fee', 3)

      # apply the labeling
      aLab = Labeling(aView)
      aLab.labelFrom(lab)

      # test the labeling
      assert aLab.getLabelByCols() == ['a', 'b', 1, '2']
      assert aLab.getLabelsByCol(0)     == ['a']
      assert aLab.getLabelsByCol(1)     == ['b']
      assert aLab.getLabelsByCol(2)     == [1]
      assert aLab.getLabelsByCol(3)     == ['2']
      assert aLab.getColsByLabel('a')   == [0]

      # Just to make sure the row labelings are still working too
      assert aLab.getRowsByLabel('foo') == [1]
      assert aLab.getLabelsByRow(1)     == ['foo']

      for row in aLab.getRowsByLabel('fee') :
        assert row in [3,2]
      assert aLab.getLabelsByRow(2)     == ['fee']

      # Labeling view --> parent
      aLab2 = Labeling(aView)
      aLab2.labelCols([1,2,3,4])
      aLab2.addLabelToRow('foo', 1)
      lab2 = Labeling(ds)
      lab2.labelFrom(aLab2)
      assert lab2.getLabelByCols() == [1,1,2,3,3,4,4,4,4,3]
      assert lab2.getRowsByLabel('foo') == [1]

  def testRowFilteredView(self):
    """RowFilteredView: data and labelings with and without row filters"""

    # Set up a dataset with one row labeling and one column labeling.
    data = MLab.rand(10,3)
    ds = Dataset(data)
    rowLab = Labeling(ds)
    rowLab.labelRows(range(ds.getNumRows()))
    colLab = Labeling(ds)
    colLab.labelCols(range(ds.getNumCols()))

    # Create a filtered view with no filter installed.
    fds = RowFilteredView(ds)
    rLab = Labeling(fds)
    rLab.labelFrom(rowLab)
    cLab = Labeling(fds)
    cLab.labelFrom(colLab)

    # Pass-thru filter: labelings parent --> view
    assert rLab.getLabelByRows() == rowLab.getLabelByRows()
    assert cLab.getLabelByCols() == colLab.getLabelByCols()

    # Pass-thru filter: getData
    assert Numeric.alltrue(Numeric.ravel(ds.getData() == fds.getData()))

    # Keep only the even-numbered rows.
    fds.setFilter(lambda ds, row: row%2==0)
    evenRows = filter(lambda x: x%2==0, range(len(data)))

    # Labelings parent --> view
    assert rLab.getLabelByRows() == filter(lambda x: x%2 == 0, rowLab.getLabelByRows())
    assert cLab.getLabelByCols() == colLab.getLabelByCols()
    # getData
    assert Numeric.alltrue(Numeric.ravel(fds.getData() == Numeric.take(data, evenRows)))

    l = Labeling(fds)
    l.labelRows(range(fds.getNumRows()))
    l2 = Labeling(ds)
    l2.labelFrom(l)
    lc = Labeling(fds)
    lc.labelCols(range(fds.getNumCols()))
    lc2 = Labeling(ds)
    lc2.labelFrom(lc)

    # Labelings view --> parent (with filter): filtered-out rows get None.
    assert l2.getLabelByRows() == [0, None, 1, None, 2, None, 3, None, 4, None]
    assert lc2.getLabelByCols() == lc.getLabelByCols()
    fds.setFilter()
    # Labelings view --> parent (no filter)
    assert l.getLabelByRows() == [0, None, 1, None, 2, None, 3, None, 4, None]
    assert lc2.getLabelByCols() == lc.getLabelByCols()

    # Filtering a filtered view
    fds.setFilter(lambda ds, row: row%2==0)
    ffds = RowFilteredView(fds)

    # getData before and after the filter functions change.  Each pair is
    # compared separately: a chained == on arrays does not combine the
    # element-wise results of both comparisons.
    expected = Numeric.take(data, evenRows)
    assert Numeric.alltrue(Numeric.ravel(fds.getData() == expected))
    assert Numeric.alltrue(Numeric.ravel(ffds.getData() == expected))
    ffds.setFilter(lambda ds, row: row%2==0)
    assert Numeric.alltrue(Numeric.ravel(
      ffds.getData() == Numeric.take(fds.getData(), filter(lambda x: x%2==0, range(len(fds.getData()))))))
    fds.setFilter()

    assert Numeric.alltrue(Numeric.ravel(
      ffds.getData() == Numeric.take(ds.getData(), filter(lambda x: x%2==0, range(len(ds.getData()))))))

    # Labelings view --> parent
    fds.setFilter(lambda ds, row: row%2==0)
    ffds.setFilter(lambda ds, row: row%2==0)

    l = Labeling(ffds)
    l.labelRows(range(ffds.getNumRows()))
    l2 = Labeling(fds)
    l2.labelFrom(l)
    l3 = Labeling(ds)
    l3.labelFrom(l)
    assert l2.getLabelByRows() == [0, None, 1, None, 2]
    fds.setFilter()
    ffds.setFilter()
    assert l.getLabelByRows() == l2.getLabelByRows() == l3.getLabelByRows()

    # Labelings parent --> view
    l = Labeling(ds)
    l.labelRows(range(ds.getNumRows()))
    l1 = Labeling(fds)
    l1.labelFrom(l)
    l2 = Labeling(ffds)
    l2.labelFrom(l)

    assert l.getLabelByRows() == l1.getLabelByRows() == l2.getLabelByRows() == range(ds.getNumRows())
    fds.setFilter(lambda ds, row: row%2==0)
    assert l1.getLabelByRows() == l2.getLabelByRows() == [0, 2, 4, 6, 8]
    ffds.setFilter(lambda ds, row: row%2==0)
    assert l2.getLabelByRows() == [0, 4, 8]

    # Labelings parent <--- midView --> view
    l = Labeling(fds)
    l.labelRows(range(fds.getNumRows()))
    l0 = Labeling(ds)
    l0.labelFrom(l)
    l2 = Labeling(ffds)
    l2.labelFrom(l)
    assert l2.getLabelByRows() == [0,2,4]
    assert l0.getLabelByRows() == [0, None, 1, None, 2, None, 3, None, 4, None]
    fds.setFilter()
    ffds.setFilter()

  def testGenericLabelingsActions(self):

    """ Perform some generic tests on any View which can be casted as
    a pass through view """

    # Each entry builds a pass-through view of the dataset it is given.
    viewFuncs = [lambda ds: BaseView(ds),
                 lambda ds: RowSubsetView(ds, range(ds.getNumRows())),
                 lambda ds: ColumnSubsetView(ds, range(ds.getNumCols())),
                 lambda ds: RowFilteredView(ds),
                 lambda ds: RowFunctionView(ds, lambda ds,row: ds.getRowData(row)),
                 lambda ds: ColumnFunctionView(ds, lambda y: y),
                 lambda ds: FunctionView(ds, lambda y:y),
                 lambda ds: SortedView(ds),
                 lambda ds: CachedView(ds) ]

    for func in viewFuncs:
      self.passThruLabelTest(func)

  def passThruLabelTest(self, func):

    """
    passThruLabelTest(self, func):

    func is a function which takes a dataset and returns a view:
    e.g:  view = func(dataset)

    Any view which can be instantiated as a pass-through view
    (i.e. Row/Col subsets, filter views, function views, etc.) can be
    tested with this function.

    This performs a fairly rigorous labeling test on the views produced
    by func, provided that they do not modify the data values
    (masking/subsetting is acceptable).
    """
    shape = (50, 10)
    data = Numeric.reshape(Numeric.arange(shape[0]*shape[1]), shape)
    ds = Dataset(copy.copy(data))
    ds.setName('root DS')
    rLabs = range(shape[0])
    cLabs = range(shape[1])

    # create a view hierarchy
    #
    #             ds
    #           /    \
    #         /        \
    #        vA         vB
    #     /  |   \      /
    #   vAA vAB vAC    vBA
    #

    vA    = func(ds)
    vA.setName('vA')
    vB    = func(ds)
    vB.setName('vB')
    vAA   = func(vA)
    vAA.setName('vAA')
    vAB   = func(vA)
    vAB.setName('vAB')
    vAC   = func(vA)
    vAC.setName('vAC')
    vBA   = func(vB)
    vBA.setName('vBA')

    views = [vA, vB, vAA, vAB, vAC, vBA, ds]

    # Label each view in turn and check that the labels percolate to
    # every other view in the hierarchy.
    for v in views:
      rowLabs = Labeling(v, 'rowLabs')
      colLabs = Labeling(v, 'colLabs')
      rowLabs.labelRows(rLabs)
      colLabs.labelCols(cLabs)
      for targetView in views:
        l = Labeling(targetView, "%s from %s"%(rowLabs.getName(), v.getName()))
        l.labelFrom(rowLabs)

        l = Labeling(targetView, "%s from %s"%(colLabs.getName(), v.getName()))
        l.labelFrom(colLabs)

      for targetView in views:
        # Skip the source view itself (identity, not data equality).
        if targetView is not v:
          rl = targetView.getLabeling(
            '%s from %s'%(rowLabs.getName(), v.getName()))
          cl = targetView.getLabeling(
            '%s from %s'%(colLabs.getName(), v.getName()))
          assert rl.getLabelByRows() == range(ds.getNumRows())
          assert cl.getLabelByCols() == range(ds.getNumCols())

  def testRowFunctionView(self):
    """RowFunctionView applies a per-row function ds, row -> row data"""

    data = MLab.rand(10,3)
    ds = Dataset(data)
    v = RowFunctionView(ds, lambda ds, row: ds.getRowData(row)+1)
    correctData = data + 1

    # test if no key is provided
    assert Numeric.alltrue(Numeric.ravel(v.getData() == correctData))
    # test if a row key is provided
    assert Numeric.alltrue(Numeric.ravel(v.getData(2) == correctData[2,:]))
    # test if a col key is provided (10 rows, so key 12 is column 2)
    assert Numeric.alltrue(Numeric.ravel(v.getData(12) == correctData[:,2]))

    # test if setting the function does things right
    v.setFunction(None)
    assert Numeric.alltrue(Numeric.ravel(v.getData() == data))

    # test to make sure the number of cols is correct
    v.setFunction(lambda ds,row: ds.getRowData(row)[:2])
    assert v.getNumCols() == 2

  def testColumnFunctionView(self):
    """ColumnFunctionView applies a per-column function ds, col -> col data"""

    data = MLab.rand(10,3)
    ds = Dataset(data)
    v = ColumnFunctionView(ds, lambda ds, col: ds.getColData(col)+1)
    correctData = data + 1

    # test if no key is provided
    assert Numeric.alltrue(Numeric.ravel(v.getData() == correctData))
    # test if a row key is provided
    assert Numeric.alltrue(Numeric.ravel(v.getData(2) == correctData[2,:]))
    # test if a col key is provided (10 rows, so key 12 is column 2)
    assert Numeric.alltrue(Numeric.ravel(v.getData(12) == correctData[:,2]))
        
def suite(**kw):
  """Return a unittest suite holding every test in ViewTestCases.

  Keyword arguments are accepted and ignored.
  """
  allTests = unittest.makeSuite(ViewTestCases)
  return allTests

if __name__ == "__main__":
  # Hand the module-level suite() factory to unittest's command-line runner.
  defaultSuiteName = "suite"
  unittest.main(defaultTest=defaultSuiteName)


