#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
# Author:  Christopher Hart
# Date  :  September 2001

"""
Test suite for the compClust.score.ConfusionMatrix2 module
"""


# standard modules
import inspect
import os
import string
import unittest

import MLab
import Numeric

from compClust.score.ConfusionMatrix2 import ConfusionMatrix, tuple_stringify, tuple_unstringify
from compClust.mlx.datasets import Dataset
from compClust.mlx.datasets import PhantomDataset
from compClust.mlx.labelings import Labeling

class ConfusionMatrix2TestCases(unittest.TestCase):
  """
  Simple test cases for the ConfusionMatrix2 confusion matrix, exercised
  both with default options and with web_safe=True.
  """

  def _assertArrayEqual(self, actual, expected, msg):
    """
    Fail with msg unless actual and expected hold the same values in the
    same nested (shape-preserving) layout.

    Objects providing tolist() (e.g. Numeric arrays) are converted to
    nested Python lists first.  Comparing arrays with == is element-wise,
    so handing that result straight to failUnless does not reliably
    require *every* element to match.
    """
    if hasattr(actual, 'tolist'):
      actual = actual.tolist()
    if hasattr(expected, 'tolist'):
      expected = expected.tolist()
    if actual != expected:
      self.fail('%s: expected %r, got %r' % (msg, expected, actual))

  def setUp(self):
    """
    Set up the test case using a small synthetic dataset, a ground-truth
    (reference) labeling, and a permuted copy of the ground-truth labeling.

    Test files (read from the directory containing this source file):

      synth_t_05c0_p_0075_d_03_v_1d0.txt
      clust_t_05c0_p_0075_d_03_v_1d0_a_reference.txt
      clust_t_05c0_p_0075_d_03_v_1d0_a_permuted.txt

    Changelog for the permuted clustering:

       data ID   referenceClass  permutedClass
       ---------------------------------------
       1              1              5
       29             1              5
       67             1              5
       24             4              2
       69             4              2
       all            3              4
       all            4              3

    Correct confusion matrix (rows are permuted classes, columns are
    reference classes):

                  1       2       3       4       5
            +--------------------------------------
        1   |    13       0       0       0       0
        2   |     0      17       0       2       0
        3   |     0       0       0      19       0
        4   |     0       0       7       0       0
        5   |     3       0       0       0      14

    Two ConfusionMatrix objects are built from [reference, permuted]:
    confMatByLabeling with default options and confMatWebSafe with
    web_safe=True.
    """
    # The test data files live next to this source file.
    source = os.path.realpath(inspect.getsourcefile(ConfusionMatrix2TestCases))
    self.datapath = os.path.split(source)[0]

    # Load the dataset and close the stream once the Dataset is built, so
    # each test method does not leak a file handle.
    # NOTE(review): assumes Dataset reads the whole stream in its
    # constructor -- TODO confirm.
    datasetName = os.path.join(self.datapath, 'synth_t_05c0_p_0075_d_03_v_1d0.txt')
    datasetStream = open(datasetName, 'r')
    try:
      self.dataset = Dataset(datasetStream)
    finally:
      datasetStream.close()

    refLabelingName = os.path.join(self.datapath, 'clust_t_05c0_p_0075_d_03_v_1d0_a_reference.txt')
    self.refLabeling = Labeling(self.dataset)
    self.refLabeling.labelRows(refLabelingName)

    permutedLabelingName = os.path.join(self.datapath, 'clust_t_05c0_p_0075_d_03_v_1d0_a_permuted.txt')
    self.permutedLabeling = Labeling(self.dataset)
    self.permutedLabeling.labelRows(permutedLabelingName)

    confMatLabelings = [self.refLabeling, self.permutedLabeling]
    self.confMatByLabeling = ConfusionMatrix(confMatLabelings)
    self.confMatWebSafe = ConfusionMatrix(confMatLabelings, web_safe=True)

  def testMarginals(self):
    """
    Check that the confusion matrix is constructed properly from the
    labelings by comparing its marginal sums with known values.
    """
    # rowSum is compared with MLab.sum of the transposed matrix,
    # colSum with MLab.sum of the matrix itself.
    rowSum = Numeric.array((16, 21, 14,  7, 17))
    colSum = Numeric.array((17, 13, 19,  7, 19))

    for confMat, description in ((self.confMatByLabeling, 'Labeling loaded'),
                                 (self.confMatWebSafe, 'web safe')):
      matrix = confMat.getMatrix()
      self._assertArrayEqual(MLab.sum(matrix), colSum,
                             'Marginal column sum failed on %s confusion matrix'
                             % description)
      self._assertArrayEqual(MLab.sum(Numeric.transpose(matrix)), rowSum,
                             'Marginal row sum failed on %s confusion matrix'
                             % description)

