########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#  Written By    :  Christopher Hart, Diane Trout, Lucas Scharenbroich
#  Date          :  February 2001
#  Last Modified :  Jul     2004
#

"""
The following class provides a confusion matrix with scoring tools.

The ConfusionMatrix class provides N-dimensional confusion analysis as
well as scoring functions (currently NMI and LA) associated with a confusion
matrix. The confusion matrix is constructed such that each cell of the
matrix contains a list of the elements shared by the clusters (one per
dimension) indicated by that cell's indices.
"""

import string
import sys

import Numeric

from compClust.mlx.datasets  import PhantomDataset
from compClust.mlx.labelings import Labeling

from compClust.mlx import graphmatching

DEBUG=0

class ConfusionMatrix:

    """
    The ConfusionMatrix class' main purpose is to construct a confusion matrix
    between two (or more) labelings of a dataset and perform analysis on the
    confusion matrix.
    """

    def __init__(self):
        """
        Create an empty confusion hypercube.

        Attributes:

        self.dimensions
            A list containing the magnitude along each dimension, i.e. the
            number of classes counted for the corresponding labeling.

        self.hypercube
            A dictionary containing the confusion hypercube.  Each key is a
            tuple of class indices, one per dimension (e.g. (1, 2, 3)).  Each
            value is the list of element (row) indices that fall in that
            cell.

        self.numElements
            The total number of elements (genes) stored in the confusion
            hypercube.

        self.dimensionLabeling
            A list of names, one per dimension of the hypercube.  Operations
            on the confusion hypercube should refer to dimensions by these
            names, not by position.

        self.index2cell
            Indexing this list with an element index returns the key (tuple
            of class indices) of the cell that element lives in.

        self.classNames
            A list of two-way dictionaries (class name <-> class index), one
            per dimension.

        self.rowClassNames & self.colClassNames
            Aliases of self.classNames[0] and self.classNames[1].  For a
            one-dimensional hypercube both alias self.classNames[0].
        """
        self.dimensions        = []
        self.hypercube         = {}
        self.numElements       = 0
        self.dimensionLabeling = None
        self.index2cell        = []
        self.rowClassNames     = {}
        self.colClassNames     = {}
        self.classNames        = []

    def getNumElements(self):
        """
        Returns the total number of elements stored in the confusion hypercube.
        """

        return self.numElements

    def getDimensionalLabeling(self):
        """
        Returns the list of names labeling each dimension of the hypercube.
        Returns None if no hypercube has been built yet.
        """

        return self.dimensionLabeling


    def projectConfusionHypercube(self, labels):
        """
        Returns a new ConfusionMatrix built from the dimensions of this
        confusion hypercube named in 'labels'.  This is equivalent to
        projecting the hypercube onto those dimensions: all cells whose
        coordinates agree on the kept dimensions are merged into one cell.

        Labels that do not name a dimension of this hypercube are ignored,
        and so are repeated labels.  The kept dimensions always appear in
        the same order as in this hypercube, whatever the order of 'labels'.
        This keeps the result's dimension sizes, dimension names and class
        dictionaries aligned with its cell coordinates.
        """

        originalLabels = self.dimensionLabeling or []

        #
        # Build the sorted list of dimension indices to keep
        #

        keep = []
        for label in labels:
            label = str(label)
            if label in originalLabels:
                index = originalLabels.index(label)
                if index not in keep:
                    keep.append(index)
        keep.sort()

        #
        # Everything that describes a dimension is derived from the sorted
        # 'keep' list, so it lines up with the projected coordinates.
        #

        dimensions        = [self.dimensions[i] for i in keep]
        dimensionLabeling = [originalLabels[i] for i in keep]
        classNames        = [self.classNames[i] for i in keep]

        #
        # Merge every cell into the projected cell it falls into.  Each
        # projected cell gets a fresh list, so the original cube's cell lists
        # are never aliased.
        #

        newHypercube = {}
        for key in self.hypercube.keys():
            newKey = tuple([key[i] for i in keep])
            newHypercube.setdefault(newKey, []).extend(
                self.getConfusionHypercubeCell(key))

        #
        # Rebuild the element -> cell lookup for the new matrix
        #

        newIndex2Cell = [tuple([key[i] for i in keep])
                         for key in self.index2cell]

        #
        # Create and initialize a brand new confusion matrix
        #

        confMat                   = ConfusionMatrix()
        confMat.dimensions        = dimensions
        confMat.hypercube         = newHypercube
        confMat.numElements       = self.numElements
        confMat.dimensionLabeling = dimensionLabeling
        confMat.index2cell        = newIndex2Cell
        confMat.classNames        = classNames
        if classNames:
            confMat.rowClassNames = classNames[0]
            confMat.colClassNames = classNames[0]
            if len(classNames) > 1:
                confMat.colClassNames = classNames[1]

        return confMat

    def getAgreementList(self):
        """
        Returns a list of numElements entries.  Each entry is 1 if the
        dimensions of the hypercube agree on that element's classification,
        and 0 otherwise.  An element is "agreed upon" when its cell is
        non-zero in the adjacency matrix.  This method is only valid for 2D
        confusion matrices.

        A perfect agreement would return a list of all 1's.

        Some elements lie in a cell outside the adjacency matrix.  For
        example, unlabeled rows get a class index that self.dimensions does
        not count.  Such elements are reported as 0 instead of raising an
        IndexError.
        """

        #
        # Find out which clusters are common to each other
        #

        adjMatrix = self.getAdjacencyMatrix()
        shape     = Numeric.shape(adjMatrix)

        agreeList = [0] * self.getNumElements()

        for i in range(self.getNumElements()):
            coord = self.findCellCoordinates(i)

            inside = len(coord) == len(shape)
            if inside:
                for c, size in zip(coord, shape):
                    if not 0 <= c < size:
                        inside = 0
                        break

            #
            # If that entry is non-zero then it is an agreed upon cluster
            #

            if inside and adjMatrix[coord] != 0:
                agreeList[i] = 1

        return agreeList


    def getStarburst(self, index):
        """
        Returns a list of all the elements in the starburst centered on the
        cell containing index.  The starburst can be conceptualized as a
        series of rays expanding along each dimension of the hypercube from
        the center point.  The data of every cell these rays touch (the
        center cell excluded) is appended to the list.

        This operation is useful to determine what data, while not perfectly
        associated with 'index', is considered 'related' to some degree.
        """

        starburst = []

        #
        # Get the coordinates of this element's cell as a mutable list
        #

        coords = list(self.findCellCoordinates(index))

        #
        # Now run along each dimension, gobbling up index lists
        #

        for i in range(len(coords)):

            #
            # Save the center coordinate on this axis
            #

            thisCell = coords[i]

            #
            # Hold all other coordinates fixed, and run along the range of
            # the dimension.
            #

            for cell in range(self.dimensions[i]):
                if thisCell == cell:
                    continue
                coords[i] = cell
                starburst += self.getConfusionHypercubeCell(tuple(coords))

            #
            # Restore the center coordinate before moving to the next axis
            #

            coords[i] = thisCell

        return starburst

    def getInverseStarburst(self, index):
        """
        Similar to getStarburst(), but returns every element that is neither
        in the starburst of index nor in the cell containing index.  This is
        not a proper inverse, because the home cell is excluded as well.

        This operation tells what data is 'unrelated' to index.  The result
        is in ascending element order.
        """

        #
        # Use a dictionary for O(1) membership tests instead of scanning the
        # starburst and home-cell lists for every element.
        #

        homeCell = self.getConfusionHypercubeCell(
            self.findCellCoordinates(index))
        excluded = dict.fromkeys(self.getStarburst(index))
        excluded.update(dict.fromkeys(homeCell))

        return [i for i in range(self.numElements) if i not in excluded]


    def getConfusionHypercubeCell(self, cellCoordinates):
        """
        Returns the list of element indices held in a cell of the hypercube.
        cellCoordinates is a tuple with one class index per dimension.

        Returns a fresh empty list if the cell does not exist.  It also
        returns an empty list if the coordinates have the wrong number of
        dimensions.  For an existing cell, the stored list itself is
        returned, so mutating it changes the hypercube.
        """

        if len(cellCoordinates) != len(self.dimensions):
            return []

        return self.hypercube.get(cellCoordinates, [])


    def removeCellFromHypercube(self, cellCoordinates):
        """
        Given a cell's coordinates, removes that cell from the hypercube.
        Does nothing if the cell does not exist.  self.index2cell and
        self.numElements are not updated.
        """

        if cellCoordinates in self.hypercube:
            del self.hypercube[cellCoordinates]

    def removeIndexFromHypercube(self, index):
        """
        Removes element 'index' from the cell that contains it.

        Only the cell's element list is changed.  self.index2cell and
        self.numElements are left untouched.  Does nothing if the element is
        no longer in its cell.
        """

        #
        # Look up the cell this index resides in and get the data
        #

        cell = self.findCellCoordinates(index)
        data = self.getConfusionHypercubeCell(cell)

        #
        # The cell stores element indices, so remove by value.  Deleting
        # data[index] would drop whatever element happens to sit at that
        # position.
        #

        if index in data:
            data.remove(index)


    def findCellCoordinates(self, index):
        """
        Finds and returns the coordinates of the cell which contains the
        index in question.  The coordinates are returned as a tuple of class
        indices, suitable to be passed to getConfusionHypercubeCell().
        """

        return self.index2cell[index]


    def createConfusionHypercubeFromLabeling(self, labelings):
        """
        A Confusion Hypercube is a generalization of the confusion matrix which
        allows for any number of labelings to be analyzed at the same time.
        This is a realization of the full Cartesian product of the classes
        defined in the labelings.

        The ability to ask questions about more than two datasets at a time
        is a valuable tool and can be used for sophisticated analysis.

        Each labeling contributes one dimension, named after the labeling,
        or 'Dimension <n>' (1-based) when the labeling has no name.  Only the
        first N elements are used, where N is the smallest number of rows
        among the labelings' datasets.

        When a labeling leaves some rows unlabeled (None), an extra 'None'
        class is registered for it.  That class gets an index in the
        class-name dictionary but is not counted in self.dimensions.

        Calling this method again rebuilds the hypercube from scratch.  It
        does nothing if 'labelings' is None or empty.
        """

        if labelings is None or len(labelings) == 0:
            return

        dimensionLabels = []
        for count, labeling in enumerate(labelings):
            name = labeling.getName()
            if name is None:
                name = 'Dimension ' + str(count + 1)
            dimensionLabels.append(name)

        #
        # We can only look at elements which exist in all the labelings,
        # so use the smallest number of rows
        #

        numberOfElements = min([labeling.getDataset().getNumRows()
                                for labeling in labelings])

        #
        # Build the class name <-> index dictionary for every dimension
        #

        classes        = []
        dimensions     = []
        classNamesList = []

        for labeling in labelings:
            labels = labeling.getLabelByRows()

            # Copy the class list: None may be appended below, and the
            # labeling's own list must not be modified.
            classList = list(labeling.getLabels())
            dimensions.append(len(classList))

            # Unlabeled rows get their own trailing 'None' class.  It is
            # added after the dimension size is recorded, so it does not
            # count towards self.dimensions.
            numLabeled = len([x for x in labels if x is not None])
            if numLabeled < labeling.getDataset().getNumRows():
                classList.append(None)

            classNames = {}
            for index in range(len(classList)):

                #
                # Allow any object to be a label by keying on its string form
                #

                className = str(classList[index])
                classNames[className] = index
                classNames[index] = className

            #
            # Translate every row's label into its class index
            #

            classes.append([classNames[str(x)] for x in labels])
            classNamesList.append(classNames)

        #
        # Now fill in the dictionary.  The keys are tuples of class indices,
        # one per dimension.
        #

        hypercube  = {}
        index2cell = []
        for k in range(numberOfElements):
            key = tuple([c[k] for c in classes])
            hypercube.setdefault(key, []).append(k)
            index2cell.append(key)

        self.dimensionLabeling = dimensionLabels
        self.numElements       = numberOfElements
        self.dimensions        = dimensions
        self.hypercube         = hypercube
        self.index2cell        = index2cell
        self.classNames        = classNamesList

        #
        # Set the rowClassNames and colClassNames to the classNames list
        #

        self.rowClassNames = classNamesList[0]
        self.colClassNames = classNamesList[0]
        if len(classNamesList) > 1:
            self.colClassNames = classNamesList[1]


    def createConfusionMatrixFromLabeling(self, labeling1, labeling2):
        """
        A confusion matrix is constructed from the two labelings
        labeling1 (rows) and labeling2 (columns).  Also sets self.numRows and
        self.numCols to the shape of the count matrix.
        """

        self.createConfusionHypercubeFromLabeling([labeling1, labeling2])
        self.numRows, self.numCols = Numeric.shape(self.getCounts())


    def createConfusionMatrixFromFile(self, clusteringFile1, clusteringFile2):

        """
        A confusion matrix is constructed from the two clustered
        files clusteringFile1 and clusteringFile2.  clusteringFile1
        clusters are arranged across the rows and clusteringFile2
        clusters are arranged across the columns of the confusion
        matrix.  We assume each file contains a list of cluster labels,
        one per line.  The confusion matrix will only be constructed for
        data which is shared between the two clusterings.

        The constructed confusion matrix, instead of storing straight
        numeric counts, stores lists of elements shared between each
        pair of clusters.  This adds quite a bit of exploratory power
        to the confusion matrix.
        """

        from compClust.util.FileIO import readLabelFile

        # The first file is read only to size the placeholder dataset
        labels    = readLabelFile(clusteringFile1)
        num       = len(labels)

        phantom   = PhantomDataset(num, 1)

        labeling1 = Labeling(phantom)
        labeling1.labelRows(clusteringFile1)

        labeling2 = Labeling(phantom)
        labeling2.labelRows(clusteringFile2)

        self.createConfusionMatrixFromLabeling(labeling1, labeling2)


    def __countHelper(self, dims, partialKey):
        """
        Recursive subroutine which traverses the hypercube along each
        dimension and creates a nested list structure of the counts of the
        cells.  This is an O(n^d) algorithm, where d is the number of
        dimensions of the hypercube and n is the magnitude of each
        dimension.  Thus, this routine should be called sparingly.
        """

        if len(dims) == 0:
            return len(self.getConfusionHypercubeCell(tuple(partialKey)))

        counts = []
        for index in range(dims[0]):
            counts.append(self.__countHelper(dims[1:], partialKey + [index]))

        return counts

    def getCounts(self):
        """
        Alias of getHypercubeCounts().
        """
        return self.getHypercubeCounts()

    def getHypercubeCounts(self):
        """
        Returns the cell counts as nested lists, with one nesting level per
        dimension.  For a 2D matrix this is a list of rows of counts.
        Returns [] for an empty hypercube.
        """

        if len(self.dimensions) < 1:
            return []
        return self.__countHelper(self.dimensions, [])

    def printCounts(self, labels=0, outputStream=sys.stdout):
        """
        Makes a pretty print out of the confusion matrix.  With labels=1 the
        axes are labeled with the class names; any other value (e.g. the
        default 0) prints only the counts.  outputStream allows the output
        of the function, including the error message for hypercubes with
        more than 2 dimensions, to be redirected to any open stream.
        """

        if len(self.dimensions) > 2:
            outputStream.write(
                "Can't print confusion matrices with more than 2 dimensions\n")
            return

        countMatrix = self.getHypercubeCounts()
        is2D        = len(self.dimensions) == 2

        if labels == 1:

            #
            # Column headers come from the counted columns only, so a
            # trailing uncounted 'None' class does not add a header without
            # a matching column.
            #

            if is2D:
                numCols = self.dimensions[1]
            else:
                numCols = 0

            #
            # print the column class names
            #

            outputStream.write("\t\t\t")
            for col in range(numCols):
                outputStream.write(
                    "%s\t" % str(self.colClassNames[col]).strip())
            outputStream.write("\n")

            #
            # print a separator line
            #

            outputStream.write("\t\t+--\t")
            outputStream.write("---\t" * numCols)
            outputStream.write("\n")

            #
            # print the matrix, one line per row class
            #

            for rowIndex in range(len(countMatrix)):
                row = countMatrix[rowIndex]
                outputStream.write(
                    "\t%s\t|" % str(self.rowClassNames[rowIndex]).strip())
                if is2D:
                    for item in row:
                        outputStream.write("\t  %s" % (item,))
                else:
                    # 1D hypercube: each "row" is a single count
                    outputStream.write("\t  %s" % (row,))
                outputStream.write('\n')

        else:
            for row in countMatrix:
                if is2D:
                    for item in row:
                        outputStream.write(str(item) + "\t")
                    outputStream.write('\n')
                else:
                    outputStream.write(str(row) + "\t")
            outputStream.write('\n')

    def getAdjacencyMatrix(self):

        """
        Returns a numeric array with a 1 at every (row class, column class)
        position that graphmatching.match() pairs up as corresponding
        clusters, and zero everywhere else.
        """
        score, adjacency_matrix = graphmatching.match(self.getCounts())
        return adjacency_matrix

    def getAdjacencyList(self):

        """
        Returns a list of (row class name, column class name) tuples
        indicating which classes correspond with each other.

        Raises ValueError if the adjacency matrix could not be computed.
        """

        adjMatrix = self.getAdjacencyMatrix()
        adjList = []
        # FIXME: how bad of an error is it for there to be no adjacency matrix
        if len(Numeric.shape(adjMatrix)) == 0:
            raise ValueError("unable to compute adjacency matrix")
        for r in range(0, Numeric.shape(adjMatrix)[0]):
            for c in range(0, Numeric.shape(adjMatrix)[1]):
                if adjMatrix[r][c] > 0:
                    adjList.append((self.rowClassNames[r],
                                    self.colClassNames[c]))

        return adjList

    def NMI(self):
        """
        Returns the NMI score of the confusion matrix
        (delegates to compClust.score.NMI).
        """
        from compClust.score import NMI
        return NMI(self)

    def averageNMI(self):
        """
        Returns the average NMI score between the confusion matrix and its
        transpose (delegates to compClust.score.averageNMI).
        """
        from compClust.score import averageNMI
        return averageNMI(self)

    def transposeNMI(self):
        """
        Returns the NMI score of the transposed confusion matrix
        (delegates to compClust.score.transposeNMI).
        """
        from compClust.score import transposeNMI
        return transposeNMI(self)

    def linearAssignment(self):
        """
        Returns the linear assignment score for the confusion matrix
        (delegates to compClust.score.LinearAssignment).
        """
        from compClust.score import LinearAssignment
        return LinearAssignment(self)
