try:
    import psyco
    psyco.full()
except:
    pass

import sys
import optparse
import ReadDataset
from commoncode import getConfigParser, getConfigBoolOption, getConfigFloatOption

print "normalizeFinalExonic: version 3.6"

def main(argv=None):
    """Parse command-line arguments and run normalizeFinalExonic.

    argv[0] is treated as the program name and skipped; when argv is None or
    empty, sys.argv is used instead.  With fewer than four positional
    arguments the usage string is printed and the process exits with status 1.
    """
    argv = argv or sys.argv
    usage = ("usage: python %prog rdsfile expandedRPKMfile multicountfile outfile"
             " [--multifraction] [--multifold] [--minrpkm minThreshold] [--cache] [--withGID]")

    options, positional = makeParser(usage).parse_args(argv[1:])
    if len(positional) < 4:
        print(usage)
        sys.exit(1)

    rdsfilename, expandedRPKMfile, multicountfile, outfilename = positional[:4]
    normalizeFinalExonic(rdsfilename, expandedRPKMfile, multicountfile, outfilename,
                         reportFraction=options.reportFraction,
                         reportFold=options.reportFold,
                         minThreshold=options.minThreshold,
                         doCache=options.doCache,
                         writeGID=options.writeGID)


def makeParser(usage=""):
    """Build the optparse parser for this script.

    Defaults come from the "normalizeFinalExonic" section of the configuration
    returned by getConfigParser(); options given on the command line override
    them.

    Args:
        usage: usage string shown by optparse.

    Returns:
        An optparse.OptionParser whose options carry the attributes
        reportFraction, reportFold, minThreshold, doCache and writeGID.
    """
    parser = optparse.OptionParser(usage=usage)
    # Every dest must match both the attribute main() reads and the
    # set_defaults() key below.  "--multifraction" used to store into
    # "reportfraction" (lower-case f), so the flag was silently ignored.
    parser.add_option("--multifraction", action="store_true", dest="reportFraction",
                      help="add a multi/all column: multiread RPKM divided by RPKM")
    parser.add_option("--multifold", action="store_true", dest="reportFold",
                      help="add an all/uniq column (ignored with --multifraction)")
    parser.add_option("--minrpkm", type="float", dest="minThreshold",
                      help="omit genes whose renormalized RPKM is below this value")
    parser.add_option("--cache", action="store_true", dest="doCache",
                      help="pass cache=True when opening the read dataset")
    parser.add_option("--withGID", action="store_true", dest="writeGID",
                      help="write the gene ID as the first output column")

    configParser = getConfigParser()
    section = "normalizeFinalExonic"
    # NOTE: this config key is "multifraction" while the others use the dest
    # name; kept as-is so existing configuration files keep working.
    reportFraction = getConfigBoolOption(configParser, section, "multifraction", False)
    reportFold = getConfigBoolOption(configParser, section, "reportFold", False)
    minThreshold = getConfigFloatOption(configParser, section, "minThreshold", 0.)
    doCache = getConfigBoolOption(configParser, section, "doCache", False)
    writeGID = getConfigBoolOption(configParser, section, "writeGID", False)

    parser.set_defaults(reportFraction=reportFraction, reportFold=reportFold, minThreshold=minThreshold,
                        doCache=doCache, writeGID=writeGID)

    return parser


def normalizeFinalExonic(rdsfilename, expandedRPKMfilename, multicountfilename, outfilename,
                         reportFraction=False, reportFold=False, minThreshold=0., doCache=False,
                         writeGID=False):
    """Renormalize expanded RPKM values after adding multiread counts.

    Each gene's RPKM from the expanded RPKM file is converted back to a read
    count using the dataset's unique + spliced read total.  The gene's
    multiread count from the multicount file is then added.  Finally the
    result is renormalized by the unique + spliced + multi read total and the
    gene length.  Genes are written to outfilename, tab-separated and sorted by
    descending RPKM, after a "#"-prefixed header line.

    Args:
        rdsfilename: read dataset opened with ReadDataset.ReadDataset to get
            the unique, spliced and multiread totals.
        expandedRPKMfilename: whitespace-separated file.  Field 0 is the GID,
            field 1 the gene symbol, the second-to-last field the length in
            kb, and the last field the RPKM.
        multicountfilename: whitespace-separated file.  Field 0 is the GID and
            the last field its multiread count.  GIDs absent from the expanded
            RPKM file are reported and ignored.
        outfilename: output path (overwritten).
        reportFraction: add a "multi/all" column (multiread RPKM / RPKM).
        reportFold: add an "all/uniq" column; ignored when reportFraction is set.
        minThreshold: genes whose renormalized RPKM is below this are omitted.
        doCache: passed as ``cache=`` to ReadDataset.ReadDataset.
        writeGID: prefix each output line (and the header) with a GID column.
    """
    # Both inputs are opened before the dataset is loaded, so a bad path fails
    # before that work; the with-statements close them even if parsing raises.
    with open(expandedRPKMfilename) as expandedRPKMfile:
        with open(multicountfilename) as multicountfile:
            if reportFraction:
                print("reporting fractional contribution of multireads")
                reportFold = False
            elif reportFold:
                print("reporting fold contribution of multireads")

            RDS = ReadDataset.ReadDataset(rdsfilename, verbose=True, cache=doCache, reportCount=False)
            uniqcount = RDS.getUniqsCount()
            splicecount = RDS.getSplicesCount()
            multicount = RDS.getMultiCount()

            # Read totals expressed in millions of reads.
            uniqspliceCount = (uniqcount + splicecount) / 1000000.
            totalCount = (uniqcount + splicecount + multicount) / 1000000.

            gidList, symbolDict, countDict, lengthDict, multicountDict = \
                _readExpandedRPKM(expandedRPKMfile, uniqspliceCount)
            _addMultiCounts(multicountfile, countDict, multicountDict)

    outlines = _formatOutlines(gidList, symbolDict, countDict, lengthDict, multicountDict,
                               totalCount, minThreshold, reportFraction, reportFold, writeGID)

    outheader = "#"
    if writeGID:
        outheader += "GID\t"

    outheader += "gene\tlen_kb\tRPKM"
    if reportFraction:
        outheader += "\tmulti/all"
    elif reportFold:
        outheader += "\tall/uniq"

    outheader += "\n"

    with open(outfilename, "w") as outfile:
        outfile.write(outheader)
        outfile.writelines(outlines)

    print("returned %d genes" % len(outlines))


