#
#  plotbardist.py
#  ENRAGE
#
#  Created by Ali Mortazavi on 12/13/07.

# psyco is an optional accelerator: enable it when it is installed, and carry
# on without it otherwise.
try:
    import psyco
    psyco.full()
except Exception:
    # Best effort only.  Catching Exception (rather than using a bare except)
    # keeps KeyboardInterrupt and SystemExit from being swallowed here.
    pass

import sys
import optparse
import matplotlib
from pylab import *
from math import *
from commoncode import getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption


# Version banner.  It sits at module level, so it is printed whenever this
# module is loaded (on import as well as when run as a script).
print "plotbardist: version 3.3"


def main(argv=None):
    """Command-line entry point: parse the arguments and draw the bar plot.

    argv is the full argument vector, program name included; when it is None
    or empty, sys.argv is used instead.  After option parsing, the positional
    arguments must be one to three input file names followed by the output
    image file name.  Any other count prints the usage and exits with
    status 1.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog infile1 [infile2] [infile3] [options] outfile.png"

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 2 or len(args) > 4:
        # Ask the parser for the usage text so that the %prog placeholder is
        # expanded to the script name instead of being printed literally.
        print(parser.get_usage().rstrip())
        print("where labelList and legendList are comma delimited strings of the form 'labelA,labelB,...,labelN'")
        sys.exit(1)

    # Every positional argument but the last is an input file.
    fileList = args[:-1]
    pngfilename = args[-1]

    plotbardist(fileList, pngfilename, options.bins, options.binnedField, options.binLength,
                options.logBase, options.maxY, options.xLabel, options.yLabel, options.binLabels,
                options.figTitle, options.barsLegend, options.pointOffset, options.figSizes)


def makeParser(usage=""):
    """Build the optparse parser for plotbardist.

    usage is handed to optparse.OptionParser unchanged.  The default of every
    option is looked up in the "plotbardist" section of the configuration
    returned by getConfigParser(); the last argument of each getConfig*Option
    call is presumably the value used when the entry is absent (verify
    against commoncode).

    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--bins", type="int", dest="bins", help="number of bins")
    parser.add_option("--field", type="int", dest="binnedField",
                      help="index of the whitespace separated field to bin")
    parser.add_option("--binSize", type="float", dest="binLength",
                      help="bin width; a negative value derives it from the data range")
    parser.add_option("--doLog", type="int", dest="logBase",
                      help="take the log in this base of every data point")
    parser.add_option("--ymax", type="int", dest="maxY",
                      help="upper limit of the y axis; only applied when positive")
    parser.add_option("--xlabel", dest="xLabel", help="x axis label")
    parser.add_option("--ylabel", dest="yLabel", help="y axis label")
    parser.add_option("--binLabels", dest="binLabels", help="comma separated list")
    parser.add_option("--title", dest="figTitle", help="figure title")
    parser.add_option("--legend", dest="barsLegend", help="comma separated list")
    parser.add_option("--xoffset", type="float", dest="pointOffset",
                      help="value added to every data point before binning")
    parser.add_option("--figsize", dest="figSizes", help="x,y pair")

    configParser = getConfigParser()
    section = "plotbardist"
    bins = getConfigIntOption(configParser, section, "bins", 10)
    binnedField = getConfigIntOption(configParser, section, "binnedField", -1)
    # --binSize is a float option, so its configured default is read as a
    # float too (it used to go through the int getter).
    binLength = getConfigFloatOption(configParser, section, "binLength", -1.)
    # --doLog is an int option and the value is later used with "%d" and as a
    # logarithm base.  getConfigOption presumably returns the raw string, so
    # convert a configured value here; int() is a no-op on an int.
    logBase = getConfigOption(configParser, section, "logBase", None)
    if logBase is not None:
        logBase = int(logBase)

    maxY = getConfigIntOption(configParser, section, "maxY", 0)
    xLabel = getConfigOption(configParser, section, "xLabel", "bins")
    yLabel = getConfigOption(configParser, section, "yLabel", "count")
    binLabels = getConfigOption(configParser, section, "binLabels", None)
    figTitle = getConfigOption(configParser, section, "figTitle", "")
    barsLegend = getConfigOption(configParser, section, "barsLegend", None)
    pointOffset = getConfigFloatOption(configParser, section, "pointOffset", 0.)
    figSizes = getConfigOption(configParser, section, "figSizes", None)

    parser.set_defaults(bins=bins, binnedField=binnedField, binLength=binLength, logBase=logBase, maxY=maxY,
                        xLabel=xLabel, yLabel=yLabel, binLabels=binLabels, figTitle=figTitle,
                        barsLegend=barsLegend, pointOffset=pointOffset, figSizes=figSizes)

    return parser


def _readDataPoints(fileName, binnedField, pointOffset, logBase=None):
    """Return the sorted list of data points read from fileName.

    Every line is split on whitespace; the field at index binnedField is
    converted to float and pointOffset is added to it.  Lines that lack the
    field or whose field is not a number are skipped.  When logBase is not
    None, values below 1 are raised to 1 and the log in that base is taken.
    """
    dataList = []
    aFile = open(fileName)
    try:
        for line in aFile:
            fields = line.split()
            try:
                point = float(fields[binnedField]) + pointOffset
            except (IndexError, ValueError):
                # Header, blank or otherwise non-numeric line.
                continue

            if logBase is not None:
                if point < 1:
                    point = 1

                point = log(point, logBase)

            dataList.append(point)
    finally:
        # Close the handle even if reading fails part way through.
        aFile.close()

    dataList.sort()
    return dataList


