#
#  profilebins.py
#  ENRAGE
#

# Optional speed-up: enable the psyco JIT when it is installed.  Failure to
# import or initialise it is deliberately ignored, but only ordinary errors
# are swallowed so that Ctrl-C / SystemExit still propagate.
try:
    import psyco
    psyco.full()
except Exception:
    pass

import sys
import optparse
from commoncode import getConfigParser, getConfigOption, getConfigBoolOption, getConfigFloatOption

# Announce the script version on import/run (single-argument parenthesised
# form prints the same text under the print statement and print function).
print("profilebins: version 2.3")


def main(argv=None):
    """Command-line entry point.

    Parses *argv* (falling back to ``sys.argv`` when it is empty or None),
    requires at least three positional arguments -- label, input file and
    output file -- and hands everything to ``profilebins``.  Prints the usage
    string and exits with status 1 when fewer than three are given.
    """
    argv = argv or sys.argv

    usage = "usage: python %prog label infile1 [--upstream infile2] [--downstream infile3] [--uplength kb] [--downlength kb] [--gene geneName] [--genes genefile] [--append] outfile"

    opts, positional = makeParser(usage).parse_args(argv[1:])

    if len(positional) < 3:
        print(usage)
        sys.exit(1)

    # Any positional arguments beyond the first three are ignored.
    label, infilename, outfilename = positional[:3]

    profilebins(label, infilename, outfilename,
                upfilename=opts.upfilename,
                downfilename=opts.downfilename,
                uplength=opts.uplength,
                downlength=opts.downlength,
                gene=opts.gene,
                genefile=opts.genefile,
                doAppend=opts.doAppend)


