########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Written By: Christopher Hart

This module provides a general mechanism for performing ROC-like analyses on lists.
"""

import Numeric
import MLab

from compClust.mlx import views
from compClust.mlx import labelings

from compClust.util import DistanceMetrics
from compClust.util import listOps

def rocCurve(orderedList, set1, set2):

  """
  Build the X and Y coordinates of an ROC-style curve from an ordered list.

  orderedList is walked from front to back. At every position a running
  count is kept of how many of the items seen so far belong to set1 and,
  separately, how many belong to set2. Each running count is then divided
  by its own final value, so a sequence whose final count is non-zero
  ends at 1.0.

  returns (xcoords, ycoords): the normalized running counts for set1 and
  set2 respectively, with one value per item of orderedList.

  Items are looked up as dictionary keys, so they have to be hashable.

  NOTE(review): the divisor is the last running count. If orderedList holds
  no item from one of the sets that divisor is 0.0, and an empty orderedList
  has no last element to index at all -- neither case is handled here.

  The area under the curve can be obtained with MLab.trapz(ycoords, xcoords).

  example:

    orderedList = [1,1,0,0] set1 = [1] set2 = [0]

    running counts for set1 are 1,2,2,2 and for set2 are 0,0,1,2, so
    rocCurve(orderedList, set1, set2) gives x = 0.5,1,1,1 and y = 0,0,0.5,1
  """

  ## membership tables: only the keys matter, they give fast lookups

  firstMembers = {}
  for element in set1:
    firstMembers.setdefault(element)
  secondMembers = {}
  for element in set2:
    secondMembers.setdefault(element)

  # x axis: running fraction of set1 members met so far

  firstHits = [entry in firstMembers for entry in orderedList]
  firstRunning = MLab.cumsum(firstHits)
  xcoords = firstRunning / float(firstRunning[-1])

  # y axis: running fraction of set2 members met so far

  secondHits = [entry in secondMembers for entry in orderedList]
  secondRunning = MLab.cumsum(secondHits)
  ycoords = secondRunning / float(secondRunning[-1])

  return (xcoords, ycoords)

def clusterROC(dataset, labeling, label, distanceMetric=DistanceMetrics.EuclideanDistance):

  """
  Compute an ROC curve, and the area under it, for one cluster of a labeling.

  The rows that labeling.getRowsByLabel(label) reports form the cluster;
  every other index in range(dataset.getNumRows()) counts as a non-member.
  All rows are ordered by ascending distance to the cluster mean (via
  Numeric.argsort) and that ordering is handed to rocCurve, with the
  non-member rows counted along the x axis and the member rows along the
  y axis.

  distanceMetric must be of the functional form:

    [list of distances] = distanceMetric(vector, array)

  It is called with the cluster mean as the vector and the full data as the
  array. Implementations of several distance metrics can be found in
  compClust.util.DistanceMetrics

  returns (area, xcoords, ycoords) where area is MLab.trapz(ycoords, xcoords).
  """

  matrix = dataset.getData()
  memberRows = labeling.getRowsByLabel(label)
  everyRow = range(dataset.getNumRows())
  outsiderRows = listOps.difference(everyRow, memberRows)

  # centre of the cluster, then the distance from it to every row
  memberData = Numeric.take(matrix, memberRows)
  centroid = MLab.mean(memberData)
  nearestFirst = Numeric.argsort(distanceMetric(centroid, matrix))

  xcoords, ycoords = rocCurve(nearestFirst, outsiderRows, memberRows)
  return (MLab.trapz(ycoords, xcoords), xcoords, ycoords)
 
def clusteringROC(dataset, labeling, distanceMetric=DistanceMetrics.EuclideanDistance):

  """
  Run clusterROC for every label of a labeling.

  The labels visited are those produced by
  listOps.unique(labeling.getLabelByRows()).

  returns a dictionary keyed by label whose values are the
  (area, xcoords, ycoords) tuples returned by clusterROC for that label,
  computed with the given distanceMetric.
  """
  labels = listOps.unique(labeling.getLabelByRows())
  return dict([(label, clusterROC(dataset, labeling, label, distanceMetric))
               for label in labels])
 
def interclusterROC(dataset, labeling1, label1, labeling2, label2, distanceMetric=DistanceMetrics.EuclideanDistance):

  """
  Compute an ROC curve, and the area under it, between two clusters.

  The first cluster is the set of rows bound to label1 in labeling1, the
  second the set of rows bound to label2 in labeling2. Only the rows of
  these two clusters take part: they are ordered by ascending distance to
  the mean of the first cluster and the ordering is handed to rocCurve.
  As in clusterROC, the cluster the mean was taken from (cluster 1) is
  counted along the y axis and the other rows (cluster 2) along the x axis.

  distanceMetric must be of the functional form:

    [list of distances] = distanceMetric(vector, array)

  returns (area, xcoords, ycoords) where area is MLab.trapz(ycoords, xcoords).
  """

  allData = dataset.getData()
  cluster1Rows = labeling1.getRowsByLabel(label1)
  cluster2Rows = labeling2.getRowsByLabel(label2)
  cluster1Mean = MLab.mean(Numeric.take(allData, cluster1Rows))

  # unionData holds the cluster 1 rows first, followed by the cluster 2 rows.
  # NOTE(review): this relies on getRowsByLabel returning sequences for which
  # "+" means concatenation (e.g. lists) -- confirm.
  unionData = Numeric.take(allData, cluster1Rows + cluster2Rows)
  distances = distanceMetric(cluster1Mean, unionData)
  ranks = Numeric.argsort(distances)

  # ranks index into unionData, so the membership sets must follow its layout:
  # positions 0..n1-1 are cluster 1 and positions n1..n1+n2-1 are cluster 2.
  # (The previous code split the positions at n2, which attributed rows to the
  # wrong cluster whenever the two clusters differed in size.)
  n1 = len(cluster1Rows)
  n2 = len(cluster2Rows)
  cluster1Positions = range(n1)
  cluster2Positions = range(n1, n1 + n2)

  xcoords, ycoords = rocCurve(ranks, cluster2Positions, cluster1Positions)
  area = MLab.trapz(ycoords, xcoords)
  return(area, xcoords, ycoords)
