########################################
# The contents of this file are subject to the MLX PUBLIC LICENSE version
# 1.0 (the "License"); you may not use this file except in
# compliance with the License.
# 
# Software distributed under the License is distributed on an "AS IS"
# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See
# the License for the specific language governing rights and limitations
# under the License.
# 
# The Original Source Code is "compClust", released 2003 September 03.
# 
# The Original Source Code was developed by the California Institute of
# Technology (Caltech).  Portions created by Caltech are Copyright (C)
# 2002-2003 California Institute of Technology. All Rights Reserved.
########################################

"""
This module contains wrapper functions around various python
plotting tools to make simple plotting easier.

Author:  Christopher Hart
Date  :  May 2001

"""

import string
import sping
import sping.PS 
import math
import MLab
import sys
import Gnuplot
import Numeric

from compClust.util import WrapperUtil
from compClust.util import Usage

def plot(values, yvalues=None, error=None, min=None, max=None, plotStyle="default", fileName=None, previousPlot=None ):

    """
    gPlot = plot(values, yvalues=None, error=None, min=None, max=None,
                 plotStyle="default", fileName=None, previousPlot=None)

    A gnuplot based plot function for matlab-like plotting.

        Usage: plot(x, y, <options>)

                     This creates a plot of the x-vector vs the
                     y-vector.  x and y can be either Numeric arrays
                     or standard python lists.

               plot(y, <options>)

                     This creates a plot of the y-vector values vs their
                     index values.

        options:

              error: optional sequence of error-bar sizes, one per
                     point; drawn on top of the plot with gnuplot's
                     'errorbars' style.

              min, max: optional y-axis limits.  Either may be given on
                     its own, and 0 is a valid limit.

              plotStyle = {default, line, points, bar}
                     "default" and "line" both draw linespoints.  Any
                     other value leaves gnuplot's data style unchanged.

              fileName: if given, an EPS hardcopy of the plot is written
                     to this file.

              previousPlot: defaults to None, but accepts a gnuplot
                     object to plot onto.  This allows adding more data
                     to an existing plot.

        Returns: The gnuplot object

    Sequences of unequal length are truncated to the shortest one, and
    a warning is written to stderr.
    """

    # Use the module-level Gnuplot import directly.  A function-local
    # "import Gnuplot" would make the name local to the whole function
    # and leave it unbound on the previousPlot path.
    if previousPlot is None:
        g = Gnuplot.Gnuplot(persist=1)
    else:
        g = previousPlot

    g('set grid')

    # Map each supported plot style to its gnuplot command.
    styleCommands = {"default": 'set data style linespoints',
                     "line": 'set data style linespoints',
                     "points": 'set data style points',
                     "bar": 'set data style boxes'}
    if plotStyle in styleCommands:
        g(styleCommands[plotStyle])

    # Test against None explicitly.  The truthiness of a Numeric array
    # is not a reliable "was it given" check.
    if yvalues is not None:
        # Plot the x-vector vs the y-vector.
        xvalues = values
        if len(xvalues) != len(yvalues):
            sys.stderr.write( "Warning: xarray and yarray are different sizes- truncating to the shorter array\n")
    else:
        # Plot the values vs their index.
        yvalues = values
        xvalues = range(0, len(yvalues))

    # zip() truncates to the shorter sequence.  list() keeps this a
    # concrete list for Gnuplot.
    tempArray = list(zip(xvalues, yvalues))

    # Compare with None so that a limit of 0 is honoured.
    if min is not None and max is not None:
        g('set yrange [%s:%s]'%(min,max))
    elif min is not None:
        g('set yrange [%s:]'%min)
    elif max is not None:
        g('set yrange [:%s]'%max)

    if previousPlot is None:
        g.clear()
        g.plot(tempArray)
    else:
        g.replot(tempArray)

    if error is not None and len(error) > 0:
        if len(xvalues) != len(yvalues) or len(yvalues) != len(error):
            sys.stderr.write( "Warning: Array lengths not equal, truncating to the shortest array\n")
        tmpErrorMatrix = list(zip(xvalues, yvalues, error))
        # 'with' is a reserved word in Python >= 2.6, so it is passed as
        # a keyword via a dict.  The call is identical for Gnuplot.Data.
        errorData = Gnuplot.Data(tmpErrorMatrix, **{'with': 'errorbars'})
        g.replot(errorData)

    if fileName:
        g.hardcopy(filename=fileName, eps=1, color=1, solid=1)

    return(g)