def makeParser(usage=""):
    """Build the optparse parser for the profilebins command line.

    Option defaults are looked up in the "profilebins" section of the
    configuration returned by ``getConfigParser()``; each lookup passes the
    fallback shown below (None, 0.0 or False) to the commoncode helper.

    usage: usage string handed to ``optparse.OptionParser``.

    Returns the configured ``optparse.OptionParser``.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--upstream", dest="upfilename")
    parser.add_option("--downstream", dest="downfilename")
    parser.add_option("--uplength", type="float", dest="uplength")
    # Previously registered with dest="" (so the parsed value never reached
    # options.downlength) and type="int"; it now mirrors --uplength and the
    # float default read from the configuration.
    parser.add_option("--downlength", type="float", dest="downlength")
    parser.add_option("--gene", dest="gene")
    parser.add_option("--genes", dest="genefile")
    parser.add_option("--append", action="store_true", dest="doAppend")

    configParser = getConfigParser()
    section = "profilebins"
    upfilename = getConfigOption(configParser, section, "upfilename", None)
    downfilename = getConfigOption(configParser, section, "downfilename", None)
    uplength = getConfigFloatOption(configParser, section, "uplength", 0.0)
    downlength = getConfigFloatOption(configParser, section, "downlength", 0.0)
    gene = getConfigOption(configParser, section, "gene", None)
    genefile = getConfigOption(configParser, section, "genefile", None)
    doAppend = getConfigBoolOption(configParser, section, "doAppend", False)

    parser.set_defaults(upfilename=upfilename, downfilename=downfilename, uplength=uplength, downlength=downlength,
                        gene=gene, genefile=genefile, doAppend=doAppend)

    return parser


def _readGeneNames(genefilename):
    """Return the first whitespace-separated field of every non-blank line."""
    names = []
    genefile = open(genefilename)
    try:
        for line in genefile:
            fields = line.split()
            if fields:
                names.append(fields[0])
    finally:
        genefile.close()

    return names


def _accumulateBins(filename, geneNames, totalWeight):
    """Sum weight * bin over the rows of one file.

    geneNames is a set of gene names to keep, or None to keep every row.
    totalWeight is the running weight total; the updated total is returned
    together with the per-bin sums as ``(bins, totalWeight)``.
    """
    bins = None
    infile = open(filename)
    try:
        for line in infile:
            fields = line.split()
            if not fields:
                continue

            # The first non-blank row fixes the number of bins, even if that
            # row itself is filtered out below.
            if bins is None:
                bins = [0.] * (len(fields) - 4)

            if geneNames is not None and fields[1] not in geneNames:
                continue

            weight = float(fields[2])
            for index, value in enumerate(fields[4:]):
                bins[index] += weight * float(value)

            totalWeight += weight
    finally:
        infile.close()

    return bins or [], totalWeight


def profilebins(label, infilename, outfilename, upfilename=None, downfilename=None,
                uplength=0.0, downlength=0.0, gene=None, genefile=None, doAppend=False):
    """Write one row of weight-normalised bin totals for *label*.

    Each input file is whitespace-delimited; per row, field 1 is the gene
    name, field 2 its weight and fields 4 onward the bin values (fields 0
    and 3 are not used here).  For every file the weighted bin values are
    summed, then each sum is divided by the total weight of all rows used
    (accumulated across all files) and written tab-separated after *label*,
    followed by that total weight.

    label:        text written in the first column of the output row.
    infilename:   main input file; its bins span x = 0 .. 10.
    outfilename:  output file; overwritten with an "x-axis" header row
                  unless doAppend is true, in which case the row is appended
                  without a header.
    upfilename:   optional file whose bins are placed before the main ones,
                  spanning -uplength .. 0.
    downfilename: optional file whose bins are placed after the main ones,
                  spanning 10 .. 10 + downlength.
    uplength, downlength: x-axis extent of the upstream / downstream parts.
    gene:         optional single gene name to restrict the rows to.
    genefile:     optional name of a file listing gene names (first field
                  of each line) to restrict the rows to.
    doAppend:     append to outfilename instead of rewriting it.

    Prints the sum of all bin totals and the sum of the written values.
    """
    # None means "no restriction"; a set keeps the per-row lookup cheap.
    geneNames = None
    if gene is not None or genefile is not None:
        geneNames = set()
        if gene is not None:
            geneNames.add(gene)

        if genefile is not None:
            # genefile is a file name: read it rather than iterating over
            # the characters of the name.
            geneNames.update(_readGeneNames(genefile))

    # One (filename, x-length, x-offset) entry per input file, so the axis
    # geometry always lines up with the files actually read.
    parts = []
    if upfilename is not None:
        upSpan = float(uplength or 0.)
        parts.append((upfilename, upSpan, 0. - upSpan))

    parts.append((infilename, 10., 0.))

    if downfilename is not None:
        parts.append((downfilename, float(downlength or 0.), 10.))

    totalWeight = 0.
    totalBins = []
    for (partFile, partLen, partOffset) in parts:
        partBins, totalWeight = _accumulateBins(partFile, geneNames, totalWeight)
        totalBins.append(partBins)

    pieces = []
    if not doAppend:
        pieces.append("x-axis")
        for (partFile, partLen, partOffset), partBins in zip(parts, totalBins):
            numBins = len(partBins)
            for binIndex in range(numBins):
                pieces.append("\t%.2f" % (partOffset + (binIndex * partLen / numBins)))

        pieces.append("\tweight\n")

    sumWeight = 0.
    totalPercent = 0.
    pieces.append(label)
    for partBins in totalBins:
        for aBin in partBins:
            # No row contributed any weight: report zeros instead of
            # dividing by zero.
            if totalWeight:
                percent = aBin / totalWeight
            else:
                percent = 0.

            pieces.append("\t%.1f" % percent)
            sumWeight += aBin
            totalPercent += percent

    pieces.append("\t%.1f\n" % totalWeight)

    if doAppend:
        outfile = open(outfilename, "a")
    else:
        outfile = open(outfilename, "w")

    try:
        outfile.write("".join(pieces))
    finally:
        outfile.close()

    print(sumWeight)
    print(totalPercent)


# Run the command-line interface only when executed as a script.
if "__main__" == __name__:
    main(argv=sys.argv)