"""
    usage: python scatterfields.py infilename xaxisLabel xField yaxisLabel yField outImageName [--xmin xMin] [--ymin yMin]
                  [--xmax xMax] [--ymax yMax] [--doLogF1] [--doLogF2] [--arcsinh] [--order polyOrder] [--base logBase]
                  [--markGenes geneFile] [--markfold times] [--noregression] [--large] [--markdiag] [--title text] [--verbose]

           Do a scatter plot of 2 fields from an input file.
           fields are counted from 0.
           use [--order polyOrder] to specify polynomial fits > 1
           Supports very rudimentary compound fields for X value
           using python's lambda functions (omit the keyword lambda)
"""

import matplotlib
matplotlib.use("Agg")

from pylab import *
import math, cmath
import sys
import optparse
from commoncode import getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption

# Opacity passed as alpha= to every marker plot() call in scatterfields().
alphaVal = 0.5

# Version banner; this is a module-level statement, so it is printed both when
# the script is run and when the module is imported.
print "scatterfields: version 3.2"

def main(argv=None):
    """Command-line entry point: parse arguments and draw the scatter plot.

    argv is a full argument vector (program name first); when it is empty
    or None, sys.argv is used.  Six positional arguments are required:
    infilename, xaxisLabel, xField, yaxisLabel, yField and outImageName.
    Prints the usage text and exits with status 1 if too few positional
    arguments are given or if yField is not an integer.
    """
    if not argv:
        argv = sys.argv

    usage = __doc__

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 6:
        print(usage)
        sys.exit(1)

    # scatterfields() opens the input itself, so hand it the path rather than
    # an already-open file object.
    infilename = args[0]
    xaxis = args[1]
    # xField stays a string: it is either a column index or a compound
    # (lambda-body) expression, and scatterfields() decides which.
    xField = args[2]
    yaxis = args[3]
    try:
        yField = int(args[4])
    except ValueError:
        print("yField must be an integer column index, got %s" % args[4])
        print(usage)
        sys.exit(1)

    outfilename = args[5]

    # --markGenes yields a path, but scatterfields() iterates over the lines
    # of markFile and closes it, so pass an open file.
    markFile = None
    if options.markFile is not None:
        markFile = open(options.markFile)

    try:
        scatterfields(infilename, xaxis, xField, yaxis, yField, outfilename, options.forcexmin,
                      options.forceymin, options.forcexmax, options.forceymax, options.doLogF1,
                      options.doLogF2, options.doArcsinh, options.fitOrder, options.base,
                      markFile, options.foldChange, options.doRegression, options.plotLarge,
                      options.markDiag, options.figtitle, options.verbose)
    finally:
        # Closing an already-closed file is a no-op; this only matters when
        # scatterfields() fails before it gets to close the file itself.
        if markFile is not None:
            markFile.close()


