#
#  regionintersects.py
#  ENRAGE
#
# psyco is an optional accelerator: enable it when it can be imported and
# started, and carry on without it otherwise.
try:
    import psyco
    psyco.full()
except Exception:
    # Best effort only.  Catching Exception (rather than using a bare except)
    # keeps KeyboardInterrupt and SystemExit from being silently swallowed.
    pass

import sys
import optparse
from commoncode import getMergedRegions, findPeak, getConfigParser, getConfigOption, getConfigBoolOption
import ReadDataset

# Version banner; emitted at module level, so it appears on import as well as
# when the file is run as a script.
print("regionintersects: version 3.1")

def main(argv=None):
    """Command-line entry point.

    argv is a full argument vector whose first element is the program name;
    sys.argv is used when argv is None or empty.  The five required
    positional arguments are, in order: rdsfile1, regionfile1, rdsfile2,
    regionfile2 and outfile.  When fewer are supplied the usage message is
    printed and the process exits with status 1.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog rdsfile1 regionfile1 rdsfile2 regionfile2 outfile [--reject1 File1] [--reject2 File2] [--union] [--cache] [--raw] [--verbose]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 5:
        # Let optparse print the usage so that %prog is expanded to the
        # program name; printing the raw string would show "%prog" literally.
        parser.print_usage()
        sys.exit(1)

    (readOneName, regionOneName, readTwoName, regionTwoName, outfilename) = args[:5]

    regionintersects(readOneName, regionOneName, readTwoName, regionTwoName,
                     outfilename, options.rejectOneName, options.rejectTwoName,
                     options.trackReject, options.doCache, options.normalize,
                     options.doVerbose)


def getParser(usage):
    """Build the optparse parser for the regionintersects command line.

    Registers --reject1, --reject2, --union, --cache, --raw and --verbose,
    then installs per-option defaults looked up in the "regionintersects"
    section of the object returned by getConfigParser().  The last argument
    handed to getConfigOption / getConfigBoolOption is presumably the value
    used when the configuration has no entry (confirm in commoncode).

    Returns the configured optparse.OptionParser.
    """
    optionParser = optparse.OptionParser(usage=usage)

    # (flag, optparse action, destination attribute)
    optionSpecs = (
        ("--reject1", "store", "rejectOneName"),
        ("--reject2", "store", "rejectTwoName"),
        ("--union", "store_true", "trackReject"),
        ("--cache", "store_true", "doCache"),
        ("--raw", "store_false", "normalize"),
        ("--verbose", "store_true", "doVerbose"),
    )
    for (flag, action, dest) in optionSpecs:
        optionParser.add_option(flag, action=action, dest=dest)

    config = getConfigParser()
    configSection = "regionintersects"

    defaults = {}
    for name in ("rejectOneName", "rejectTwoName"):
        defaults[name] = getConfigOption(config, configSection, name, None)

    boolFallbacks = (
        ("trackReject", False),
        ("doCache", False),
        ("normalize", True),
        ("doVerbose", False),
    )
    for (name, fallback) in boolFallbacks:
        defaults[name] = getConfigBoolOption(config, configSection, name, fallback)

    optionParser.set_defaults(**defaults)

    return optionParser


def _getChromosomeReads(rds, fullchrom):
    """Return the reads rds reports for fullchrom, or [] if it reports none."""
    readsDict = rds.getReadsDict(fullChrom=True, chrom=fullchrom, withWeight=True, doMulti=True)
    try:
        return readsDict[fullchrom]
    except KeyError:
        return []


def _readsInRegion(reads, firstIndex, start, stop):
    """Collect the reads whose "start" lies within [start, stop].

    Scanning begins at firstIndex and ends at the first read that starts
    after stop.  Returns (readList, nextFirstIndex), where nextFirstIndex is
    firstIndex advanced by one for every scanned read that started before
    start, so the scan for the following region can skip them.
    NOTE(review): this presumes both reads and regions are ordered by start
    position -- confirm against ReadDataset / getMergedRegions.
    """
    readList = []
    nextFirstIndex = firstIndex
    index = firstIndex
    numReads = len(reads)
    while index < numReads:
        read = reads[index]
        readStart = read["start"]
        if readStart < start:
            nextFirstIndex += 1
        elif readStart <= stop:
            readList.append(read)
        else:
            break

        index += 1

    return (readList, nextFirstIndex)


def _findRegionPeaks(regionList, reads, normalization):
    """Return one peak tuple per region of regionList that contains reads.

    Each tuple is (peakPosition, radius, start, stop, hits, peakScore), where
    radius is half the region length and hits / peakScore are the findPeak
    results divided by normalization.  Regions without reads are skipped.
    """
    peakList = []
    firstIndex = 0
    for region in regionList:
        start = region.start
        stop = region.stop
        length = region.length
        (readList, firstIndex) = _readsInRegion(reads, firstIndex, start, stop)
        if not readList:
            continue

        readList.sort()
        peak = findPeak(readList, start, length, doWeight=True)
        topPos = peak.topPos[0]
        peakList.append((topPos + start, length/2, start, stop,
                         peak.numHits/normalization,
                         peak.smoothArray[topPos]/normalization))

    return peakList


def regionintersects(readOneName, regionOneName, readTwoName, regionTwoName,
                     outfilename, rejectOneName=None, rejectTwoName=None,
                     trackReject=False, doCache=False, normalize=True, doVerbose=False):
    """Report the regions of two region files whose peaks lie close together.

    For every chromosome named in either region file, a peak is located (via
    findPeak) in each region that contains reads from the matching dataset.
    A region from file two is "common" with a region from file one when the
    distance between their peak positions is smaller than the sum of their
    half-lengths; one "common" line is written to outfilename per such pair.

    readOneName, readTwoName     -- datasets opened with ReadDataset
    regionOneName, regionTwoName -- region files read with getMergedRegions
    outfilename                  -- output file (overwritten)
    rejectOneName, rejectTwoName -- optional files receiving the regions of
                                    file one / file two that have no partner;
                                    giving either one turns trackReject on
    trackReject -- also report unmatched regions; each kind goes to its
                   reject file when one was given, otherwise to outfilename
    doCache     -- passed to ReadDataset as cache
    normalize   -- divide hits and peak scores by len(dataset) / 1000000
    doVerbose   -- echo every output line to stdout

    Regions that contain no reads are neither reported as common nor as
    rejected.
    """
    mergedist = 0

    if rejectOneName is not None or rejectTwoName is not None:
        trackReject = True

    outfile = open(outfilename, "w")
    # Each reject stream falls back to the main output on its own, so giving
    # only one reject file no longer leaves the other handle undefined.
    rejectOne = outfile
    rejectTwo = outfile
    try:
        if rejectOneName is not None:
            rejectOne = open(rejectOneName, "w")

        if rejectTwoName is not None:
            rejectTwo = open(rejectTwoName, "w")

        oneDict = getMergedRegions(regionOneName, mergedist, verbose=doVerbose)
        twoDict = getMergedRegions(regionTwoName, mergedist, verbose=doVerbose)

        oneRDS = ReadDataset.ReadDataset(readOneName, verbose=doVerbose, cache=doCache)
        twoRDS = ReadDataset.ReadDataset(readTwoName, verbose=doVerbose, cache=doCache)

        if normalize:
            normalize1 = len(oneRDS) / 1000000.
            normalize2 = len(twoRDS) / 1000000.
        else:
            normalize1 = 1.
            normalize2 = 1.

        commonRegions = 0
        oneRejectIndex = 0
        twoRejectIndex = 0

        numRegionsOne = 0
        numRegionsTwo = 0
        commonChromosomeList = set(oneDict.keys())
        for rchrom in oneDict:
            numRegionsOne += len(oneDict[rchrom])

        for rchrom in twoDict:
            commonChromosomeList.add(rchrom)
            numRegionsTwo += len(twoDict[rchrom])

        outfile.write("#%d\tregions in\t%s\n#%d\tregions in\t%s\n" % (numRegionsOne, regionOneName, numRegionsTwo, regionTwoName))

        for chromosome in commonChromosomeList:
            print(chromosome)
            fullchrom = "chr%s" % chromosome
            oneReads = _getChromosomeReads(oneRDS, fullchrom)
            twoReads = _getChromosomeReads(twoRDS, fullchrom)

            # The chromosome list is the union of both region sets, so a
            # chromosome may be missing from either one of them.
            if chromosome in oneDict:
                oneRegions = oneDict[chromosome]
            else:
                oneRegions = []

            if chromosome in twoDict:
                twoRegions = twoDict[chromosome]
            else:
                twoRegions = []

            onePeaks = _findRegionPeaks(oneRegions, oneReads, normalize1)
            twoPeaks = _findRegionPeaks(twoRegions, twoReads, normalize2)

            # Keys of the file-one regions that matched at least one file-two
            # region on this chromosome.
            oneFound = set()
            for (twoPeak, twoRadius, start, stop, numHits, twoPeakScore) in twoPeaks:
                twoIsCommon = False
                for (onePeak, oneRadius, ostart, ostop, ohits, opeakScore) in onePeaks:
                    if abs(twoPeak - onePeak) < (twoRadius + oneRadius):
                        oneFound.add((onePeak, oneRadius, ostart, ostop, ohits))
                        twoIsCommon = True
                        commonRegions += 1
                        outline = "common%d\tchr%s\t%d\t%d\t%.1f\t%.1f\tchr%s\t%d\t%d\t%.1f\t%.1f" % (commonRegions, chromosome, ostart, ostop, ohits, opeakScore, chromosome, start, stop, numHits, twoPeakScore)
                        if doVerbose:
                            print(outline)

                        outfile.write(outline + "\n")

                if trackReject and not twoIsCommon:
                    twoRejectIndex += 1
                    outline = "rejectTwo%d\tchr%s\t%d\t%d\t%.1f\t%.1f" % (twoRejectIndex, chromosome, start, stop, numHits, twoPeakScore)
                    rejectTwo.write(outline + "\n")
                    if doVerbose:
                        print(outline)

            if trackReject:
                for (onePeak, oneRadius, ostart, ostop, ohits, opeakScore) in onePeaks:
                    if (onePeak, oneRadius, ostart, ostop, ohits) in oneFound:
                        continue

                    oneRejectIndex += 1
                    outline = "rejectOne%d\tchr%s\t%d\t%d\t%.1f\t%.1f" % (oneRejectIndex, chromosome, ostart, ostop, ohits, opeakScore)
                    rejectOne.write(outline + "\n")
                    if doVerbose:
                        print(outline)

        if trackReject:
            print("common: %d   one-only: %d   two-only: %d" % (commonRegions, oneRejectIndex, twoRejectIndex))
            outfile.write("#common: %d\tone-only: %d\ttwo-only: %d\n" % (commonRegions, oneRejectIndex, twoRejectIndex))
        else:
            print("common: %d" % commonRegions)
            outfile.write("#common: %d\n" % commonRegions)
    finally:
        # Close every file that was opened, even when processing fails.
        if rejectOne is not outfile:
            rejectOne.close()

        if rejectTwo is not outfile:
            rejectTwo.close()

        outfile.close()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main(sys.argv)