#
#  regionCounts.py
#  ENRAGE
#

try:
    import psyco
    psyco.full()
except:
    print 'psyco not running'

import sys
import string
import optparse
from commoncode import getMergedRegions, findPeak, writeLog, getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption
import ReadDataset

# Version banner: printed when the module is loaded and also passed to
# writeLog() by regionCounts() so each log entry records the tool version.
versionString = "regionCounts: version 3.10"
print versionString

def main(argv=None):
    """Command-line entry point.

    Parses ``argv`` (falling back to ``sys.argv`` when it is None or empty),
    requires the three positional arguments ``regionfile rdsfile outfilename``
    and forwards them, together with all parsed options, to regionCounts().

    Exits with status 1 after printing the usage line when fewer than three
    positional arguments are supplied. Extra positional arguments are ignored.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog regionfile rdsfile outfilename [options]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        # print_usage() expands the %prog placeholder to the script name;
        # printing the raw usage string would show a literal "%prog".
        parser.print_usage()
        sys.exit(1)

    regionfilename, hitfile, outfilename = args[0], args[1], args[2]

    # Keyword arguments keep this call correct even if the long parameter
    # list of regionCounts() is ever reordered.
    regionCounts(regionfilename, hitfile, outfilename,
                 flagRDS=options.flagRDS,
                 cField=options.cField,
                 useFullchrom=options.useFullchrom,
                 normalize=options.normalize,
                 padregion=options.padregion,
                 mergeregion=options.mergeregion,
                 merging=options.merging,
                 doUniqs=options.doUniqs,
                 doMulti=options.doMulti,
                 doSplices=options.doSplices,
                 usePeak=options.usePeak,
                 cachePages=options.cachePages,
                 logfilename=options.logfilename,
                 doRPKM=options.doRPKM,
                 doLength=options.doLength,
                 forceRegion=options.forceRegion)


def getParser(usage):
    """Create the optparse parser for the regionCounts command line.

    Every option is registered without an inline default; the defaults are
    instead looked up in the "regionCounts" section of the parser returned
    by getConfigParser(), each with a hard-coded fallback, and installed
    with set_defaults().

    usage -- usage string handed to optparse.OptionParser.

    Returns the configured optparse.OptionParser.
    """
    optionParser = optparse.OptionParser(usage=usage)

    # (flag, add_option keyword arguments), listed in registration order so
    # that the generated help output keeps the same option ordering.
    optionSpecs = [
        ("--markRDS", {"action": "store_true", "dest": "flagRDS"}),
        ("--chromField", {"type": "int", "dest": "cField"}),
        ("--fullchrom", {"action": "store_true", "dest": "useFullchrom"}),
        ("--raw", {"action": "store_false", "dest": "normalize"}),
        ("--padregion", {"type": "int", "dest": "padregion"}),
        ("--mergeregion", {"type": "int", "dest": "mergeregion"}),
        ("--nomerge", {"action": "store_false", "dest": "merging"}),
        ("--noUniqs", {"action": "store_false", "dest": "doUniqs"}),
        ("--noMulti", {"action": "store_false", "dest": "doMulti"}),
        ("--splices", {"action": "store_true", "dest": "doSplices"}),
        ("--peak", {"action": "store_true", "dest": "usePeak"}),
        ("--cache", {"type": "int", "dest": "cachePages"}),
        ("--log", {"dest": "logfilename"}),
        ("--rpkm", {"action": "store_true", "dest": "doRPKM"}),
        ("--length", {"action": "store_true", "dest": "doLength"}),
        ("--force", {"action": "store_true", "dest": "forceRegion"}),
    ]
    for flag, settings in optionSpecs:
        optionParser.add_option(flag, **settings)

    # (config lookup function, option name, fallback value). The config key
    # and the optparse destination share the same name for every option.
    defaultSpecs = [
        (getConfigBoolOption, "flagRDS", False),
        (getConfigIntOption, "cField", 1),
        (getConfigBoolOption, "useFullchrom", False),
        (getConfigBoolOption, "normalize", True),
        (getConfigIntOption, "padregion", 0),
        (getConfigIntOption, "mergeregion", 0),
        (getConfigBoolOption, "merging", True),
        (getConfigBoolOption, "doUniqs", True),
        (getConfigBoolOption, "doMulti", True),
        (getConfigBoolOption, "doSplices", False),
        (getConfigBoolOption, "usePeak", False),
        (getConfigIntOption, "cachePages", -1),
        (getConfigOption, "logfilename", "regionCounts.log"),
        (getConfigBoolOption, "doRPKM", False),
        (getConfigBoolOption, "doLength", False),
        (getConfigBoolOption, "forceRegion", False),
    ]

    config = getConfigParser()
    sectionName = "regionCounts"
    defaults = {}
    for lookup, name, fallback in defaultSpecs:
        defaults[name] = lookup(config, sectionName, name, fallback)

    optionParser.set_defaults(**defaults)

    return optionParser


def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
                 useFullchrom=False, normalize=True, padregion=0, mergeregion=0,
                 merging=True, doUniqs=True, doMulti=True, doSplices=False, usePeak=False,
                 cachePages=-1, logfilename="regionCounts.log", doRPKM=False, doLength=False,
                 forceRegion=False):

    print "padding %d bp on each side of a region" % padregion
    print "merging regions closer than %d bp" % mergeregion
    print "will use peak values"

    if cachePages != -1:
        doCache = True
    else:
        doCache = False

    normalize = True
    doRPKM = False
    if doRPKM == True:
        normalize = True

    writeLog(logfilename, versionString, string.join(sys.argv[1:]))

    regionDict = getMergedRegions(regionfilename, maxDist=mergeregion, minHits=-1, keepLabel=True,
                                  fullChrom=useFullchrom, verbose=True, chromField=cField,
                                  doMerge=merging, pad=padregion)

    labelList = []
    labeltoRegionDict = {}
    regionCount = {}

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    readlen = hitRDS.getReadSize()
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)

    totalCount = len(hitRDS)
    if normalize:
        normalizationFactor = totalCount / 1000000.

    chromList = hitRDS.getChromosomes(fullChrom=useFullchrom)
    if len(chromList) == 0 and doSplices:
        chromList = hitRDS.getChromosomes(table="splices", fullChrom=useFullchrom)

    chromList.sort()

    if flagRDS:
        hitRDS.setSynchronousPragma("OFF")        

    for rchrom in regionDict:
        if forceRegion and rchrom not in chromList:
            print rchrom
            for region in regionDict[rchrom]:
                regionCount[region.label] = 0
                labelList.append(region.label)
                labeltoRegionDict[region.label] = (rchrom, region.start, region.stop)

    for rchrom in chromList:
        regionList = []
        if rchrom not in regionDict:
            continue

        print rchrom
        if useFullchrom:
            fullchrom = rchrom
        else:
            fullchrom = "chr%s" % rchrom

        if usePeak:
            readDict = hitRDS.getReadsDict(chrom=fullchrom, withWeight=True, doMulti=True, findallOptimize=True)
            rindex = 0
            dictLen = len(readDict[fullchrom])

        for region in regionDict[rchrom]:
            label = region.label
            start = region.start
            stop = region.stop
            regionCount[label] = 0
            labelList.append(label)
            labeltoRegionDict[label] = (rchrom, start, stop)
            regionList.append((label, fullchrom, start, stop))
            if usePeak:
                readList = []
                for localIndex in xrange(rindex, dictLen):
                    read = readDict[fullchrom][localIndex]
                    if read["start"] < start:
                        rindex += 1
                    elif start <= read["start"] <= stop:
                        readList.append(read)
                    else:
                        break

                if len(readList) < 1:
                    continue

                readList.sort()
                peak = findPeak(readList, start, stop - start, readlen, doWeight=True)
                try:
                    topValue = peak.smoothArray[peak.topPos[0]]
                except:
                    print "problem with %s %s" % (str(peak.topPos), str(peak.smoothArray))
                    continue

                regionCount[label] += topValue
            else:
                regionCount[label] += hitRDS.getCounts(fullchrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices)

        if flagRDS:
            hitRDS.flagReads(regionList, uniqs=doUniqs, multi=doMulti, splices=doSplices)

    if flagRDS:
        hitRDS.setSynchronousPragma("ON")    

    if normalize:
        for label in regionCount:
            regionCount[label] = float(regionCount[label]) / normalizationFactor

    outfile = open(outfilename, "w")

    if forceRegion:
        labelList.sort()

    for label in labelList:
        (chrom, start, stop) = labeltoRegionDict[label]
        if useFullchrom:
            fullchrom = chrom
        else:
            fullchrom = "chr%s" % chrom

        if normalize:
            if doRPKM:
                length = abs(stop - start) / 1000.
            else:
                length = 1.

            if length < 0.001:
                length = 0.001

            outfile.write("%s\t%s\t%d\t%d\t%.2f" % (label, fullchrom, start, stop, regionCount[label]/length))
            if doLength:
                outfile.write("\t%.1f" % length)
        else:
            outfile.write('%s\t%s\t%d\t%d\t%d' % (label, fullchrom, start, stop, regionCount[label]))

        outfile.write("\n")

    outfile.close()
    if doCache and flagRDS:
        hitRDS.saveCacheDB(hitfile)

    writeLog(logfilename, versionString, "returned %d region counts for %s (%.2f M reads)" % (len(labelList), hitfile, totalCount / 1000000.))


# Script entry point: when executed directly (not imported), run main() with
# the process command-line arguments.
if __name__ == "__main__":
    main(sys.argv)