"""
A fresh implementation of the confusion matrix using the current CompClust/MLX APIs.

Written By: Christopher Hart

"""
import re
import sys
import string
import types

import Numeric

from compClust.util import listOps
from compClust.util.FileIO import _parseToken
from compClust.mlx import labelings
from compClust.mlx import labelings

from compClust.score.NMI import NMI
from compClust.score.NMI import transposeNMI
from compClust.score.NMI import averageNMI
from compClust.score.LinearAssignment import LinearAssignment
from compClust.mlx import graphmatching


class ConfusionMatrix:
  """
  A new implementation of the Confusion Matrix, based on labelings.

  The matrix is derived from a confusion labeling (see
  buildConfusionLabeling) in which every key is labeled with the tuple of
  labels it carries in each input labeling.  This maintains partial
  compatibility with the old API (e.g. getCounts).
  """

  def __init__(self, labelingList, clusterOrders=None, web_safe=False):
    """
    Construct a confusion matrix from a list of labelings.

    labelingList  -- labelings attached to the same dataset; the dataset of
                     the first one becomes this matrix's dataset.
    clusterOrders -- optional list with one sequence per labeling giving the
                     order of that labeling's clusters along its matrix
                     axis.  Defaults to each labeling's getLabels() order.
    web_safe      -- forwarded to buildConfusionLabeling; when true the
                     confusion labels are stringified tuples.
    """
    self.__confusionLabeling = buildConfusionLabeling(labelingList, web_safe)
    if clusterOrders is None:
      clusterOrders = [l.getLabels() for l in labelingList]
    self.__clusterOrders = clusterOrders
    self.__dataset = labelingList[0].getDataset()
    self.__is_web_safe = web_safe

  def __get_is_web_safe(self):
    return self.__is_web_safe
  isWebSafe = property(__get_is_web_safe, doc="Indicate if our labeling was generated in 'web_safe' mode")

  ## getters and setters
  def setClusterOrder(self, clusterOrders):
    """Replace the per-axis cluster orders used when building the matrix."""
    self.__clusterOrders = clusterOrders

  def setDataset(self, dataset):
    """Changing the attached dataset is not implemented."""
    raise NotImplementedError()

  def getDataset(self):
    """return dataset we're attached to"""
    return self.__dataset
  dataset = property(getDataset, doc="return dataset we're attached to")

  def getConfusionLabeling(self):
    """Return the labeling built by buildConfusionLabeling."""
    return self.__confusionLabeling

  def getClusterOrders(self):
    """Return the per-axis cluster orders."""
    return self.__clusterOrders

  def getCounts(self):
    """
    DEPRECATED - use getMatrix
    """
    sys.stderr.write('DEPRECATED - use getMatrix instead\n')
    return self.getMatrix()

  def getMatrix(self):
    """ return the confusion matrix itself """
    return buildConfusionMatrix(self.__confusionLabeling,
                                self.__clusterOrders)

  def _get2dMatrix(self, errorMessage):
    """Return getMatrix(), raising ValueError(errorMessage) if it has more than 2 dims."""
    m = self.getMatrix()
    if len(Numeric.shape(m)) > 2:
      raise ValueError(errorMessage)
    return m

  ## scoring functions (only good if 1-1 labelings)
  def NMI(self):
    """ Return the Normalized Mutual Information score """
    return NMI(self._get2dMatrix('NMI only supported with 2d confusion matrices'))

  def transposeNMI(self):
    """ Return the Transpose NMI score """
    return transposeNMI(self._get2dMatrix('NMI only supported with 2d confusion matrices'))

  def averageNMI(self):
    """ Return the average NMI score between the NMI and transpose NMI scores """
    return averageNMI(self._get2dMatrix('NMI only supported with 2d confusion matrices'))

  def linearAssignment(self):
    """ Return the Linear Assignment Value of the matrix """
    return LinearAssignment(
      self._get2dMatrix('Linear Assignment only supported with 2d confusion matrices'))

  def getAdjacencyMatrix(self):
    """
    Return the adjacency matrix from graphmatching.match on the confusion
    matrix, which describes the optimal cluster pairings.  The match score
    is discarded.
    """
    m = self._get2dMatrix('Adjancecy calculation only supported with 2d confusion matrices')
    # Reuse m rather than rebuilding the confusion matrix a second time.
    score, adj_matrix = graphmatching.match(m)
    return adj_matrix

  def getAdjacencyList(self):
    """
    Return a list of (rowCluster, colCluster) tuples, i.e. the confusion
    matrix cells whose adjacency entry is positive (the optimal pairings).
    """
    # getAdjacencyMatrix already enforces the 2d restriction with the same
    # error message, so no separate matrix build is needed here.
    adjMatrix = self.getAdjacencyMatrix()
    shape = Numeric.shape(adjMatrix)
    adjList = []
    for r in range(shape[0]):
      for c in range(shape[1]):
        if adjMatrix[r][c] > 0:
          adjList.append((self.__clusterOrders[0][r], self.__clusterOrders[1][c]))
    return adjList






