#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Test suite for the Global Labeling class.
"""

import unittest
import os

# Python 2.2/2.3 have no builtin frozenset; fall back to sets.ImmutableSet.
# Probe the name directly rather than hasattr(__builtins__, ...): in CPython
# __builtins__ is the builtins module inside __main__ but a plain dict inside
# imported modules, so the hasattr() test fails whenever this file is imported
# and the fallback is taken even though the builtin exists.
try:
  frozenset
except NameError:
  from sets import ImmutableSet as frozenset
  
from compClust.mlx.datasets import Dataset
from compClust.mlx.views import *
from compClust.mlx.labelings.GlobalLabeling import *

import compClust.mlx
import MLab

class GlobalLabelingTestCases(unittest.TestCase):
  def setUp(self):
    """Build the fixture dataset used by the tests.

    The current working directory is remembered in self.original_dir and the
    process then changes into the compClust.mlx package directory; tearDown
    undoes that change.
    """
    self.original_dir = os.getcwd()
    package_dir = compClust.mlx.__path__[0]
    os.chdir(package_dir)

    # Ten rows of four columns; every column of a row holds the same value.
    row_values = [0, 3, 1, 2, 1, 3, 3, 2, 2, 3]
    rows = [[value] * 4 for value in row_values]
    self.dataset = Dataset(rows)

  def tearDown(self):
    """Change back to the working directory that setUp recorded."""
    saved_dir = self.original_dir
    os.chdir(saved_dir)

  def testCast(self):
    """Casting a local Labeling to a global wrapper must keep its row labels.

    A Labeling is built on a cached view of the fixture dataset, cast with
    castToGlobalWrapper(), and the first labeling reported by the dataset is
    then compared row-by-row with the local one.
    """
    cv = CachedView(self.dataset)
    l = Labeling(cv)
    l.labelRows([1,2,2,3,4,3,5,5,'a','a'])

    # The returned wrapper is kept referenced for the rest of the test; the
    # comparison below goes through the dataset's own list of labelings.
    gl = castToGlobalWrapper(l, 'gFoo', removeLocal=False)
    gl2 = self.dataset.getLabelings()[0]

    # Call the accessors so the failure message shows the label lists
    # themselves rather than the repr of the bound methods.
    global_rows = gl2.getLabelByRows()
    local_rows = l.getLabelByRows()
    self.failUnless(global_rows == local_rows,
                    "gl2 != l, %s, %s" % (global_rows, local_rows))

  def testWithAggregate(self):
    """Aggregating rows by a global or a local labeling should give equal data.

    Both labelings receive the same row labels and each drives a
    RowAggregateFunctionView using MLab.mean over the same cached view.
    """
    row_labels = [1,2,2,3,4,3,5,5,'a','a']

    cached = CachedView(self.dataset)
    global_labeling = GlobalWrapper(cached)
    local_labeling = Labeling(cached)

    # Each labeling gets its own copy of the identical label list.
    local_labeling.labelRows(list(row_labels))
    global_labeling.labelRows(list(row_labels))

    aggregates = [RowAggregateFunctionView(cached, labeling, MLab.mean)
                  for labeling in (global_labeling, local_labeling)]

    # NOTE(review): if getData() returns array objects, == may be elementwise
    # and the truth value of the result decides the assert -- confirm.
    assert aggregates[0].getData() == aggregates[1].getData()
    
  def testLabelUsing(self):
    """After labelFrom(), a local labeling should report the global row labels.

    The global wrapper sits on the dataset itself, the local Labeling on a
    cached view of it.
    """
    cached = CachedView(self.dataset)
    source = GlobalWrapper(self.dataset)
    source.labelRows(list('abcdefghij'))

    target = Labeling(cached)
    target.labelFrom(source)

    assert source.getLabelByRows() == target.getLabelByRows()
    
  def simple(self):
    """Label a row through one cached view and look it up through another.

    The method name lacks the 'test' prefix, so suite() does not collect it.
    """
    first_view = CachedView(self.dataset)
    second_view = CachedView(self.dataset)
    labeling = GlobalLabeling(self.dataset)

    labeling.addLabelToRow(first_view, 'x', 5)
    rows_seen = labeling.getRowsByLabel(second_view, 'x')
    assert rows_seen == [5]

    # Detach everything that was hung on the fixture dataset.
    self.dataset.removeLabeling(labeling)
    for view in (first_view, second_view):
      self.dataset.removeView(view)

  def medium(self):
    """Check one GlobalLabeling seen from two datasets and the views on them.

    The labeling is created on a ColumnSupersetView over both datasets and
    queried from each dataset, each cached view, the superset view and a
    permuted SortedView.

    The method name lacks the 'test' prefix, so suite() does not collect it.
    """

    #
    # Test this config:
    #
    #    A   B
    #   / \ / \
    #  C   D   E
    #      
    #      ^
    #      |
    #      +------ create global labeling here
    #
    # Presumably A/B are ds1/ds2, C/E are the cached views cv1/cv2 and D is
    # the superset view ss -- confirm against the view classes.
    #

    ds1 = Dataset(MLab.rand(10,5))
    ds2 = Dataset(MLab.rand(10,5))

    ds1.setName("ds1")
    ds2.setName("ds2")

    ss = ColumnSupersetView(ds1, ds2)
    cv1 = CachedView(ds1)
    cv2 = CachedView(ds2)
    sv = SortedView(cv2)
    
    cv1.setName("cv1")
    cv2.setName("cv2")
    sv.setName("sv")
    
    gl = GlobalLabeling(ss, 'foobar')
    gw = GlobalWrapper(ss, glabeling=gl)
    
    gw.addLabelToRow('x', 2)

    # The label placed through the superset view is expected on row 2 of
    # every node in the diagram.  The second cached-view check targets cv2;
    # it used to repeat cv1, which left cv2 unchecked.
    assert gl.getRowsByLabel(ds1, 'x') == [2]
    assert gl.getRowsByLabel(ds2, 'x') == [2]
    assert gl.getRowsByLabel(cv1, 'x') == [2]
    assert gl.getRowsByLabel(cv2, 'x') == [2]
    assert gl.getRowsByLabel(ss, 'x')  == [2]

    # Reversing the ten rows should move the labeled row from index 2 to 7.
    sv.permuteRows([9,8,7,6,5,4,3,2,1,0])
    assert gl.getRowsByLabel(sv, 'x') == [7]

    # Label through the wrapper reported by cv1: the test expects that label
    # on the superset view but not on cv2.
    gw = cv1.getLabelings()[0]
    gw.addLabelToRow('y', 3)

    assert gl.getRowsByLabel(cv2, 'y') == []
    assert gl.getRowsByLabel(ss, 'y') == [3]
    
    ss.removeLabeling(gl)

    # Remove the views from their parents, innermost first.
    cv2.removeView(sv)
    ds1.removeView(cv1)
    ds2.removeView(cv2)
    ds1.removeView(ss)
    ds2.removeView(ss)

  def medium2(self):
    """Labels set through one row subset should show up in an overlapping one.

    The method name lacks the 'test' prefix, so suite() does not collect it.
    """
    ds = Dataset(MLab.rand(10,3))

    first_subset = RowSubsetView(ds, [1,2,3,5,8,9])
    second_subset = RowSubsetView(ds, [0,2,4,6,7,8])

    # Label every row of the first subset.
    wrapper = GlobalWrapper(first_subset)
    wrapper.addLabelToRows('x', [0,1,2,3,4,5])

    # Dataset rows 2 and 8 are the only ones in both subsets; they sit at
    # positions 1 and 5 of the second subset.
    labeled_rows = second_subset.getLabelings()[0].getRowsByLabel('x')
    labeled_rows.sort()
    assert labeled_rows == [1,5]

    for subset in (first_subset, second_subset):
      ds.removeView(subset)
    
  def test_constructor(self):
    """A GlobalWrapper can be created on the dataset and removed again."""
    wrapper = GlobalWrapper(self.dataset)
    self.dataset.removeLabeling(wrapper)

  def test_label_retrieval(self):
    """Attach labels by key, row and column, then read them back.

    Covers getLabels, getLabelsByRow, getLabelsByCol, getRowsByLabel,
    getColsByLabel, getKeysByLabel and getLabelByRows.
    """
    labeling = GlobalWrapper(self.dataset)

    labeling.addLabelToKey("foo", 1)
    labeling.addLabelToRow("r", 0)
    labeling.addLabelToCol("c", 0)
    labeling.addLabelToKeys("foo", [3,4])
    labeling.addLabelToRows("r", [5,6])
    labeling.addLabelToCols("c", [1,2])

    # Exactly the three labels used above may be reported.
    known_labels = ['foo','r','c']
    labs = labeling.getLabels()
    for lab in labs:
      if lab not in known_labels:
        self.fail("Bad label")
    if len(labs) != 3:
      self.fail("Wrong number of labels")

    # (row index, expected labels, failure message)
    row_checks = [(1, ['foo'], "foo, row 1"),
                  (0, ['r'],   "r, row 0"),
                  (6, ['r'],   "r, row 6"),
                  (2, [],      "empty, row 2")]
    for row, expected, message in row_checks:
      if labeling.getLabelsByRow(row) != expected:
        self.fail(message)

    # (column index, expected labels, failure message)
    col_checks = [(0, ['c'], "c, col 0"),
                  (1, ['c'], "c, col 1"),
                  (2, ['c'], "c, col 2"),
                  (3, [],    "empty, col 3")]
    for col, expected, message in col_checks:
      if labeling.getLabelsByCol(col) != expected:
        self.fail(message)

    # Reverse lookups; results are sorted in place before comparing.
    reverse_checks = [(labeling.getRowsByLabel, "r",   [0,5,6], "label r"),
                      (labeling.getColsByLabel, "c",   [0,1,2], "label c"),
                      (labeling.getKeysByLabel, "foo", [1,3,4], "label foo")]
    for lookup, label, expected, message in reverse_checks:
      found = lookup(label)
      found.sort()
      if found != expected:
        self.fail(message)

    # One label per row, None where a row has no label.
    per_row = ['r', 'foo', None, 'foo', 'foo', 'r', 'r', None, None, None]
    if labeling.getLabelByRows() != per_row:
      self.fail()

    self.dataset.removeLabeling(labeling)
    
  def test_name(self):
    """Labelings on one view can be fetched back by the names given to them.

    Three named wrappers each get a single distinct label; looking a wrapper
    up by name must return the one carrying the matching label.
    """
    rows = range(self.dataset.getNumRows())
    view = RowSubsetView(self.dataset, rows)

    # (labeling name, label, row index modulo 4 that receives the label)
    plan = [("foo", "1", 0), ("bar", "2", 1), ("baz", "3", 2)]

    wrappers = []
    for name, label, remainder in plan:
      wrapper = GlobalWrapper(view)
      wrapper.setName(name)
      wrappers.append(wrapper)

    position = 0
    for name, label, remainder in plan:
      members = [row for row in rows if row % 4 == remainder]
      wrappers[position].addLabelToRows(label, members)
      position = position + 1

    for name, label, remainder in plan:
      found = view.getLabeling(name)
      assert found.getLabels()[0] == label
    
  def test_sort(self):
    """sortDatasetByLabel() on a SortedView should group rows by their label.

    Rows are labeled "1".."4" according to row index modulo 4; after sorting,
    the per-row labels are expected in label order.
    """
    rows = range(self.dataset.getNumRows())

    sorted_view = SortedView(self.dataset)
    labeling = GlobalWrapper(sorted_view)

    for remainder in range(4):
      members = [row for row in rows if row % 4 == remainder]
      labeling.addLabelToRows(str(remainder + 1), members)

    labeling.sortDatasetByLabel()

    # Ten rows: remainders 0 and 1 occur three times, 2 and 3 twice.
    expected = ['1']*3 + ['2']*3 + ['3']*2 + ['4']*2
    assert labeling.getLabelByRows() == expected

  def test_remove_labels(self):
    """Exercise removeLabelFromRow, removeLabel, removeLabelFromCol, removeAll.

    Row 0 and column 0 both start with labels a, b and c; the expected
    remaining labels are checked after every removal.
    """
    labeling = GlobalWrapper(self.dataset)

    starting_labels = ('a', 'b', 'c')
    for label in starting_labels:
      labeling.addLabelToRow(label, 0)
    for label in starting_labels:
      labeling.addLabelToCol(label, 0)

    assert labeling.getLabelsByRow(0) == ['a','b','c']
    assert labeling.getLabelsByCol(0) == ['a','b','c']

    # Removing from the row: only the row is checked afterwards.
    labeling.removeLabelFromRow('b', 0)
    assert labeling.getLabelsByRow(0) == ['a','c']

    # Removing a label as a whole: it should vanish from row and column.
    labeling.removeLabel('a')
    assert labeling.getLabelsByRow(0) == ['c']
    assert labeling.getLabelsByCol(0) == ['b','c']

    labeling.removeLabelFromCol('c', 0)
    assert labeling.getLabelsByCol(0) == ['b']

    labeling.removeAll()
    assert labeling.getLabelsByRow(0) == []
    assert labeling.getLabelsByCol(0) == []

    self.dataset.removeLabeling(labeling)
    
  def test_create_labels(self):
    """Smoke test: each addLabelTo* method can be called without error."""
    labeling = GlobalWrapper(self.dataset)

    # (method, label, single target or list of targets), in call order.
    additions = [(labeling.addLabelToKey,  "foo", 1),
                 (labeling.addLabelToRow,  "r",   0),
                 (labeling.addLabelToCol,  "c",   0),
                 (labeling.addLabelToKeys, "foo", [3,4]),
                 (labeling.addLabelToRows, "r",   [5,6]),
                 (labeling.addLabelToCols, "c",   [1,2])]
    for add, label, target in additions:
      add(label, target)

    self.dataset.removeLabeling(labeling)

  def test_addLabel(self):
    """Keys labeled via addLabelToKeys come back from getKeysByLabel.

    Keys at even positions get "foo" and keys at odd positions get "bar".
    """
    labeling = GlobalWrapper(self.dataset)

    keys = self.dataset.getKeys()

    def at_even_position(key):
      return keys.index(key) % 2 == 0

    def at_odd_position(key):
      return keys.index(key) % 2 == 1

    even = filter(at_even_position, keys)
    odd  = filter(at_odd_position, keys)

    labeling.addLabelToKeys("foo", even)
    labeling.addLabelToKeys("bar", odd)

    # Compared as frozensets because getKeysByLabel started returning the keys
    # in a different order; whether that reordering is itself a bug is open.
    foo_matches = frozenset(labeling.getKeysByLabel("foo")) == frozenset(even)
    foo_message = "keys: %s != %s"%(labeling.getKeysByLabel("foo"), even)
    self.failUnless(foo_matches, foo_message)

    bar_matches = frozenset(labeling.getKeysByLabel("bar")) == frozenset(odd)
    self.failUnless(bar_matches)

    self.dataset.removeLabeling(labeling)


  def testSuperset(self):
    """A RowSupersetView should report both same-named labelings.

    Two datasets each get a GlobalWrapper called 'names'; the view over both
    is expected to list two labelings, each with that name.
    """
    ds1, ds2 = [Dataset([[1,2,3],[2,3,4],[3,4,5]]) for unused in (0, 1)]

    # The wrapper name is rebound on each pass, as in a plain sequence of
    # assignments to one variable.
    for dataset, row_names in ((ds1, ['d','e','f']), (ds2, ['a','b','c'])):
      wrapper = GlobalWrapper(dataset, name='names')
      wrapper.labelRows(row_names)

    superset = RowSupersetView(ds1, ds2)
    labels = superset.getLabelings()

    assert len(labels) == 2
    for index in (0, 1):
      assert labels[index].getName() == 'names'

  def test_is_row_or_col(self):
    """Check isRowLabeling / isColLabeling on pure and mixed labelings.

    Each check runs once against the dataset and once against a two-row
    subset view of it.
    """
    d1 = Dataset([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])
    v1 = RowSubsetView(d1, [0,1])
    first_two = (("one", 0), ("two", 1))

    def check_row_labelings(view):
      # A labeling holding only row labels ...
      rows_only = GlobalLabeling(d1)
      for label, index in first_two:
        rows_only.addLabelToRow(view, label, index)

      # ... and one holding the same row labels plus column labels.
      rows_and_cols = GlobalLabeling(d1)
      for label, index in first_two:
        rows_and_cols.addLabelToRow(view, label, index)
      rows_and_cols.labelCols(view, ["c1", "c2", "c3", ])

      self.failIf(rows_and_cols.isRowLabeling(view))
      self.failUnless(rows_only.isRowLabeling(view ))

    def check_column_labelings(view):
      # A labeling holding only column labels ...
      cols_only = GlobalLabeling(d1)
      for label, index in first_two:
        cols_only.addLabelToCol(view, label, index)

      # ... and one holding the same column labels plus a label on every row.
      cols_and_rows = GlobalLabeling(d1)
      for label, index in first_two:
        cols_and_rows.addLabelToCol(view, label, index)
      cols_and_rows.labelRows(view, ["row"]*view.numRows)

      self.failUnless(cols_only.isColLabeling(view))
      self.failIf(cols_and_rows.isColLabeling(view))

    for view in (d1, v1):
      check_row_labelings(view)
    for view in (d1, v1):
      check_column_labelings(view)
    
  def test_castToGlobalWrapper(self):
    """If you cast a global wrapper to a global wrapper do the labels map to the right dataset?

    A wrapper labeling all four rows of a dataset is cast onto a subset view
    holding dataset rows 1 and 3; the view's rows 0 and 1 are then expected
    to carry "one" and "three".
    """
    d1 = Dataset([[0,0,0],[1,1,1],[2,2,2],[3,3,3]])
    labels_to_add = ["zero","one","two","three"]
    primary_labeling = GlobalWrapper(d1)
    primary_labeling.labelRows(labels_to_add)
    # First confirm the labels landed on the dataset rows in order.
    for i in xrange(d1.getNumRows()):
      self.failUnless(primary_labeling.getLabelByRow(i) == labels_to_add[i])
      
    v1 = RowSubsetView(d1, [1,3])
    self.failUnless(v1.getNumRows() == 2)
    view_primary_labeling = castToGlobalWrapper(primary_labeling, dataset=v1, removeLocal=False)
    # The message names the value that is actually expected ("one"); it used
    # to say "zero", which contradicted the check.
    first_label = view_primary_labeling.getLabelByRow(0)
    self.failUnless(first_label == "one",
                    "Got %s should've been one" % (first_label,))
    self.failUnless(view_primary_labeling.getLabelByRow(1) == "three")

def suite(**kw):
  """Return a test suite built from GlobalLabelingTestCases.

  Keyword arguments are accepted but not used.
  """
  test_case_class = GlobalLabelingTestCases
  return unittest.makeSuite(test_case_class)

# When executed as a script, run the tests returned by suite() above.
if __name__ == "__main__":
  unittest.main(defaultTest="suite")
  



