#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
# Author:  Christopher Hart
# Date  :  September 2001

"""
Test suite for the ConfusionMatrix module
"""


# standard modules
import inspect
import os
import string
import unittest

import MLab
try:
  import matplotlib.numerix as nx
except ImportError, e:
  import Numeric as nx

from compClust.score.ConfusionMatrix import ConfusionMatrix
from compClust.mlx.datasets import Dataset
from compClust.mlx.datasets import PhantomDataset
from compClust.mlx.labelings import Labeling

class ConfusionMatrixTestCases(unittest.TestCase):
    
    """
    Simple Test Cass for the confusion matrix
    
    """
    
    
    def setUp(self):
        
        """
        sets up the test case using a small synthetic dataset a groundTruth
        labeling and a permuted groundTruth labeling
        
        Testfiles in test/
      
           clust_t_05c0_p_0075_d_03_v_1d0_a_reference.txt
           clust_t_05c0_p_0075_d_03_v_1d0_a_permuted.txt
           synth_t_05c0_p_0075_d_03_v_1d0.txt

        Changelog for permuted clustering:

           data ID   referenceClass  permutedClass
           ---------------------------------------
           1              1              5
           29             1              5
           67             1              5
           24             4              2
           69             4              2
           all            3              4
           all            4              3
           
           Correct Confusion Matrix:

                          1       2       3       4       5       
                +--     ---     ---     ---     ---     ---     
        1       |         13      0       0       0       0
        2       |         0       17      0       2       0
        3       |         0       0       0       19      0
        4       |         0       0       7       0       0
        5       |         3       0       0       0       14



           """

        # load the datain
        source = os.path.realpath(inspect.getsourcefile(ConfusionMatrixTestCases))
        self.datapath = os.path.split(source)[0]
        
        datasetName = os.path.join(self.datapath, 'synth_t_05c0_p_0075_d_03_v_1d0.txt')
        datasetStream = open(datasetName, 'r')
        self.dataset = Dataset(datasetStream)
        
        refLabelingName = os.path.join(self.datapath, 'clust_t_05c0_p_0075_d_03_v_1d0_a_reference.txt')

        self.refLabeling = Labeling(self.dataset)
        self.refLabeling.labelRows(refLabelingName)
        
        permutedLabelingName = os.path.join(self.datapath, 'clust_t_05c0_p_0075_d_03_v_1d0_a_permuted.txt')

        self.permutedLabeling = Labeling(self.dataset)
        self.permutedLabeling.labelRows(permutedLabelingName)

        self.confMatByLabeling = ConfusionMatrix()
        self.confMatByFile     = ConfusionMatrix()

        self.confMatByLabeling.createConfusionMatrixFromLabeling(self.refLabeling, self.permutedLabeling)
        self.confMatByFile.createConfusionMatrixFromFile(refLabelingName, permutedLabelingName)


    def testMarginals(self):

        """
        This checks to make sure the confusion matrix is being constructed
        properlywith the labeling
        """

        # self.confMatByLabeling.printCounts()
        
        # Check the marginal sums

        rowSum = nx.array((16, 21, 14,  7, 17))
        colSum = nx.array((17, 13, 19,  7, 19))

        # "\tChecking Marginals...",

        self.failUnless(((MLab.sum(self.confMatByLabeling.getCounts()) == colSum) and
                         (MLab.sum(nx.transpose(self.confMatByLabeling.getCounts()))) == rowSum),
                        'Marginal row sum Failed on Labeling loaded confusion matrix')
        
        self.failUnless(((MLab.sum(self.confMatByFile.getCounts()) == colSum) and
                         (MLab.sum(nx.transpose(self.confMatByFile.getCounts()))) == rowSum),
                        'Mariginal col sum Failed on File loaded confusion matrix')


    def testHyper(self):
        """
        Checks the hypercube functions.
        """

        ds = PhantomDataset(100,1)
        cm = ConfusionMatrix()

        l1 = Labeling(ds, 'x')
        l2 = Labeling(ds, 'y')
        l3 = Labeling(ds, 'z')

        l1.addLabelToRows('1', range(0,100,2))
        l1.addLabelToRows('2', range(1,100,2))

        l2.addLabelToRows('1', range(0,100,3))
        l2.addLabelToRows('2', range(1,100,3))
        l2.addLabelToRows('3', range(2,100,3))

        l3.addLabelToRows('1', range(0,100,4))
        l3.addLabelToRows('2', range(1,100,4))
        l3.addLabelToRows('3', range(2,100,4))
        l3.addLabelToRows('4', range(3,100,4))

        cm.createConfusionHypercubeFromLabeling([l1,l2,l3])

        tmp = cm.projectConfusionHypercube(['x','y'])
        a_list = tmp.getAgreementList()

        assert tmp.getCounts() == [[17, 17, 16], [17, 16, 17]]
        
    def testMatrix(self):

        """
        Checks the real arrangement of the confusion matrix

        FIXME:  currently this fails on correct confusion matrixes with
        FIXME:  different permutations...this shouldn't be too much of a
        FIXME:  problem because things should always be loaded the same way
        """
        rightAnswer = [[ 3,  0, 14, 0, 0],
                       [13,  0,  0, 0, 0],
                       [ 0, 19,  0, 0, 0],
                       [ 0,  0,  0, 7, 0],
                       [ 0,  2,  0, 0, 17]]


        # "\tChecking Confusion Matrix...",
        self.failUnless(nx.transpose(self.confMatByLabeling.getCounts()) == nx.transpose(self.confMatByFile.getCounts()) == rightAnswer,
                        'ConfusionMatrix is incorrect')
        


    def testAdjacencyCalculations(self):
        """Does confusion matrix generate a valid adjecency matrix?
        """
        a = self.confMatByLabeling.getAdjacencyMatrix()
        b = self.confMatByFile.getAdjacencyMatrix()

        for matrix, axis in [(a,0), (a,1), (b,0), (b,1)]:
          self.failUnless(nx.sum(matrix,axis) == [1,1,1,1,1],
                          msg="%s was not an adjaceny matrix" % (str(matrix)))

    def testAgreement(self):

        rightAnswer = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                       1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
                       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
                       0, 1, 1, 1, 1, 1, 1]

        agree1 = self.confMatByLabeling.getAgreementList();
        agree2 = self.confMatByFile.getAgreementList();

        # "\tChecking agreement lists...",
        self.failUnless((agree1 == agree2 == rightAnswer),
                        'Failed rowLabels check - class-to-index dict broken')
        

    def testRowLabels(self):

        for confusion in [ self.confMatByLabeling, self.confMatByFile ]:
          rowClassNames = confusion.rowClassNames
          for key, value in rowClassNames.items():
            self.failUnless(rowClassNames[value] == key, 
                            msg="value %s didn't return key %s, in %s" % (value, key, rowClassNames))

    def testColLabels(self):

        for confusion in [ self.confMatByLabeling, self.confMatByFile ]:
          colClassNames = confusion.colClassNames
          for key, value in colClassNames.items():
            self.failUnless(colClassNames[value] == key, 
                            msg="value %s didn't return key %s, in %s" % (value, key, colClassNames))

def suite(**kw):
    """
    Returns a test suite holding every test method of
    ConfusionMatrixTestCases.  Keyword arguments are accepted but unused.
    """
    loader = unittest.defaultTestLoader
    return loader.loadTestsFromTestCase(ConfusionMatrixTestCases)

# When run as a script, execute the suite returned by suite() above.
if __name__ == "__main__":
    unittest.main(defaultTest="suite")
