#
#  geneUpstreamBins.py
#  ENRAGE
#
# originally from version 1.3 of geneDownstreamBins.py
try:
    import psyco
    psyco.full()
except:
    pass

import sys
import optparse
import ReadDataset
from cistematic.genomes import Genome
from commoncode import getGeneInfoDict, getConfigParser, getConfigBoolOption, getConfigIntOption

# Announce the script version. This is module-level code, so it is also
# printed whenever the module is imported, not only when run as a script.
print "geneUpstreamBins: version 2.1"

def main(argv=None):
    """Command-line entry point: parse arguments and run geneUpstreamBins.

    argv follows the sys.argv convention (argv[0] is the program name and is
    skipped); when argv is None or empty, sys.argv is used. Exactly the
    three positionals genome, rdsfile and outfilename are consumed; with
    fewer, the usage text is printed and the process exits with status 1.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog genome rdsfile outfilename [--max regionSize] [--raw] [--cache]"

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        # print_usage expands %prog, which a plain print of `usage` did not
        parser.print_usage()
        sys.exit(1)

    genome = args[0]
    hitfile = args[1]
    # The output file is the third positional argument. Reading args[3]
    # raised IndexError whenever exactly the documented three were given.
    outfilename = args[2]

    geneUpstreamBins(genome, hitfile, outfilename, options.standardMinDist, options.normalize, options.doCache)


def makeParser(usage=""):
    """Build the optparse parser for the geneUpstreamBins command line.

    Options:
        --max N   window length in bp (dest standardMinDist)
        --raw     turn off normalization of the tag count (dest normalize)
        --cache   open the read dataset with caching (dest doCache)

    Defaults are read from the "geneUpstreamBins" section of the config
    returned by getConfigParser(). The values 3000, True and False are passed
    to the config helpers, presumably as fallbacks when an option is absent.
    """
    parser = optparse.OptionParser(usage=usage)
    # "maximum region in bp" used to be attached to --raw by mistake; it
    # describes --max.
    parser.add_option("--raw", action="store_false", dest="normalize",
                      help="do not scale the tag count by the read dataset size")
    parser.add_option("--max", type="int", dest="standardMinDist",
                      help="maximum region in bp")
    parser.add_option("--cache", action="store_true", dest="doCache",
                      help="open the read dataset with caching enabled")

    configParser = getConfigParser()
    section = "geneUpstreamBins"
    standardMinDist = getConfigIntOption(configParser, section, "regionSize", 3000)
    normalize = getConfigBoolOption(configParser, section, "normalize", True)
    doCache = getConfigBoolOption(configParser, section, "doCache", False)

    parser.set_defaults(standardMinDist=standardMinDist, normalize=normalize, doCache=doCache)

    return parser


def geneUpstreamBins(genome, hitfile, outfilename, standardMinDist=3000, normalize=True, doCache=False):
    """Bin the reads lying in a fixed-size window upstream of each gene.

    For every gene ID of `genome` (in sorted order), a window of
    `standardMinDist` bp is placed upstream of the gene. It lies to the left
    of the gene's first feature when the strand is "F", and to the right of
    its last feature otherwise. The window is cut into 10 bins, with bin 0
    next to the gene. The weights of the reads from `hitfile` that start
    strictly inside the window are summed per bin.

    A gene is skipped when:
        - it has no gene info or no features;
        - its chromosome is absent from the read dictionary;
        - a neighbouring gene is closer than 2 * standardMinDist, so the
          window would have to shrink;
        - fewer than 2 weighted reads fall in the window.

    Parameters:
        genome: genome name given to cistematic's Genome and getGeneInfoDict.
        hitfile: read dataset path, opened with ReadDataset.
        outfilename: output path; the file is overwritten.
        standardMinDist: window length in bp.
        normalize: when true, the reported tag count is multiplied by
            len(hitRDS) / 1e6 (presumably reads in millions - see note below).
        doCache: forwarded as ReadDataset's cache flag.

    Output: one tab-separated line per reported gene containing gid, symbol,
    the scaled tag count, the window length, and the 10 bin sums (numbers
    formatted with %d). The same values are echoed to stdout.
    """
    bins = 10
    # Integer bin width, kept at least 1 so that windows shorter than `bins`
    # bp no longer raise ZeroDivisionError.
    standardMinThresh = max(1, standardMinDist // bins)

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    normalizationFactor = 1.0
    if normalize:
        # NOTE(review): the tag count is *multiplied* by this factor below.
        # A reads-per-million value would divide by it instead. The existing
        # behavior is kept because downstream consumers may rely on it;
        # confirm the intent.
        normalizationFactor = len(hitRDS) / 1000000.

    hitDict = hitRDS.getReadsDict(doMulti=True, findallOptimize=True)

    hg = Genome(genome)
    geneinfoDict = getGeneInfoDict(genome, cache=True)
    featuresDict = hg.getallGeneFeatures()

    with open(outfilename, "w") as outfile:
        for gid in sorted(hg.allGIDs()):
            symbol = "LOC" + gid
            geneinfo = ""
            featureList = []
            try:
                geneinfo = geneinfoDict[gid]
                featureList = featuresDict[gid]
                symbol = geneinfo[0][0]
            except (KeyError, IndexError, TypeError):
                # Missing or malformed annotation: report it. A gene whose
                # feature lookup failed is skipped just below.
                print(geneinfo)

            if not featureList:
                continue

            # Chromosome and strand come from the last listed feature. This
            # was previously implicit through loop variables leaking out of
            # the feature loop.
            chrom = featureList[-1][1]
            fsense = featureList[-1][4]
            if chrom not in hitDict:
                continue

            # unique (start, stop) extents of the gene's features, sorted
            extents = sorted(set((feature[2], feature[3]) for feature in featureList))

            glen = standardMinDist
            if fsense == "F":
                nextGene = hg.leftGeneDistance((genome, gid), glen * 2)
            else:
                nextGene = hg.rightGeneDistance((genome, gid), glen * 2)

            if nextGene < glen * 2:
                glen = nextGene // 2

            if glen < 1:
                glen = 1

            # only full-length windows are reported
            if glen < standardMinDist:
                continue

            if fsense == "F":
                gstart = max(0, extents[0][0] - glen)
            else:
                gstart = extents[-1][1]

            tagCount, binList = _binUpstreamReads(hitDict[chrom], gstart, glen, fsense,
                                                  bins, standardMinThresh)
            if tagCount < 2:
                continue

            scaledCount = normalizationFactor * tagCount
            print("%s %s %d %d %s" % (gid, symbol, scaledCount, glen, str(binList)))
            outfile.write("%s\t%s\t%d\t%d" % (gid, symbol, scaledCount, glen))
            outfile.write("".join("\t%d" % binAmount for binAmount in binList))
            outfile.write("\n")


def _binUpstreamReads(reads, gstart, glen, fsense, bins, binSize):
    """Sum read weights per bin over the window (gstart, gstart + glen).

    reads: iterable of dicts with "start" and "weight" keys. The reads are
    assumed to be sorted by start, because the scan stops at the first read
    at or beyond gstart + glen (TODO confirm getReadsDict guarantees this).
    Reads whose offset from gstart is <= 0 are ignored.

    Bin 0 is next to the gene. For fsense "R" the distance is the offset
    from gstart; otherwise it is glen minus that offset. Distances past the
    last bin are clamped into it, so a window length that is not a multiple
    of `bins` no longer raises IndexError.

    Returns (tagCount, binList), where tagCount is the total weight.
    """
    tagCount = 0
    binList = [0] * bins
    for read in reads:
        offset = read["start"] - gstart
        if offset >= glen:
            break

        if offset <= 0:
            continue

        weight = read["weight"]
        tagCount += weight
        if fsense == "R":
            distance = offset
        else:
            distance = glen - offset

        binList[min(int(distance // binSize), bins - 1)] += weight

    return tagCount, binList


# Script entry: hand the full command line to main(), which skips argv[0].
if __name__ == "__main__":
    main(sys.argv)