import colorsys

import matplotlib.numerix as nx

from compClust.iplot.mappers.Mapper import Mapper, scaleList
from compClust.iplot.views import IPlotView

__docformat__ = "restructuredtext en"

###############
# Internal color model conversions
def __rgba_to_hsv(color):
  """Convert an RGBA color into an [h, s, v] list; the alpha channel is dropped."""
  rgb = color[0:3]
  hue, saturation, value = colorsys.rgb_to_hsv(*rgb)
  return [hue, saturation, value]
def __rgba_to_rgb(color):
  """Return the [r, g, b] components of an RGBA color as a new list.

  The alpha channel is dropped.  A list is always returned (matching the
  other converters in this module), whatever sequence type `color` is.
  """
  return list(color[:3])
def __rgba_to_rgba(color):
  """Return `color` as a new list after checking it has four components.

  Raises ValueError when `color` does not contain exactly four values.
  """
  if len(color) == 4:
    return list(color)
  raise ValueError("RGBA colors must be 4 values")
def __hsv_to_rgba(color):
  """Convert an (h, s, v) color into an opaque [r, g, b, 1] list."""
  red, green, blue = colorsys.hsv_to_rgb(*color)
  return [red, green, blue, 1]
def __rgb_to_rgba(color):
  """Convert an (r, g, b) color into an opaque [r, g, b, 1] list.

  Raises ValueError when `color` does not contain exactly three values,
  mirroring the length check applied to RGBA input; without it a
  wrong-length value would silently yield a malformed color.
  """
  if len(color) != 3:
    raise ValueError("RGB colors must be 3 values")
  return list(color) + [1]

# Names of the whole-color models accepted wherever a `colorModel`
# argument is taken.
colorModels = ('rgb', 'rgba', 'hsv')

# Conversion dispatch tables: turn a color expressed in a given model into
# the internal RGBA list (toRGBAModel), or an internal RGBA color back into
# the given model (fromRGBAModel), with a single dictionary lookup.
toRGBAModel = dict(
  hsv=__hsv_to_rgba,
  rgb=__rgb_to_rgba,
  rgba=__rgba_to_rgba,
)
fromRGBAModel = dict(
  hsv=__rgba_to_hsv,
  rgb=__rgba_to_rgb,
  rgba=__rgba_to_rgba,
)

class ColorMapper(Mapper):

  """
  Support class for the IPlotView class which describes how the
  elements in the view should be colored.

  Colors are kept internally as a list of `numColors` RGBA lists.  While
  no colors have been assigned that list is None, which the accessors
  report as "monocolor" (all elements drawn in the same color).
  """
  # FIXME: Should numColors be an axis instead to allow us to pick if we want to look at rows or cols?
  def __init__(self, plotView, numColors):
    """
    :Parameters:
      - `plotView`: the view whose elements this mapper colors
      - `numColors`: the number of independently colorable elements
    """
    super(ColorMapper, self).__init__(plotView)
    self.__numColors = numColors
    # The RGBA list is allocated lazily; None means no colors assigned yet.
    self.__colors = None

  def __getitem__(self, index):
    """Return the RGBA color stored at `index`, or None if no colors are set.
    """
    if self.__colors is None:
      return None
    return self.__colors[index]

  def __setitem__(self, index, value):
    """Store an RGBA color at `index`.

    :Parameters:
      - `index`: position in the color list
      - `value`: a sequence of 4 numbers (r, g, b, a), each in [0, 1]

    :Raises ValueError: if `value` does not have exactly 4 components or
      any component lies outside [0, 1].
    """
    if len(value) != 4:
      raise ValueError("Base colors must contain 4 values")
    for c in value:
      if not (0 <= c <= 1):
        raise ValueError("Colors must be in the range [0,1]")
    # First individual assignment: start every element as [0, 0, 0, 0].
    if self.__colors is None:
      self.__colors = [[0, 0, 0, 0] for _ in xrange(self.__numColors)]
    # Store a private list copy so later mutation of the caller's
    # sequence cannot alter the stored color.
    self.__colors[index] = list(value)

  def __len__(self):
    """Return numColors, or 0 while no colors are assigned."""
    if self.__colors is None:
      return 0
    return self.__numColors

  def setColor(self, index, value, colorModel='rgba'):
    """
    Set a particular color index to the color specified by colorModel

    :Parameters:
      - `index`: the index into the color array
      - `value`: the color (which will be interpreted according to colorModel)
      - `colorModel`: which color model to use for the above color value;
        one of 'rgb', 'rgba', 'hsv' (any other name raises KeyError)
    """
    self[index] = toRGBAModel[colorModel](value)

  def getColor(self, index, colorModel='rgba'):
    """
    Return the color for index converting it to colorModel if needed

    :Parameters:
      - `index`: the index into the color array
      - `colorModel`: which color model to return the color in

    :Returns: the converted color, or None when no colors are assigned
      (monocolor plot), consistent with `__getitem__` and `getColors`.
    """
    color = self[index]
    # Nothing to convert for a monocolor plot.
    if color is None:
      return None
    return fromRGBAModel[colorModel](color)

  def setColors(self, values, colorModel='rgba'):
    """
    Replace every color at once.

    :Parameters:
      - `values`: None to clear the colors (monocolor plot); otherwise an
        iterable of exactly numColors colors expressed in `colorModel`
      - `colorModel`: one of 'rgb', 'rgba', 'hsv'

    :Raises ValueError: if `values` does not contain numColors colors.
    """
    if values is None:
      self.__colors = None
      return
    # Materialize so any iterable (e.g. a generator) can be length-checked.
    values = list(values)
    if len(values) != self.__numColors:
      raise ValueError("Color list had %d elements instead of %d elements" %
                       (len(values), self.__numColors))
    # Hoist the dictionary lookup out of the list comprehension.
    specificToRGBAModel = toRGBAModel[colorModel]
    self.__colors = [specificToRGBAModel(color) for color in values]

  def getColors(self, colorModel='rgba'):
    """
    Return all colors converted to `colorModel`.

    :Returns: None for monocolor plots (all elements will be drawn in the
      same color), otherwise a list of colors in the requested model.
    """
    if self.__colors is None:
      return None
    # Hoist the dictionary lookup out of the list comprehension.
    specificFromRGBAModel = fromRGBAModel[colorModel]
    return [specificFromRGBAModel(color) for color in self.__colors]

  def clearColors(self):
    """
    Clear the plot colors, reverting to a monocolor plot.
    """
    self.__colors = None
    