def _rawBinIndex(point, binLength):
    """Return int(round(point / binLength)), or None when that is undefined.

    None covers a zero binLength as well as NaN and infinite quotients.
    """
    try:
        return int(round(point / binLength))
    except (ZeroDivisionError, ValueError, OverflowError):
        return None


def _binDataPoints(dataList, bins, binLength):
    """Count the points of dataList into a list of `bins` counters.

    A point belongs to bin int(round(point / binLength)).  Points beyond the
    last bin, and points whose index cannot be computed, are counted in the
    last bin.  Points with a negative index are counted in the first bin
    instead of wrapping around to the right-hand bins.
    """
    distbin = bins * [0]
    lastBin = bins - 1
    for point in dataList:
        binIndex = _rawBinIndex(point, binLength)
        if binIndex is None or binIndex > lastBin:
            binIndex = lastBin
        elif binIndex < 0:
            binIndex = 0

        distbin[binIndex] += 1

    return distbin


def _findMedianBin(distbin):
    """Return the smallest n such that the first n bins hold at least half
    (integer division) of the total count."""
    halfCount = sum(distbin) // 2
    runningCount = 0
    for median, count in enumerate(distbin):
        if runningCount >= halfCount:
            return median

        runningCount += count

    return len(distbin)


def plotbardist(fileList, pngfilename, bins=10, binnedField=-1, binLength=-1, logBase=None,
                maxY=0, xLabel="bins", yLabel="count", binLabels=None, figTitle="",
                barsLegend=None, pointOffset=0., figSizes=None):
    """Plot side-by-side bar histograms of one to three files and save them.

    fileList    -- one to three input file names, one group of bars per file
    pngfilename -- name of the image file handed to savefig
    bins        -- number of bins
    binnedField -- index of the whitespace separated field that is binned
    binLength   -- bin width; when negative it is set to the data range of the
                   first file that has data divided by bins, and that width
                   is then reused for the remaining files
    logBase     -- when not None, the log in this base of every point is
                   binned (points below 1 are treated as 1) and the x label
                   is wrapped as "log<base>(<xLabel>)"
    maxY        -- upper y limit, only applied when positive
    xLabel, yLabel -- axis labels
    binLabels   -- comma separated x tick labels, or None
    figTitle    -- figure title, only applied when not empty
    barsLegend  -- comma separated legend entries, or None
    pointOffset -- value added to every point before the optional log
    figSizes    -- "x,y" figure size, or None

    The data count, data range, bin width, bin counts and median bin of every
    file are printed to stdout.  A file without usable data is reported and
    drawn as empty bars.

    Raises KeyError when fileList does not hold one to three names.
    """
    # NOTE(review): pylab is already imported at module level when this runs,
    # so this backend request may come too late to take effect on some
    # matplotlib versions -- confirm against the matplotlib.use documentation.
    matplotlib.use("Agg")
    plotParameters = {1: {"width": 0.5,
                          "offset": [-0.25]},
                      2: {"width": 0.3,
                          "offset": [-0.3, 0]},
                      3: {"width": 0.2,
                          "offset": [-0.2, 0., 0.2]}
    }

    colorList = ["b", "r", "c"]
    width = plotParameters[len(fileList)]["width"]
    offset = plotParameters[len(fileList)]["offset"]

    if logBase is not None:
        print("taking log%d of x datapoints" % logBase)
        xLabel = "log%d(%s)" % (logBase, xLabel)

    if figSizes is not None:
        sizes = figSizes.strip().split(",")
        figure(figsize=(float(sizes[0]), float(sizes[1])))

    doLabels = binLabels is not None
    if doLabels:
        binLabels = binLabels.strip().split(",")

    if barsLegend is not None:
        barsLegend = barsLegend.strip().split(",")
    else:
        barsLegend = []

    ind2 = arange(bins)

    barsColors = []
    for index, fileName in enumerate(fileList):
        dataList = _readDataPoints(fileName, binnedField, pointOffset, logBase)
        print("%d data points" % len(dataList))

        if dataList:
            print("low = %f high = %f" % (dataList[0], dataList[-1]))

            if binLength < 0:
                binLength = abs(dataList[-1] - dataList[0]) / bins

            # Diagnostic: the bin width and the raw (unclamped) bin index of
            # the largest point; None when the index cannot be computed.
            print("%s %s" % (binLength, _rawBinIndex(dataList[-1], binLength)))
        else:
            print("no usable data points in %s" % fileName)

        distbin = _binDataPoints(dataList, bins, binLength)

        barGroup = bar(ind2 + offset[index], distbin, width, color=colorList[index])
        # One bar of each group serves as that group's legend handle.
        barsColors.append(barGroup[0])

        print(distbin)
        print(_findMedianBin(distbin))

    xlim(-1 * width - 0.2, bins + 0.2)

    if barsLegend:
        legend(barsColors, barsLegend)

    ylabel(yLabel)
    xlabel(xLabel)

    if doLabels:
        setp(gca(), "xticklabels", binLabels)

    if maxY > 0:
        ylim(0, maxY)

    if figTitle:
        title(figTitle)

    gca().get_xaxis().tick_bottom()
    gca().get_yaxis().tick_left()

    savefig(pngfilename)


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main(sys.argv)