#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
A test suite for clusterCenterDistances.py
"""

# standard modules
import inspect
import os
import string
import unittest

try:
    import gracePlot
except:
    gracePlot = None
    #print "gracePlot Not installed - no visualization"

import Numeric
import MLab

from compClust.score import ClusterCenterDistances
from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling
from compClust.mlx.views import RowPCAView
from compClust.mlx.models import constructMixtureOfDiagonalGaussiansFromLabeling

class ClusterCenterDistancesTestCases(unittest.TestCase):

    """
    Exercises ClusterCenterDistances.clusterMeansDistance on a small
    synthetic dataset by comparing a reference labeling against an
    identical labeling, a permuted labeling, a labeling with a merged
    cluster and a labeling with a split cluster.

    The test methods use the legacy ``check`` prefix, so they are collected
    by the module-level suite() function rather than by unittest's default
    ``test`` prefix discovery.
    """

    def setUp(self):

        """
        Load the synthetic dataset and four labelings of it.

        All files are read from the directory that contains this source
        file:

           synth_t_05c0_p_0075_d_03_v_1d0.txt                  (dataset)
           clust_t_05c0_p_0075_d_03_v_1d0_a_reference.txt      (reference)
           clust_t_05c0_p_0075_d_03_v_1d0_a_permuted.txt
           clust_t_05c0_p_0075_d_03_v_1d0_a_mergedCluster.txt
           clust_t_05c0_p_0075_d_03_v_1d0_a_splitCluster.txt
        """

        # Resolve data files relative to this module rather than the current
        # working directory, so the tests can be run from anywhere.
        source = os.path.realpath(inspect.getsourcefile(ClusterCenterDistancesTestCases))
        datapath = os.path.split(source)[0]

        self.dataset = Dataset(os.path.join(datapath, 'synth_t_05c0_p_0075_d_03_v_1d0.txt'))

        self.refLabeling = self._loadLabeling(datapath, 'clust_t_05c0_p_0075_d_03_v_1d0_a_reference.txt')
        self.permutedLabeling = self._loadLabeling(datapath, 'clust_t_05c0_p_0075_d_03_v_1d0_a_permuted.txt')
        self.mergedLabeling = self._loadLabeling(datapath, 'clust_t_05c0_p_0075_d_03_v_1d0_a_mergedCluster.txt')
        self.splitLabeling = self._loadLabeling(datapath, 'clust_t_05c0_p_0075_d_03_v_1d0_a_splitCluster.txt')

        # Maximum absolute difference allowed between a computed score and
        # its hard-coded expected value (expected values carry 3 decimals).
        self.tolerance = .005

        # Maximum absolute difference allowed between score(X, Y) and
        # score(Y, X). Much tighter than self.tolerance: the only slack
        # wanted here is floating-point round-off.
        self.symmetryTolerance = 1e-9


    def tearDown(self):

        pass

    def _loadLabeling(self, datapath, filename):

        """
        Return a new Labeling of self.dataset whose rows are labeled from
        the file datapath/filename.
        """

        labeling = Labeling(self.dataset)
        labeling.labelRows(os.path.join(datapath, filename))
        return labeling

    def _failUnlessClose(self, score, expected, tolerance):

        """
        Fail unless |score - expected| < tolerance. The failure message
        reports the actual score, the expected score and the tolerance.
        """

        self.failUnless(MLab.absolute(score - expected) < tolerance,
                        'score %r differs from expected %r by more than %r'
                        % (score, expected, tolerance))

    def checkClusterDistanceCalculationsIdenity(self):

        """
        Check that the clusterMeansDistance between two identical
        labelings is exactly 1.0.
        """

        # "\t Checking Identity Calculation..."
        score = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.refLabeling, self.refLabeling)
        self.failUnless(score == 1.0,
                        'identity score was %r, expected 1.0' % (score,))


    def checkClusterDistanceCalculationsSymmetry(self):

        """
        Check that the calculation is symmetric, i.e.
        clusterMeansDistance(X, Y) == clusterMeansDistance(Y, X),
        up to floating-point round-off.
        """

        # "\t Checking Symmetry Calculation..."
        score1 = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.refLabeling, self.permutedLabeling)
        score2 = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.permutedLabeling, self.refLabeling)

        # NOTE(review): the implementation is not visible from here. If it
        # accumulates terms in argument order, swapping the arguments can
        # change round-off, so compare within a tiny epsilon, not with ==.
        self._failUnlessClose(score1, score2, self.symmetryTolerance)


    def checkClusterDistancesWithPermutedDataset(self):

        """
        Check the score of the reference labeling against the permuted
        labeling against a precomputed expected value.
        """

        # "\t Checking exact Calculation with permuted labeling..."
        rightScore = .974
        score = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.refLabeling, self.permutedLabeling)
        self._failUnlessClose(score, rightScore, self.tolerance)

    def checkClusterDistancesWithMergedLabeling(self):

        """
        Check the score of the reference labeling against the
        merged-cluster labeling against a precomputed expected value.
        """

        # "\t Checking exact Calculation with merged labeling..."
        rightScore = .883
        score = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.refLabeling, self.mergedLabeling)
        self._failUnlessClose(score, rightScore, self.tolerance)

    def checkClusterDistancesWithSplitLabeling(self):

        """
        Check the score of the reference labeling against the
        split-cluster labeling against a precomputed expected value.
        """

        # "\t Checking exact Calculation with split labeling..."
        rightScore = .996
        score = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.refLabeling, self.splitLabeling)
        self._failUnlessClose(score, rightScore, self.tolerance)

    def visualizeResults(self):

        """
        Plot the cluster means of the reference, permuted, split and merged
        labelings, projected through a RowPCAView of the dataset (columns 0
        and 1 of the projection, presumably the two leading principal
        components). The legend shows each labeling's score against the
        reference.

        Requires gracePlot; suite() only adds this case when it imported.
        """

        # "\t Checking visualizing mean space..."
        g = gracePlot.gracePlot()
        pca = RowPCAView(self.dataset)
        scorePer = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.refLabeling, self.permutedLabeling)
        scoreMerged = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.refLabeling, self.mergedLabeling)
        scoreSplit = ClusterCenterDistances.clusterMeansDistance(self.dataset, self.refLabeling, self.splitLabeling)

        # Plot order must match the legend order below.
        labelings = [self.refLabeling,
                     self.permutedLabeling,
                     self.splitLabeling,
                     self.mergedLabeling]

        # Project each labeling's means once, instead of recomputing the
        # dot product for each plotted coordinate.
        projections = []
        for labeling in labelings:
            means = constructMixtureOfDiagonalGaussiansFromLabeling(self.dataset, labeling).means
            projections.append(Numeric.dot(means, pca.matrix))

        g.hold(1)
        for projected in projections:
            g.plot(projected[:,0], projected[:,1], linetype = 'none')

        g.legend (['reference Means',
                   'permuted labeling: score = %3.3f'%(scorePer),
                   'split labeling  : score = %3.3f'%(scoreSplit),
                   'merged  labeling  : score = %3.3f'%(scoreMerged),
                   ])

        g.title ('PCA Visualization of Mean Space ')
        
def suite(**kw):
    """
    Assemble the test suite by hand.

    The test methods use the ``check`` prefix instead of unittest's default
    ``test`` prefix, so they are listed here explicitly and run in this
    order. The visualizeResults case is appended only when the optional
    gracePlot module imported. Keyword arguments are accepted and ignored.
    """
    caseNames = ["checkClusterDistanceCalculationsIdenity",
                 "checkClusterDistancesWithPermutedDataset",
                 "checkClusterDistanceCalculationsSymmetry",
                 "checkClusterDistancesWithMergedLabeling",
                 "checkClusterDistancesWithSplitLabeling"]
    if gracePlot is not None:
        caseNames.append("visualizeResults")

    testSuite = unittest.TestSuite()
    for caseName in caseNames:
        testSuite.addTest(ClusterCenterDistancesTestCases(caseName))
    return testSuite

if __name__ == "__main__":
    unittest.main(defaultTest="suite")
