#!/usr/bin/env python2.2
########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
This script is a pared down combined version of wrapper_confusionMatrix and clusterCenterDistances.

Author: Christopher Hart
Date  : August, 2001
"""

# standard modules
import string
import os
import sys

# common modules
import MLab
import Numeric

# MLS modules
from compClust.score import ConfusionMatrix
from compClust.score import ClusterCenterDistances
from compClust.util import Usage
from compClust.util import DistanceMetrics
from compClust.util import WrapperUtil

from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling

def main(opts, argv):

    """
    Usage: labelingsSummary <options> <fileList>

    required parameters:

    -d or --dataset-path <path-to-dataset>

    optional parameters:

    -r or --reference reference file to compare to
    
    -f or --input      list of files to compare (use this instead of a long argument list)

    -m or --metric       <distance metric to use> defaults to euclidean
                        [euclidean, manhattan, maximum, correlation]
    
    --full-filename    output full paths with filename

    --mean             only print out the mean (instead of mean+-std

    --median           only print out the median

    

    """

    inputfile = opts.get('-f', opts.get('--input-file', None))
    if inputfile:
        fileList = map (lambda x: string.strip(x), open(inputfile, 'r').readlines())
    else:
        fileList = argv

        
    if opts.has_key('-h') or opts.has_key('--help') or len(fileList)==0:
        Usage.showHelp(main, exit=1)

    datasetName = opts.get("-d", opts.get("--dataset-path"))
    metric = opts.get('-m', opts.get('--metric','euclidean'))
    referenceName = opts.get('-r', opts.get('--reference', None))
    
    # set up the dataset

    dataset = Dataset(datasetName)

    count = 0
    distances = {}
    confMatScores = {}

    
    for labelingName1 in fileList:
        labeling1 = Labeling(dataset)
        labeling1.labelRows(labelingName1)
        if referenceName:
            referenceLabels = Labeling(dataset)
            referenceLabels.labelRows(referenceName)

            sys.stderr.write('working on: %s %s..\n'%(os.path.basename(referenceName),
                                                      os.path.basename(labelingName1)))


            confMatrix =  ConfusionMatrix()
            confMatrix.createConfusionMatrixFromLabeling(referenceLabels, labeling1)
            confMatScores[(referenceName, labelingName1)] = [confMatrix.NMI(),
                                                             confMatrix.transposeNMI(),
                                                             confMatrix.averageNMI(),
                                                             confMatrix.linearAssignment()]
            distances[(referenceName, labelingName1)] = ClusterCenterDistances.calculateClusterDistances(referenceLabels,
                                                                                                         labeling1,
                                                                                                         dataset,
                                                                                                         metric=metric,
                                                                                                         confusionMatrix=confMatrix)
        else:
            for labelingName2 in fileList[count+1:]:
                labeling2 = Labeling(dataset)
                labeling2.labelRows(labelingName2)
                sys.stderr.write('working on: %s %s..\n'%(os.path.basename(labelingName1),
                                                          os.path.basename(labelingName2)))
                confMatrix = ConfusionMatrix.ConfusionMatrix()

                confMatrix.createConfusionMatrixFromLabeling(labeling1, labeling2)
                confMatScores[(labelingName1, referenceName)] = [confMatrix.NMI(),
                                                                 confMatrix.transposeNMI(),
                                                                 confMatrix.averageNMI(),
                                                                 confMatrix.linearAssignment()]
                distances[(labelingName1, labelingName2)] = ClusterCenterDistances.calculateClusterDistances(labeling1,
                                                                                                             labeling2,
                                                                                                             dataset,
                                                                                                             metric=metric,
                                                                                                             confusionMatrix=confMatrix)
                
        count +=1 


    # parse and print the output
    
    print "#file1\tfile2\tNMI\tTransposedNMI\tAverageNMI\tLinearAssignment\tdistance"
    for key in distances.keys():
        if not (opts.has_key('-f') or opts.has_key('--full-filename')):
            file1 = os.path.basename(key[0])
            file2 = os.path.basename(key[1])
        else:
            file1 = key[0]
            file2 = key[1]

        print "%s\t%s\t"%(file1, file2),

        for item in confMatScores[key]:
            print "\t%3.2f"%item,
        print "\t",
        if opts.has_key('--median'):
            print "%3.2f"%MLab.median(Numeric.array(distances[key].values())),
        elif opts.has_key('--mean'):
            print "%3.2f"%MLab.mean(Numeric.array(distances[key].values())),
        else:
            print "%3.2f+-%3.2f"%(MLab.mean(Numeric.array(distances[key].values())),
                                  MLab.std (Numeric.array(distances[key].values()))),

        print


if __name__ == "__main__":

    # BUGFIX in the long-option list:
    #   * 'metric' takes an argument (short form 'm:') so it needs '=',
    #   * a missing comma after 'reference=' made Python concatenate it with
    #     'full-filenames' into the bogus option 'reference=full-filenames',
    #   * the documented --mean and --median flags were absent entirely.
    optTree, argv = WrapperUtil.createOptTree('d:m:f:hr:',
                                              ['dataset-path=',
                                               'metric=',
                                               'reference=',
                                               'full-filenames',
                                               'input-file=',
                                               'mean',
                                               'median',
                                               'help'])

    main(optTree, argv)



