"""Plot is the base class of the IPlot/PlotPage hierarchy.

Currently its main purpose is to manage the matplotlib backend mixins
defined by the CanvasFactory hierarchy.
"""
import os
import tempfile

from matplotlib.figure import Figure
from matplotlib import rcParams
from matplotlib.numerix import sqrt, nonzero, equal, asarray
from matplotlib.numerix.mlab import amin

from compClust.util.NaN import inf

import CanvasFactory

import StringIO
PILAvailable = True
try:
  import Image
except ImportError, e:
  PILAvailable = False

class Plot:
  """Plot is the base class of IPlot and PlotPage.

  It owns the matplotlib Figure, the axis being drawn on (if any) and the
  canvas produced by an optional CanvasFactory backend mixin.  It also
  provides click hit-testing helpers and image persistence (save/tostring).
  """
  def __init__(self, canvasFactory=None, axis=None, figsize=None, minsize=None, dpi=100):
    """Create a figure (or adopt an existing axis' figure) and its canvas.

    :Parameters:
      - `canvasFactory`: factory object for the different matplotlib
                         backends; if None, no canvas is created.
      - `axis`: an existing matplotlib axis to draw on.  When given, its
                figure is reused and no new figure or canvas is created.
      - `figsize`: A (width, height) tuple for a newly created figure.  If
                   None, the figure size defaults to (6,6).
      - `minsize`: The minimum (width, height) a new figure may have; each
                   dimension of figsize is raised to at least this value.
                   If None, defaults to (6,6).
      - `dpi`: resolution passed to a newly created Figure.
    """
    if minsize is None:
      minsize = (6, 6)
    if figsize is None:
      figsize = (6, 6)
    # never let either dimension drop below the requested minimum
    figsize = (max(figsize[0], minsize[0]), max(figsize[1], minsize[1]))
    self.canvas = None
    self.canvasFactory = canvasFactory
    self.__id = None
    self.title_font = {'family': 'sans-serif',
                       'fontsize': 10 }
    if axis is None:
      self.figure = Figure(figsize=figsize, dpi=dpi)
      self.createCanvas()
    else:
      # draw into the figure the caller's axis already belongs to
      self.figure = axis.figure
    self.axis = axis

    # persistence: path of the image written by save(); see image_filename
    self.__image_filename = None
    self.image_extension = ".png"

  def __get_id(self):
    """Return the plot id, defaulting to the builtin id() of this object."""
    if self.__id is None:
      # inside a method `id` is the builtin; the class-level `id`
      # property below is not in scope here
      self.__id = id(self)
    return self.__id

  def __set_id(self, plotid):
    """Set the plot id."""
    self.__id = plotid
  id = property(__get_id, __set_id, doc="provide a way of accessing an identifier for a plot")

  def createCanvas(self):
    """Construct a canvas for self.figure using self.canvasFactory.

    Does nothing if no canvas factory is set.  When the backend canvas
    supports mpl_connect, button presses are routed to self.onClick.
    """
    if self.canvasFactory is None:
      return

    self.canvas = self.canvasFactory.createCanvas(self.figure)
    try:
      self.canvas.mpl_connect('button_press_event', self.onClick)
    except AttributeError:
      # FIXME: some backends don't have canvas.mpl_connect, so click
      # FIXME: handling is simply unavailable for them right now
      pass

  def getCanvasFactory(self):
    """Return the current canvas factory (may be None)."""
    return self.canvasFactory

  def setCanvasFactory(self, factory):
    """Replace the canvas factory, building a new canvas if it changed.

    :Raises:
      - ValueError: if factory is neither None nor a CanvasFactory instance.
    """
    if factory is None or isinstance(factory, CanvasFactory.CanvasFactory):
      old_factory = self.canvasFactory
      self.canvasFactory = factory
      # someone went through the trouble of handing us a different
      # factory, so build a fresh canvas with it
      if factory is not None and factory != old_factory:
        self.createCanvas()
    else:
      raise ValueError("Factory must be subclass of CanvasFactory")

  def findNearestPoint(self, click_x, click_y, tolerance=5.0):
    """Find the data point on self.axis nearest to a mouse click.

    This is an exhaustive search over every point, so it is slow for large
    plots.  Scatter collections are searched first (the first collection
    with a point inside the hit radius wins), then all lines together.

    :Parameters:
      - `click_x`, `click_y`: click position in display space
      - `tolerance`: hit radius in display units (pixels), default 5

    :Returns:
      - ``(label, (x, y))`` for a scatter point or the closest line point
        inside the hit radius (label may be None for scatter points)
      - ``("No label found", None)`` if the axis has lines but no line
        point is inside the hit radius
      - ``(None, None)`` if no scatter point was hit and there are no lines
    """
    def compute_epsilon(artist):
      # distances are measured in data space, so convert the pixel radius
      # into data units along each axis and keep the smaller (stricter) one
      transform = artist.get_transform()
      data_box = transform.get_bbox1()
      display_box = transform.get_bbox2()
      epsilon_x = float(tolerance) / display_box.width() * data_box.width()
      epsilon_y = float(tolerance) / display_box.height() * data_box.height()
      return min(epsilon_x, epsilon_y)

    def closest_in(values):
      """Index of the first smallest entry of values, or None if empty."""
      best = None
      for index, value in enumerate(values):
        if best is None or value < values[best]:
          best = index
      return best

    data_click_x, data_click_y = self.axis.transData.inverse_xy_tup((click_x, click_y))

    # search through collections (on scatter plots)
    for collection in self.axis.collections:
      offsets = getattr(collection, '_offsets', None)
      if offsets is None or len(offsets) == 0:
        # nothing to hit in this collection
        continue
      epsilon = compute_epsilon(collection)
      distances = sqrt([((x - data_click_x)**2 + (y - data_click_y)**2)
                        for x, y in offsets])
      closest_index = closest_in(distances)
      if distances[closest_index] < epsilon:
        labels = getattr(collection, '_labels', None)
        if labels is not None:
          label = labels.getAllRowLabels()[closest_index]
        else:
          label = None
        return (label, offsets[closest_index])

    # search through lines (on line plots)
    lines = self.axis.lines
    if len(lines) == 0:
      return (None, None)

    epsilon = compute_epsilon(lines[0])
    best_line = best_point = best_distance = None
    for line_index, line in enumerate(lines):
      xdata = asarray(line.get_xdata())
      ydata = asarray(line.get_ydata())
      line_distances = sqrt((xdata - data_click_x)**2 + (ydata - data_click_y)**2)
      point_index = closest_in(line_distances)
      if point_index is None:
        # line has no data points
        continue
      # strict < keeps the earliest line on ties
      if best_distance is None or line_distances[point_index] < best_distance:
        best_line, best_point = line_index, point_index
        best_distance = line_distances[point_index]

    if best_distance is None or not (best_distance < epsilon):
      return ("No label found", None)

    closest_line = lines[best_line]
    x = closest_line.get_xdata()[best_point]
    y = closest_line.get_ydata()[best_point]
    return (closest_line.get_label(), (x, y))

  def findPatch(self, x, y):
    """Return the patch the user clicked on, or None.

    Returns the first patch in self.axis.patches whose window extent
    overlaps a 1x1 box at (x, y); x, y need to be in display space.
    """
    # legacy matplotlib.transforms helper; the module does not import it at
    # the top level, so it is brought into scope here
    from matplotlib.transforms import lbwh_to_bbox
    click = lbwh_to_bbox(x, y, 1, 1)
    for patch in self.axis.patches:
      if patch.get_window_extent().overlap(click):
        return patch
    return None

  def highlightLine(self, line, update=False):
    """Highlight one line by fading and thinning all the others.

    If update is true, redraw via self.show() afterwards.
    """
    for other in self.axis.lines:
      if other != line:
        other.set_alpha(0.1)
        other.set_linewidth(0.5)
      else:
        other.set_alpha(1)
        other.set_linewidth(2)
    if update:
      self.show()

  def onClick(self, event):
    """Handle a canvas button_press_event; subclasses must override this."""
    raise NotImplementedError("onClick needs to be overriden")

  #######
  # Image persistence functions
  def __get_image_filename(self):
    """Return the filename of the saved image, writing it first if needed.

    If no image has been saved yet, or the saved file has disappeared, the
    plot is saved (to a new temporary file when none was recorded).  An
    existing file is returned as-is, even if the figure changed since.
    """
    if self.__image_filename is None or not os.path.isfile(self.__image_filename):
      self.save(self.__image_filename)
    return self.__image_filename
  image_filename = property(__get_image_filename, doc="filename of stored image")

  def save(self, filename=None, **kwargs):
    """Save the plot to filename (a temporary file if None) and return it.

    When filename is None, the previously created temporary file is reused
    if it still exists; otherwise a new one ending in self.image_extension
    is created.

    Optional keyword arguments are passed to canvas.print_figure; dpi,
    facecolor and edgecolor default to rcParams['savefig.*'].  e.g.

      dpi          plot resolution
      facecolor    color of the figure rectangle
      edgecolor    edge color of the figure rectangle
      orientation  'landscape' or 'portrait' (not supported on all
                   backends; currently only on postscript output)
    """
    if filename is None:
      if self.__image_filename is None or not os.path.isfile(self.__image_filename):
        # mkstemp is used only to reserve a unique, existing name;
        # print_figure reopens the path itself, so release the descriptor
        # immediately (it would otherwise leak if rendering raises)
        image_fd, self.__image_filename = tempfile.mkstemp(self.image_extension)
        os.close(image_fd)
      filename = self.__image_filename
    # largely lifted from matplotlib.matlab.savefig
    for key in ('dpi', 'facecolor', 'edgecolor'):
      if key not in kwargs:
        kwargs[key] = rcParams['savefig.%s' % (key)]
    self.canvas.print_figure(filename, **kwargs)
    return filename

  def tostring(self, encoding="RAW", size=None):
    """Convert the canvas into image data stored in a string.

    Encoding type RAW returns the raw RGB buffer; any other encoding is
    passed to PIL to convert the image to that format.

    size is a (width, height) tuple in pixels; if None, it is taken from
    the figure's bounding box.

    :Raises:
      - RuntimeError: if a non-RAW encoding is requested without PIL.
    """
    # attempt to create the image without touching the disk; draw once so
    # the canvas has a rendered buffer to read from
    if not hasattr(self.canvas, 'renderer'):
      self.canvas.draw()

    rgb_data = self.canvas.tostring_rgb()

    if encoding == "RAW":
      return rgb_data

    if not PILAvailable:
      raise RuntimeError("PIL not available, can't convert images")

    if size is None:
      left, bottom, width, height = self.figure.bbox.get_bounds()
      size = (int(width), int(height))

    rgb_fp = StringIO.StringIO()
    image = Image.fromstring("RGB", size, rgb_data)
    image.convert("RGB").save(rgb_fp, encoding)
    return rgb_fp.getvalue()

  def show(self):
    """Redraw the plot on its canvas, if there is one."""
    if self.canvas is not None:
      try:
        self.canvas.show()
      except AttributeError:
        # ignore backends whose canvas has no show()
        pass

  # Configure plot style
  def title(self, text):
    """Add a title to the figure (as opposed to particular axes)."""
    # FIXME: it'd be nice if we figured out the distance between the top
    # FIXME: of the figure and the top row of axes to determine what the
    # FIXME: y value should be. (Other than a value in the range [0,1])
    self.figure.text(.5, .96, text, verticalalignment='top', horizontalalignment='center')

  def set_xticks_to_labeling(self, col_labeling, col_ordering=None):
    """Use a column labeling's labels as the x axis tick labels.

    The first element of each entry from col_labeling.getAllColLabels() is
    used, one tick per column at x = 0, 1, 2, ...  If col_ordering is given,
    labels are reordered by those column indices first.

    :Raises:
      - ValueError: if col_labeling is None.
    """
    # FIXME: need some way of setting a default condition labeling
    if col_labeling is None:
      raise ValueError("requires labeling")
    col_labels = [x[0] for x in col_labeling.getAllColLabels()]
    if col_ordering is not None:
      col_labels = [col_labels[x] for x in col_ordering]
    self.axis.set_xticks(range(len(col_labels)))
    self.axis.set_xticklabels(col_labels)