#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Test suite for the ElementLabeling class.
"""

import unittest
import os

from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings.ElementLabeling import ElementLabeling
import compClust.mlx

class ElementLabelingTestCases(unittest.TestCase):
  """Tests for ElementLabeling attached to a small in-memory Dataset.

  setUp builds a fresh 10-row by 4-column Dataset for every test.  All
  checks go through self.assertEqual / self.fail rather than the bare
  ``assert`` statement: ``assert`` is compiled away under ``python -O``,
  which would make those checks pass without testing anything.
  """

  def setUp(self):
    """Build a fresh 10x4 Dataset to attach labelings to."""
    self.original_dir = os.getcwd()
    # NOTE(review): why the tests must run from inside the compClust.mlx
    # package directory is not visible here -- presumably something resolves
    # package-relative paths; confirm before removing.
    os.chdir(compClust.mlx.__path__[0])
    try:
      # None of the tests inspect these values; they only address rows and
      # columns by index.
      data = [ [ 0, 0, 0, 0],
               [ 3, 3, 3, 3],
               [ 1, 1, 1, 1],
               [ 2, 2, 2, 2],
               [ 1, 1, 1, 1],
               [ 3, 3, 3, 3],
               [ 3, 3, 3, 3],
               [ 2, 2, 2, 2],
               [ 2, 2, 2, 2],
               [ 3, 3, 3, 3] ]
      self.dataset = Dataset(data)
    except:
      # unittest does not call tearDown when setUp raises, so undo the
      # chdir here before letting the error propagate.
      os.chdir(self.original_dir)
      raise

  def tearDown(self):
    """Return to the working directory that was current before setUp."""
    os.chdir(self.original_dir)

  def test_constructor(self):
    """An ElementLabeling can be created on a dataset and removed again."""
    labeling = ElementLabeling(self.dataset)
    self.dataset.removeLabeling(labeling)

  def test_label_retrieval(self):
    """Labels added by key, row and column can be read back every way."""
    labeling = ElementLabeling(self.dataset)

    labeling.addLabelToKey("foo", 1)
    labeling.addLabelToRow("r", 0)
    labeling.addLabelToCol("c", 0)
    labeling.addLabelToKeys("foo", [3,4])
    labeling.addLabelToRows("r", [5,6])
    labeling.addLabelToCols("c", [1,2])

    # The order of getLabels() is not pinned; only membership and count are.
    labs = labeling.getLabels()
    for lab in labs:
      if lab not in ['foo', 'r', 'c']:
        self.fail("Bad label %r in %r" % (lab, labs))
    self.assertEqual(len(labs), 3, "Wrong number of labels: %r" % (labs,))

    # Labels attached through addLabelToKey(s) are reported per row.
    self.assertEqual(labeling.getLabelsByRow(1), ['foo'])
    self.assertEqual(labeling.getLabelsByRow(0), ['r'])
    self.assertEqual(labeling.getLabelsByRow(6), ['r'])
    self.assertEqual(labeling.getLabelsByRow(2), [])

    self.assertEqual(labeling.getLabelsByCol(0), ['c'])
    self.assertEqual(labeling.getLabelsByCol(1), ['c'])
    self.assertEqual(labeling.getLabelsByCol(2), ['c'])
    self.assertEqual(labeling.getLabelsByCol(3), [])

    self.assertEqual(labeling.getRowsByLabel("r"), [0,5,6])
    self.assertEqual(labeling.getColsByLabel("c"), [0,1,2])
    self.assertEqual(labeling.getKeysByLabel("foo"), [1,3,4])

    # One entry per row: that row's label, or None for unlabelled rows.
    self.assertEqual(labeling.getLabelByRows(),
                     ['r', 'foo', None, 'foo', 'foo', 'r', 'r',
                      None, None, None])

    self.dataset.removeLabeling(labeling)

  def test_name(self):
    """Named labelings on a view can be fetched back by name."""
    from compClust.mlx.views import RowSubsetView

    rows = range(self.dataset.getNumRows())

    v = RowSubsetView(self.dataset, rows)
    l1 = ElementLabeling(v, "foo")
    l2 = ElementLabeling(v, "bar")
    l3 = ElementLabeling(v, "baz")

    s1 = [r for r in rows if r % 4 == 0]
    s2 = [r for r in rows if r % 4 == 1]
    s3 = [r for r in rows if r % 4 == 2]

    l1.addLabelToRows("1", s1)
    l2.addLabelToRows("2", s2)
    l3.addLabelToRows("3", s3)

    self.assertEqual(v.getLabeling("foo").getLabels()[0], "1")
    self.assertEqual(v.getLabeling("bar").getLabels()[0], "2")
    self.assertEqual(v.getLabeling("baz").getLabels()[0], "3")

  def test_sort(self):
    """sortDatasetByLabel reorders a SortedView's rows by label."""
    from compClust.mlx.views import SortedView

    rows = range(self.dataset.getNumRows())

    v = SortedView(self.dataset)
    labeling = ElementLabeling(v)

    # Rows 0..9 grouped by row % 4 give group sizes 3, 3, 2, 2.
    s1 = [r for r in rows if r % 4 == 0]
    s2 = [r for r in rows if r % 4 == 1]
    s3 = [r for r in rows if r % 4 == 2]
    s4 = [r for r in rows if r % 4 == 3]

    labeling.addLabelToRows("1", s1)
    labeling.addLabelToRows("2", s2)
    labeling.addLabelToRows("3", s3)
    labeling.addLabelToRows("4", s4)

    labeling.sortDatasetByLabel()
    self.assertEqual(labeling.getLabelByRows(),
                     ['1']*3 + ['2']*3 + ['3']*2 + ['4']*2)

  def test_remove_labels(self):
    """Labels can be removed per row, per column, by name, or all at once."""
    labeling = ElementLabeling(self.dataset)

    labeling.addLabelToRow('a', 0)
    labeling.addLabelToRow('b', 0)
    labeling.addLabelToRow('c', 0)

    labeling.addLabelToCol('a', 0)
    labeling.addLabelToCol('b', 0)
    labeling.addLabelToCol('c', 0)

    self.assertEqual(labeling.getLabelsByRow(0), ['a','b','c'])
    self.assertEqual(labeling.getLabelsByCol(0), ['a','b','c'])

    # Removing from a row leaves the column's copy of the label alone.
    labeling.removeLabelFromRow('b', 0)
    self.assertEqual(labeling.getLabelsByRow(0), ['a','c'])

    # removeLabel drops the label from rows and columns alike.
    labeling.removeLabel('a')
    self.assertEqual(labeling.getLabelsByRow(0), ['c'])
    self.assertEqual(labeling.getLabelsByCol(0), ['b','c'])

    labeling.removeLabelFromCol('c', 0)
    self.assertEqual(labeling.getLabelsByCol(0), ['b'])

    labeling.removeAll()
    self.assertEqual(labeling.getLabelsByRow(0), [])
    self.assertEqual(labeling.getLabelsByCol(0), [])

    self.dataset.removeLabeling(labeling)

  def test_create_labels(self):
    """Smoke test: every addLabelTo* variant runs without raising.

    Makes no assertions; test_label_retrieval checks the results of the
    same calls.
    """
    labeling = ElementLabeling(self.dataset)

    labeling.addLabelToKey("foo", 1)
    labeling.addLabelToRow("r", 0)
    labeling.addLabelToCol("c", 0)
    labeling.addLabelToKeys("foo", [3,4])
    labeling.addLabelToRows("r", [5,6])
    labeling.addLabelToCols("c", [1,2])

    self.dataset.removeLabeling(labeling)

  def test_elements(self):
    """Labels attached to individual (row, col) elements."""
    labeling = ElementLabeling(self.dataset)

    labeling.addLabelToElement('x', (4,3))

    self.assertEqual(labeling.getRowsByLabel('x'), [4])
    self.assertEqual(labeling.getColsByLabel('x'), [3])
    self.assertEqual(labeling.getElementsByLabel('x'), [(4,3)])

    labeling.addLabelToElements('y', [(0,0), (1,1), (2,1), (4,1)])
    self.assertEqual(labeling.getElementsByLabel('y'),
                     [(0,0), (1,1), (2,1), (4,1)])
    # One entry per labelled element, so repeated columns are kept.
    self.assertEqual(labeling.getRowsByLabel('y'), [0,1,2,4])
    self.assertEqual(labeling.getColsByLabel('y'), [0,1,1,1])

    labeling.addLabelToElement('z', (1,1))
    self.assertEqual(labeling.getLabelsByElement((1,1)), ['y','z'])

    labeling.removeLabelFromElement('y', (1,1))
    self.assertEqual(labeling.getLabelsByElement((1,1)), ['z'])

    self.dataset.removeLabeling(labeling)

def suite(**kw):
  """Return a TestSuite with every test* method of ElementLabelingTestCases.

  Keyword arguments are accepted and ignored (presumably so a test driver
  can call every module's suite() uniformly -- confirm against the driver).
  """
  # loadTestsFromTestCase is what unittest.makeSuite wraps, with the same
  # 'test' prefix and sorting; makeSuite itself is deprecated since Python
  # 3.11 and removed in 3.13, while TestLoader exists in every version.
  return unittest.TestLoader().loadTestsFromTestCase(ElementLabelingTestCases)

if __name__ == "__main__":
  # Run as a script: let unittest's command-line runner collect tests from
  # this module through the module-level suite() factory.
  unittest.main(module="__main__", defaultTest="suite")
  