class RowColorMapper(ColorMapper):

  """
  A flexible, dynamically adjustable color mapper with one color per
  dataset row.

  Each setColorBy* method derives one number per row (from a labeling,
  label counts, a column's values, an arbitrary function, or the row
  index), scales it into `colorRange`, and places it in one color
  channel: 'h', 's', 'v' (HSV) or 'r', 'g', 'b', 'a' (RGBA).  Passing a
  whole color model name ('rgb', 'rgba', 'hsv') as the component means
  the values already are complete colors in that model.
  """

  def __init__(self, plotView):
    """
    Create a mapper with one color slot per row of the view's dataset.

    :Parameters:
      - `plotView`: the view to color; `plotView.getDataset().getNumRows()`
        determines how many colors are stored
    """
    super(RowColorMapper, self).__init__(plotView, plotView.getDataset().getNumRows())

  def __setComponent(self, values, component):
    """
    Build per-row colors by placing `values` in a single channel.

    For HSV channels the other two components are 1; for RGB(A)
    channels the other components are 0 (so 'a' yields black with
    varying alpha).  A whole color model name uses `values` as-is.

    :Raises ValueError: if `component` is not recognised.
    """
    # The values already are complete colors in one of the color models.
    if component in colorModels:
      self.setColors(values, component)
      return
    zeros = [0] * len(values)
    ones = [1] * len(values)
    # HSV channels
    if component == 'h':
      self.setColors(zip(values, ones, ones), 'hsv')
    elif component == 's':
      self.setColors(zip(ones, values, ones), 'hsv')
    elif component == 'v':
      self.setColors(zip(ones, ones, values), 'hsv')
    # RGBA channels
    elif component == 'r':
      self.setColors(zip(values, zeros, zeros), 'rgb')
    elif component == 'g':
      self.setColors(zip(zeros, values, zeros), 'rgb')
    elif component == 'b':
      self.setColors(zip(zeros, zeros, values), 'rgb')
    elif component == 'a':
      self.setColors(zip(zeros, zeros, zeros, values), 'rgba')
    else:
      raise ValueError("unexpected component type %s" %(component))

  def setColorByLabelingCounts(self, labeling, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):
    """
    Set the given color channel proportional to the size of each row's
    class (the number of rows sharing its label).  Valid channels are
    h, s, v, r, g, b or a.  Rows whose label is unknown count as 0.
    """
    # Number of rows carrying each label.
    sizeDict = dict((label, len(labeling.getRowsByLabel(label)))
                    for label in labeling.getLabels())
    values = [sizeDict.get(labeling.getLabelByRow(row), 0)
              for row in xrange(self.dataset.getNumRows())]

    values = scaleList(values, minValue, maxValue, colorRange[0], colorRange[1])
    self.__setComponent(values, component)

  def setColorByLabeling(self, labeling, labels=None, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):
    """
    setColorByLabeling(self, labeling, labels=None, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    Uses the labeling to produce a color list for plotting.  If the
    labeling contains real values, then the optional parameters
    minValue and maxValue can be used to scale the colorRange used.
    Only uses the first label attached to every row.

    :Parameters:
      - `labels`: optional collection of labels to keep; rows whose label
        is not in it are passed to scaleList as None
    """
    if labels is None:
      rowLabels = labeling.getLabelByRows()
    else:
      # Lookup table of the labels to keep (label -> label).
      keep = dict((l, l) for l in labels)
      rowLabels = [keep.get(l) for l in labeling.getLabelByRows()]

    values = scaleList(rowLabels, minValue, maxValue, colorRange[0], colorRange[1])
    self.__setComponent(values, component)

  def setColorByTupleLabeling(self, labeling, element=0, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):
    """
    setColorByTupleLabeling(self, labeling, element=0, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    Like setColorByLabeling, but each row's label is an indexable tuple
    and only its `element`-th entry is used as the value.  minValue and
    maxValue can be used to scale the colorRange used.
    """
    values = scaleList([x[element] for x in labeling.getLabelByRows()], minValue, maxValue, colorRange[0], colorRange[1])
    self.__setComponent(values, component)

  def setColorByColValue(self, col, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):
    """
    setColorByColValue(self, col, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    Sets the color to be a function of the value for that row at a
    given column.  minValue/maxValue default to the column's minimum and
    maximum.

    :Raises ValueError: if the dataset has no columns.
    """
    if self.dataset.getNumCols() == 0:
      raise ValueError("It is not possible to setColorByColValue when there are no columns")

    data = self.dataset.getColData(col)
    if maxValue is None:
      maxValue = nx.mlab.max(data)
    if minValue is None:
      minValue = nx.mlab.min(data)

    values = scaleList(data, minValue, maxValue, colorRange[0], colorRange[1])
    self.__setComponent(values, component)

  def setColorByFunction(self, function, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):
    """
    setColorByFunction(self, function, component='h', minValue=None, maxValue=None, colorRange=(0,.7))

    function should take in a dataset and a row and return a single
    float.  Colors will be scaled based on that value.  If `component`
    is a whole color model ('rgb', 'rgba', 'hsv') the function must
    instead return a complete color, which is used unscaled.

    function prototype should be:

    value = function(dataset, row)
    """
    data = [function(self.dataset, row) for row in xrange(self.dataset.getNumRows())]
    # Complete colors are used as-is; only single-channel values are scaled.
    if component not in colorModels:
      data = scaleList(data, minValue, maxValue, colorRange[0], colorRange[1])

    self.__setComponent(data, component)

  def setColorByIndex(self, component='h', minValue=None, maxValue=None, colorRange=(0,.7)):
    """
    setColorByIndex converts the row index range into a scaled color.
    """
    values = scaleList(range(self.dataset.getNumRows()), minValue, maxValue, colorRange[0], colorRange[1])
    self.__setComponent(values, component)
  
    