def colorMap(data, fileName,  rowLabels=None, colLabels=None,
            dataMin=None, dataMax=None, printValues=0, title=None,
            previousCanvas=None, offset=(0,0), scaleBar=1,
            backgroundValue=None, backgroundColor=(0,0,0)):


    """
    canvas = colorMap(data, fileName, rowLabels=None, colLabels=None,
                      dataMin=None, dataMax=None, printValues=0, title=None,
                      previousCanvas=None, offset=(0,0), scaleBar=1,
                      backgroundValue=None, backgroundColor=(0,0,0))

    Draw a 2d Numeric array as a grid of colored cells on a sping
    PostScript canvas, save the canvas to fileName and return it.

    where:..

          data     : is a 2d Numeric array.

          fileName : is a file where the postscript output will be
                     placed.

          rowLabels: is a list of strings, one per row of the data
                     array.

          colLabels: is a list of strings, one per column of the data
                     array.

          dataMin,dataMax : the values mapped to the two ends of the
                     color scale.  If None, the data min/max is used.
                     If the two are equal, every cell gets the
                     mid-scale color.

          printValues: if 1, the numeric value is printed in each
                      cell.  This is undesirable if the data is large.

          title     : An optional title string.

          previousCanvas : an existing sping.PS canvas to draw onto.
                      A new 500x500 PSCanvas is created when None.

          offset    : (x, y) shift applied to the grid and its labels.
                      The color legend is positioned without the offset.

          scaleBar  : if 1, draw a color legend one row past the last
                      grid row.

          backgroundValue : cells equal to this value are drawn in
                      backgroundColor instead of the color scale.
          backgroundColor : (0,0,0) (This needs to be a tuple of 0..1
                      values for R,G,B.)

    """

    # A couple of utility functions borrowed from sample2 distributed
    # with piddle/sping.

    def genRGBColor(value, dataMin=0, dataMax=1):

        """
        Map value onto a red -> green -> blue color scale.

        value is first normalized to x = (value-dataMin)/(dataMax-dataMin).
        Low x is mostly red, middle x mostly green and high x mostly
        blue.  When dataMax == dataMin, x is fixed at 0.5 to avoid
        dividing by zero.
        """

        span = float(dataMax) - dataMin
        if span == 0:
            # Constant data: no range to normalize over, so use the
            # mid-scale color.
            x = 0.5
        else:
            x = (float(value)-dataMin)/span

        blue = 1.0 / (1.0 + math.exp(-10*(x-0.6)))
        red =  1.0 / (1.0 + math.exp(10*(x-0.5)))
        green = (1 - pow( (1.0 / (1.0 + math.exp(10*((x+.2)-0.5)))) ,2) - (1.0 / (1.0 + math.exp(-10*((x-.3)-0.6)))))
        color = sping.PS.Color(red,green,blue)
        return(color)


    # this uses the PS backend - only because it is the one that is
    # easiest to get to work.

    # here we set up the basic geometry and position of things
    # relative to the total canvas size and the data.shape

    # these are the critical parameters

    size = (500,500)
    gridOrigin = (150,150)

    rows,cols = data.shape
    maximumFontSize = 8

    spingBackgroundColor = sping.PS.Color(backgroundColor[0],
                                          backgroundColor[1],
                                          backgroundColor[2])

    # default the color-scale limits to the data range
    if dataMax is None:
        dataMax = MLab.max(MLab.max(data))
    if dataMin is None:
        dataMin = MLab.min(MLab.min(data))

    # these are all dependent variables
    gridSize = (min(size) - MLab.max(gridOrigin))/(MLab.max(data.shape)+2)
    legendHeight = min(size[1] - (gridOrigin[1]+((gridSize*(rows+1)))) - 50, gridSize)
    # NOTE: computed from the un-offset gridOrigin, so the legend does
    # not move with `offset`.
    legendOrigin = (gridOrigin[0] , gridOrigin[1] + gridSize*(rows+1))
    legendNumberOfLines = 100  # works best if this is divisible by 4
    legendLineWidth = (size[0]-(legendOrigin[0]+50))/legendNumberOfLines

    if previousCanvas is None:
        canvas = sping.PS.PSCanvas(size=size)
    else:
        canvas = previousCanvas

    gridOrigin = (gridOrigin[0]+offset[0], gridOrigin[1]+offset[1])

    # initialize the coordinates of the first cell
    x1 = gridOrigin[0]
    y1 = gridOrigin[1]
    x2 = gridOrigin[0] + gridSize
    y2 = gridOrigin[1] + gridSize

    # draw the grid: one row of cells per data row
    for row in data:
        for element in row:
            if backgroundValue == element:
                color = spingBackgroundColor
            else:
                color = genRGBColor(element, dataMin, dataMax)
            canvas.drawRect(x1,y1,x2,y2, edgeWidth=0, fillColor=color)
            if printValues:
                canvas.drawString("%3.2f" % (element), x1+(gridSize/4),y1+(gridSize/2), sping.PS.Font(size=min(maximumFontSize, gridSize/4)))
            x1 += gridSize
            x2 += gridSize
        y1 += gridSize
        y2 += gridSize
        x1 = gridOrigin[0]
        x2 = gridOrigin[0] + gridSize

    # label the axes:
    if rowLabels:
        x1 = 1
        y1 = gridOrigin[1]+gridSize/2
        for label in rowLabels:
            canvas.drawString(label, x1, y1, sping.PS.Font(size=min(maximumFontSize, gridSize*.5)))
            y1 += gridSize
    if colLabels:
        # x follows the columns (gridOrigin[0]); y sits just before the
        # first grid row (gridOrigin[1]).  The two components were
        # previously swapped, which only mattered for a non-zero offset.
        x1 = gridOrigin[0]+gridSize/2
        y1 = gridOrigin[1] - gridSize/4
        for label in colLabels:
            canvas.drawString(label, x1, y1, sping.PS.Font(size=min(maximumFontSize, gridSize*.5)), angle=90)
            x1 += gridSize


    if scaleBar == 1:
        # add a color legend:
        for i in range(legendNumberOfLines):

            x = dataMin + ((float(i)/legendNumberOfLines)*(dataMax-dataMin))

            canvas.drawLine(legendOrigin[0] + (i*legendLineWidth) , legendOrigin[1],
                            legendOrigin[0] + (i*legendLineWidth) , legendOrigin[1]+legendHeight,
                            genRGBColor(x, dataMin, dataMax), width=legendLineWidth)

            # label every quarter of the legend
            if (i%(legendNumberOfLines/4) == 0):
                canvas.drawString("%3.1f" % (x),
                                  legendOrigin[0]+(i*legendLineWidth),
                                  legendOrigin[1]-(gridSize/4),
                                  sping.PS.Font(size=min(maximumFontSize, gridSize*.25)))

        # NOTE(review): this labels the last legend line
        # (i == legendNumberOfLines-1), whose value is just below
        # dataMax.  It was presumably meant to label dataMax - confirm.
        canvas.drawString("%3.1f" % (x), legendOrigin[0]+(i*legendLineWidth),
                          legendOrigin[1]-(gridSize/4),
                          sping.PS.Font(size=min(maximumFontSize,
                                                 gridSize*.25)))

        if title:
            # NOTE(review): y = gridOrigin[1]+size[1] lies beyond the
            # default 500-unit canvas height - confirm the intended
            # position.
            canvas.drawString(title, gridOrigin[0], gridOrigin[1]+size[1], sping.PS.Font(size=gridSize/2))

    else:
        if title:
            canvas.drawString(title, gridOrigin[0], gridOrigin[1]-50, sping.PS.Font(size=gridSize/2))

    canvas.save(file=fileName)
    return(canvas)


