#
#  geneLocusBins.py
#  ENRAGE
#

# originally from version 1.3 of geneDownstreamBins.py
# Optional speed-up: enable the psyco JIT when it is installed. This is
# strictly best-effort, so any ordinary failure is ignored, but
# KeyboardInterrupt and SystemExit are no longer swallowed the way a bare
# "except:" would swallow them.
try:
    import psyco
    psyco.full()
except Exception:
    pass

import sys
import optparse
import string
from commoncode import getMergedRegions, getLocusByChromDict, computeRegionBins, getConfigParser, getConfigIntOption, getConfigOption, getConfigBoolOption
import ReadDataset
from cistematic.genomes import Genome
from commoncode import getGeneInfoDict

# Version banner. This is a module-level statement, so it is printed when the
# module is imported as well as when the script is run directly.
print "geneLocusBins: version 2.2"

def main(argv=None):
    """Command-line entry point.

    Parses ``argv`` (``sys.argv`` when empty or None), requires the three
    positional arguments genome, rdsfile and outfilename, and hands the
    resulting settings to geneLocusBins(). Prints the usage string and exits
    with status 1 when fewer than three positional arguments are given.
    """
    argv = argv or sys.argv

    usage = "usage: python %prog genome rdsfile outfilename [--bins numbins] [--flank bp] [--upstream bp] [--downstream bp] [--nocds] [--regions acceptfile] [--cache] [--raw] [--force]"

    options, args = getParser(usage).parse_args(argv[1:])

    if len(args) < 3:
        print(usage)
        sys.exit(1)

    genome, hitfile, outfilename = args[:3]

    # --flank seeds both sides; --upstream / --downstream then override their
    # own side when given.
    flankBp = options.flankBP
    if flankBp is None:
        upstreamBp = downstreamBp = 0
    else:
        upstreamBp = downstreamBp = flankBp

    if options.upstreamBP is not None:
        upstreamBp = options.upstreamBP

    if options.downstreamBP is not None:
        downstreamBp = options.downstreamBP

    # Flanking is enabled as soon as any of the three flank options is set.
    flankSettings = (flankBp, options.upstreamBP, options.downstreamBP)
    doFlank = any(setting is not None for setting in flankSettings)

    geneLocusBins(genome, hitfile, outfilename, upstreamBp, downstreamBp, doFlank,
                  options.normalizeBins, options.doCache, options.bins, options.doCDS,
                  options.limitNeighbor, options.acceptfile)


