########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
Simple summary plotters for datasets and clusterings, built on gracePlot and Pmw/BLT graphs.

"""

import colorsys
import string
import tempfile
import os
import copy
import Tkinter
import re
import types

import Pmw
import MLab
import Numeric
import gracePlot
from Scientific.Functions import LeastSquares


from compClust.score import roc
from compClust.visualize import plot

from compClust.mlx.datasets import Dataset
from compClust.mlx.labelings import Labeling
from compClust.mlx.labelings import subsetByLabeling
from compClust.mlx.views import RowPCAView, RowSubsetView

from compClust.util import DistanceMetrics
from compClust.util import InterpreterTools
from compClust.util import NaN
from compClust.visualize import IPlot


def estimateLD(matrixSpectrum):

  """
  estimateLD(s)  where s is the variances from an svd decomposition,
  expected in decreasing order (s[0] is taken as the maximum variance).

  Fits the power law  s[i-1] ~= s[0] / i**K  for i = 1..len(s) with
  LeastSquares.leastSquaresFit, starting from K = 0.5, and computes an
  entropy of the spectrum.

  Returns the tuple (K, s[0], fit, s, entropy) where
    fit     -- the second value returned by leastSquaresFit
    entropy -- -sum(v*log(v)) / len(s), summed over the strictly positive
               entries v of s.  Zero (or non-positive, e.g. numerical
               noise) variances are skipped, i.e. 0*log(0) is taken as 0;
               otherwise they would poison the result with NaN.
  """
  import math

  maxVariance = matrixSpectrum[0]
  dims = len(matrixSpectrum)
  estK, fit = LeastSquares.leastSquaresFit(lambda p, x: (1/(x**p[0]))*maxVariance,
                                           (.5,),
                                           zip(range(1,dims+1), matrixSpectrum))

  # accumulate v*log(v) only where log(v) is defined
  entropySum = 0.0
  for variance in matrixSpectrum:
    if variance > 0:
      entropySum = entropySum + variance * math.log(variance)
  normalizedShannonEntropy = -1 * entropySum / dims

  return(estK[0], maxVariance, fit, matrixSpectrum, normalizedShannonEntropy)


def visualizeLD(matrixSpectrum, name=None, g=None):

  """
  visualizeLD(s, name=None, g=None)

  Plot the variance spectrum s (points only) together with the power-law
  curve fitted by estimateLD.  Draws into the gracePlot g, or into a new
  gracePlot when g is None.  When given, name is prefixed to the title.
  Returns the gracePlot that was drawn into.
  """

  numDims = len(matrixSpectrum)
  (exponent, topVariance, fitScore,
   matrixSpectrum, entropy) = estimateLD(matrixSpectrum)

  # evaluate the fitted model  topVariance / x**exponent  at x = 1..numDims
  fittedSpectrum = [(1.0/(dim**exponent))*topVariance
                    for dim in range(1, numDims+1)]

  if g is None:
    g = gracePlot.gracePlot()

  # measured spectrum as bare points, fitted curve overlaid on top
  g.plot(matrixSpectrum, linetype='none')
  g.hold(1)
  g.plot(fittedSpectrum)

  if name is not None:
    g.title('%s: Variance vs Num Dims' % (name))
  else:
    g.title('Variance vs Number of Dimensions')

  subtitleText = ('Maxvar = %3.2f, Linear Dependence = %3.2f, Entropy = %3.2f'
                  % (topVariance, exponent, entropy))
  g.subtitle(subtitleText)
  g.xlabel('Number of Dimensions')
  g.ylabel('Variance')
  legendText = 'x\S-%3.2f \N fit = %3.4f' % (exponent, fitScore)
  g.legend(['Real Spectrum', legendText])

  return g



def pca(dataset, labeling=None, plot=None):

  """
  pca(dataset, labeling=None, plot=None)

  Quick PCA plotter.  Projects the rows of dataset onto the two
  highest-variance principal components (via RowPCAView) and plots them.
  It then switches the plot to a 2x1 layout, focuses graph (1,0) and draws
  the variance spectrum there with visualizeLD.

  dataset  -- Dataset to project.
  labeling -- optional Labeling on dataset; when given, the rows of each
              label are plotted as a separate set.
  plot     -- optional existing gracePlot to draw into; a new one is
              created when None.

  Returns the gracePlot used.
  """

  if plot is None:
    g = gracePlot.gracePlot()
  else:
    g = plot

  pcaDataset = RowPCAView(dataset)
  # Work on a copy: the list is sorted in place below, and the object
  # returned by getVariances() may be the view's own internal list.
  variances = list(pcaDataset.getVariances())

  # Tag every PCA column with its variance so the columns can be found
  # again by variance after sorting.
  colVariances = Labeling(pcaDataset)
  map(colVariances.addLabelToCol, variances, range(pcaDataset.getNumCols()))

  variances.sort()
  variances.reverse()

  # columns holding the 1st and 2nd principal components
  firstPC = colVariances.getColsByLabel(variances[0])[0]
  secondPC = colVariances.getColsByLabel(variances[1])[0]

  if labeling is None:
    # no cluster labels: every row goes into a single set
    xdata = pcaDataset.getColData(firstPC)
    ydata = pcaDataset.getColData(secondPC)
    g.plot(xdata, ydata, linetype='none')

  else:
    # one plot set per cluster label
    g.hold()
    for cluster in labeling.getLabels():
      clusterView = RowSubsetView(pcaDataset, labeling.getRowsByLabel(cluster))
      # HACK: the column data is copied because plotting the subset view's
      # data directly did not work (root cause not yet understood).
      xdata = copy.copy(clusterView.getColData(firstPC))
      ydata = copy.copy(clusterView.getColData(secondPC))
      g.plot(xdata, ydata, linetype='none')

  g.xlabel('1st PC- %3.2f variance captured'%(variances[0]))
  g.ylabel('2nd PC- %3.2f variance captured'%(variances[1]))
  g.title('2D-PCA projection')

  # second graph: the variance spectrum and its power-law fit
  g.multi(2,1)
  g.focus(1,0)
  visualizeLD(variances, g=g)

  return(g)


def iPCA(ds, labeling=None, annotes=None):

  """
  iPCA(ds, labeling=None, annotes=None)

  Generate an interactive 2D PCA scatter plot of the rows of ds, using
  columns 0 and 1 of a temporary RowPCAView.

  ds       -- Dataset to project.
  labeling -- optional Labeling on ds.  When given, each label is plotted
              as its own scatter set, and right-clicking (<Button-3>) a
              legend entry opens a plot of that cluster's rows.
  annotes  -- optional Labeling on ds used as per-point labels.  When None,
              a temporary labeling of row indices is used and removed
              before returning.

  The IPlot function2 callback on a point plots a row of ds (see the note
  in drawLine).  Returns the IPlot graph.
  """
  removeAnnotes = 0
  if annotes is None:
    annotes = Labeling(ds, 'tmp')
    annotes.labelRows(range(0, ds.numRows))
    removeAnnotes = 1

  def drawLine(event, labels):
    # NOTE(review): 'index' is the point's position inside the clicked
    # scatter element.  With a labeling, each element holds only one
    # cluster's rows, so this may not be a row index of ds -- confirm.
    index = event.widget.element_closest(event.x, event.y)['index']
    IPlot.plot(ds.getRowData(index))

  def summaryPlot(event):
    # Plot every row of the cluster whose legend entry was clicked.
    # NOTE(review): runs after iPCA has returned; a temporary annotes
    # labeling has already been removed from ds by then -- confirm that
    # labelFrom still works on it.
    g = event.widget
    pos = "@" +str(event.x) +"," +str(event.y)
    legend = g.legend_get(pos)    # get the selected cluster
    cluster = subsetByLabeling(ds, labeling, legend)
    clusterAnnotes = Labeling(cluster)
    clusterAnnotes.labelFrom(annotes)
    g = IPlot.plot(cluster.getData(), setLabel=clusterAnnotes.getLabelByRows())
    g.legend_configure(hide=1)
    IPlot.hold(0)
    ds.removeView(cluster)
    del(cluster)

  pcaDs = RowPCAView(ds)
  pcaLabs = Labeling(pcaDs)
  if labeling is not None:
    # copy the cluster labels onto the PCA view (only when there are any)
    pcaLabs.labelFrom(labeling)
    for label in labeling.getLabels():
      cluster = subsetByLabeling(pcaDs, pcaLabs,  label)
      # point labels for just this cluster's rows, so they line up with
      # the points of this scatter set
      clusterAnnotes = Labeling(cluster)
      clusterAnnotes.labelFrom(annotes)
      g = IPlot.plot(cluster.getColData(0), cluster.getColData(1),
                     pointLabels = clusterAnnotes.getLabelByRows(),
                     setType='scatter', function2=drawLine  )
      pcaDs.removeView(cluster)
      IPlot.hold(1)
  else:
    g = IPlot.plot (pcaDs.getColData (0), pcaDs.getColData (1), setType='scatter', pointLabels = annotes.getLabelByRows(), function2=drawLine )

  g.legend_bind("all", "<Button-3>", summaryPlot)
  IPlot.hold(0)

  # detach the temporary objects created above
  if removeAnnotes:
    ds.removeLabeling(annotes)
  pcaDs.removeLabeling(pcaLabs)
  ds.removeView(pcaDs)

  return(g)




def tragectorySummary(ds, labeling, annotes=None, clusters=None, computeROC=1):

  """
  This creates a set of thumbnail size trajectories
  """
  
  def summaryPlot(event, labels):

    global ev
    ev= event

    try:
      g = Pmw.Blt.Graph()
    except:
      reload(Pmw.Blt)
      g = Pmw.Blt.Graph()
    
    row = int(event.widget.grid_info()['row'])
    col = int(event.widget.grid_info()['column'])
    label = re.sub('^label ', '', event.widget.winfo_name())
    print "working on plotting cluster %s..."%(label)
    cluster = subsetByLabeling(ds, labeling, label)

    if annotes  is not None:
      clusterAnnotes = Labeling(cluster)
      clusterAnnotes.labelFrom(annotes)
      g = IPlot.plot(cluster.getData(), setLabel=clusterAnnotes.getLabelByRows())
    else:
      g = IPlot.plot(cluster.getData())
    g.legend_configure(hide=1)
    g.configure(title='Cluster %s'%(label))
    IPlot.hold(0)
    ds.removeView(cluster)
    del(cluster)

  # this seems like a reasonable default settings
  numCols = 2
  plotSize = (200,200)
  

  # this sets up the scrolled frame
  r = Tkinter.Tk()
  master = Pmw.ScrolledFrame(r)
  master.pack(expand='y', fill='both')


  # this sets up the dataset/labeling stuff
  if clusters is None:
    labels = labeling.getLabels()
  else:
    labels = clusters
  tmp = zip(map(lambda x: len(labeling.getRowsByLabel(x)),  labels), labels)
  tmp.sort()
  tmp.reverse()
  labels = map(lambda x: x[1], tmp)
  k= len(labels)
  numRows = MLab.ceil(float(k)/numCols)
  for count in range(k):
    col = count%numCols
    row = int(MLab.ceil(count/numCols))
    label = labels[count]
    print "working on cluster %s"%(label)
    print "\t Subseting Cluster"
    cluster = subsetByLabeling(ds, labeling, label)
    clusterData = cluster.getData()
    print "\t Caclulating Mean/Std Cluster"
    dataMean = MLab.mean(clusterData)
    try:
      dataStd = MLab.std(clusterData)
    except:
      dataStd = MLab.zeros(cluster.getNumCols())
        
    print "\t Generating Plots"
    frm = Tkinter.Frame(master.interior())
    frm.grid(row=row, col=col)
    #g = IPlot.plot(Numeric.array([dataMean+dataStd,
    #                              dataMean,
    #                              dataMean-dataStd]),
    #               master = frm)
    try:
      g = Pmw.Blt.Graph(master=frm)
    except:
      reload(Pmw.Blt)
      g= Pmw.Blt.Graph(master=frm)
    g.line_create('+std', xdata=tuple(range(dataMean.shape[0])), ydata=tuple(dataMean+dataStd),
                  color='red', pixels = "0.04i", dashes=(5,1))
    g.line_create('mean', xdata=tuple(range(dataMean.shape[0])), ydata=tuple(dataMean),  color='blue', pixels = "0.04i" )
    g.line_create('-std', xdata=tuple(range(dataMean.shape[0])), ydata=tuple(dataMean-dataStd),
                  color='red', pixels = "0.04i", dashes=(5,1))

    g.legend_configure(hide=1)
    g.configure(width=plotSize[0],height=plotSize[1])
    g.grid_on()
    g.grid(row=0, col=0)
    
    if computeROC and cluster.getNumRows > 3:
      print "\t Computing ROC"
      try:
        rocArea = roc.computeRocForLabel(label, labeling, ds)['area']
      except:
        rocArea = NaN.nan
    else:
      rocArea = NaN.nan
    summaryText = "Cluster %s\n\n  Size: %i\n  ROC Area: %3.2f\n"%(label, cluster.getNumRows(), rocArea)
    button = Tkinter.Button(frm, text=summaryText, width=20, name="label "+label)
    button.grid(row=0,col=1)
    #IPlot.hold(0)
    button.bind('<Button-1>', lambda e: summaryPlot(e, labels))
    count +=1 

    del(clusterData)
    cluster.detatch()
    del(cluster)

  return(master)
      

def fitnessTable(fitnessTable, k='k', bestParam=None):

  """
  fitnessTable(fitnessTable, k='k', bestParam=None)

  Plot the mean fitness score, with +/- one standard deviation error bars,
  for each parameter value.

  fitnessTable -- either a table whose rows are (k, k-prime, fitness, ...),
                  e.g. mccvAlgo.getFitnessTable(), or the name of a file.
                  The file's first line is skipped; every other non-blank
                  line holds whitespace-separated numbers in the same
                  column order.
  k            -- 'k' groups by column 0; any other value (e.g. 'k-prime')
                  groups by column 1.
  bestParam    -- optional parameter value to mark with a text marker
                  (e.g. mccvAlgo.getBestParam()).  It is ignored when that
                  value does not occur in the table.

  Groups may contain different numbers of runs.  A group with a single
  run gets a standard deviation of 0.

  Returns the IPlot graph.
  """
  if isinstance(fitnessTable, types.StringTypes):
    f = open(fitnessTable,'r')
    try:
      f.readline()    # first line is not part of the table
      rows = []
      for line in f.readlines():
        fields = line.split()
        if fields:    # skip blank lines
          rows.append(map(float, fields))
    finally:
      f.close()
    fitnessTable = Numeric.array(rows)

  if k=='k':
    KIndex=0
  else:
    KIndex=1

  # fitness scores (column 2) grouped by parameter value
  data = {}
  for row in fitnessTable:
    data.setdefault(row[KIndex], []).append(row[2])

  X = data.keys()
  X.sort()
  Y = map(data.get, X)

  # statistics per group: groups can have different lengths, so they are
  # not stacked into one rectangular array
  yMean = []
  yStd = []
  for scores in Y:
    yMean.append(MLab.mean(scores))
    if len(scores) > 1:
      yStd.append(MLab.std(scores))
    else:
      yStd.append(0.0)   # std (divides by n-1) is undefined for one run

  g = IPlot.IPlot()

  X = tuple(X)
  g.line_create('mean',
                xdata=X, ydata=tuple(yMean),
                yerror =  tuple(yStd),
                showerrorbars = 'y',
                color='blue', pixels = "0.04i" )

  if bestParam is not None and data.has_key(bestParam):
    g.marker_create("text", name="bestK")
    g.marker_configure("bestK",
                       coords=(bestParam, MLab.mean(data[bestParam])),
                       text="%i"%(bestParam),
                       background="lightblue")

  g.grid_on()
  g.pack(expand=1, fill='both')
  return(g)
  


      
