###########################################################################
#                                                                         #
# C O P Y R I G H T   N O T I C E                                         #
#  Copyright (c) 2005 by:                                                 #
#    * California Institute of Technology                                 #
#                                                                         #
#    All Rights Reserved.                                                 #
#                                                                         #
###########################################################################
#
#          Authors: Brandon King, Joe Roden & Diane Trout
# $LastChangedDate: 2005-12-12 16:29:00 -0800 (Mon, 12 Dec 2005) $
#        $Revision: 1410 $
#
# Module version metadata.  __revision__ and __date__ appear to be
# Subversion keyword expansions (Rev / LastChangedDate) of this file.
__version__  = '1.2'
__revision__ = '$Rev: 1410 $'
__date__     = '$LastChangedDate: 2005-12-12 16:29:00 -0800 (Mon, 12 Dec 2005) $'



from compClust.score.ConfusionMatrix2 import ConfusionMatrix

"""
Functions used to analyze and report on sets of ColumnScore objects.

A ColumnScore object attempts to measure the degree of correlation between one
discrete column partitioning and another, or between a discrete column
partitioning and a numeric column (covariate) labeling.

This code was written for the pcaGinzu work, in which we needed to correlate
column partitions generated from principal components (PCs), i.e. Up/Flat/Down
groupings, with a number of continuous or discrete covariates.  We realize this
work has more general applications, so we plan to refactor portions of this
module and pcaGinzu's ColumnScore class into a separate, more general set of
functions.  Pressed for time, however, we added this code as-is to support an
upcoming software release.

"""


def minSort(item1, item2):
  """Compare two score objects by min_score(), in cmp() style.

  Returns 1 when item1's min_score() is greater than item2's, 0 when the
  two compare equal, and -1 in every other case (including unordered
  values such as NaN).  Usable as a Python 2 list.sort() comparator.
  """
  left = item1.min_score()
  right = item2.min_score()
  if left > right:
    return 1
  if left == right:
    return 0
  return -1
  

def getAllResults(p, pcMin, pcMax):
  """Score every column labeling against each PC in pcMin..pcMax inclusive.

  p must provide scoreColumnLabelingsForPCN(pcNum), returning a list of
  score objects that expose min_score().  Each returned list is sorted in
  place into ascending min_score() order (smallest score first).

  Returns a dict mapping each PC number in the inclusive range to its
  sorted score list; the dict is empty when pcMin > pcMax.
  """
  resultDict = {}
  for pcNum in range(pcMin, pcMax + 1):
    scoreList = p.scoreColumnLabelingsForPCN(pcNum)
    # Sort by key instead of the minSort comparator.  min_score() is then
    # evaluated once per item rather than up to four times per comparison,
    # and key= (unlike a positional cmp function) also works on Python 3.
    # The ordering is the same as minSort's for ordered (non-NaN) scores.
    scoreList.sort(key=lambda score: score.min_score())
    resultDict[pcNum] = scoreList
  return resultDict


def generateResultsTable(resultDict, cutoff=None):
  """Build a {pcNum: {covariateName: min_score}} table from scored results.

  resultDict maps PC numbers to lists of score objects (such as those
  produced by getAllResults); each score object must expose
  labeling.getName() and min_score().

  cutoff -- when not None, any score that is not <= cutoff is stored as -1,
      so only scores at or below the cutoff keep their value.  With
      cutoff=None every score is stored as-is.

  Raises ValueError if the same covariate name appears twice within one
  PC's score list.
  """
  tableDict = {}
  for pcNum in resultDict:
    # Dict keys are unique, so every PC gets a fresh row exactly once.
    row = {}
    tableDict[pcNum] = row

    for scoreObj in resultDict[pcNum]:
      covariateName = scoreObj.labeling.getName()
      if covariateName in row:
        raise ValueError('duplicate! --> %s' % (covariateName,))

      # Evaluate min_score() once and reuse it in both branches.
      min_score = scoreObj.min_score()
      if cutoff is None or min_score <= cutoff:
        row[covariateName] = min_score
      else:
        # -1 is the placeholder for a score that fails the cutoff.
        row[covariateName] = -1

  return tableDict


