import colorsys


#import matplotlib.numerix as nx
from matplotlib.numerix import mlab

from IPlot import IPlot
from compClust.iplot.mappers.Mapper import scaleList

class HistogramPlotter(IPlot):

  """
  Interactive bar chart of label frequencies.

  For every labeling that has been added, one bar is drawn per label
  returned by labeling.getLabels(); the bar height is the number of rows
  returned by labeling.getRowsByLabel(label).

  Drawing goes through self.axis, which is expected to be provided by
  IPlot (not visible in this module).
  """
  def __init__(self, canvasFactory, labeling=None, seriesName=None, primaryLabeling=None, secondaryLabeling=None, sortBy = 'labels', plot = 1, axis=None):
    """
    canvasFactory, axis -- forwarded unchanged to IPlot.__init__
    labeling            -- optional first labeling to register
    seriesName          -- display name for that labeling (None means
                           labeling.getName() is used when plotting)
    primaryLabeling,
    secondaryLabeling   -- stored and handed to the row view built by
                           fullPlot
    sortBy              -- passed to plot(); 'labels' or 'size'
    plot                -- when true (the default) and a labeling was
                           given, draw immediately
    """
    IPlot.__init__(self, canvasFactory, axis)
    # FIXME: port
    #self.configure(barmode="aligned")
    self.primaryLabeling = primaryLabeling
    self.secondaryLabeling = secondaryLabeling
    self.labelings = []
    self.seriesNames = []
    # Labels of the most recently drawn series; defined up front so that
    # fullPlot does not fail with AttributeError before the first plot().
    self.xlabels = ()
    # FIXME: port
    #self.element_bind('all', '<Button-2>', self.fullPlot)
    if labeling:
      self.addLabeling(labeling, seriesName)
      # Honour the 'plot' flag, which used to be accepted but ignored.
      if plot:
        self.plot(sortBy=sortBy)

  def addLabeling(self, labeling, seriesName=None):
    """
    Register another labeling (and its optional series name) to be drawn
    by the next call to plot().
    """
    self.labelings.append(labeling)
    self.seriesNames.append(seriesName)

  def plot(self, sortBy='labels', fixedColor=None, pack=1, alpha=.7):

    """
    plot(self, sortBy='labels', fixedColor=None, pack=1, alpha=.7)

    Draw one bar series per registered labeling.

    sortBy     -- 'labels' orders the bars by label, 'size' orders them by
                  bar height; any other value keeps the order returned by
                  labeling.getLabels()
    fixedColor -- if not None, used for every series instead of the
                  per-series colour
    pack       -- currently unused; kept for interface compatibility
    alpha      -- how transparent the bars should be

    After the call self.xlabels holds the labels of the last series drawn,
    in plotting order.
    """

    # One hue per series, scaled into [0, .7], at full saturation/value.
    hues = scaleList(list(range(len(self.labelings))),
                     minReturn=0, maxReturn=.7)
    colors = [colorsys.hsv_to_rgb(hue, 1, 1) for hue in hues]

    series = zip(self.labelings, self.seriesNames, colors)
    for count, (labeling, name, color) in enumerate(series):
      labels = labeling.getLabels()
      ydata = [len(labeling.getRowsByLabel(label)) for label in labels]
      if fixedColor is not None:
        color = fixedColor

      if sortBy == 'labels':
        # Sort (label, size) pairs: by label, ties broken by size.
        pairs = list(zip(labels, ydata))
        pairs.sort()
        labels = [pair[0] for pair in pairs]
        ydata = [pair[1] for pair in pairs]

      elif sortBy == 'size':
        # Sort (size, label) pairs: by size, ties broken by label.
        pairs = list(zip(ydata, labels))
        pairs.sort()
        ydata = [pair[0] for pair in pairs]
        labels = [pair[1] for pair in pairs]

      ydata = tuple(ydata)
      self.xlabels = tuple(labels)
      xdata = tuple(range(len(self.xlabels)))

      # "<series index>: <series name>" is the format fullPlot parses.
      # NOTE(review): the name is not attached to the bars yet (the BLT
      # element name has not been ported), so it is computed but unused.
      if name is None:
        name = "%s: %s"%(str(count), str(labeling.getName()))
      else:
        name = "%s: %s"%(str(count), name)

      bars = self.axis.bar(xdata, ydata, color=color)
      for b in bars:
        b.set_alpha(alpha)

      # Numeric labels are shown with two decimals; anything that cannot
      # be formatted as a float falls back to its str() form.
      try:
        tickLabels = ["%3.2f" % label for label in self.xlabels]
      except (TypeError, ValueError):
        tickLabels = [str(label) for label in self.xlabels]

      # Pin one tick per bar before labelling; otherwise the labels are
      # assigned to whatever tick positions were chosen automatically.
      self.axis.set_xticks(xdata)
      tick_labels = self.axis.set_xticklabels(tickLabels)

      # Long labels are rotated so that they do not overlap.
      if tickLabels and max([len(text) for text in tickLabels]) > 4:
        tickrotation = 'vertical'
      else:
        tickrotation = 'horizontal'

      for tick in tick_labels:
        tick.set_rotation(tickrotation)


  def fullPlot(self, event):
    """
    fullPlot(event)

    show the full row-wise plot of the group whose bar is closest to the
    position carried by event.

    NOTE(review): this handler has not been ported yet. It uses the widget
    method element_closest and the names Tkinter, subsetByLabeling,
    DatasetPlot and DatasetRowPlotView, none of which are imported in the
    visible part of this module, and its binding in __init__ is commented
    out.
    """
    # Presumably a debugging aid: keeps the last event reachable as a
    # module-level name.
    global ev
    ev = event
    win = Tkinter.Toplevel()

    # Query the widget once and read every field from the same result.
    closest = event.widget.element_closest(event.x, event.y)
    index = closest['index']
    name = closest['name']

    # Element names have the form "<series index>: <series name>".
    print(name.split(": ")[1])
    labeling = self.labelings[int(name.split(": ")[0])]
    dataset = labeling.getDataset()
    label = self.xlabels[index]

    print("printing %s"%(str(label)))
    subset = subsetByLabeling(dataset, labeling, label)

    dp = DatasetPlot(master=win)
    dp.configure(title = "Group %s\n from %s \n %i elements"%(str(label),name,subset.getNumRows()))
    rv = DatasetRowPlotView(subset, primaryLabeling=self.primaryLabeling,
                                    secondaryLabeling=self.secondaryLabeling)
    dp.plot(rv)
    self.selectedPlot = dp
    self.selectedView = rv
    dataset.removeView(subset)
