import operator

from IPlot import DatasetPlot
from views import DatasetRowPlotView
from compClust.mlx.views import RowSubsetView, ColumnSubsetView
from compClust.mlx.views import SortedView
from compClust.mlx import labelings
from compClust.iplot.mappers.DataMapper import PCADataMapper
from compClust.iplot.mappers.ColorMapper import RowColorMapper
from compClust.iplot.mappers.MarkersMapper import RowMarkersMapper, DEFAULT_MARKER_SIZE
import compClust.mlx.pcaGinzu

import matplotlib.numerix as nx
try:
  import matplotlib.pylab as pylab
except ImportError, e:
  # try older version of matplotlib matlab(r)-like commands
  import matplotlib.matlab as pylab

# Colors handed to the color mappers (together with the 'rgb' mode string)
# for rows labeled 'high' and 'low' in the high/low outlier labeling.
HIGH_RGB_COLOR = (0,0,1)
LOW_RGB_COLOR = (1,0,0)
# the original pca ginzu name was pcaGinzu, however all of the other plots
# in iplot were of the convention CapWord, so I changed the capitalization
# to make it more consistent
class PCAGinzu(compClust.mlx.pcaGinzu.pcaGinzu):
  """pcaGinzuVisualizeMatplotlib uses pcaGinzu to construct all the necessary labelings
  and then constructs non-interactive matplotlib plots showing the various outlier
  representations.
  """
  # NOTE, if you change the parameter list here, you'll need to change it in the
  # compClust.iplot.IPlot* modules, compClust.mlx.pcaGinzu, and PCAGinzu
  def __init__(self, canvasFactory, dataset, nOutliers=None, outlierCutoff=None, sigCutoff = 0.05,
               maxPCNum = None, verbose = False, rowPCAView = None, makeLabelings = True,
               primaryLabeling=None, secondaryLabeling=None, primaryColumnLabeling=None):
    """
    Store the plotting configuration and then run the pcaGinzu initializer.

    :Parameters:
      - `canvasFactory`: object the plot methods ask for canvases
                         (through getIPlot() and getDatasetPlot())
      - `dataset`: the dataset to analyse; forwarded to pcaGinzu.__init__
      - `nOutliers`, `outlierCutoff`, `sigCutoff`, `maxPCNum`, `verbose`,
        `rowPCAView`, `makeLabelings`: forwarded unchanged to
        pcaGinzu.__init__
      - `primaryLabeling`: row labeling given to the plot views; when None
                           the dataset's primaryRowLabeling is used
      - `secondaryLabeling`: second row labeling given to the plot views
                             (may be None)
      - `primaryColumnLabeling`: column labeling to remember; when None the
                                 dataset's primaryColumnLabeling is used
    """
    self.canvasFactory = canvasFactory

    if primaryLabeling is None:
      self.primaryLabeling = dataset.primaryRowLabeling
    else:
      self.primaryLabeling = primaryLabeling
    self.secondaryLabeling = secondaryLabeling

    if primaryColumnLabeling is None:
      # The column default must come from the dataset's *column* labeling;
      # falling back to the row labeling here was a copy/paste slip.
      self.primaryColumnLabeling = dataset.primaryColumnLabeling
    else:
      self.primaryColumnLabeling = primaryColumnLabeling

    # The plotting attributes are set first, matching the original order, so
    # they are already available while the base class initializes itself.
    compClust.mlx.pcaGinzu.pcaGinzu.__init__(self, dataset, nOutliers, outlierCutoff, sigCutoff,
                                             maxPCNum, verbose, rowPCAView, makeLabelings)
                                             
  def getEigenVariances(self):
    """Return how much variance each eigenvector captures.

    The value comes straight from the row PCA view, so it has one entry
    per eigenvector.
    """
    return self.rowPCAView.getVariances()


  def getHighLowLabeling(self, pcNum):
    """
    Fetch the PCA ginzu high/low labeling of one principal component.

    :Parameters:
      - `pcNum`: zero-based number of the component; one is added before
                 the lookup through _getHighLowLabelingByPCN
    """
    shiftedPCNum = pcNum + 1
    return self._getHighLowLabelingByPCN(shiftedPCNum)
  
  def getUpFlatDownLabeling(self, pcNum):
    """
    Fetch the labeling that classifies the general trend of the vector
    for one principal component.

    :Parameters:
      - `pcNum`: zero-based number of the component; one is added before
                 the lookup through _getUpFlatDownLabelingByPCN
    """
    shiftedPCNum = pcNum + 1
    return self._getUpFlatDownLabelingByPCN(shiftedPCNum)
    
  def plotPCNOutlierRowsInOriginalColumnOrder(self, pcNum):
    """
    Plot the trajectories of the high and low outlier rows of one
    principal component and return the plot.

    The x-axis follows the dataset's original column order.  Rows labeled
    'high' are colored with HIGH_RGB_COLOR, rows labeled 'low' with
    LOW_RGB_COLOR.

    :Parameters:
      - `pcNum`: zero-based principal component number

    :Returns: the plot object obtained from self.canvasFactory.getIPlot()
              (only the plot; the temporary row subset is detached again
              before returning)
    """
    if self.verbose:
      print("  plotPCNOutlierRowsInOriginalColumnOrder %s" % (pcNum,))

    l = self.getHighLowLabeling(pcNum)
    highrows = l.getRowsByLabel('high')
    lowrows = l.getRowsByLabel('low')

    subset = RowSubsetView(self.dataset, highrows + lowrows)
    try:
      dp = self.canvasFactory.getIPlot()
      dp.setTitle("PC %d Outlier Rows In Original Column Order"%(pcNum+1))
      rv = DatasetRowPlotView(subset, primaryLabeling=self.primaryLabeling,
                                      secondaryLabeling=self.secondaryLabeling)

      # the subset lists the high rows first, then the low rows, so the
      # color list is built in that same order
      rv.getColorMapper().setColors([HIGH_RGB_COLOR]*len(highrows) +
                                    [LOW_RGB_COLOR]*len(lowrows), 'rgb')

      conditions = self.dataset.primaryColumnLabeling
      if conditions is not None:
        dp.set_xticks_to_labeling(conditions)

      self.format_significant_condition_labels(dp,
                                               self.getUpFlatDownLabeling(pcNum))
      dp.plot(rv)
    finally:
      # always detach the temporary view, even when plotting fails, so the
      # dataset does not accumulate stale subsets
      self.dataset.removeView(subset)
    return dp
                

  def plotPCNOutlierRowsInSigGroupOrder(self, pcNum):
    """
    Plot the outlier trajectories of one principal component with the
    conditions/dimensions reordered by mean difference, and return the plot.

    The columns are sorted by (mean of the 'high' rows - mean of the 'low'
    rows), ascending, which approximates ordering by significance of high
    vs. low.

    Note: Presently we order conditions/dimensions by mean diff, which is
    approximately ordered by significance of difference, but we SHOULD
    make this more precisely partitioned first into Up/Flat/Down, and
    then within group ordered by mean difference.  This ordering needs
    to correspond to the order of rows output by getOutputForSigGroups.

    :Parameters:
      - `pcNum`: zero-based principal component number

    :Returns: the plot object obtained from self.canvasFactory.getIPlot()
              (only the plot; all temporary views are detached again before
              returning)
    """
    if self.verbose:
      print("  plotPCNOutlierRowsInSigGroupOrder %s" % (pcNum,))

    l = self.getHighLowLabeling(pcNum)
    highrows = l.getRowsByLabel('high')
    lowrows = l.getRowsByLabel('low')

    high_mean = self._averageOverRows(highrows)
    low_mean = self._averageOverRows(lowrows)

    # NOTE(review): if BOTH groups are empty, both means are the scalar 0 and
    # len() below raises TypeError (same as before this cleanup) -- decide
    # what an "empty" plot should look like and handle it here.
    mean_differences = high_mean - low_mean
    temp = list(zip(mean_differences, range(len(mean_differences))))
    temp.sort()
    condition_ordering = [x[1] for x in temp]

    subset = RowSubsetView(self.dataset, highrows+lowrows)
    try:
      sorted_subset = SortedView(subset)
      try:
        sorted_subset.permuteCols(condition_ordering)

        dp = self.canvasFactory.getIPlot()
        dp.setTitle("PC %d Outlier Rows In Mean Difference Order"%(pcNum+1))
        rv = DatasetRowPlotView(sorted_subset,
                                primaryLabeling=self.primaryLabeling,
                                secondaryLabeling=self.secondaryLabeling)

        # the subset lists the high rows first, then the low rows, so the
        # color list is built in that same order
        rv.getColorMapper().setColors([HIGH_RGB_COLOR]*len(highrows) +
                                      [LOW_RGB_COLOR]*len(lowrows),'rgb')

        conditions = self.dataset.primaryColumnLabeling
        if conditions is not None:
          dp.set_xticks_to_labeling(conditions, condition_ordering)

        self.format_significant_condition_labels(dp,
                                                 self.getUpFlatDownLabeling(pcNum),
                                                 condition_ordering)
        dp.plot(rv)
      finally:
        # detach in reverse order of creation, even when plotting fails
        subset.removeView(sorted_subset)
    finally:
      self.dataset.removeView(subset)
    return dp

  def _averageOverRows(self, rows):
    """Average the data of `rows` along axis 0, or return 0 if no row is selected.

    A temporary RowSubsetView is attached to self.dataset for the
    computation and is always detached again.
    """
    view = RowSubsetView(self.dataset, rows)
    try:
      if view.getNumRows() > 0:
        return nx.average(view.getData(),0)
      return 0
    finally:
      self.dataset.removeView(view)

  def plotPCvsPCWithOutliersInY(self, pcNumForXAxis=0, pcNumForYAxis=1):
    """
    Scatter-plot one principal component against another, highlighting the
    outliers of the component shown on the y-axis, and return the plot.

    Rows labeled 'high' / 'low' in the y-axis component's high/low labeling
    are drawn in HIGH_RGB_COLOR / LOW_RGB_COLOR at DEFAULT_MARKER_SIZE; all
    other rows are drawn black at one tenth of that size.

    :Parameters:
      - `pcNumForXAxis`: zero-based component number for the x-axis
      - `pcNumForYAxis`: zero-based component number for the y-axis; its
                         outliers are the ones highlighted

    :Returns: the plot object obtained from
              self.canvasFactory.getDatasetPlot()
    """
    l = self.getHighLowLabeling(pcNumForYAxis)
    highrows = l.getRowsByLabel('high')
    lowrows = l.getRowsByLabel('low')

    # make plotView
    plotView = DatasetRowPlotView(self.rowPCAView, primaryLabeling=self.primaryLabeling, secondaryLabeling=self.secondaryLabeling)

    # configure data mapper
    pcaDataMapper = PCADataMapper(plotView)
    pcaDataMapper.setXColumn(pcNumForXAxis)
    pcaDataMapper.setYColumn(pcNumForYAxis)
    plotView.setDataMapper(pcaDataMapper)

    # configure color mapper
    def color_outliers(dataset, row, labeling=l):
      label = labeling.getLabelByRow(row)
      if label == 'high':
        return HIGH_RGB_COLOR
      elif label == 'low':
        return LOW_RGB_COLOR
      else:
        return (0,0,0)
    pcaColorMapper = RowColorMapper(plotView)
    pcaColorMapper.setColorByFunction(color_outliers, 'rgb')
    plotView.setColorMapper(pcaColorMapper)

    # configure marker sizes
    def markers_mapper(dataset, row, labeling=l):
      label = labeling.getLabelByRow(row)
      if label in ('high', 'low'):
        return DEFAULT_MARKER_SIZE
      else:
        return DEFAULT_MARKER_SIZE / 10.0
    pcaMarkerMapper = RowMarkersMapper(plotView)
    pcaMarkerMapper.setMarkerSizeByFunction(markers_mapper)
    plotView.setMarkersMapper(pcaMarkerMapper)

    pcaPlot = self.canvasFactory.getDatasetPlot()
    try:
      pcaPlot.axis.set_aspect('equal', True, 'upperleft')
    except AttributeError:
      # deliberate best effort: set_aspect is missing on matplotlib < 0.84,
      # in which case the plot simply keeps its default aspect
      pass
    # attach legends
    pcaPlot.setTitle("PC %d: %d highest and %d Lowest Outliers" % (
                                                                 pcNumForYAxis+1,
                                                                 len(highrows),
                                                                 len(lowrows)))
    # axis labels show the one-based component number and the share of the
    # variance reported for it by the PCA view
    percentages = self.rowPCAView.getVariances()
    pcaPlot.axis.set_xlabel('PC %d (%4.2f%%)' % (pcNumForXAxis+1, percentages[pcNumForXAxis]*100))
    pcaPlot.axis.set_ylabel('PC %d (%4.2f%%)' % (pcNumForYAxis+1, percentages[pcNumForYAxis]*100))

    # plot it
    pcaPlot.scatter(plotView)

    return pcaPlot
    

  def format_significant_condition_labels(self, dataset_plot, sig_labeling, ordering=None):
    """Given a labeling that defines significance, indicate this on the plot
    (also rotate the text so that more of each label is likely to show up on
    the plot).

    Every x tick label is rotated by 45 degrees.  A tick label is additionally
    enlarged when the first label of the matching entry in
    `sig_labeling.getAllColLabels()` is anything other than "flat".

    :Parameters:
      - `dataset_plot`: plot whose `axis.get_xticklabels()` are formatted
      - `sig_labeling`: labeling providing getAllColLabels()
      - `ordering`: optional list of indexes; when given, the column label
                    lists are rearranged into this order before being matched
                    to the tick labels by position
    """
    # note: significant_columns is a list of lists
    significant_columns = sig_labeling.getAllColLabels()
    if ordering is not None:
      significant_columns = [ significant_columns[x] for x in ordering ]

    n_labelled = len(significant_columns)
    for i, xtick in enumerate(dataset_plot.axis.get_xticklabels()):
      xtick.set_rotation(45)
      # The axis may carry more tick labels than there are labelled columns
      # (e.g. default ticks when no column labeling set them); those extra
      # ticks have no significance information, so they are only rotated.
      if i >= n_labelled:
        continue
      labels = significant_columns[i]
      # a column without any label carries no significance either
      if len(labels) > 0 and labels[0] != "flat":
        xtick.set_fontsize('large')


  def generateResults(self,rowLabelingNames,colLabelingNames):
    """
    For each principal component, call the major result generating
    functions and save those results to appropriately-named files.

    For PC number NN (one-based, zero padded to two digits) this writes:
      - pcNN-outliers (figure, only for NN > 1): PC NN-1 on the x-axis
        against PC NN on the y-axis
      - pcNN-outlier-trajectories-order-meandiff (figure)
      - pcNN-outliers.txt, if getOutputForPCNOutliers returns output
      - pcNN-condition-groups.txt, if getOutputForSigGroups returns output

    :Parameters:
      - `rowLabelingNames`: passed to getOutputForPCNOutliers
      - `colLabelingNames`: passed to getOutputForSigGroups
    """
    # Within this function pcNum is 1-origin to match the file names and the
    # getOutputFor* calls.  The plot* methods of this class are zero-based
    # (they add one themselves), so they get pcIndex instead.
    for pcNum in range(1,self.maxPCNum+1):
      pcIndex = pcNum - 1

      if self.verbose:
        print('  Writing results for PC %d' % pcNum)

      if pcNum > 1:
        # previous component on x, this component (and its outliers) on y
        self.plotPCvsPCWithOutliersInY(pcIndex-1, pcIndex)
        # NOTE(review): savefig writes pylab's current figure; presumably the
        # canvas factory makes the new plot current -- confirm.
        pylab.savefig('pc%02d-outliers' % pcNum)

      self.plotPCNOutlierRowsInSigGroupOrder(pcIndex)
      pylab.savefig('pc%02d-outlier-trajectories-order-meandiff' % pcNum)

      output = self.getOutputForPCNOutliers(pcNum,rowLabelingNames)
      if output is not None:
        write2DStringArrayToFile(output, 'pc%02d-outliers.txt' % pcNum)

      output = self.getOutputForSigGroups(pcNum,colLabelingNames)
      if output is not None:
        write2DStringArrayToFile(output, 'pc%02d-condition-groups.txt' % pcNum)


def write2DStringArrayToFile(stringArray, filename, delim='\t'):
  """
  Simple utility function to spew a 2D string array to a delimited
  text file (tab-delimited by default).

  Every cell is formatted with '%s' and followed by `delim` (so each line
  ends with a trailing delimiter), and every row is terminated by a newline.
  An existing file of the same name is overwritten.

  :Parameters:
    - `stringArray`: iterable of rows, each row an iterable of cells
    - `filename`: path of the file to write
    - `delim`: text written after every cell
  """
  fd = open(filename,'w')
  try:
    for row in stringArray:
      cells = ['%s%s' % (col, delim) for col in row]
      fd.write(''.join(cells) + '\n')
  finally:
    # close the handle even if formatting or writing a row fails
    fd.close()
