#
#  geneLocusPeaks.py
#  ENRAGE
#

try:
    import psyco
    psyco.full()
except:
    pass

import sys
import optparse
from commoncode import getMergedRegions, findPeak, getLocusByChromDict
import ReadDataset
from cistematic.genomes import Genome
from commoncode import getGeneInfoDict, getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption


# Announce the script version whenever this module is loaded.
print("geneLocusPeaks: version 2.1")

def main(argv=None):
    """Command-line entry point: parse argv and run geneLocusPeaks().

    argv follows the sys.argv convention (argv[0] is the program name and is
    skipped).  When argv is None or empty, sys.argv is used instead.

    Exits with status 1 after printing the usage text when fewer than three
    positional arguments (genome, rdsfile, outfilename) are supplied.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog genome rdsfile outfilename [--up upstream] [--down downstream] [--regions acceptfile] [--raw]"

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        # print_usage() expands %prog to the program name; printing the raw
        # usage string would show a literal "%prog" to the user.
        parser.print_usage()
        print("\twhere upstream and downstream are in bp and are optional")
        sys.exit(1)

    genome, hitfile, outfilename = args[:3]

    geneLocusPeaks(genome, hitfile, outfilename, options.upstream,
                   options.downstream, options.acceptfile, options.normalize,
                   options.doCache)


def makeParser(usage=""):
    """Build the optparse parser used by geneLocusPeaks.

    Options: --up (int), --down (int), --regions (file), --raw (clears
    normalize) and --cache (sets doCache).  Their defaults come from the
    "geneLocusPeaks" section of the configuration returned by
    getConfigParser(), falling back to upstream=0, downstream=0,
    acceptfile="", normalize=True and doCache=False.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--up", type="int", dest="upstream")
    parser.add_option("--down", type="int", dest="downstream")
    parser.add_option("--regions", dest="acceptfile")
    parser.add_option("--raw", action="store_false", dest="normalize")
    parser.add_option("--cache", action="store_true", dest="doCache")

    config = getConfigParser()
    sectionName = "geneLocusPeaks"

    # Each config key has the same name as the option dest it provides a
    # default for, so one table drives every lookup.
    defaultSpecs = (
        ("upstream", getConfigIntOption, 0),
        ("downstream", getConfigIntOption, 0),
        ("acceptfile", getConfigOption, ""),
        ("normalize", getConfigBoolOption, True),
        ("doCache", getConfigBoolOption, False),
    )
    defaults = {}
    for (key, reader, fallback) in defaultSpecs:
        defaults[key] = reader(config, sectionName, key, fallback)

    parser.set_defaults(**defaults)

    return parser


def geneLocusPeaks(genome, hitfile, outfilename, upstream=0, downstream=0, acceptfile="", normalize=True, doCache=False):
    """Write the top smoothed read peak found in each gene locus.

    For every locus returned by getLocusByChromDict() (built with the given
    upstream/downstream bp and any extra regions from acceptfile), findPeak()
    is run over the reads of that chromosome and the value of smoothArray at
    the first topPos is recorded together with its genomic position.  Loci
    with no topPos get a value of 0 at the locus start.

    One tab-separated line is written to outfilename for each gene id that
    received a position, in this order: sorted genome gene ids, then new
    acceptfile region labels.  Columns:

        gid, symbol, "chr" + chrom, position, peak value (%.2f)

    Parameters:
        genome -- genome name passed to Genome() and getGeneInfoDict().
        hitfile -- path of the ReadDataset to read from.
        outfilename -- output path (opened with mode "w").
        upstream, downstream -- bp passed to getLocusByChromDict().
        acceptfile -- optional region file read with getMergedRegions(); its
            region labels are treated as additional gene ids.
        normalize -- see the NOTE below; currently does not change output.
        doCache -- passed to ReadDataset as cache.
    """
    acceptDict = {}
    if acceptfile:
        acceptDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=True)

    print("upstream = %d downstream = %d" % (upstream, downstream))

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    readlen = hitRDS.getReadSize()

    # NOTE(review): normalizationFactor (reads per million) is computed but is
    # never applied to the peak values written below, so --raw currently has
    # no effect on the output.  Confirm whether peak values should be divided
    # by it before changing the written numbers.
    normalizationFactor = 1.0
    if normalize:
        normalizationFactor = len(hitRDS) / 1000000.

    hitDict = hitRDS.getReadsDict(doMulti=True, findallOptimize=True)

    hg = Genome(genome)
    geneinfoDict = getGeneInfoDict(genome, cache=True)
    locusByChromDict = getLocusByChromDict(hg, upstream, downstream, useCDS=True, additionalRegionsDict=acceptDict)

    gidList = _buildGidList(hg, acceptDict)
    gidCount = dict.fromkeys(gidList, 0)
    gidPos = {}

    for chrom in hitDict:
        if chrom not in locusByChromDict:
            continue

        print(chrom)
        for (start, stop, gid, glen) in locusByChromDict[chrom]:
            peak = findPeak(hitDict[chrom], start, glen, readlen)
            if len(peak.topPos) > 0:
                topPos = peak.topPos[0]
                gidCount[gid] = peak.smoothArray[topPos]
                gidPos[gid] = (chrom, start + topPos)
            else:
                gidCount[gid] = 0.
                gidPos[gid] = (chrom, start)

    _writePeaks(outfilename, gidList, gidCount, gidPos, geneinfoDict)


def _buildGidList(hg, acceptDict):
    """Return hg's gene ids sorted, followed by acceptDict labels not already present."""
    # sorted() avoids mutating the list object returned by hg.allGIDs().
    gidList = sorted(hg.allGIDs())
    # A set makes each membership test O(1) instead of scanning gidList.
    seen = set(gidList)
    for chrom in acceptDict:
        for region in acceptDict[chrom]:
            if region.label not in seen:
                seen.add(region.label)
                gidList.append(region.label)

    return gidList


def _geneSymbol(gid, geneinfoDict):
    """Return the display symbol for gid: the gid itself for "FAR" ids, else geneinfo[0][0] or "LOC" + gid."""
    if "FAR" in gid:
        return gid

    try:
        return geneinfoDict[gid][0][0]
    except (KeyError, IndexError, TypeError):
        # Missing or malformed gene info: fall back to a LOC-style name.
        return "LOC" + gid


def _writePeaks(outfilename, gidList, gidCount, gidPos, geneinfoDict):
    """Write one tab-separated peak line per gid in gidList that has a recorded position."""
    with open(outfilename, "w") as outfile:
        for gid in gidList:
            # gidPos and gidCount are always assigned together in
            # geneLocusPeaks, so a position implies a count.
            if gid not in gidPos:
                continue

            (chrom, pos) = gidPos[gid]
            symbol = _geneSymbol(gid, geneinfoDict)
            outfile.write("%s\t%s\tchr%s\t%d\t%.2f\n" % (gid, symbol, chrom, pos, gidCount[gid]))


if __name__ == "__main__":
    # Script entry: main() reads sys.argv itself when called without argv.
    main()