def getParser(usage):
    """Build the optparse parser for the geneLocusBins command line.

    usage is the usage string shown by optparse. Option defaults are read
    from the "geneLocusBins" section of the configuration returned by
    getConfigParser(), falling back to the literal defaults given below.

    The three flank options are read with getConfigOption() rather than
    getConfigIntOption(). optparse runs string defaults through the option's
    declared type, so a textual value would still reach the caller as an int
    (NOTE(review): the return type of getConfigOption is not visible here).

    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    # %default is expanded by optparse, so the help shows the effective
    # default even when the configuration file overrides the built-in 10.
    parser.add_option("--bins", type="int", dest="bins",
                      help="number of bins to use [default: %default]")
    parser.add_option("--flank", type="int", dest="flankBP",
                      help="number of flanking BP on both upstream and downstream [default: 0]")
    parser.add_option("--upstream", type="int", dest="upstreamBP",
                      help="number of upstream flanking BP [default: 0]")
    parser.add_option("--downstream", type="int", dest="downstreamBP",
                      help="number of downstream flanking BP [default: 0]")
    parser.add_option("--nocds", action="store_false", dest="doCDS",
                      help="do not use CDS")
    parser.add_option("--raw", action="store_false", dest="normalizeBins",
                      help="do not normalize results")
    # --force stores False into limitNeighbor, i.e. it switches the limit off.
    parser.add_option("--force", action="store_false", dest="limitNeighbor",
                      help="do not limit neighbor region")
    parser.add_option("--regions", dest="acceptfile",
                      help="file with additional regions to include")
    parser.add_option("--cache", action="store_true", dest="doCache",
                      help="use cache")

    configParser = getConfigParser()
    section = "geneLocusBins"
    normalizeBins = getConfigBoolOption(configParser, section, "normalizeBins", True)
    doCache = getConfigBoolOption(configParser, section, "doCache", False)
    bins = getConfigIntOption(configParser, section, "bins", 10)
    flankBP = getConfigOption(configParser, section, "flankBP", None)
    upstreamBP = getConfigOption(configParser, section, "upstreamBP", None)
    downstreamBP = getConfigOption(configParser, section, "downstreamBP", None)
    doCDS = getConfigBoolOption(configParser, section, "doCDS", True)
    limitNeighbor = getConfigBoolOption(configParser, section, "limitNeighbor", True)

    parser.set_defaults(normalizeBins=normalizeBins, doCache=doCache, bins=bins, flankBP=flankBP,
                        upstreamBP=upstreamBP, downstreamBP=downstreamBP, doCDS=doCDS,
                        limitNeighbor=limitNeighbor)

    return parser


def geneLocusBins(genome, hitfile, outfilename, upstreamBp=0, downstreamBp=0, doFlank=False,
                  normalizeBins=True, doCache=False, bins=10, doCDS=True, limitNeighbor=True,
                  acceptfile=None):
    """Bin the reads of an RDS file over gene loci and write one row per gene.

    genome        -- genome name passed to Genome() and getGeneInfoDict()
    hitfile       -- read dataset opened with ReadDataset.ReadDataset()
    outfilename   -- path of the tab-separated output written by writeBins()
    upstreamBp, downstreamBp, doCDS, limitNeighbor
                  -- forwarded to getLocusByChromDict() (as upstream,
                     downstream, useCDS and adjustToNeighbor) only when
                     doFlank is true
    normalizeBins -- when true, the normalization factor is the dataset size
                     in millions of reads and writeBins() emits percentages
    doCache       -- forwarded as cache= to ReadDataset and getGeneInfoDict
    bins          -- number of bins forwarded to computeRegionBins()
    acceptfile    -- optional region file; its region labels are appended to
                     the gene id list and its regions are passed on as
                     additionalRegionsDict
    """
    hg = Genome(genome)
    gidList = hg.allGIDs()
    gidList.sort()
    if acceptfile is None:
        acceptDict = {}
    else:
        acceptDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=True)
        # Track known ids in a set so each membership test is O(1) instead of
        # a scan over the whole list; gidList keeps its order for the output.
        # Ids are already used as dictionary keys downstream, so hashing them
        # is safe.
        knownGids = set(gidList)
        for chrom in acceptDict:
            for region in acceptDict[chrom]:
                if region.label not in knownGids:
                    knownGids.add(region.label)
                    gidList.append(region.label)

    if doFlank:
        locusByChromDict = getLocusByChromDict(hg, upstream=upstreamBp, downstream=downstreamBp, useCDS=doCDS, additionalRegionsDict=acceptDict, keepSense=True, adjustToNeighbor=limitNeighbor)
    else:
        locusByChromDict = getLocusByChromDict(hg, additionalRegionsDict=acceptDict, keepSense=True)

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    hitDict = hitRDS.getReadsDict(doMulti=True, findallOptimize=True)
    readlen = hitRDS.getReadSize()
    normalizationFactor = 1.0
    if normalizeBins:
        # Scale by the dataset size expressed in millions of reads.
        totalCount = len(hitRDS)
        normalizationFactor = totalCount / 1000000.

    (gidBins, gidLen) = computeRegionBins(locusByChromDict, hitDict, bins, readlen, gidList, normalizationFactor, defaultRegionFormat=False)
    geneinfoDict = getGeneInfoDict(genome, cache=doCache)
    writeBins(gidList, geneinfoDict, gidBins, gidLen, outfilename, normalizeBins)


def writeBins(gidList, geneinfoDict, gidBins, gidLen, outfilename, normalizeBins=True):
    """Write one tab-separated line per gene id to outfilename.

    Each line holds: gid, symbol, total bin count, locus length, then one
    value per bin. The symbol is the gid itself when it contains "FAR",
    otherwise geneinfoDict[gid][0][0], falling back to "LOC<gid>" when the
    lookup raises KeyError. With normalizeBins each bin is written as a
    percentage of the gene's total; when that total is zero the bin is
    written as 100 * value instead of dividing. Counts and bin values are
    formatted with one decimal place.

    Gene ids missing from gidBins or gidLen are skipped: they have nothing to
    report, and previously they either raised or reused the previous gene's
    total.
    """
    outfile = open(outfilename, "w")
    try:
        for gid in gidList:
            if gid not in gidBins or gid not in gidLen:
                continue

            if "FAR" in gid:
                symbol = gid
            else:
                try:
                    symbol = geneinfoDict[gid][0][0]
                except KeyError:
                    symbol = "LOC%s" % gid

            binValues = gidBins[gid]
            tagCount = sum(binValues, 0.)

            # Every field is turned into text before joining; joining the raw
            # float/int values raises TypeError.
            outputList = [gid, symbol, "%.1f" % tagCount, str(gidLen[gid])]
            for binAmount in binValues:
                if normalizeBins:
                    if tagCount == 0:
                        # No tags at all: avoid dividing by zero.
                        binAmount = 100. * binAmount
                    else:
                        binAmount = 100. * binAmount / tagCount

                outputList.append("%.1f" % binAmount)

            outfile.write("%s\n" % "\t".join(outputList))
    finally:
        # Release the file handle even when a row fails to format.
        outfile.close()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main(sys.argv)
