#
#
#  geneStallingBins.py
#  ENRAGE
#

# originally from geneLocusBins.py
# psyco is an optional JIT accelerator: enable it when it is installed and
# otherwise run unaccelerated.  Catch Exception rather than using a bare
# except so that KeyboardInterrupt / SystemExit are not swallowed here.
try:
    import psyco
    psyco.full()
except Exception:
    pass

import sys
import optparse
from commoncode import getMergedRegions, computeRegionBins, getLocusByChromDict, getConfigParser, getConfigBoolOption, getConfigIntOption, getConfigOption
import ReadDataset
from cistematic.genomes import Genome
from commoncode import getGeneInfoDict

# Version banner; this runs at import time as well as when the file is
# executed as a script, because it is a module-level statement.
print "geneStallingBins: version 1.4"


def main(argv=None):
    """Command-line entry point: parse the arguments and run geneStallingBins.

    argv is the full argument vector including the program name in
    argv[0]; when it is None or empty, sys.argv is used.  Four positional
    arguments are required (genome, rdsfile, controlrdsfile, outfilename);
    if fewer are given the usage message is printed and the process exits
    with status 1.
    """
    if not argv:
        argv = sys.argv

    # Substitute the program name into the usage text; previously the "%s"
    # placeholder was shown literally both here and in optparse's --help.
    usage = "usage: python %s genome rdsfile controlrdsfile outfilename [--upstream bp] [--downstream bp] [--regions acceptfile] [--cache] [--normalize] [--tagCount]" % argv[0]

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 4:
        print(usage)
        sys.exit(1)

    genome = args[0]
    hitfile = args[1]
    controlfile = args[2]
    outfilename = args[3]

    geneStallingBins(genome, hitfile, controlfile, outfilename, options.upstreamBp,
                     options.downstreamBp, options.acceptfile, options.doCache,
                     options.normalize, options.doTagCount, options.bins)


def getParser(usage):
    """Build the optparse parser for this script.

    The command-line options are registered first; their default values
    are then read from the "geneStallingBins" section of the configuration
    returned by getConfigParser(), with hard-coded fallbacks supplied to
    the getConfig*Option helpers.

    usage is passed to optparse.OptionParser unchanged.
    Returns the configured optparse.OptionParser.
    """
    # (flag, add_option keyword arguments), in the order they are registered.
    optionSpecs = (
        ("--upstream", {"type": "int", "dest": "upstreamBp"}),
        ("--downstream", {"type": "int", "dest": "downstreamBp"}),
        ("--regions", {"dest": "acceptfile"}),
        ("--cache", {"action": "store_true", "dest": "doCache"}),
        ("--normalize", {"action": "store_true", "dest": "normalize"}),
        ("--tagCount", {"action": "store_true", "dest": "doTagCount"}),
        ("--bins", {"type": "int", "dest": "bins"}),
    )

    optionParser = optparse.OptionParser(usage=usage)
    for flag, keywords in optionSpecs:
        optionParser.add_option(flag, **keywords)

    config = getConfigParser()
    sectionName = "geneStallingBins"

    # Configuration values are looked up in the same order as before.
    defaultValues = {}
    defaultValues["upstreamBp"] = getConfigIntOption(config, sectionName, "upstreamBp", 300)
    defaultValues["downstreamBp"] = getConfigIntOption(config, sectionName, "downstreamBp", 0)
    defaultValues["acceptfile"] = getConfigOption(config, sectionName, "acceptfile", "")
    defaultValues["doCache"] = getConfigBoolOption(config, sectionName, "doCache", False)
    defaultValues["normalize"] = getConfigBoolOption(config, sectionName, "normalize", False)
    defaultValues["doTagCount"] = getConfigBoolOption(config, sectionName, "doTagCount", False)
    defaultValues["bins"] = getConfigIntOption(config, sectionName, "bins", 4)

    optionParser.set_defaults(**defaultValues)

    return optionParser


