#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Test suite for the Labeling class.
"""

import unittest
import os

from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling, ClusteredLabeling, LabelingDatasetLengthError, subsetByLabeling
from compClust.util.LoadExample import LoadCho


import compClust.mlx

class LabelingTestCases(unittest.TestCase):
  """Unit tests for Labeling, ClusteredLabeling and subsetByLabeling.

  Every check goes through the TestCase helpers (failUnless / failIf /
  failUnlessRaises) instead of bare ``assert`` statements.  Bare asserts are
  removed when Python runs with -O, which would make these tests pass
  without checking anything.
  """

  def setUp(self):
    """Construct a small 10x4 dataset to play with.

    Every row is constant (values 0-3), so rows are easy to tell apart.  The
    working directory is switched to the compClust.mlx package directory
    (presumably so package-relative data files resolve -- TODO confirm) and
    restored in tearDown.
    """
    self.original_dir = os.getcwd()
    os.chdir(compClust.mlx.__path__[0])

    data = [ [ 0, 0, 0, 0],
             [ 3, 3, 3, 3],
             [ 1, 1, 1, 1],
             [ 2, 2, 2, 2],
             [ 1, 1, 1, 1],
             [ 3, 3, 3, 3],
             [ 3, 3, 3, 3],
             [ 2, 2, 2, 2],
             [ 2, 2, 2, 2],
             [ 3, 3, 3, 3] ]
    self.dataset = Dataset(data)

  def tearDown(self):
    """Restore the working directory saved by setUp."""
    os.chdir(self.original_dir)


  def test_constructor(self):
    """A Labeling can be attached to a dataset and removed again."""
    labeling = Labeling(self.dataset)
    self.dataset.removeLabeling(labeling)

  def test_label_retrieval(self):
    """Labels added by key, row and column are retrievable along every axis."""
    labeling = Labeling(self.dataset)

    labeling.addLabelToKey("foo", 1)
    labeling.addLabelToRow("r", 0)
    labeling.addLabelToCol("c", 0)
    labeling.addLabelToKeys("foo", [3,4])
    labeling.addLabelToRows("r", [5,6])
    labeling.addLabelToCols("c", [1,2])

    labs = labeling.getLabels()
    for lab in labs:
      self.failIf(lab not in ['foo','r','c'], "Bad label: %r" % (lab,))
    self.failIf(len(labs) != 3, "Wrong number of labels: %r" % (labs,))

    # This test relies on key N addressing row N in the freshly built dataset
    # (see also the getLabelByRows expectation below).
    self.failIf(labeling.getLabelsByRow(1) != ['foo'], "labels of row 1")
    self.failIf(labeling.getLabelsByRow(0) != ['r'], "labels of row 0")
    self.failIf(labeling.getLabelsByRow(6) != ['r'], "labels of row 6")
    self.failIf(labeling.getLabelsByRow(2) != [], "labels of row 2")

    self.failIf(labeling.getLabelsByCol(0) != ['c'], "labels of col 0")
    self.failIf(labeling.getLabelsByCol(1) != ['c'], "labels of col 1")
    self.failIf(labeling.getLabelsByCol(2) != ['c'], "labels of col 2")
    self.failIf(labeling.getLabelsByCol(3) != [], "labels of col 3")

    self.failIf(labeling.getRowsByLabel("r") != [0,5,6], "rows labelled 'r'")
    self.failIf(labeling.getColsByLabel("c") != [0,1,2], "cols labelled 'c'")
    self.failIf(labeling.getKeysByLabel("foo") != [1,3,4], "keys labelled 'foo'")

    # One label per row; None marks rows that carry no label.
    tmp = labeling.getLabelByRows()
    self.failIf(tmp != ['r', 'foo', None, 'foo', 'foo', 'r', 'r', None, None, None],
                "getLabelByRows returned %r" % (tmp,))

    self.dataset.removeLabeling(labeling)

  def test_name(self):
    """Named labelings on a view can be looked up by name."""

    from compClust.mlx.views import RowSubsetView

    rows = range(self.dataset.getNumRows())

    v = RowSubsetView(self.dataset, rows)
    l1 = Labeling(v, "foo")
    l2 = Labeling(v, "bar")
    l3 = Labeling(v, "baz")

    s1 = filter(lambda x : x%4 == 0, rows)
    s2 = filter(lambda x : x%4 == 1, rows)
    s3 = filter(lambda x : x%4 == 2, rows)

    l1.addLabelToRows("1",s1)
    l2.addLabelToRows("2",s2)
    l3.addLabelToRows("3",s3)

    l = v.getLabeling("foo")
    self.failUnless(l.getLabels()[0] == "1", "labeling 'foo'")
    l = v.getLabeling("bar")
    self.failUnless(l.getLabels()[0] == "2", "labeling 'bar'")
    l = v.getLabeling("baz")
    self.failUnless(l.getLabels()[0] == "3", "labeling 'baz'")

  def test_sort(self):
    """sortDatasetByLabel groups rows by their label."""

    from compClust.mlx.views import SortedView

    rows = range(self.dataset.getNumRows())

    v = SortedView(self.dataset)
    labeling = Labeling(v)

    # With 10 rows, the residues mod 4 split them 3/3/2/2.
    s1 = filter(lambda x : x%4 == 0, rows)
    s2 = filter(lambda x : x%4 == 1, rows)
    s3 = filter(lambda x : x%4 == 2, rows)
    s4 = filter(lambda x : x%4 == 3, rows)

    labeling.addLabelToRows("1",s1)
    labeling.addLabelToRows("2",s2)
    labeling.addLabelToRows("3",s3)
    labeling.addLabelToRows("4",s4)

    labeling.sortDatasetByLabel()
    tmp = map(labeling.getLabelByKey, self.dataset.getRowKeys())
    self.failUnless(tmp == ['1']*3 + ['2']*3 + ['3']*2 + ['4']*2,
                    "sorted labels were %r" % (tmp,))

  def test_remove_labels(self):
    """Labels can be removed per row, per column, by name, or all at once."""

    labeling = Labeling(self.dataset)

    labeling.addLabelToRow('a',0)
    labeling.addLabelToRow('b',0)
    labeling.addLabelToRow('c',0)

    labeling.addLabelToCol('a',0)
    labeling.addLabelToCol('b',0)
    labeling.addLabelToCol('c',0)

    self.failUnless(labeling.getLabelsByRow(0) == ['a','b','c'], "initial row 0")
    self.failUnless(labeling.getLabelsByCol(0) == ['a','b','c'], "initial col 0")

    labeling.removeLabelFromRow('b', 0)
    self.failUnless(labeling.getLabelsByRow(0) == ['a','c'], "after removeLabelFromRow")

    # removeLabel drops the label from both axes.
    labeling.removeLabel('a')
    self.failUnless(labeling.getLabelsByRow(0) == ['c'], "row 0 after removeLabel")
    self.failUnless(labeling.getLabelsByCol(0) == ['b','c'], "col 0 after removeLabel")

    labeling.removeLabelFromCol('c', 0)
    self.failUnless(labeling.getLabelsByCol(0) == ['b'], "after removeLabelFromCol")

    labeling.removeAll()
    self.failUnless(labeling.getLabelsByRow(0) == [], "row 0 after removeAll")
    self.failUnless(labeling.getLabelsByCol(0) == [], "col 0 after removeAll")

    self.dataset.removeLabeling(labeling)

  def test_create_labels(self):
    """Every add* labelling method accepts its arguments without error."""
    labeling = Labeling(self.dataset)

    labeling.addLabelToKey("foo", 1)
    labeling.addLabelToRow("r", 0)
    labeling.addLabelToCol("c", 0)
    labeling.addLabelToKeys("foo", [3,4])
    labeling.addLabelToRows("r", [5,6])
    labeling.addLabelToCols("c", [1,2])

    self.dataset.removeLabeling(labeling)

  def test_addLabel(self):
    """Labels applied to alternating keys are returned for exactly those keys."""
    labeling = Labeling(self.dataset)

    keys = self.dataset.getKeys()

    # Split the keys by the parity of their position in the key list.
    even = filter(lambda x : keys.index(x) & 1 == 0, keys)
    odd  = filter(lambda x : keys.index(x) & 1 == 1, keys)

    # NOTE(review): 'addLabelsToKeys' (plural "Labels") differs from the
    # 'addLabelToKeys' used below.  If Labeling does not define it, the
    # attribute lookup raises AttributeError before failUnlessRaises runs --
    # confirm which method was meant to raise ValueError here.
    self.failUnlessRaises(ValueError, labeling.addLabelsToKeys, "foo", even)
    labeling.addLabelToKeys("foo", even)
    labeling.addLabelToKeys("bar", odd)

    self.failUnless(labeling.getKeysByLabel("foo") == even, "keys labelled 'foo'")
    self.failUnless(labeling.getKeysByLabel("bar") == odd, "keys labelled 'bar'")

    self.dataset.removeLabeling(labeling)

  def test_is_row_or_col(self):
    """Check isRowLabeling / isColLabeling.

    Each must be true only for a labeling confined to its own axis, and false
    once the other axis is labelled as well.
    """
    d1 = Dataset([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])

    # test row labelings
    row_label = Labeling(d1)
    row_label.addLabelToRow("one",0)
    row_label.addLabelToRow("two", 1)
    row_label.addLabelToRow("four", 3)
    row_mixed_label = Labeling(d1)
    row_mixed_label.addLabelToRow("one",0)
    row_mixed_label.addLabelToRow("two", 1)
    row_mixed_label.addLabelToRow("four", 3)
    row_mixed_label.labelCols(["c1", "c2", "c3", ])

    self.failIf(row_mixed_label.isRowLabeling())
    self.failUnless(row_label.isRowLabeling())

    # test column labelings
    column_label = Labeling(d1)
    column_label.addLabelToCol("one", 0)
    column_label.addLabelToCol("two", 1)

    col_mixed_label = Labeling(d1)
    col_mixed_label.addLabelToCol("one", 0)
    col_mixed_label.addLabelToCol("two", 1)
    col_mixed_label.labelRows(["row"]*d1.numRows)

    self.failUnless(column_label.isColLabeling())
    self.failIf(col_mixed_label.isColLabeling())

  def test_is_numeric(self):
    """isNumeric is true only when every label is a number."""
    d1 = Dataset([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])

    # test row labelings
    str_label = Labeling(d1)
    str_label.labelRows(["r1", "r2", "r3", "r4"])
    self.failIf(str_label.isNumeric())

    # A single string among numbers makes the labeling non-numeric.
    mixed_label = Labeling(d1)
    mixed_label.labelRows([1, 2, 3.1, "4"])
    self.failIf(mixed_label.isNumeric())

    numeric_label = Labeling(d1)
    numeric_label.labelRows([1, 2, 3.1, 4])
    self.failUnless(numeric_label.isNumeric())

  def test_is_unique(self):
    """Check isRowUnique / isColUnique.

    Each must be true only when every row (respectively column) carries a
    distinct label, and false when labels repeat along that axis.
    """
    d1 = Dataset([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])

    # test row labelings
    row_label_duplicate = Labeling(d1)
    row_label_duplicate.labelRows(["row"] * d1.numRows)
    row_label_unique = Labeling(d1)
    row_label_unique.labelRows(["one", "two", "three", "four"])

    self.failIf(row_label_duplicate.isRowUnique())
    self.failUnless(row_label_unique.isRowUnique())

    # test column labelings
    column_label_duplicate = Labeling(d1)
    column_label_duplicate.labelCols(["col"]*d1.numCols)
    column_label_unique = Labeling(d1)
    column_label_unique.labelCols(["one", "two", "three"])

    self.failIf(column_label_duplicate.isColUnique())
    self.failUnless(column_label_unique.isColUnique())

  def test_clustered_labeling(self):
    """Smoke test: constructing a ClusteredLabeling with DiagEM (k=3) must not raise."""
    from compClust.mlx.wrapper.DiagEM import DiagEM
    algorithm = DiagEM
    parameters = {'k': 3}
    ClusteredLabeling(self.dataset, algorithm, parameters)

  def testLabelingDatasetLengthError(self):
    """labelRows and labelCols raise LabelingDatasetLengthError when given
    the wrong number of labels.

    d1 has 4 rows and 3 columns, so 3 row labels and 4 column labels are
    both the wrong length.
    """
    d1 = Dataset([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])

    # test row labelings
    row_label = Labeling(d1)
    self.failUnlessRaises(LabelingDatasetLengthError, row_label.labelRows, ["one", "two", "three", ])

    col_label = Labeling(d1)
    self.failUnlessRaises(LabelingDatasetLengthError, col_label.labelCols, ["one", "two", "three", "four"])

  def testSubsetByLabeling(self):
    """subsetByLabeling selects the rows (or columns) carrying the given label(s)."""
    # real data subsetByLabeling
    cho = LoadCho()
    diagem = cho.getLabeling('diagem clusters')
    self.failUnless(diagem is not None, "couldn't retrieve labeling")

    # The test expects the 'diagem clusters' labels to be numbers: the
    # string '5' (alone or in a list) matches nothing, while the int 5 does.
    cho_diagem_5 = subsetByLabeling(cho, diagem, ['5'])
    self.failUnless(len(cho_diagem_5.getData()) == 0, "subset by ['5'] should be empty")
    cho.removeView(cho_diagem_5)

    cho_diagem_5 = subsetByLabeling(cho, diagem, '5')
    self.failUnless(len(cho_diagem_5.getData()) == 0, "subset by '5' should be empty")
    cho.removeView(cho_diagem_5)

    cho_diagem_5 = subsetByLabeling(cho, diagem, 5)
    self.failUnless(len(cho_diagem_5.getData()) > 0, "subset by 5 should not be empty")

    # fake data subsetByLabeling
    zero_row = [0,0,0,0]
    one_row = [1,1,1,1]
    data = Dataset([zero_row]*5+[one_row]*5)
    zero_one = Labeling(data)
    zero_one.labelRows(['zero']*5+['one']*5)
    same = Labeling(data)
    same.labelRows(['same']*data.getNumRows())
    number = Labeling(data)
    number.labelRows(range(10))
    cols = Labeling(data)
    cols.labelCols([ 'c%d' % (x) for x in range(len(zero_row)) ])

    zero_ds = subsetByLabeling(data, zero_one, 'zero')
    self.failUnless(zero_ds.getNumRows() == 5)
    self.failUnless(zero_ds.getNumCols() == len(zero_row))
    for row in zero_ds.getData():
      self.failUnless(row == zero_row)

    one_ds = subsetByLabeling(data, zero_one, 'one')
    self.failUnless(one_ds.getNumRows() == 5)
    self.failUnless(one_ds.getNumCols() == len(one_row))
    for row in one_ds.getData():
      self.failUnless(row == one_row)

    # A label nobody carries yields an empty (0x0) subset.
    empty_ds = subsetByLabeling(data, zero_one, 'bleem')
    self.failUnless(empty_ds.getNumRows() == 0)
    self.failUnless(empty_ds.getNumCols() == 0)

    copy_ds = subsetByLabeling(data, same, 'same')
    self.failUnless(copy_ds.getNumRows() == 10)
    self.failUnless(copy_ds.getNumCols() == len(zero_row))

    copy_by_list_ds = subsetByLabeling(data, same, ['same'])
    self.failUnless(copy_by_list_ds.getNumRows() == 10)
    self.failUnless(copy_by_list_ds.getNumCols() == len(zero_row))

    copy_by_zero_one_ds = subsetByLabeling(data, zero_one, ['zero','one'])
    self.failUnless(copy_by_zero_one_ds.getNumRows() == 10)
    self.failUnless(copy_by_zero_one_ds.getNumCols() == len(zero_row))

    num_ds = subsetByLabeling(data, number, 5)
    self.failUnless(num_ds.getNumRows() == 1)
    self.failUnless(num_ds.getNumCols() == len(one_row))
    self.failUnless(num_ds.getData() == one_row)

    # make sure that when we subset by just columns we get all the rows
    col_ds = subsetByLabeling(data, cols, 'c0')
    self.failUnless(col_ds.getNumRows() == data.getNumRows())
    self.failUnless(col_ds.getNumCols() == 1)
    
def suite(**kw):
  """Build and return a TestSuite of every test* method in LabelingTestCases.

  Any keyword arguments are accepted but ignored.  The default TestLoader
  collects the same 'test'-prefixed methods, in the same order, as
  unittest.makeSuite does.
  """
  loader = unittest.TestLoader()
  return loader.loadTestsFromTestCase(LabelingTestCases)

if __name__ == "__main__":
  # When executed directly, hand the suite() factory to unittest's runner.
  unittest.main(module="__main__", defaultTest="suite")
  



