import sets
from compClust.iplot import Plot


class VennDiagram(Plot.Plot):
  """
  A simple class to display a Venn diagram of two or three sets of items.

  Each set is drawn as a translucent circle, and the size of every region
  (items unique to one set, the pairwise overlaps and, for three sets, the
  common overlap) is written inside that region.  The sets may hold labels
  from a labeling on a compClust dataset.

  NOTE(review): the constructor accepts ``labeling`` and ``displayFunction``
  for click-to-view-the-intersect interaction, but nothing in this class
  uses them -- confirm whether that interaction is wired up elsewhere.
  """

  def __init__(self, canvasFactory, set1, set2, set3=None, set1Name='', set2Name='', set3Name='', labeling=None, displayFunction=None):
    """
    Render a Venn diagram of set1, set2 and (optionally) set3.

    canvasFactory   -- forwarded unchanged to Plot.Plot.__init__.
    set1, set2      -- the sets to compare.  Any iterable collection works;
                       it is converted to a sets.Set unless it already is one.
    set3            -- optional third set.  When it is not None (even if it
                       is empty) a three-circle diagram is drawn; otherwise a
                       two-circle diagram is drawn.
    set1Name, set2Name, set3Name
                    -- captions drawn outside the matching circles.
                       set3Name is ignored in the two-set layout.
    labeling        -- the labeling the set items come from, if any.
    displayFunction -- alternative display function for a mouse click; it
                       should expect an mlx view.

    NOTE(review): labeling and displayFunction are accepted for interface
    compatibility but are not used by this constructor.
    """
    Plot.Plot.__init__(self, canvasFactory=canvasFactory)

    def asSet(items):
      # Reuse existing sets.Set instances; copy any other collection.
      if isinstance(items, sets.Set):
        return items
      return sets.Set(items)

    set1 = asSet(set1)
    set2 = asSet(set2)
    # Choose the layout by whether a third set was *supplied*, not by its
    # truthiness: an empty third set still gets its own circle and caption.
    threeWay = set3 is not None
    if threeWay:
      set3 = asSet(set3)

    # One circle per set, all with the same radius; scatter_classic sizes
    # its circles in data coordinates.
    if threeWay:
      centersX = (-1, 0, 1)
      centersY = (0, 1.5, 0)
      colors = [0, .5, 1]
    else:
      centersX = (-1, 1)
      centersY = (0, 0)
      colors = [0, 1]
    radius = 1.5

    axis = self.figure.add_subplot('111')
    circles = axis.scatter_classic(centersX, centersY, radius, colors)
    # Translucent fills keep the overlapping regions visible.
    for circle in circles:
      circle.set_alpha(.4)
    axis.set_ylim([-4, 4])
    axis.set_xlim([-4, 4])
    textProps = {'horizontalalignment': 'center',
                 'verticalalignment': 'center'}

    if threeWay:
      # Set captions, just outside each circle.
      axis.text(-3, 0, set1Name, textProps, rotation='vertical')
      axis.text(0, 3.5, set2Name, textProps, rotation='horizontal')
      axis.text(3, 0, set3Name, textProps, rotation='vertical')
      # Region sizes, each written inside its own region.
      axis.text(-1.5, -.5, str(len(set1 - set3 - set2)), textProps)
      axis.text(0.0, 2.2, str(len(set2 - set1 - set3)), textProps)
      axis.text(1.5, -.5, str(len(set3 - set2 - set1)), textProps)
      axis.text(-1.0, 1.0, str(len((set1 & set2) - set3)), textProps)
      axis.text(1.0, 1.0, str(len((set2 & set3) - set1)), textProps)
      axis.text(0.0, -0.7, str(len((set3 & set1) - set2)), textProps)
      # The common region spans roughly y in [0, 1.1] at x = 0; y = 0 would
      # sit exactly on set2's circle edge (center (0, 1.5), radius 1.5), so
      # the count is placed at y = 0.5, inside all three circles.
      axis.text(0.0, 0.5, str(len(set1 & set2 & set3)), textProps)
    else:
      # Set captions, just outside each circle.
      axis.text(-3, 0, set1Name, textProps, rotation='vertical')
      axis.text(3, 0, set2Name, textProps, rotation='vertical')
      # Region sizes: set1 only, set2 only, and the overlap.
      axis.text(-1, 0, str(len(set1 - set2)), textProps)
      axis.text(1, 0, str(len(set2 - set1)), textProps)
      axis.text(0, 0, str(len(set1 & set2)), textProps)


 