def _readExpandedRPKM(lines, uniqspliceCount):
    """Parse expanded-RPKM lines into per-gene dictionaries.

    Blank lines are skipped.  For each remaining line, field 0 is the GID,
    field 1 the gene symbol, the second-to-last field the length in kb and the
    last field the RPKM.  RPKM * length * uniqspliceCount recovers a read
    count.  A repeated GID overwrites earlier values but keeps its first
    position in the order.

    Returns:
        (gidList, symbolDict, countDict, lengthDict, multicountDict), where
        gidList holds GIDs in first-seen order and multicountDict starts at 0.
    """
    gidList = []
    seen = set()  # O(1) duplicate check instead of scanning gidList
    symbolDict = {}
    countDict = {}
    lengthDict = {}
    multicountDict = {}
    for line in lines:
        fields = line.split()
        if not fields:
            continue
        gid = fields[0]
        length = float(fields[-2])
        symbolDict[gid] = fields[1]
        countDict[gid] = float(fields[-1]) * length * uniqspliceCount
        lengthDict[gid] = length
        multicountDict[gid] = 0
        if gid not in seen:
            seen.add(gid)
            gidList.append(gid)

    return gidList, symbolDict, countDict, lengthDict, multicountDict


def _addMultiCounts(lines, countDict, multicountDict):
    """Add each line's multiread count (last field) to its GID (field 0), in place.

    The count is added to both countDict and multicountDict, so a GID listed
    more than once is summed consistently in both.  Blank lines are skipped;
    unknown GIDs are reported and ignored.
    """
    for line in lines:
        fields = line.split()
        if not fields:
            continue
        gid = fields[0]
        if gid in countDict:
            multireads = float(fields[-1])
            countDict[gid] += multireads
            multicountDict[gid] += multireads
        else:
            print("could not find gid %s in dictionaries" % gid)


def _multiValue(rpm, rpkm, multirpm, multirpkm, length, reportFraction):
    """Return the extra per-gene statistic column.

    With reportFraction, returns multirpkm / rpkm.  Otherwise it returns the
    fold of all reads over unique reads: rpkm over the RPKM of the non-multi
    reads.  When multireads account for all of rpm, it returns 100. if
    rpkm > 0.01, else 1.0.  A zero denominator yields 0.
    """
    try:
        if reportFraction:
            return multirpkm / rpkm
        if rpm > multirpm:
            uniqrpkm = (rpm - multirpm) / length
            return rpkm / uniqrpkm
    except ZeroDivisionError:
        return 0

    if rpkm > 0.01:
        return 100.

    return 1.0


def _formatOutlines(gidList, symbolDict, countDict, lengthDict, multicountDict,
                    totalCount, minThreshold, reportFraction, reportFold, writeGID):
    """Return output lines for genes with RPKM >= minThreshold, highest RPKM first.

    totalCount is the read total in millions used for renormalization.  Genes
    whose values cannot be computed (zero totalCount or zero length) are
    reported and skipped.
    """
    ranked = []
    for gid in gidList:
        length = lengthDict[gid]
        try:
            rpm = countDict[gid] / totalCount
            rpkm = rpm / length
            multirpm = multicountDict[gid] / totalCount
            multirpkm = multirpm / length
        except ZeroDivisionError:
            print("problem with %s - skipping " % gid)
            continue

        if rpkm < minThreshold:
            continue

        outline = "%s\t" % gid if writeGID else ""
        if reportFraction or reportFold:
            multivalue = _multiValue(rpm, rpkm, multirpm, multirpkm, length, reportFraction)
            outline += "%s\t%.3f\t%.2f\t%.2f\n" % (symbolDict[gid], length, rpkm, multivalue)
        else:
            outline += "%s\t%.3f\t%.2f\n" % (symbolDict[gid], length, rpkm)

        ranked.append((rpkm, outline))

    # Sorting (rpkm, line) tuples breaks RPKM ties by the line text.
    ranked.sort(reverse=True)
    return [outline for _, outline in ranked]


if __name__ == "__main__":
    # main() falls back to sys.argv when called without an argument list.
    main()