def geneStallingBins(genome, hitfile, controlfile, outfilename, upstreamBp=300,
                     downstreamBp=0, acceptfile="", doCache=False, normalize=False,
                     doTagCount=False, bins=4):
    """Write per-gene binned read counts and a first-bins/last-bin ratio.

    Reads are loaded from hitfile and controlfile, binned over the gene
    loci of genome (plus any regions from acceptfile) with
    computeRegionBins, and one tab-separated line per gene is written to
    outfilename:

        gid, symbol, combined count, locus length, one column per bin, ratio

    Genes whose length minus 3 * upstreamBp is smaller than upstreamBp get
    the first four columns followed by "short" instead of bins and ratio.
    Gene IDs that have no bins or no length are skipped.

    Parameters:
        genome        genome name passed to Genome() and getGeneInfoDict()
        hitfile       path of the RDS file holding the sample reads
        controlfile   path of the RDS file holding the control reads
        outfilename   path of the output file (overwritten)
        upstreamBp    upstream extension of each locus; also used as the
                      bin length and as a divisor, so it must be non-zero
        downstreamBp  downstream extension of each locus
        acceptfile    optional file of additional regions to include
        doCache       passed as cache= to ReadDataset and getGeneInfoDict
        normalize     scale by dataset size (per million reads) and write
                      bins as a percentage of the gene's total
        doTagCount    multiply each bin by tagCount / 100 before writing
        bins          number of bins per gene
    """
    acceptDict = {}
    if acceptfile:
        acceptDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=True)

    doCDS = True
    limitNeighbor = False

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    readlen = hitRDS.getReadSize()
    hitNormalizationFactor = 1.0
    if normalize:
        hitNormalizationFactor = len(hitRDS) / 1000000.

    # Fix: the control dataset must come from controlfile (it used to be
    # opened from hitfile) and be normalized by its own size.
    controlRDS = ReadDataset.ReadDataset(controlfile, verbose=True, cache=doCache)
    controlNormalizationFactor = 1.0
    if normalize:
        controlNormalizationFactor = len(controlRDS) / 1000000.

    hitDict = hitRDS.getReadsDict(doMulti=True, findallOptimize=True)
    controlDict = controlRDS.getReadsDict(doMulti=True, findallOptimize=True)

    hg = Genome(genome)
    geneinfoDict = getGeneInfoDict(genome, cache=doCache)
    locusByChromDict = getLocusByChromDict(hg, upstream=upstreamBp, downstream=downstreamBp, useCDS=doCDS, additionalRegionsDict=acceptDict, keepSense=True, adjustToNeighbor=limitNeighbor)

    gidList = hg.allGIDs()
    gidList.sort()
    # Append labels of the extra regions that are not already gene IDs.
    # A set gives constant-time membership tests while gidList keeps order.
    knownGIDs = set(gidList)
    for chrom in acceptDict:
        for region in acceptDict[chrom]:
            if region.label not in knownGIDs:
                knownGIDs.add(region.label)
                gidList.append(region.label)

    (gidBins, gidLen) = computeRegionBins(locusByChromDict, hitDict, bins, readlen, gidList, hitNormalizationFactor, defaultRegionFormat=False, binLength=upstreamBp)
    (controlBins, gidLen) = computeRegionBins(locusByChromDict, controlDict, bins, readlen, gidList, controlNormalizationFactor, defaultRegionFormat=False, binLength=upstreamBp)

    outfile = open(outfilename, "w")
    try:
        for gid in gidList:
            if "FAR" not in gid:
                # Fall back to "LOC<gid>" when no gene symbol is available.
                symbol = "LOC" + gid
                try:
                    symbol = geneinfoDict[gid][0][0]
                except (KeyError, IndexError, TypeError):
                    pass
            else:
                symbol = gid

            # Fix: nothing is written for genes without bins or length.
            # The trailing ratio used to be written for them as well,
            # which raised NameError or repeated the previous gene's ratio.
            if gid not in gidBins or gid not in gidLen:
                continue

            geneBins = gidBins[gid]
            geneControlBins = controlBins[gid]
            geneLength = gidLen[gid]

            tagCount = 0.
            for binAmount in geneBins:
                tagCount += binAmount

            controlCount = 0.
            for binAmount in geneControlBins:
                controlCount += abs(binAmount)

            # NOTE(review): despite the name, this is the sum of the sample
            # total and the control magnitudes -- confirm that is intended.
            diffCount = tagCount + controlCount
            if diffCount < 0:
                diffCount = 0

            outfile.write("%s\t%s\t%.1f\t%d" % (gid, symbol, diffCount, geneLength))
            if (geneLength - 3 * upstreamBp) < upstreamBp:
                outfile.write("\tshort\n")
                continue

            # First two bins versus the last bin, each weighted by the
            # dataset totals and scaled by the span it is divided over.
            TSSbins = (tagCount * (geneBins[0] + geneBins[1]) + controlCount * (geneControlBins[0] + geneControlBins[1])) / (upstreamBp / 50.)
            finalbin = (tagCount * geneBins[-1] + controlCount * geneControlBins[-1]) / ((geneLength - 3. * upstreamBp) / 100.)
            if finalbin <= 0.:
                # Avoid dividing by zero (or a negative value) below.
                finalbin = 0.01

            if TSSbins < 0:
                TSSbins = 0

            ratio = float(TSSbins) / float(finalbin)
            for binAmount in geneBins:
                if doTagCount:
                    binAmount = binAmount * tagCount / 100.

                if normalize:
                    if tagCount == 0:
                        tagCount = 1

                    outfile.write("\t%.1f" % (100. * binAmount / tagCount))
                else:
                    outfile.write("\t%.1f" % binAmount)

            outfile.write("\t%.2f\n" % ratio)
    finally:
        # Close the output file even if a gene fails part-way through.
        outfile.close()


# Run the command-line entry point only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main(sys.argv)