#
#  geneLocusBins.py
#  ENRAGE
#

# originally from version 1.3 of geneDownstreamBins.py
try:
    import psyco
    psyco.full()
except:
    pass

import sys
import optparse
from commoncode import getMergedRegions, getLocusByChromDict, computeRegionBins, getConfigParser, getConfigIntOption, getConfigOption, getConfigBoolOption
import ReadDataset
from cistematic.genomes import Genome
from commoncode import getGeneInfoDict

print "geneLocusBins: version 2.2"

def main(argv=None):
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog genome rdsfile outfilename [--bins numbins] [--flank bp] [--upstream bp] [--downstream bp] [--nocds] [--regions acceptfile] [--cache] [--raw] [--force]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        print usage
        sys.exit(1)

    genome = args[0]
    hitfile =  args[1]
    outfilename = args[2]
   
    upstreamBp = 0
    downstreamBp = 0
    doFlank = False
    if options.flankBP is not None:
        upstreamBp = options.flankBP
        downstreamBp = options.flankBP
        doFlank = True

    if options.upstreamBP is not None:
        upstreamBp = options.upstreamBP
        doFlank = True

    if options.downstreamBP is not None:
        downstreamBp = options.downstreamBP
        doFlank = True

    geneLocusBins(genome, hitfile, outfilename, upstreamBp, downstreamBp, doFlank, options.normalizeBins, options.doCache, options.bins, options.doCDS, options.limitNeighbor, options.acceptfile)


def getParser(usage):
    """ Build the optparse parser for geneLocusBins.

        usage - usage string shown by optparse.

        Option defaults are read from the "geneLocusBins" section of the
        configuration returned by getConfigParser(), falling back to the
        literal defaults passed to the getConfig*Option helpers below.
        flankBP / upstreamBP / downstreamBP default to None, which main()
        treats as "not supplied" (0 bp, no flanking).
    """
    parser = optparse.OptionParser(usage=usage)
    # %default is expanded by optparse to the effective default, which may
    # come from the config file rather than the hard-coded fallback.
    parser.add_option("--bins", type="int", dest="bins",
                      help="number of bins to use [default: %default]")
    parser.add_option("--flank", type="int", dest="flankBP",
                      help="number of flanking BP on both upstream and downstream [default: 0]")
    parser.add_option("--upstream", type="int", dest="upstreamBP",
                      help="number of upstream flanking BP [default: 0]")
    parser.add_option("--downstream", type="int", dest="downstreamBP",
                      help="number of downstream flanking BP [default: 0]")
    parser.add_option("--nocds", action="store_false", dest="doCDS",
                      help="do not use CDS coordinates when building flanked loci")
    parser.add_option("--raw", action="store_false", dest="normalizeBins",
                      help="do not normalize results")
    # store_false: passing --force DISABLES neighbor limiting.
    parser.add_option("--force", action="store_false", dest="limitNeighbor",
                      help="do not limit flanked loci by neighboring loci")
    parser.add_option("--regions", dest="acceptfile",
                      help="file of additional labelled regions to bin")
    parser.add_option("--cache", action="store_true", dest="doCache",
                      help="use cache")

    configParser = getConfigParser()
    section = "geneLocusBins"
    normalizeBins = getConfigBoolOption(configParser, section, "normalizeBins", True)
    doCache = getConfigBoolOption(configParser, section, "doCache", False)
    bins = getConfigIntOption(configParser, section, "bins", 10)
    # Read as plain options; optparse passes string defaults through the
    # option's type checker, so config values end up as ints.
    flankBP = getConfigOption(configParser, section, "flankBP", None)
    upstreamBP = getConfigOption(configParser, section, "upstreamBP", None)
    downstreamBP = getConfigOption(configParser, section, "downstreamBP", None)
    doCDS = getConfigBoolOption(configParser, section, "doCDS", True)
    limitNeighbor = getConfigBoolOption(configParser, section, "limitNeighbor", True)

    parser.set_defaults(normalizeBins=normalizeBins, doCache=doCache, bins=bins, flankBP=flankBP,
                        upstreamBP=upstreamBP, downstreamBP=downstreamBP, doCDS=doCDS,
                        limitNeighbor=limitNeighbor)

    return parser

