#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#       Authors: Lucas Scharenbroich
# Last Modified: 30-May-2002, 11:00
#

"""
Test suite for the Extended Dataset class.
"""

import unittest
import os

from compClust.mlx.datasets import ExtendedDataset
import compClust.mlx

import Numeric

class ExtDatasetTestCases(unittest.TestCase):
  """
  Tests for ExtendedDataset covering:
    - layer naming;
    - layer selection by index and by name;
    - rejection of an out-of-range index or an unknown layer name.
  """

  def setUp(self):
    # Remember the starting directory so tearDown can restore it, then
    # run from inside the compClust.mlx package directory.
    # NOTE(review): check_basics reads no files; presumably this chdir is
    # shared boilerplate for tests that resolve package-relative paths.
    self.original_dir    = os.getcwd()
    os.chdir(compClust.mlx.__path__[0])

  def tearDown(self):
    os.chdir(self.original_dir)

  def _assertDataEquals(self, ds, expected, layer_desc):
    """
    Fail unless ds.getData() equals `expected` element-wise.

    layer_desc is used only to make the failure message readable.
    """
    if not Numeric.alltrue(Numeric.ravel(ds.getData() == expected)):
      self.fail("getData() does not match %s" % layer_desc)

  def check_basics(self):
    """
    Build a three-layer ExtendedDataset from a (3,3,3) array, then check:
      - each layer can be named;
      - each layer can be selected by index and by name;
      - selecting a missing layer raises IndexError (bad index) or
        ValueError (unknown name).
    """
    # Layer 0 holds 0..8, layer 1 holds 9..17 mod 2 (0/1 values), and
    # layer 2 holds 18..26 scaled by 1/100.
    a = Numeric.reshape(Numeric.array(range(27), Numeric.Float), (3,3,3))
    a[1] = a[1] % 2
    a[2] = a[2] / 100.0

    layers = [(0, 'Data Layer'), (1, 'Mask Layer'), (2, 'PValues')]

    ds = ExtendedDataset(a)
    ds.setName('ExtendedData')
    for index, name in layers:
      ds.setLayerName(index, name)

    self.assertEqual(ds.getName(), 'ExtendedData')

    # Before any setLayer() call the dataset should expose layer 0.
    self._assertDataEquals(ds, a[0], "layer 0 (default)")

    # Selection by index: the name and the data must both follow the index.
    for index, name in layers:
      ds.setLayer(index)
      self.assertEqual(ds.getLayerName(index), name)
      self._assertDataEquals(ds, a[index], "layer %d" % index)

    # Use assertRaises rather than try/except: a missing exception becomes
    # a real test failure instead of a NameError from an unbound fail().
    self.assertRaises(IndexError, ds.setLayer, 3)

    self.assertEqual(ds.getLayerNames(),
                     ['Data Layer', 'Mask Layer', 'PValues'])

    # Selection by name: getLayerName() with no argument reports the
    # currently selected layer.
    for index, name in layers:
      ds.setLayerByName(name)
      self.assertEqual(ds.getLayerName(), name)

    self.assertRaises(ValueError, ds.setLayerByName, 'foo')
    
def suite(**kw):
  """
  Return a unittest.TestSuite containing the ExtendedDataset tests.

  Any keyword arguments are accepted and ignored.
  """
  # The test method uses a "check" prefix, so unittest's default "test"
  # prefix discovery would not find it; list it explicitly instead.
  cases = [ExtDatasetTestCases("check_basics")]
  return unittest.TestSuite(cases)

if __name__ == '__main__':
  # Run the suite() defined above rather than letting unittest.main look
  # for "test*" methods, which would miss check_basics.
  unittest.main(defaultTest='suite')