def writeResultsTableDictToFile(filePath, tableDict, covariates=None):
  """Write a generateResultsTable() table to filePath as tab-separated text.

  The first line is a header: an empty corner cell followed by the
  covariate names.  Each following line is 'PC<n>' followed by that PC's
  value for each covariate.  Lines are written in ascending PC-number order.

  covariates -- optional iterable that fixes which covariate columns are
      written, and in what order.  When None, the sorted covariate names of
      the lowest-numbered PC are used (no columns if tableDict is empty).

  Raises KeyError if a PC's row lacks one of the requested covariates.
  """
  pcNumList = sorted(tableDict)

  if covariates is None:
    # Default to every covariate of the lowest-numbered PC, sorted by name,
    # so the column choice is deterministic; an empty table has no columns.
    if pcNumList:
      covariates = sorted(tableDict[pcNumList[0]])
    else:
      covariates = []
  else:
    # Materialise once: the sequence is walked for the header and again
    # for every PC row, which would silently exhaust a one-shot iterator.
    covariates = list(covariates)

  f = open(filePath, 'w')
  try:
    # Header row: blank corner cell, then one column per covariate.
    f.write('\t'.join([''] + covariates) + '\n')

    for pcNum in pcNumList:
      pcDict = tableDict[pcNum]
      pcRow = ['PC%s' % (pcNum,)]
      for covName in covariates:
        if covName not in pcDict:
          raise KeyError('PC%s has no value for covariate %r'
                         % (pcNum, covName))
        pcRow.append('%s' % (pcDict[covName],))
      f.write('\t'.join(pcRow) + '\n')
  finally:
    # Release the file handle even if a row fails validation part-way.
    f.close()
  


def summarizeContinuousResults(resultDict, continuousCutoff=0.05):
  """Summarize significant continuous-covariate scores for each PC.

  Produces one row per resultDict entry, in resultDict.items() order.  Each
  row is [pcNum, 'name|score', ...] and lists, in score-list order, every
  non-discrete score whose min_score() is <= continuousCutoff.  Discrete
  scores are ignored.
  """
  summaryList = []
  for pcNum, scoreList in resultDict.items():
    hits = ['%s|%s' % (s.labeling.getName(), s.min_score())
            for s in scoreList
            if not s.is_discrete and s.min_score() <= continuousCutoff]
    summaryList.append([pcNum] + hits)
  return summaryList


def summarizeDiscreteResults(resultDict, discreteCutoff=0.8):
  """Summarize strong discrete-covariate scores for each PC.

  Produces one row per resultDict entry, in resultDict.items() order.  Each
  row is [pcNum, 'name|scores', ...] and lists, in score-list order, every
  discrete score object whose `scores` attribute is >= discreteCutoff.
  Non-discrete scores are ignored.
  """
  summaryList = []
  for pcNum, scoreList in resultDict.items():
    kept = [s for s in scoreList
            if s.is_discrete and s.scores >= discreteCutoff]
    labels = ['%s|%s' % (s.labeling.getName(), s.scores) for s in kept]
    summaryList.append([pcNum] + labels)
  return summaryList


def writeSummaryResultsToFile(filePath, summaryList):
  """Write summary rows to filePath, one tab-separated line per row.

  Every element of a row is converted with str().  Rows shaped like those
  returned by summarizeContinuousResults() / summarizeDiscreteResults()
  fit this format.
  """
  f = open(filePath, 'w')
  try:
    for summary in summaryList:
      f.write('\t'.join([str(i) for i in summary]) + '\n')
  finally:
    # Release the file handle even if str() or a write fails mid-way.
    f.close()


def displayResults(scoreList, cutoff=0.05):
  """Print 'name: score' for every score whose min_score() is <= cutoff."""
  for entry in scoreList:
    if not entry.min_score() <= cutoff:
      continue
    print('%s: %s' % (entry.labeling.getName(), entry.min_score()))


def displayScoresForColLabeling(labelingName, resultDict):
  """Print '<pcNum> <scores>' for the labeling named labelingName in each PC.

  Each PC's score list must contain exactly one score object whose
  labeling.getName() equals labelingName (checked with assert).  Lines are
  printed in resultDict.items() order and show str(score.scores).
  """
  for pcNum, scoreList in resultDict.items():
    matches = []
    for candidate in scoreList:
      if candidate.labeling.getName() == labelingName:
        matches.append(candidate)
    assert len(matches) == 1

    print('%s %s' % (pcNum, str(matches[0].scores)))
  