# only provided in first confusion matrix api
#     def testHyper(self):
#         """
#         Checks the hypercube functions.
#         """
# 
#         ds = PhantomDataset(100,1)
# 
#         l1 = Labeling(ds, 'x')
#         l2 = Labeling(ds, 'y')
#         l3 = Labeling(ds, 'z')
# 
#         l1.addLabelToRows('1', range(0,100,2))
#         l1.addLabelToRows('2', range(1,100,2))
# 
#         l2.addLabelToRows('1', range(0,100,3))
#         l2.addLabelToRows('2', range(1,100,3))
#         l2.addLabelToRows('3', range(2,100,3))
# 
#         l3.addLabelToRows('1', range(0,100,4))
#         l3.addLabelToRows('2', range(1,100,4))
#         l3.addLabelToRows('3', range(2,100,4))
#         l3.addLabelToRows('4', range(3,100,4))
# 
#         cm = ConfusionMatrix([l1, l2, l3])
# 
#         tmp = cm.projectConfusionHypercube(['x','y'])
#         a_list = tmp.getAgreementList()
# 
#         assert tmp.getMatrix() == [[17, 17, 16], [17, 16, 17]]

  def testMatrix(self):
    """
    Check the exact arrangement of the confusion matrix.

    FIXME:  currently this fails on correct confusion matrixes with
    FIXME:  different permutations...this shouldn't be too much of a
    FIXME:  problem because things should always be loaded the same way
    """
    # Expected value of the *transposed* confusion matrix.
    rightAnswer = [[ 3,  0, 14, 0,  0],
                   [13,  0,  0, 0,  0],
                   [ 0, 19,  0, 0,  0],
                   [ 0,  0,  0, 7,  0],
                   [ 0,  2,  0, 0, 17]]

    self._assertArrayEqual(Numeric.transpose(self.confMatByLabeling.getMatrix()),
                           rightAnswer,
                           'ConfusionMatrix is incorrect (Labeling loaded)')
    self._assertArrayEqual(Numeric.transpose(self.confMatWebSafe.getMatrix()),
                           rightAnswer,
                           'ConfusionMatrix is incorrect (web safe)')

  def testAdjacencyCalculations(self):
    """
    Check the adjacency matrix of each confusion matrix against the
    expected 0/1 matrix.
    """
    rightAnswer = Numeric.array([[1, 0, 0, 0, 0],
                                 [0, 1, 0, 0, 0],
                                 [0, 0, 0, 1, 0],
                                 [0, 0, 1, 0, 0],
                                 [0, 0, 0, 0, 1]])

    self._assertArrayEqual(self.confMatByLabeling.getAdjacencyMatrix(),
                           rightAnswer,
                           'Failed Adjacency Test (Labeling loaded)')
    self._assertArrayEqual(self.confMatWebSafe.getAdjacencyMatrix(),
                           rightAnswer,
                           'Failed Adjacency Test (web safe)')

# Was used in first ConfusionMatrix api
#     def testAgreement(self):
# 
#         rightAnswer = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#                        1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
#                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
#                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
#                        0, 1, 1, 1, 1, 1, 1]
# 
#         agree1 = self.confMatByLabeling.getAgreementList();
# 
#         # "\tChecking agreement lists...",
#         self.failUnless((agree1 == rightAnswer),
#                         'Failed rowLabels test - class-to-index dict broken')

  def testScoring(self):
    """
    Test that the scoring functions work correctly on both confusion
    matrices.  Scores are compared after rounding to 4 decimal places.
    """
    expectedScores = [('NMI', 0.87734170819),
                      ('transposeNMI', 0.877418115328),
                      ('averageNMI', 0.877379911759),
                      ('linearAssignment', 0.933333333333)]
    for confMat in (self.confMatByLabeling, self.confMatWebSafe):
      for methodName, expected in expectedScores:
        score = getattr(confMat, methodName)()
        self.failUnless(round(score, 4) == round(expected, 4),
                        '%s returned %r, expected %r' % (methodName, score, expected))

  def testTupleStringify(self):
    """
    Check that tuple_stringify and tuple_unstringify convert tuples to and
    from their string form.
    """
    stringifyTests = [((1, 2), "(1,2)"),
                      (('a', 'b'), "('a','b')")]
    for test, result in stringifyTests:
      self.failUnless(tuple_stringify(test) == result,
                      'tuple_stringify(%r) != %r' % (test, result))

    unstringifyTests = [("(1,2)", (1, 2)),
                        ("('a','b')", ("a", "b"))]
    for test, result in unstringifyTests:
      self.failUnless(tuple_unstringify(test) == result,
                      'tuple_unstringify(%r) != %r' % (test, result))

  def testCho(self):
    """
    Regression test for ticket:20 -- for the cho dataset the sum of the
    confusion matrix was not equal to the size of the dataset.  The sum
    must not exceed the number of rows in the dataset.
    """
    from compClust.util.LoadExample import LoadCho
    cho = LoadCho()
    diagem_clustering = cho.getLabeling('diagem clusters')
    self.failUnless(diagem_clustering is not None,
                    "labeling 'diagem clusters' not found")
    cho_clustering = cho.getLabeling('cho clusters')
    self.failUnless(cho_clustering is not None,
                    "labeling 'cho clusters' not found")
    cm = ConfusionMatrix([diagem_clustering, cho_clustering])
    total = Numeric.sum(cm.getMatrix().flat)
    self.failUnless(total <= cho.getNumRows(),
                    'confusion matrix sum %r exceeds dataset size %r'
                    % (total, cho.getNumRows()))
   
def suite(**kw):
  """
  Build and return a TestSuite holding every test* method of
  ConfusionMatrix2TestCases.  Keyword arguments are accepted and ignored.
  """
  loader = unittest.TestLoader()
  return loader.loadTestsFromTestCase(ConfusionMatrix2TestCases)
  
if __name__ == '__main__':
  # Run the tests returned by suite() when this file is executed directly.
  unittest.main(module='__main__', defaultTest='suite')