def getParser(usage):
    """Build the optparse parser for the scatterfields command line.

    Every option's default is read from the "scatterfields" section of the
    configuration returned by getConfigParser(), falling back to the
    literal defaults given below.

    usage -- text shown by optparse for --help and on argument errors.

    Returns the configured optparse.OptionParser.
    Raises ValueError if the configured foldChange is not a number.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--xmin", type="float", dest="forcexmin")
    parser.add_option("--ymin", type="float", dest="forceymin")
    parser.add_option("--xmax", type="float", dest="forcexmax")
    parser.add_option("--ymax", type="float", dest="forceymax")
    parser.add_option("--doLogF1", action="store_true", dest="doLogF1")
    parser.add_option("--doLogF2", action="store_true", dest="doLogF2")
    parser.add_option("--arcsinh", action="store_true", dest="doArcsinh")
    parser.add_option("--order", type="int", dest="fitOrder")
    parser.add_option("--base", type="int", dest="base")
    parser.add_option("--markGenes", dest="markFile")
    parser.add_option("--markfold", type="float", dest="foldChange")
    parser.add_option("--noregression", action="store_false", dest="doRegression")
    parser.add_option("--large", action="store_true", dest="plotLarge")
    parser.add_option("--markdiag", action="store_true", dest="markDiag")
    # The title is free text; declaring it type="int" rejected any real title.
    parser.add_option("--title", type="string", dest="figtitle")
    parser.add_option("--verbose", action="store_true", dest="verbose")

    configParser = getConfigParser()
    section = "scatterfields"
    forcexmin = getConfigFloatOption(configParser, section, "forcexmin", 0.0)
    forceymin = getConfigFloatOption(configParser, section, "forceymin", 0.0)
    # Axis maxima are floats on the command line (--xmax/--ymax), so read the
    # configured defaults as floats too.
    forcexmax = getConfigFloatOption(configParser, section, "forcexmax", -1)
    forceymax = getConfigFloatOption(configParser, section, "forceymax", -1)
    doLogF1 = getConfigBoolOption(configParser, section, "doLogF1", False)
    doLogF2 = getConfigBoolOption(configParser, section, "doLogF2", False)
    doArcsinh = getConfigBoolOption(configParser, section, "doArcsinh", False)
    fitOrder = getConfigIntOption(configParser, section, "fitOrder", 1)
    base = getConfigIntOption(configParser, section, "base", 10)
    markFile = getConfigOption(configParser, section, "markFile", None)
    # --markfold is a float and the value is compared against float ratios, so
    # convert a configured value instead of passing the raw setting through.
    foldChange = getConfigOption(configParser, section, "foldChange", None)
    if foldChange is not None:
        foldChange = float(foldChange)

    doRegression = getConfigBoolOption(configParser, section, "doRegression", True)
    plotLarge = getConfigBoolOption(configParser, section, "plotLarge", False)
    markDiag = getConfigBoolOption(configParser, section, "markDiag", False)
    figtitle = getConfigOption(configParser, section, "figtitle", "")
    verbose = getConfigBoolOption(configParser, section, "verbose", False)

    parser.set_defaults(forcexmin=forcexmin, forceymin=forceymin, forcexmax=forcexmax, forceymax=forceymax, doLogF1=doLogF1,
                        doLogF2=doLogF2, doArcsinh=doArcsinh, fitOrder=fitOrder, base=base, markFile=markFile,
                        foldChange=foldChange, doRegression=doRegression, plotLarge=plotLarge, markDiag=markDiag,
                        figtitle=figtitle, verbose=verbose)

    return parser


def _openIfPath(source):
    """Return source unchanged if it is file-like, otherwise open it as a path for reading."""
    if hasattr(source, "read"):
        return source

    return open(source)


def _loadMarkedGenes(markFile):
    """Return the upper-cased first token of each non-blank line of markFile as a set.

    markFile may be a path or an open file; it is closed before returning.
    """
    handle = _openIfPath(markFile)
    genes = set()
    try:
        for line in handle:
            tokens = line.split()
            if tokens:
                genes.add(tokens[0].upper())
    finally:
        handle.close()

    return genes


def scatterfields(infilename, xaxis, xField, yaxis, yField, outfilename, forcexmin=0.0, forceymin=0.0,
                  forcexmax=-1, forceymax=-1, doLogF1=False, doLogF2=False, doArcsinh=False, fitOrder=1,
                  base=10, markFile=None, foldChange=None, doRegression=True, plotLarge=False,
                  markDiag=False, figtitle="", verbose=False):
    """Scatter-plot two columns of a whitespace-delimited file and save the image.

    infilename   -- path of the input file, or an already-open file object;
                    it is closed once read.  Column 0 is treated as the gene name.
    xaxis, yaxis -- axis labels.
    xField       -- X column index (counted from 0), or the body of a lambda
                    (without the keyword) that receives the list of string
                    fields of a row, e.g. "f: float(f[1]) + float(f[2])".
    yField       -- Y column index.
    outfilename  -- image path handed to savefig().
    forcexmin, forceymin, forcexmax, forceymax
                 -- axis limits; the minima also replace values whose log
                    cannot be computed.  The maxima apply only when both are > 0.
    doLogF1, doLogF2
                 -- plot log_base of X / Y.
    doArcsinh    -- plot arcsinh of both values (takes precedence over the logs).
    fitOrder     -- order of the polynomial fitted by polyfit(); any order >= 1.
    base         -- logarithm base.
    markFile     -- path or open file listing gene names (first token per
                    line) to highlight in red; closed once read.
    foldChange   -- when given, points whose Y/X or X/Y ratio exceeds it are
                    highlighted and the remaining points are drawn in grey.
    doRegression -- draw the fitted curve.
    plotLarge    -- use large triangle markers instead of pixels.
    markDiag     -- draw a green diagonal between the collected min and max.
    figtitle     -- plot title; when empty a title including R^2 is generated.
    verbose      -- print every fold-marked input line.

    Rows whose X or Y value cannot be parsed are skipped.  Progress and
    diagnostic values are printed to stdout.  Exits with status 1 if xField
    is not a usable expression or if the input yields no usable rows.
    """
    compoundField = False
    try:
        xField = int(xField)
    except (TypeError, ValueError):
        try:
            # SECURITY: xField is evaluated as Python source.  This is the
            # documented "compound field" feature, but it must never be fed
            # untrusted input.
            compoundOp = "lambda %s" % xField
            operator = eval(compoundOp)
            compoundField = True
            print("compound field %s" % xField)
        except Exception:
            pass

        if not compoundField:
            print("expression %s not supported" % xField)
            sys.exit(1)

    marking = markFile is not None
    markedGenes = set()
    if marking:
        markedGenes = _loadMarkedGenes(markFile)

    markFold = foldChange is not None
    if markFold:
        # A threshold taken from a config file may arrive as a string.
        foldChange = float(foldChange)

    newscores = []
    oldscores = []

    markednewscores = []
    markedoldscores = []

    markedfoldnewscores = []
    markedfoldoldscores = []

    # NOTE(review): xmax/ymax are only updated in log mode, so in linear mode
    # they stay 0 and neither the xlim/ylim calls nor the diagonal use the
    # data range.  Left unchanged because existing plots depend on it.
    ymax = 0.
    xmax = 0.
    infile = _openIfPath(infilename)
    try:
        for line in infile:
            fields = line.strip().split()
            if not fields:
                # Blank line: there is no gene name or value to read.
                continue

            gene = fields[0]
            geneKey = gene.upper()
            try:
                if compoundField:
                    score = operator(fields)
                else:
                    score = float(fields[xField])

                newscore = float(fields[yField])
            except Exception:
                # Header lines, short rows and non-numeric cells are skipped on
                # purpose; a compound expression may raise anything.
                continue

            foldMarkThisScore = False
            if markFold:
                # 0.03 stands in for a zero X value so the ratio stays finite.
                tempscore = score
                if tempscore == 0:
                    tempscore = 0.03

                tempratio = newscore / tempscore
                if tempratio == 0:
                    tempratio2 = tempscore / 0.03
                else:
                    tempratio2 = 1. / tempratio

                if tempratio > foldChange or tempratio2 > foldChange:
                    foldMarkThisScore = True

            if doArcsinh:
                score = abs(cmath.asinh(score))
            elif doLogF1:
                try:
                    score = math.log(score, base)
                except (ArithmeticError, ValueError, TypeError):
                    # log of zero or a negative value: pin to the axis minimum.
                    score = forcexmin

                if score > xmax:
                    xmax = score

            if doArcsinh:
                newscore = abs(cmath.asinh(newscore))
            elif doLogF2:
                try:
                    newscore = math.log(newscore, base)
                except (ArithmeticError, ValueError, TypeError):
                    newscore = forceymin

                if newscore > ymax:
                    ymax = newscore

            oldscores.append(score)
            newscores.append(newscore)
            if foldMarkThisScore:
                markedfoldoldscores.append(score)
                markedfoldnewscores.append(newscore)
                if marking and geneKey not in markedGenes:
                    print("%s %s %s unmarked" % (gene, score, newscore))

                if geneKey in markedGenes:
                    print("%s %s %s overfold" % (gene, score, newscore))

                if verbose:
                    print("%s %s" % (len(markedfoldoldscores), line.strip()))

            if geneKey in markedGenes:
                if not foldMarkThisScore:
                    print("%s %s %s" % (gene, score, newscore))

                markedoldscores.append(score)
                markednewscores.append(newscore)
    finally:
        infile.close()

    if not newscores:
        # Without data there is nothing to fit or plot; previously this case
        # failed with a NameError or ZeroDivisionError further down.
        print("no usable data rows found in input")
        sys.exit(1)

    # Diagnostic output: the last plotted pair and the last non-blank row read.
    print("%s %s" % (score, newscore))
    print(fields)

    if plotLarge and markFold:
        plot(oldscores, newscores, "^", markersize=10., color="0.75", alpha=alphaVal)
    elif plotLarge:
        plot(oldscores, newscores, "b^", markersize=10., alpha=alphaVal)
    elif markFold:
        plot(oldscores, newscores, ",", color="0.75", alpha=alphaVal)
    else:
        plot(oldscores, newscores, "b,", alpha=alphaVal)

    if markedfoldoldscores:
        if plotLarge:
            plot(markedfoldoldscores, markedfoldnewscores, "b^", markersize=10., alpha=alphaVal)
        else:
            plot(markedfoldoldscores, markedfoldnewscores, "b,", alpha=alphaVal)

    if markedoldscores:
        if plotLarge:
            plot(markedoldscores, markednewscores, "r^", color="red", markersize=10., alpha=alphaVal)
        else:
            plot(markedoldscores, markednewscores, ".", color="red", markersize=4., alpha=alphaVal)

    fitvalues = polyfit(oldscores, newscores, fitOrder)
    print(fitvalues)
    print(len(oldscores))

    # polyval evaluates the fit for any order; the previous hand-written
    # formulas were only correct for orders 1 and 2.
    predicted = polyval(fitvalues, oldscores)

    # Explicit loops rather than sum()/min()/max() over generators: the pylab
    # star import may replace those builtins with numpy versions.
    meanObserved = float(sum(newscores)) / len(newscores)
    SSt = 0.
    SSe = 0.
    for observed, fitted in zip(newscores, predicted):
        SSt += (observed - meanObserved) ** 2
        SSe += (observed - fitted) ** 2

    if SSt == 0:
        # All Y values are identical, so R^2 is undefined.
        rSquared = float("nan")
    else:
        rSquared = 1. - SSe / SSt

    print("R**2 = %f" % rSquared)

    if doRegression:
        # Sort X so the fitted curve is drawn left to right as a single line.
        sortedScores = sorted(oldscores)
        plot(sortedScores, polyval(fitvalues, sortedScores), "-k", linewidth=2)

    if figtitle == "":
        figtitle = "%s vs %s (R^2: %.2f)" % (yaxis, xaxis, rSquared)

    title(figtitle)

    if markDiag:
        diagMin = forcexmin
        if forceymin < diagMin:
            diagMin = forceymin

        diagMax = xmax
        for candidate in (ymax, forcexmax, forceymax):
            if candidate > diagMax:
                diagMax = candidate

        plot([diagMin, diagMax], [diagMin, diagMax], "-g", linewidth=2)

    print("%s %s" % (forcexmin, forceymin))

    if doLogF2:
        ylabel("log%s(%s)" % (str(base), yaxis))
    else:
        ylabel(yaxis)

    if doLogF1:
        xlabel("log%s(%s)" % (str(base), xaxis))
    else:
        xlabel(xaxis)

    if xmax > 0:
        xlim(forcexmin - 0.05, xmax)

    if ymax > 0:
        ylim(forceymin - 0.05, ymax)

    # Explicit maxima override the data-derived limits, but only as a pair.
    if forcexmax > 0 and forceymax > 0:
        xlim(forcexmin - 0.05, forcexmax)
        ylim(forceymin - 0.05, forceymax)

    gca().get_xaxis().tick_bottom()
    gca().get_yaxis().tick_left()

    savefig(outfilename, dpi=100)


# Run main() with the process arguments when this file is executed as a script.
if __name__ == "__main__":
    main(sys.argv)