#
#  geneLocusCounts.py
#  ENRAGE
#
"""  usage: python geneLocusCounts.py genome readDB outfilename [upstream] [downstream] [--noCDS] [--spanTSS] [--locusLength bplength] [--regions acceptfile] [--noUniqs] [--multi] [--splices]
            where upstream and downstream are in bp and are optional.
            Using --noCDS requires either upstream or downstream (but not both)
            to be nonzero. Using --locusLength will report the first bplength
            or the last bplength of the gene region depending on whether it
            is positive or negative.
            By default only unique reads are counted (use --noUniqs to turn this off),
            but multi and splice reads can also be counted given the appropriate flags.
"""
# Optional JIT accelerator: use psyco when it is available, otherwise run
# unaccelerated. Catch Exception rather than using a bare except so that
# KeyboardInterrupt / SystemExit are not swallowed during startup.
try:
    import psyco
    psyco.full()
except Exception:
    pass

import sys
import optparse
from commoncode import getMergedRegions, getLocusByChromDict, getConfigParser, getConfigOption, getConfigBoolOption, getConfigIntOption
import ReadDataset
from cistematic.genomes import Genome
from commoncode import getGeneInfoDict

print("geneLocusCounts: version 3.1")  # version banner emitted on import

def main(argv=None):
    """Parse the command line and run geneLocusCounts.

    Positional arguments are genome, readDB and outfilename, optionally
    followed by the upstream and downstream distances in bp. Any option
    values come from getParser(). When fewer than three positional
    arguments are supplied, the module usage text is printed and the
    process exits with status 1.

    argv -- full argument vector including the program name; defaults to
            sys.argv when empty or None.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog genome readDB outfilename [options]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        print(__doc__)
        sys.exit(1)

    genome = args[0]
    hitfile = args[1]
    outfilename = args[2]

    # upstream / downstream are optional positionals: a missing or
    # non-integer value leaves the default of 0.
    upstream = 0
    downstream = 0
    try:
        upstream = int(args[3])
    except (ValueError, IndexError):
        pass

    # IndexError must be caught here as well: with only the three required
    # positionals, args[3] does not exist and this lookup would raise.
    # NOTE(review): the "-" test inspects args[3] (upstream), not args[4];
    # the intent is unclear from here - confirm.
    try:
        if "-" not in args[3]:
            downstream = int(args[4])
    except (ValueError, IndexError):
        pass

    geneLocusCounts(genome, hitfile, outfilename, upstream, downstream, options.doUniqs,
                    options.doMulti, options.doSplices, options.useCDS, options.spanTSS,
                    options.bplength, options.acceptfile)


def getParser(usage):
    """Build the optparse parser for geneLocusCounts.

    Option defaults are read from the "geneLocusCounts" section of the
    configuration returned by getConfigParser(); command-line flags override
    them.

    usage -- usage string handed to optparse.OptionParser.
    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--noUniqs", action="store_false", dest="doUniqs",
                      help="do not count unique reads")
    # --multi and --splices must set their own destinations; pointing them at
    # doUniqs meant they never enabled multi/splice counting.
    parser.add_option("--multi", action="store_true", dest="doMulti",
                      help="count multi reads")
    parser.add_option("--splices", action="store_true", dest="doSplices",
                      help="count splice reads")
    parser.add_option("--spanTSS", action="store_true", dest="spanTSS")
    parser.add_option("--regions", dest="acceptfile")
    parser.add_option("--noCDS", action="store_false", dest="useCDS")
    parser.add_option("--locusLength", type="int", dest="bplength",
                      help="number of bases to report")

    configParser = getConfigParser()
    section = "geneLocusCounts"
    doUniqs = getConfigBoolOption(configParser, section, "doUniqs", True)
    doMulti = getConfigBoolOption(configParser, section, "doMulti", False)
    doSplices = getConfigBoolOption(configParser, section, "doSplices", False)
    useCDS = getConfigBoolOption(configParser, section, "useCDS", True)
    spanTSS = getConfigBoolOption(configParser, section, "spanTSS", False)
    bplength = getConfigIntOption(configParser, section, "bplength", 0)
    acceptfile = getConfigOption(configParser, section, "acceptfile", "")

    parser.set_defaults(doUniqs=doUniqs, doMulti=doMulti, doSplices=doSplices,
                        useCDS=useCDS, spanTSS=spanTSS, bplength=bplength,
                        acceptfile=acceptfile)

    return parser