def buildConfusionLabeling(labelingList, web_safe, glabeling=0):
  """
  Given a list of labelings, build a confusion matrix or "hyper-confusion"
  matrix labeling.

  Each key of the shared dataset is labeled with the tuple of labels it
  carries in the input labelings (one element per labeling, in order).

  labelingList -- labelings that must all be attached to the same dataset
                  and report the same number of row (or column) labels as
                  the first one.
  web_safe     -- when true, each tuple label is converted with
                  tuple_stringify before being attached.
  glabeling    -- when true, build a labelings.GlobalWrapper instead of a
                  plain labelings.Labeling.

  Returns the new labeling.  Its name is ('ConfMat', tuple(labelingList)),
  and labelingList is also stored on its labelingList attribute.

  Raises ValueError if the labelings are attached to different datasets or
  have differing label counts.
  """
  ds = labelingList[0].getDataset()
  if not listOps.allTrue([x.getDataset() == ds for x in labelingList]):
    raise ValueError('labelings not all attached to same dataset')

  # Label counts are compared via the row or column accessor, chosen by the
  # orientation of the first labeling.
  if labelingList[0].isRowLabeling():
    len_func_name = "getAllRowLabels"
  else:
    len_func_name = "getAllColLabels"
  labeling_len = len(getattr(labelingList[0], len_func_name)())

  for other in labelingList[1:]:
    # FIXME: once isRowLabeling/isColLabeling work reliably, also reject
    # labelings whose orientation differs from the first one.
    other_len = len(getattr(other, len_func_name)())
    if other_len != labeling_len:
      raise ValueError("%s had %d labels, while %s had %d" % (str(other), other_len,
                                                             labelingList[0], labeling_len))

  ## every combination of labels across the labelings, as flat tuples
  ## (fullCross avoids the nested [[[a,b],c],d] shape of a pairwise product)
  cmLabels = [tuple(combo) for combo in
              listOps.fullCross([l.getLabels() for l in labelingList])]

  cm_labeling_name = ('ConfMat', tuple(labelingList))
  if glabeling:
    cmLab = labelings.GlobalWrapper(ds, cm_labeling_name)
  else:
    cmLab = labelings.Labeling(ds, cm_labeling_name)

  for cmLabel in cmLabels:
    # Keys with this combination = intersection of the keys carrying each
    # component label in its own labeling.
    keys = [lab.getKeysByLabel(l) for lab, l in zip(labelingList, cmLabel)]
    if web_safe:
      cmLabel = tuple_stringify(cmLabel)
    cmLab.addLabelToKeys(cmLabel, reduce(listOps.intersection, keys))

  cmLab.labelingList = labelingList
  return cmLab

def tuple_stringify(label):
  """Render label with str() and drop the space following each comma.

  Meant for tuple labels, e.g. (1, 2) -> "(1,2)", so the result can be
  stuck on a URL; tuple_unstringify parses such strings back.
  """
  text = str(label)
  return text.replace(", ", ",")

def tuple_unstringify(label):
  """Convert one of our stringified tuples (see tuple_stringify) back into a tuple.

  Surrounding whitespace is ignored.  Each comma-separated element has every
  character other than ASCII letters, digits, '.', '-' and space removed
  (so quotes vanish) and is then passed to _parseToken.  A trailing comma,
  as in the repr of a 1-tuple "(3,)", does not produce an extra element, and
  "()" yields ().

  Raises ValueError if label is not wrapped in parentheses.
  """
  def filter_unsafe_chars(s):
    # '-' is kept so negative numeric labels keep their sign.
    return re.sub("[^0-9.a-zA-Z -]", "", s)

  label = label.strip()
  # Explicit check instead of assert: asserts are stripped under -O, and an
  # empty string would otherwise raise IndexError.
  if not (label.startswith('(') and label.endswith(')')):
    raise ValueError("not a stringified tuple: %r" % (label,))
  inner = label[1:-1]
  if inner.endswith(','):
    inner = inner[:-1]
  if not inner:
    return ()
  return tuple([_parseToken(filter_unsafe_chars(piece)) for piece in inner.split(',')])

def buildConfusionMatrix(confusionLabeling, clusterOrders=None):
  """
  Return the confusion count array for a confusion labeling.

  confusionLabeling -- a labeling as produced by buildConfusionLabeling:
                       each label is a tuple with one cluster per
                       constituent labeling, or a string produced by
                       tuple_stringify (web_safe mode).
  clusterOrders     -- optional list with one sequence per constituent
                       labeling giving the order of clusters along that
                       axis.  Defaults to getLabels() of each labeling in
                       confusionLabeling.getName()[1].

  The array has len(clusterOrders) dimensions, the i-th of length
  len(clusterOrders[i]).  Each cell holds the number of keys carrying that
  combination of clusters.  Combinations without a label stay 0.

  A label containing a cluster missing from its clusterOrders entry makes
  order.index(...) raise (ValueError for list/tuple orders).
  """
  if clusterOrders is None:
    clusterOrders = [l.getLabels() for l in confusionLabeling.getName()[1]]
  cm = Numeric.zeros(tuple([len(order) for order in clusterOrders]))
  for label in confusionLabeling.getLabels():
    # web_safe confusion labelings store stringified tuples; parse them back.
    # isinstance (rather than an exact type test) also covers str subclasses,
    # which would otherwise be zipped character by character.
    if isinstance(label, types.StringTypes):
      clusters = tuple_unstringify(label)
    else:
      clusters = label
    index = tuple([order.index(cluster)
                   for order, cluster in zip(clusterOrders, clusters)])
    cm[index] = len(confusionLabeling.getKeysByLabel(label))

  # Sanity check: presumably each key carries at most one confusion label,
  # so the total count cannot exceed the number of keys.
  # NOTE(review): the bound uses getNumRows() even for column labelings --
  # confirm this is intended.
  assert Numeric.sum(cm.flat) <= confusionLabeling.getDataset().getNumRows()
  return cm