def _readVectors(lines):
    """
    Parse whitespace-separated numeric vectors, one per input line.

    Blank lines are ignored.  A line containing any non-numeric token
    is reported on stderr and skipped, so it never reaches the Numeric
    array.  Returns a list of lists of floats.
    """
    vectors = []
    for line in lines:
        tokens = line.split()
        if not tokens:
            # Blank line: appending [] would make the array ragged.
            continue
        try:
            vectors.append([float(token) for token in tokens])
        except ValueError:
            sys.stderr.write("skipping non-numeric line %s\n" % (line.rstrip()))
    return vectors


def main(opts):

    """plot -s <line|points|bar|colormap> -o <filename>

    This provides a simple plotting utility which can be driven at
    the command line.

    usage:

    plot <options> < inputfile

    data format:

        Data is read from stdin, one vector per line.  Each vector is
        a list of numbers separated by whitespace.  Blank and
        non-numeric lines are skipped.  By default each vector is
        plotted against its index.  If exactly two vectors are given
        (and -m is not), the first is used as the x values and the
        second as the y values.

    options:

       -s, --style : Selects plot style, default is line.  Available
                     styles: line, points, bar, and colormap

       -m, --multiple-lines  :  plot every vector as its own line
                     against its index

       -r, --row-vectors:  row vectors <default>

       -c, --col-vectors : column vectors (the input is transposed
                     before plotting)

       -o, --outfile : write an EPS file instead of displaying the plot.
                     Required if the colormap style is selected.

       --min : set the y min value
       --max : set the y max value

    """

    # parse the command line

    if opts.has_key('-h') or opts.has_key('--help'):
        Usage.showHelp(main, exit=1)

    # Resolve the plot style.  Known styles are checked in the order
    # line, points, bar, colormap.  An unknown style prints the help.
    style = "line"
    if opts.has_key('-s') or opts.has_key('--style'):
        requested = (opts.get('-s'), opts.get('--style'))
        for candidate in ("line", "points", "bar", "colormap"):
            if candidate in requested:
                style = candidate
                break
        else:
            Usage.showHelp(main, exit=1)

    # parse the stdin data
    data = _readVectors(sys.stdin.readlines())
    if not data:
        sys.stderr.write("No numeric data found on stdin\n")
        Usage.showHelp(main, exit=1)
    data = Numeric.array(data)

    if opts.has_key('-c') or opts.has_key('--col-vectors'):
        data = Numeric.transpose(data)

    yMin = opts.get('--min')
    yMax = opts.get('--max')

    # make the requested plot

    if style in ("line", "points", "bar"):
        # uses the plot function
        if opts.has_key('-o') or opts.has_key('--outfile'):
            # Send gnuplot output to an EPS file.  plot() then draws
            # onto this pre-configured gnuplot object.
            g = Gnuplot.Gnuplot(persist =1)
            g('set term postscript eps')
            outfile = opts.get('-o', opts.get('--outfile'))
            g('set output "'+ outfile +'"')
        else:
            g = None

        # TODO: the -t/-x/-y (title/xtitle/ytitle) options are accepted
        # but not applied yet, e.g.:
        # g.title(s=opts.get('-t', opts.get('--title')))
        # g.xlabel(s=opts.get('-x', opts.get('--xtitle')))
        # g.ylabel(s=opts.get('-y', opts.get('--ytitle')))

        if (opts.has_key('-m') or
            opts.has_key('--multiple-lines') or
            Numeric.shape(data)[0] != 2):

            for row in data:
                g = plot(row, yvalues=None, min=yMin, max=yMax,
                         plotStyle=style, previousPlot=g)

        else:
            # Exactly two vectors: plot the first (x) against the second (y).
            # Only the slicing is guarded, so plotting errors are not
            # misreported as a data-format problem.
            try:
                values = data[0, :]
                yvalues = data[1, :]
            except Exception:
                print("Invalid Data Format")
                Usage.showHelp(main, exit=1)
            g = plot(values, yvalues, min=yMin, max=yMax,
                     plotStyle=style, previousPlot=g)

    elif style == "colormap":
        # uses the colorMap function
        filename = opts.get('-o', opts.get('--outfile'))
        if not filename :
            Usage.showHelp(main, exit=1)

        colorMap(data, filename)


        
if __name__ == "__main__":

    # Command-line flags accepted by main().  This is a getopt-style
    # spec, presumably handed on to getopt by WrapperUtil - confirm.
    _shortOptions = 's:o:hmcrt:y:x:'
    _longOptions = ["help", "style=", "outfile=",
                    "multiple-lines", "row-vectors", "col-vectors",
                    "min=", "max=",
                    "title=", "xtitle=", "ytitle="]

    try:
        opts, args = WrapperUtil.createOptTree(_shortOptions, _longOptions)
    except:
        # Any failure while building the option tree shows main()'s
        # help text.
        Usage.showHelp(main, exit=1)

    main(opts)
    