def geneLocusBins(genome, hitfile, outfilename, upstreamBp=0, downstreamBp=0, doFlank=False,
                  normalizeBins=True, doCache=False, bins=10, doCDS=True, limitNeighbor=True,
                  acceptfile=None):
    """ Bin the reads of an RDS file over every gene locus and write one row per locus.

        genome        - genome name given to cistematic's Genome and getGeneInfoDict
        hitfile       - ReadDataset file holding the reads
        outfilename   - tab-delimited output file (overwritten)
        upstreamBp    - upstream flank in bp, used only when doFlank is True
        downstreamBp  - downstream flank in bp, used only when doFlank is True
        doFlank       - when False, loci are built without flank/CDS/neighbor options
        normalizeBins - when True, counts are scaled by (total reads / 1e6) and each
                        bin is written as a percentage of the locus total
        doCache       - cache flag passed to ReadDataset and getGeneInfoDict
        bins          - number of bins per locus
        doCDS         - passed as useCDS to getLocusByChromDict (flanked mode only)
        limitNeighbor - passed as adjustToNeighbor to getLocusByChromDict (flanked mode only)
        acceptfile    - optional region file; its labelled regions are binned as extra loci

        Output columns: gid, symbol, total count, locus length, one column per bin.
        Loci for which computeRegionBins reports no bins or no length are skipped.
    """
    if acceptfile is None:
        acceptDict = {}
    else:
        acceptDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=True)

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    readlen = hitRDS.getReadSize()
    normalizationFactor = 1.0
    if normalizeBins:
        # reads-per-million scaling
        totalCount = len(hitRDS)
        normalizationFactor = totalCount / 1000000.

    hitDict = hitRDS.getReadsDict(doMulti=True, findallOptimize=True)

    hg = Genome(genome)
    geneinfoDict = getGeneInfoDict(genome, cache=doCache)
    if doFlank:
        locusByChromDict = getLocusByChromDict(hg, upstream=upstreamBp, downstream=downstreamBp,
                                               useCDS=doCDS, additionalRegionsDict=acceptDict,
                                               keepSense=True, adjustToNeighbor=limitNeighbor)
    else:
        locusByChromDict = getLocusByChromDict(hg, additionalRegionsDict=acceptDict, keepSense=True)

    gidList = _getGidList(hg, acceptDict)

    (gidBins, gidLen) = computeRegionBins(locusByChromDict, hitDict, bins, readlen, gidList,
                                          normalizationFactor, defaultRegionFormat=False)

    outfile = open(outfilename, "w")
    try:
        for gid in gidList:
            # A gid without bins/length has nothing to report; writing it used to
            # raise NameError/KeyError or reuse the previous locus's total.
            if gid not in gidBins or gid not in gidLen:
                continue

            symbol = _getSymbol(gid, geneinfoDict)
            outfile.write(_formatGidLine(gid, symbol, gidBins[gid], gidLen[gid], normalizeBins))
    finally:
        outfile.close()


def _getGidList(hg, acceptDict):
    """ Return the genome's GIDs sorted, followed by any new region labels from acceptDict. """
    # sorted() copies, so the list returned by hg.allGIDs() is never mutated.
    gidList = sorted(hg.allGIDs())
    seen = set(gidList)
    for chrom in acceptDict:
        for region in acceptDict[chrom]:
            if region.label not in seen:
                seen.add(region.label)
                gidList.append(region.label)

    return gidList


def _getSymbol(gid, geneinfoDict):
    """ Return the display symbol for gid: gid itself for 'FAR' ids, else the gene-info symbol or 'LOC' + gid. """
    if "FAR" in gid:
        return gid

    try:
        return geneinfoDict[gid][0][0]
    except (KeyError, IndexError, TypeError):
        # no usable gene-info entry for this gid
        return "LOC" + gid


def _formatGidLine(gid, symbol, binList, regionLength, normalizeBins):
    """ Format one tab-delimited output row (newline-terminated) for a locus. """
    tagCount = sum(binList, 0.)
    fields = ["%s\t%s\t%.1f\t%d" % (gid, symbol, tagCount, regionLength)]
    if normalizeBins:
        # avoid dividing by zero for loci with no reads; their bins print as 0.0
        divisor = tagCount
        if divisor == 0:
            divisor = 1

        fields.extend(["%.1f" % (100. * binAmount / divisor) for binAmount in binList])
    else:
        fields.extend(["%.1f" % binAmount for binAmount in binList])

    return "\t".join(fields) + "\n"


if __name__ == "__main__":
    # main() falls back to sys.argv when called without an argument vector.
    main()