def geneLocusCounts(genome, hitfile, outfilename, upstream=0, downstream=0,
                    doUniqs=True, doMulti=False, doSplices=False, useCDS=True,
                    spanTSS=False, bplength=0, acceptfile=""):
    """Count reads over each gene locus and write per-gene RPM / RPKM values.

    genome      -- genome name passed to Genome() and getGeneInfoDict().
    hitfile     -- read dataset path opened with ReadDataset.ReadDataset.
    outfilename -- path of the tab-separated output file
                   (columns: gid, symbol, gidCount, gidLen, rpm, rpkm).
    upstream, downstream, useCDS -- forwarded to getLocusByChromDict.
    doUniqs, doMulti, doSplices  -- select which read classes
                                    hitRDS.getCounts counts.
    spanTSS     -- forwarded to getLocusByChromDict as upstreamSpanTSS.
    bplength    -- forwarded to getLocusByChromDict as lengthCDS.
    acceptfile  -- optional regions file; when non-empty it is loaded with
                   getMergedRegions and passed to getLocusByChromDict.
    """
    print("returning only up to %d bp from gene locus" % bplength)
    print("upstream = %d downstream = %d useCDS = %s spanTSS = %s" % (upstream, downstream, useCDS, spanTSS))

    # Without a regions file pass an empty dict; leaving acceptDict unbound
    # made the getLocusByChromDict call below raise NameError.
    acceptDict = {}
    if acceptfile:
        acceptDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=True)

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True)
    totalCount = hitRDS.getCounts(uniqs=doUniqs, multi=doMulti, splices=doSplices)

    hg = Genome(genome)
    locusByChromDict = getLocusByChromDict(hg, upstream, downstream, useCDS, acceptDict,
                                           upstreamSpanTSS=spanTSS, lengthCDS=bplength)

    gidDict = _countReadsPerGene(hitRDS, locusByChromDict, doUniqs, doMulti, doSplices)
    _writeGeneCounts(outfilename, genome, gidDict, totalCount)


def _countReadsPerGene(hitRDS, locusByChromDict, doUniqs, doMulti, doSplices):
    """Sum read counts over every locus of each gid; return {gid: {"count", "length"}}."""
    gidDict = {}
    for chrom in sorted(hitRDS.getChromosomes(fullChrom=False)):
        # Membership is tested against the dict itself (same keys, O(1) lookup).
        if doNotProcessChromosome(chrom, locusByChromDict):
            continue

        fullchrom = "chr%s" % chrom
        print(fullchrom)
        hitRDS.memSync(fullchrom, index=True)
        for (start, stop, gid, length) in locusByChromDict[chrom]:
            # The first locus seen for a gid fixes its recorded length.
            if gid not in gidDict:
                gidDict[gid] = {"count": 0, "length": length}

            gidDict[gid]["count"] += hitRDS.getCounts(fullchrom, start, stop, uniqs=doUniqs,
                                                      multi=doMulti, splices=doSplices)

    return gidDict


def _geneSymbol(gid, geneinfoDict):
    """Return the display symbol for gid: gid itself if it contains "FAR",
    else the first gene-info symbol, falling back to "LOC<gid>"."""
    if "FAR" in gid:
        return gid

    try:
        return geneinfoDict[gid][0][0]
    except (KeyError, IndexError):
        return "LOC%s" % gid


def _writeGeneCounts(outfilename, genome, gidDict, totalCount):
    """Write one tab-separated line per gid (sorted) with count, length, RPM and RPKM."""
    # RPM normalizes by millions of reads in the whole dataset.
    millionReads = totalCount / 1000000.
    geneinfoDict = getGeneInfoDict(genome, cache=True)

    with open(outfilename, "w") as outfile:
        outfile.write("#gid\tsymbol\tgidCount\tgidLen\trpm\trpkm\n")
        for gid in sorted(gidDict):
            symbol = _geneSymbol(gid, geneinfoDict)
            gidCount = gidDict[gid]["count"]
            gidLength = gidDict[gid]["length"]
            # An empty dataset or zero-length locus reports 0 instead of
            # aborting with ZeroDivisionError.
            rpm = gidCount / millionReads if millionReads else 0.
            rpkm = 1000. * rpm / gidLength if gidLength else 0.
            outfile.write("%s\t%s\t%d\t%d\t%2.2f\t%2.2f\n" % (gid, symbol, gidCount, gidLength, rpm, rpkm))


def doNotProcessChromosome(chrom, locusChroms):
    """Return True when chrom should be skipped.

    A chromosome is skipped if it is "M" or has no entry in locusChroms.
    """
    if chrom == "M":
        return True
    return chrom not in locusChroms


if __name__ == "__main__":
    # Script entry point: hand the process arguments to the CLI driver.
    main(argv=sys.argv)