########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################
#
#  Written By    :  Christopher Hart, Diane Trout, Lucas Scharenbroich
#  Date          :  February 2001
#  Last Modified :  Aug     2001
#

from Numeric import *

def NMI(numericConfMatrix, trans=0):
  """
  Returns the normalized mutual information (NMI) score of a confusion
  matrix.

  The columns are assumed to contain the ground truth, so the score is
  the mutual information between the row and column partitions divided
  by the entropy of the columns:

      NMI = 1 - (H_sr - H_r) / H_s

  where H_r, H_s and H_sr are the row, column and joint entropies.
  For more information on NMI see:

     @article   {Forbes95,
     author   = {FORBES, {A.D.}},
     title    = {CLASSIFICATION-ALGORITHM EVALUATION - 5
                 PERFORMANCE-MEASURES BASED ON CONFUSION MATRICES},
     journal  = {JOURNAL OF CLINICAL MONITORING},
     volume   = {11},
     number   = {3},
     year     = {1995},
     pages    = {189--206}}

  numericConfMatrix -- 2-D confusion matrix of counts, either an array or
                       a nested sequence (it is converted with asarray).
  trans             -- if 1, score the transpose instead, i.e. treat the
                       rows as the ground truth.  Used by transposeNMI.

  Returns the NMI score as a float.  Returns -0.0 when the column entropy
  H_s is zero, as happens when the columns hold only one class; the
  score is undefined in that case.

  REQUIRES:  the Numeric module
  """

  # Accept nested sequences as well as arrays.  Previously a plain list
  # only worked when trans == 1, because transpose() converted it.
  matrix = asarray(numericConfMatrix)
  if trans == 1:
    matrix = transpose(matrix)

  # Add a tiny pseudocount to every cell so that no probability is exactly
  # zero and log() stays finite for empty cells, rows or columns.  This
  # only guards the logarithms; it is not a Yates continuity correction.
  # The addition also produces a new floating-point array, so the caller's
  # matrix is never modified.
  matrix = matrix + .000000005

  # Pass the axis explicitly: Numeric's sum() defaults to axis 0 while
  # numpy's sums over all axes, so relying on the default is fragile.
  colSum = sum(matrix, 0)      # one entry per column
  rowSum = sum(matrix, 1)      # one entry per row
  N = sum(colSum, 0)           # total number of data points

  # H_r: the information contained in the rows
  p_r = rowSum / N
  H_r = -sum(multiply(p_r, log(p_r)), 0)

  # H_s: the information contained in the columns (where the ground
  # truth is usually assumed to be, if there is one)
  p_s = colSum / N
  H_s = -sum(multiply(p_s, log(p_s)), 0)

  # H_sr: the joint entropy of the row and column partitions
  p_sr = matrix / N
  H_sr = -sum(sum(multiply(p_sr, log(p_sr)), 0), 0)

  # The score is undefined when the columns carry no information (a
  # single column).  Test for it explicitly instead of wrapping the
  # division in a bare ``except:``.  A bare except also swallowed
  # unrelated errors, and array scalars yield inf/nan instead of raising.
  if H_s == 0:
    return -0.0

  return 1 - ((H_sr - H_r) / H_s)

def averageNMI(confMatrix):
  """
  Returns the mean of two NMI scores for confMatrix: the score with the
  columns treated as the ground truth, and the score of its transpose
  (rows treated as the ground truth).
  """

  forward = NMI(confMatrix)
  backward = transposeNMI(confMatrix)
  return 0.5 * (forward + backward)

def transposeNMI(confMatrix):
  """
  Returns the NMI score of the transpose of confMatrix, i.e. the score
  computed with the rows, rather than the columns, treated as the ground
  truth.
  """

  score = NMI(confMatrix, trans=1)
  return score
