#
#  geneDownstreamBins.py
#  ENRAGE
#

try:
    # Optional speed-up: enable psyco when it is importable and usable.
    # Any failure is deliberately ignored so the script runs unaccelerated,
    # but KeyboardInterrupt/SystemExit are no longer swallowed.
    import psyco
    psyco.full()
except Exception:
    pass

# originally from version 1.3 of geneDnaDownstreamCounts.py
import sys
import optparse
import ReadDataset
from cistematic.genomes import Genome
from commoncode import getGeneInfoDict, getConfigParser, getConfigIntOption, ErangeError

# Announce the script version whenever the module is loaded.
print("geneDownstreamBins: version 2.1")

def main(argv=None):
    """Command-line entry point for geneDownstreamBins.

    argv -- full argument vector with the program name first; sys.argv is
            used when argv is None or empty.

    Expects three positional arguments (genome, rdsfile, outfilename).
    Prints the usage line and exits with status 1 when fewer are given.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: %prog genome rdsfile outfilename [--max regionSize]"

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        # print_usage() expands %prog to the program name; printing the raw
        # usage string would show "%prog" literally.
        parser.print_usage()
        sys.exit(1)

    genome = args[0]
    hitfile = args[1]
    outfilename = args[2]

    geneDownstreamBins(genome, hitfile, outfilename, options.standardMinDist)


def makeParser(usage=""):
    """Build the option parser for the command-line interface.

    The only option, --max, fills options.standardMinDist (region size in
    bp). Its default comes from the "regionSize" integer option of the
    "geneDownstreamBins" config section, with 3000 passed to
    getConfigIntOption as the fallback value.
    """
    regionSizeDefault = getConfigIntOption(getConfigParser(),
                                           "geneDownstreamBins",
                                           "regionSize", 3000)

    optionParser = optparse.OptionParser(usage=usage)
    optionParser.add_option("--max", type="int", dest="standardMinDist",
                            default=regionSizeDefault,
                            help="maximum region in bp")
    return optionParser


def geneDownstreamBins(genome, hitfile, outfilename, standardMinDist=3000, doCache=False, normalize=False):
    """Write binned downstream read counts for every gene in a genome.

    genome -- genome name; used to build a cistematic Genome and to load
              gene info via getGeneInfoDict.
    hitfile -- path of the ReadDataset file holding the reads.
    outfilename -- path of the tab-separated output file (overwritten).
    standardMinDist -- downstream region size in bp. Genes whose region
                       would be shorter than this, or which collect fewer
                       than 2 tags, are skipped (see getDownstreamBins).
    doCache -- passed as cache= to ReadDataset.
    normalize -- when True, divide each gene's tag count by the dataset size
                 in millions of reads (reads per million). Bin values are
                 written unnormalized.

    Each output line holds gid, symbol, tagCount (%.2f), region length (%d),
    then one %.2f column per bin. The same data is echoed to stdout.
    """
    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    normalizationFactor = 1.0
    if normalize:
        hitDictSize = len(hitRDS)
        if hitDictSize > 0:
            # library size in millions of reads; counts are divided by it below
            normalizationFactor = hitDictSize / 1000000.

    hitDict = hitRDS.getReadsDict(doMulti=True, findallOptimize=True)
    geneinfoDict = getGeneInfoDict(genome, cache=True)
    hg = Genome(genome)
    featuresDict = hg.getallGeneFeatures()
    gidList = hg.allGIDs()
    gidList.sort()

    outfile = open(outfilename, "w")
    try:
        for gid in gidList:
            try:
                featuresList = featuresDict[gid]
            except KeyError:
                # no features for this gene: report it and move on instead of
                # reusing the previous gene's feature list
                print(gid)
                continue

            try:
                # getGeneStats needs the Genome object (hg), not the name string
                binList, symbol, geneLength, tagCount = getDownstreamBins(hg, gid, hitDict, geneinfoDict, featuresList, standardMinDist)
            except ErangeError:
                continue

            tagCount /= normalizationFactor
            print("%s %s %.2f %d %s" % (gid, symbol, tagCount, geneLength, str(binList)))
            columns = ["%s\t%s\t%.2f\t%d" % (gid, symbol, tagCount, geneLength)]
            columns.extend(["%.2f" % binAmount for binAmount in binList])
            outfile.write("\t".join(columns) + "\n")
    finally:
        outfile.close()


def getDownstreamBins(genome, gid, hitDict, geneinfoDict, featuresList, standardMinDist):
    """Compute the downstream bin profile for one gene.

    genome -- a Genome object. getGeneStats calls rightGeneDistance /
              leftGeneDistance and reads .genome on it, so a plain genome
              name string will not work here.
    gid -- gene identifier.
    hitDict -- mapping chromosome -> list of reads.
    geneinfoDict -- gene info lookup used for the gene symbol.
    featuresList -- the gene's (ftype, chrom, start, stop, sense) features.
    standardMinDist -- required downstream region size in bp.

    Returns (binList, symbol, regionLength, tagCount).
    Raises ErangeError when the gene has no usable features, its chromosome
    has no reads, the region is shorter than standardMinDist, or fewer than
    2 tags fall in the region.
    """
    # Pass the dict itself: membership tests on it are O(1), whereas
    # hitDict.keys() builds a list (Python 2) that is scanned linearly.
    symbol, featurePositionList, sense, chrom = getFeatureList(gid, geneinfoDict, featuresList, hitDict)
    geneStart, geneLength = getGeneStats(genome, gid, standardMinDist, featurePositionList, sense)
    if geneLength < standardMinDist:
        raise ErangeError("gene length less than minimum")

    binList, tagCount = getBinList(hitDict[chrom], standardMinDist, geneStart, geneLength, sense)
    if tagCount < 2:
        raise ErangeError("tag count less than minimum")

    return binList, symbol, geneLength, tagCount


def getFeatureList(gid, geneinfoDict, featureList, chromosomeList):
    """Collapse a gene's features into sorted, unique (start, stop) pairs.

    gid -- gene identifier.
    geneinfoDict -- mapping gid -> gene info; the symbol is read from
                    geneinfoDict[gid][0][0].
    featureList -- non-empty sequence of (ftype, chrom, start, stop, sense).
    chromosomeList -- container of chromosomes that have reads (any object
                      supporting `in`).

    Returns (symbol, sortedPositions, sense, chrom). sense and chrom are
    taken from the last feature in featureList. When no symbol is
    available, "LOC<gid>" is used and the gid is printed.
    Raises ErangeError if featureList is empty or chrom is not in
    chromosomeList.
    """
    if len(featureList) == 0:
        raise ErangeError("no features found")

    symbol = "LOC%s" % gid
    try:
        symbol = geneinfoDict[gid][0][0]
    except (KeyError, IndexError):
        # missing or empty gene info: keep the LOC fallback symbol
        print(gid)

    positions = set()
    for (ftype, chrom, start, stop, sense) in featureList:
        positions.add((start, stop))

    if chrom not in chromosomeList:
        raise ErangeError("chromosome not found in reads")

    return symbol, sorted(positions), sense, chrom


def getGeneStats(genome, gid, minDistance, featureList, sense):
    """Place the downstream region for a gene.

    genome -- Genome object providing rightGeneDistance / leftGeneDistance
              and a .genome attribute.
    gid -- gene identifier.
    minDistance -- desired region length in bp.
    featureList -- non-empty list of (start, stop) feature positions.
    sense -- "F" for forward; any other value is treated as reverse.

    The region length starts at minDistance. When the neighbouring gene on
    the downstream side is closer than 2 * minDistance, the length is cut to
    half that distance, with a floor of 1. For "F" the region starts at the
    largest feature stop. Otherwise it ends at the smallest feature start,
    and its start is clamped at 0.

    Returns (regionStart, regionLength).
    """
    regionLength = minDistance
    if sense == "F":
        nextGene = genome.rightGeneDistance((genome.genome, gid), regionLength * 2)
    else:
        nextGene = genome.leftGeneDistance((genome.genome, gid), regionLength * 2)

    # Keep the region within half the gap to the neighbouring gene.
    if nextGene < regionLength * 2:
        regionLength = nextGene // 2

    regionLength = max(regionLength, 1)

    # Place the region only after its final length is known, so a shortened
    # reverse-strand region still ends at the gene rather than further away.
    if sense == "F":
        geneStart = max([stop for (start, stop) in featureList])
    else:
        geneStart = max(min([start for (start, stop) in featureList]) - regionLength, 0)

    return geneStart, regionLength


def getBinList(readList, standardMinDist, geneStart, geneLength, sense, bins=10):
    """Accumulate read weights into equal-width bins across a region.

    readList -- sequence of reads, each a mapping with "start" and "weight"
                keys. Assumed sorted by "start" (the loop stops at the
                first read past the region end); TODO confirm for all
                callers of ReadDataset.getReadsDict.
    standardMinDist -- nominal region size; bin width is
                       standardMinDist // bins, with a floor of 1.
    geneStart -- first coordinate of the region.
    geneLength -- region length; reads with offset in (0, geneLength) count.
    sense -- "F": bin index grows with distance from geneStart; any other
             value: bin index grows with distance from geneStart + geneLength.
    bins -- number of bins (default 10).

    Returns (binList, tagCount): the per-bin summed weights (floats) and the
    total weight of all counted reads. Offsets that would land past the last
    bin are folded into the last bin.
    """
    binList = [0.] * bins
    tagCount = 0
    # floor at 1 so tiny regions cannot cause a division by zero
    binWidth = max(standardMinDist // bins, 1)
    lastBin = bins - 1
    for read in readList:
        # offset must be relative to the region before testing its end
        offset = read["start"] - geneStart
        if offset >= geneLength:
            break

        if offset <= 0:
            continue

        weight = read["weight"]
        tagCount += weight
        if sense == "F":
            distance = offset
        else:
            distance = geneLength - offset

        # explicit floor division; clamp when standardMinDist is not a
        # multiple of bins (e.g. 3005 // 10 == 300 gives index 10 at 3004)
        binList[min(distance // binWidth, lastBin)] += weight

    return binList, tagCount

if __name__ == "__main__":
    # main() reads sys.argv itself when called without an argument vector
    main()