#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

import unittest
import os

from compClust.mlx.datasets import Dataset
from compClust.mlx.views import RowSubsetView
from compClust.mlx.labelings import GlobalElementLabeling
from compClust.mlx.labelings import GlobalElementWrapper

import compClust.mlx
import MLab

class GlobalElementLabelingTestCases(unittest.TestCase):
  """Tests for GlobalElementLabeling and the view-bound GlobalElementWrapper.

  Labels are added through three handles: the labeling with the dataset,
  the labeling with a RowSubsetView over the dataset, and a wrapper bound to
  that view. The assertions check that all three handles report the same
  rows, columns and elements.
  """

  def setUp(self):
    # Remember the current directory, then run the test from inside the
    # compClust.mlx package directory.
    # NOTE(review): presumably needed so package-relative paths resolve;
    # nothing in this test reads files directly -- confirm.
    self.original_dir = os.getcwd()
    os.chdir(compClust.mlx.__path__[0])

  def tearDown(self):
    # Undo the directory change made in setUp.
    os.chdir(self.original_dir)

  def test_basics(self):
    """Exercise row, column and element labels through dataset, view and wrapper.

    Checks use self.assertEqual rather than bare ``assert`` statements.
    Bare asserts are removed under ``python -O``, which would let this test
    pass without checking anything. assertEqual also reports both values
    when a check fails.
    """
    ds = Dataset(MLab.rand(5,5))
    # The view selects all 5 rows in order. The expected coordinates below
    # rely on view and dataset coordinates coinciding.
    v  = RowSubsetView(ds, range(5))

    l1  = GlobalElementLabeling(ds, 'donner')
    # The wrapper is bound to the view, so its methods omit the
    # dataset/view argument that l1's methods take.
    l2  = GlobalElementWrapper(v, 'blitzen', l1)

    # Row labels: one each via the dataset, the view and the wrapper.
    l1.addLabelToRow(ds, 'x', 1)
    l1.addLabelToRow(v, 'x', 2)
    l2.addLabelToRow('x', 3)

    self.assertEqual(l2.getRowsByLabel('x'), l1.getRowsByLabel(ds, 'x'))

    # Column labels: same pattern as the rows.
    l1.addLabelToCol(ds, 'y', 2)
    l1.addLabelToCol(v, 'y', 3)
    l2.addLabelToCol('y', 4)

    self.assertEqual(l1.getColsByLabel(ds, 'y'), l2.getColsByLabel('y'))

    # Element labels on the diagonal cells (0,0), (4,4) and (2,2).
    l1.addLabelToElement(ds, 'z', (0,0))
    l1.addLabelToElement(v, 'z', (4,4))
    l2.addLabelToElement('z', (2,2))

    self.assertEqual(l1.getElementsByLabel(ds, 'z'), l2.getElementsByLabel('z'))

    self.assertEqual(l1.getLabelsByElement(ds, (0,0)), ['z'])
    self.assertEqual(l1.getLabelByElement(ds, (0,0)), 'z')

    # Per-element results over the whole 5x5 view: 25 entries. Indices 0,
    # 12 and 24 presumably map to (0,0), (2,2) and (4,4) in row-major
    # order -- confirm against getLabelsByElements.
    tmp = l1.getLabelsByElements(v)
    self.assertEqual(len(tmp), 25)
    self.assertEqual(tmp[0], ['z'])
    self.assertEqual(tmp[12], ['z'])
    self.assertEqual(tmp[24], ['z'])

    tmp = l1.getLabelByElements(v)
    self.assertEqual(len(tmp), 25)
    self.assertEqual(tmp[0], 'z')
    self.assertEqual(tmp[12], 'z')
    self.assertEqual(tmp[24], 'z')

    # Add an off-diagonal element label. The results are sorted before
    # comparing because the order is not relied upon.
    l1.addLabelToElement(ds, 'z', (2,4))
    tmp = l2.getElementsByLabel('z')
    tmp.sort()
    self.assertEqual(tmp, [(0, 0), (2, 2), (2, 4), (4, 4)])

    # Removals through either handle must be visible through the wrapper.
    l1.removeLabelFromElement(ds, 'z', (0,0))
    tmp = l2.getElementsByLabel('z')
    tmp.sort()
    self.assertEqual(tmp, [(2, 2), (2, 4), (4, 4)])

    l1.removeLabelFromElement(v, 'z', (2,2))
    tmp = l2.getElementsByLabel('z')
    tmp.sort()
    self.assertEqual(tmp, [(2, 4), (4, 4)])

    l1.removeLabelFromElement(v, 'z', (4,4))
    self.assertEqual(l2.getElementsByLabel('z'), [(2, 4)])

    l1.removeLabelFromElement(v, 'z', (2,4))
    self.assertEqual(l2.getElementsByLabel('z'), [])
    
def suite(**kw):
  """Collect every test* method of GlobalElementLabelingTestCases into a TestSuite.

  Any keyword arguments are accepted and ignored.
  """
  loader = unittest.TestLoader()
  return loader.loadTestsFromTestCase(GlobalElementLabelingTestCases)

if __name__ == "__main__":
  # When run as a script, run the tests returned by this module's suite().
  # module="__main__" is spelled out but matches unittest.main's default.
  unittest.main(module="__main__", defaultTest="suite")