def getBestResultForScore(labelingName, resultDict):
  """Find the PC whose score for labelingName is best.

  For each PC in resultDict, the single score object whose
  labeling.getName() equals labelingName is examined.  Discrete scores
  (is_discrete true) are compared by their `scores` attribute, where higher
  is better.  Continuous scores are compared by min_score(), where lower is
  better.

  PCs are visited in ascending PC-number order, and only a strictly better
  score replaces the current best.  Ties therefore resolve to the lowest PC
  number instead of depending on dict iteration order.  This assumes the PC
  keys are mutually orderable (e.g. ints).

  Returns (best_pcn, best_score), or (None, None) if resultDict is empty.
  Raises ValueError if a PC does not have exactly one score for labelingName.
  """
  best_pcn = None
  best_score = None

  for pcNum in sorted(resultDict):
    matches = [v for v in resultDict[pcNum]
               if v.labeling.getName() == labelingName]
    if len(matches) != 1:
      raise ValueError('expected exactly one score for %r in PC%s, found %d'
                       % (labelingName, pcNum, len(matches)))
    score = matches[0]

    # Evaluate each score once; the direction of "better" depends on kind.
    if score.is_discrete:
      value = score.scores
      better = best_score is None or value > best_score
    else:
      value = score.min_score()
      better = best_score is None or value < best_score

    if better:
      best_score = value
      best_pcn = pcNum

  return best_pcn, best_score


def getBestResultScoreForLabelings(labelingList, resultDict):
  """Return [(labelingName, best_pcn, best_score), ...] for each labeling.

  Each entry comes from getBestResultForScore(labelingName, resultDict),
  and entries follow labelingList order.
  """
  def _bestFor(name):
    # Unpack so a malformed (non-pair) result still fails loudly here.
    pcNum, score = getBestResultForScore(name, resultDict)
    return (name, pcNum, score)

  return [_bestFor(name) for name in labelingList]


def writeBestResultScoresToFile(filePath, resultList):
  """Write (labelName, pcNum, score) triples to filePath, one per line.

  Each line is 'labelName<TAB>pcNum<TAB>score', with every field formatted
  with %s.  Triples shaped like those returned by
  getBestResultScoreForLabelings() fit this format.
  """
  f = open(filePath, 'w')
  try:
    for labelName, pcNum, score in resultList:
      f.write('%s\t%s\t%s\n' % (labelName, pcNum, score))
  finally:
    # Release the file handle even if unpacking or a write fails mid-way.
    f.close()


def getSignificantResultsForLabeling(labelingName,
                                     resultDict,
                                     cutoff=0.05,
                                     nmiThresh=None):
  """List the PCs where the continuous labeling labelingName is significant.

  For each PC in resultDict (in resultDict.items() order), the single score
  object whose labeling.getName() equals labelingName is located; its
  uniqueness is checked with assert.  Discrete scores are skipped.  A
  continuous score is kept when its min_score() is <= cutoff.

  nmiThresh is accepted for interface compatibility but is not used.

  Returns a list of (pcNum, min_score) tuples.
  """
  significant = []
  for pcNum, scoreList in resultDict.items():
    matches = [s for s in scoreList if s.labeling.getName() == labelingName]
    assert len(matches) == 1
    score = matches[0]

    if not score.is_discrete and score.min_score() <= cutoff:
      significant.append((pcNum, score.min_score()))

  return significant


def getSignificantResultsForLabelingList(labelingList,
                                         resultDict,
                                         cutoff=0.05,
                                         nmiThresh=None):
  """Collect significant-PC rows for every labeling in labelingList.

  For each labeling name, getSignificantResultsForLabeling() is called with
  the same cutoff and nmiThresh.  Its result is echoed to stdout and then
  packed into a row [labelingName, 'pcNum|score', ...].

  Returns the rows in labelingList order.
  """
  rows = []
  for name in labelingList:
    hits = getSignificantResultsForLabeling(name, resultDict, cutoff, nmiThresh)
    # Echo the labeling and its significant PCs to stdout.
    print('--> %s' % (name,))
    print('  --> %s' % (hits,))
    rows.append([name] + ['%s|%s' % (pcNum, score) for pcNum, score in hits])
  return rows


def writeSignificantResultsToFile(filePath, resultList):
  """Write result rows to filePath, one tab-separated line per row.

  Every element of a row is converted with str().  Rows shaped like those
  returned by getSignificantResultsForLabelingList() fit this format.
  """
  f = open(filePath, 'w')
  try:
    for result in resultList:
      f.write('\t'.join([str(r) for r in result]) + '\n')
  finally:
    # Release the file handle even if str() or a write fails mid-way.
    f.close()