class ClusterProbabilityColorMapper(ColorMapper):
  """
  Display probability of class memberships.
  """
  def __init__(self, plotView, probabilities):
    """Construct a probability color mapper, initially colored by
    setHueByPartitioning.

    :Parameters:
      - `plotView`: one plotview from compClust.iplot.views
      - `probabilities`: a labeling whose per-row label is an indexable
        sequence of membership probabilities (one entry per cluster)
    """
    super(ClusterProbabilityColorMapper, self).__init__(plotView, plotView.getDataset().getNumRows())

    self.setProbabilityLabeling(probabilities)
    self.setHueByPartitioning()

  def setProbabilityLabeling(self, probabilities):
    """Replace the probability labeling used by the setHueBy* methods.

    Existing colors are not recomputed by this call.
    """
    self.probabilities = probabilities

  def __findRowMaxProbability(self, row):
    """Return (maximum, index) of the largest value in `row`.

    Ties resolve to the first occurrence; an empty `row` raises IndexError.
    """
    index = 0
    maximum = row[index]
    for i in range(1, len(row)):
      if row[i] > maximum:
        index = i
        maximum = row[i]
    return (maximum, index)

  def setHueByPartitioning(self, colorRange=(0,.7)):
    """
    Color each row by its most likely cluster.

    Hue is the index of the most probable cluster, passed through
    scaleList with return range `colorRange`; saturation is that maximum
    probability, passed through scaleList with return range [0, 1];
    value is fixed at 1.

    NOTE(review): earlier documentation said the probability set the
    value (intensity), but the code has always used saturation - confirm
    which is intended.
    """
    # Fetch the per-row probabilities once and reuse them.
    rowProbabilities = self.probabilities.getLabelByRows()
    maxima = [self.__findRowMaxProbability(row) for row in rowProbabilities]

    maxProbability = [m[0] for m in maxima]
    maxIndex = [m[1] for m in maxima]

    hue = scaleList(maxIndex, minReturn=colorRange[0], maxReturn=colorRange[1])
    saturation = scaleList(maxProbability, minReturn=0, maxReturn=1)
    value = [1] * len(hue)
    self.setColors(zip(hue, saturation, value), colorModel='hsv')

  def setHueByProbability(self, cluster, colorRange=(0,.7)):
    """
    setHueByProbability(self, cluster, colorRange=(0,.7))

    Set each row's hue from its probability of belonging to `cluster`
    (an index into the per-row probability sequence), passed through
    scaleList with return range `colorRange`.  Saturation and value are 1.
    """
    hue = scaleList([x[int(cluster)] for x in self.probabilities.getLabelByRows()], minReturn=colorRange[0], maxReturn=colorRange[1])
    ones = [1] * len(hue)
    self.setColors(zip(hue, ones, ones), colorModel='hsv')
    