########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
This module provides SPING/PIDDLE drawing tools that graphically display a
confusion matrix.
"""

import colorsys

import sping
# these direct imports are the only things that need
# to change if you want to change the backend
from sping.PS import PSCanvas as Canvas
#from sping.PIL import PILCanvas as Canvas
#from sping.TK import TKCanvas as Canvas


import MLab
import Numeric

from compClust.score import ConfusionMatrix

def _orderLabelsBySize(labeling, largestFirst):
  """
  Return the labels of `labeling` ordered by how many rows carry each one.

  Labels are sorted on (row count, label) ascending; when largestFirst is
  true that order is reversed so the most populous label comes first.
  """
  rowLabels = labeling.getLabelByRows()
  decorated = [(rowLabels.count(label), label) for label in labeling.getLabels()]
  decorated.sort()
  if largestFirst:
    decorated.reverse()
  return [label for count, label in decorated]


def _cellBox(row, col, xOrigin, yOrigin, elementSize):
  """Return the (x1, y1, x2, y2) canvas box of grid cell (row, col)."""
  x1 = xOrigin + col*elementSize
  y1 = yOrigin + row*elementSize
  return x1, y1, x1 + elementSize, y1 + elementSize


def confusionMatrixWithResiduleHistograms(l1, l2, highlightLabeling=None):
  """
  canvas = confusionMatrixWithResiduleHistograms(l1, l2, highlightLabeling=None)

  Draw the confusion matrix of two labelings as a colored grid, with
  histograms of the per-row and per-column totals along its margins, and
  return the canvas.

  l1 is aligned along the rows and l2 along the columns.  Rows run from the
  largest l1 cluster to the smallest; columns run from the smallest l2
  cluster to the largest.  Each cell is colored by its count divided by its
  column sum, so l2 should be the reference labeling.

  Elements of the drawing:
    - green bars right of the grid: each row's share of the total count
    - purple bars above the grid: each column's share of the total count
    - thick-edged cells: the label pairs from cm.getAdjacencyList()
    - gray ellipses: cells holding the rows labeled in highlightLabeling
    - text at the top left: linear assignment, NMI and transposed NMI

  highlightLabeling -- optional labeling; every row it labels is marked.
  Returns the Canvas instance (the backend selected by this module's
  imports).
  """
  cm = ConfusionMatrix()
  cm.createConfusionMatrixFromLabeling(l1, l2)

  originalMatrix = Numeric.array(cm.getCounts())

  # Order rows (l1) largest-first and columns (l2) smallest-first by size.
  l1Order = _orderLabelsBySize(l1, 1)
  l2Order = _orderLabelsBySize(l2, 0)

  # Rebuild the count matrix in that display order.
  matrix = Numeric.array(
    [[originalMatrix[cm.rowClassNames[rowLabel], cm.colClassNames[colLabel]]
      for colLabel in l2Order]
     for rowLabel in l1Order])
  del originalMatrix

  rows, cols = matrix.shape

  # Estimate a canvas large enough for the grid plus its margins.
  elementSize = 5
  xOrigin = 10
  yOrigin = 150
  width = elementSize*cols + xOrigin + 150
  height = elementSize*rows + yOrigin + 150

  canvas = Canvas((width, height))

  # Marginal histograms: each bar's length is its share of the total count,
  # scaled so that a bar holding everything is barScale pixels long.
  barGap = 2      # pixels between the grid and the start of each bar
  barScale = 100
  totalSum = float(MLab.sum(MLab.sum(matrix)))
  colSums = MLab.sum(matrix)
  rowSums = MLab.sum(MLab.transpose(matrix))

  gridRight = xOrigin + elementSize*cols
  for i in range(rows):
    barSize = (rowSums[i] / totalSum) * barScale
    # The bar is measured from its own start (after the gap), so tiny
    # shares are not shortened or inverted by the gap.
    canvas.drawRect(gridRight + barGap,
                    yOrigin + elementSize*i,
                    gridRight + barGap + barSize,
                    yOrigin + elementSize*(i+1),
                    fillColor=sping.colors.green,
                    edgeWidth=0)

  for i in range(cols):
    barSize = (colSums[i] / totalSum) * barScale
    canvas.drawRect(xOrigin + i*elementSize,
                    yOrigin - barGap,
                    xOrigin + (i+1)*elementSize,
                    yOrigin - barGap - barSize,
                    fillColor=sping.colors.purple,
                    edgeWidth=0)

  # Normalize each cell by its column sum (Numeric.sum reduces over axis 0
  # and broadcasts across the rows); the small epsilon keeps an empty
  # column from dividing by zero.
  matrix = matrix.astype(Numeric.Float)
  matrix = matrix / (Numeric.sum(matrix) + .001)

  # Confusion-matrix grid: hue goes from 0 (red) for an empty cell up to
  # 0.7 for a cell holding its entire column.
  for r in range(rows):
    for c in range(cols):
      rgb = colorsys.hsv_to_rgb(matrix[r, c]*.7, 1, 1)
      x1, y1, x2, y2 = _cellBox(r, c, xOrigin, yOrigin, elementSize)
      canvas.drawRect(x1, y1, x2, y2,
                      edgeWidth=.3, fillColor=sping.colors.Color(*rgb))

  # Outline the cells named by the adjacency list with a thick edge.  The
  # pairs are labels, so they are mapped to display positions first.
  for pair in cm.getAdjacencyList():
    x1, y1, x2, y2 = _cellBox(l1Order.index(pair[0]), l2Order.index(pair[1]),
                              xOrigin, yOrigin, elementSize)
    canvas.drawRect(x1, y1, x2, y2, edgeWidth=2)

  if highlightLabeling is not None:
    # Mark the cell of every row carried by highlightLabeling with a thin
    # gray ellipse (no fill color is passed, so the canvas default applies).
    # NOTE(review): findCellCoordinates presumably reports cm's own row/col
    # numbering, while the grid above is reordered by cluster size; confirm
    # whether these coordinates need mapping through l1Order/l2Order.
    for label in highlightLabeling.getLabels():
      for index in highlightLabeling.getRowsByLabel(label):
        row, col = cm.findCellCoordinates(index)
        x1, y1, x2, y2 = _cellBox(row, col, xOrigin, yOrigin, elementSize)
        canvas.drawEllipse(x1, y1, x2, y2,
                           edgeWidth=1,
                           edgeColor=sping.colors.gray)

  canvas.drawString("LA = %3.2f NMI = %3.2f NMI' = %3.2f"%(
    cm.linearAssignment(), cm.NMI(), cm.transposeNMI()), 5,10)

  return canvas
  
  
    
  
  
