#
#  farPairs.py
#  ENRAGE
#
#  Created by Ali Mortazavi on 7/13/10.
#

# Optional psyco acceleration: run without it whenever it cannot be imported
# or initialized. Catch Exception rather than using a bare except so that
# KeyboardInterrupt / SystemExit raised during import still propagate.
try:
    import psyco
    psyco.full()
except Exception:
    pass

import sys
import time
import optparse
import string
import ReadDataset
from commoncode import getConfigParser, getConfigOption, getConfigBoolOption, getConfigIntOption, countDuplicatesInList

# Announce the script version whenever this module is loaded.
print("farPairs: version 1.4")


def main(argv=None):
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog rdsfile outfile bedfile [--verbose] [--cache numPages] [--minDist bp] [--maxDist bp] [--minCount count] [--label string]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        print usage
        print "\tIs both slow and takes up large amount of RAM"
        sys.exit(1)

    rdsfile = args[0]
    outfilename = args[1]
    outbedname = args[2]

    farPairs(rdsfile, outfilename, outbedname, options.doVerbose,
             options.cachePages, options.minDist, options.maxDist, options.minCount,
             options.label)


def getParser(usage):
    """Build the command-line parser for farPairs.

    usage -- usage string shown by optparse.

    Option defaults are read from the "farPairs" section of the configuration
    returned by getConfigParser(). When a setting is absent there, these
    fallbacks apply: doVerbose=False, cachePages=None, minDist=1000,
    maxDist=500000, minCount=2, label=None.

    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--verbose", action="store_true", dest="doVerbose")
    parser.add_option("--minDist", type="int", dest="minDist")
    parser.add_option("--maxDist", type="int", dest="maxDist")
    parser.add_option("--minCount", type="int", dest="minCount")
    parser.add_option("--label", dest="label")

    # getConfigParser must be *called*: the getConfig*Option helpers need a
    # parser instance, not the factory function itself, to look up settings.
    configParser = getConfigParser()
    section = "farPairs"
    doVerbose = getConfigBoolOption(configParser, section, "doVerbose", False)
    # A string value here is fine: optparse converts string defaults using
    # the option's declared type ("int" for --cache).
    cachePages = getConfigOption(configParser, section, "cachePages", None)
    minDist = getConfigIntOption(configParser, section, "minDist", 1000)
    maxDist = getConfigIntOption(configParser, section, "maxDist", 500000)
    minCount = getConfigIntOption(configParser, section, "minCount", 2)
    label = getConfigOption(configParser, section, "label", None)

    parser.set_defaults(doVerbose=doVerbose, cachePages=cachePages,
                        minDist=minDist, maxDist=maxDist, minCount=minCount, label=label)

    return parser


def farPairs(rdsfile, outfilename, outbedname, doVerbose=False,
             cachePages=None, minDist=1000, maxDist=500000, minCount=2, label=None):

    flagDict = processRDSFile(rdsfile, outbedname, minDist, maxDist, cachePages, label, doVerbose)

    outfile = open(outfilename, "w")
    for region in flagDict:
        regionConnections = countDuplicatesInList(flagDict[region])
        for (connectedRegion, count) in regionConnections:
            if count >= minCount:
                outline = "%s\t%s\t%d" % (region, connectedRegion, count)
                print >> outfile, outline
                if doVerbose:
                    print outline

    outfile.close()
    if doVerbose:
        print "finished: ", time.ctime()


def processRDSFile(rdsfile, outbedname, minDist, maxDist, cachePages, label, doVerbose):
    doCache = False
    if cachePages is not None:
        doCache = True
    else:
        cachePages = 0

    if label is None:
        label = rdsfile

    RDS = ReadDataset.ReadDataset(rdsfile, verbose=True, cache=doCache)
    rdsChromList = RDS.getChromosomes()

    if doVerbose:
        print time.ctime()
    
    outbed = open(outbedname, "w")
    outbed.write('track name="%s distal pairs" color=0,255,0\n' % label)
    outbed.close()

    readlen = RDS.getReadSize()
    flagDict = {}
    for chromosome in rdsChromList:
        if doNotProcessChromosome(chromosome):
            continue

        writeFarPairs(flagDict, chromosome, RDS, readlen, outbedname, doVerbose, minDist, maxDist)

    print "%d connected regions" % len(flagDict)

    return flagDict


def doNotProcessChromosome(chrom):
    """Return True when chrom must be skipped (currently only "chrM")."""
    skippedChromosomes = ("chrM",)
    return chrom in skippedChromosomes


def writeFarPairs(flagDict, chromosome, RDS, readlen, outbedname, doVerbose, minDist, maxDist):
    outbed = open(outbedname, "a")
    print chromosome
    uniqDict = RDS.getReadsDict(fullChrom=True, chrom=chromosome, noSense=True, withFlag=True, doUniqs=True, readIDDict=True)
    if doVerbose:
        print len(uniqDict), time.ctime()

    for readID in uniqDict:
        readList = uniqDict[readID]
        if readsAreFarPair(readList, minDist, maxDist):
            start1 = readList[0]["start"]
            start2 = readList[1]["start"]
            startList = [start1, start2]
            startList.sort()
            outputLine = splitReadWrite(chromosome, 2, startList, readlen, "+", readID, "0,255,0", "0,255,0")
            outbed.write(outputLine)
            flag1 = readList[0]["flag"]
            flag2 = readList[1]["flag"]
            if doVerbose:
                print flag1, flag2, abs(start1 - start2)

            try:
                flagDict[flag1].append(flag2)
            except KeyError:
                flagDict[flag1] = [flag2]

            try:
                flagDict[flag2].append(flag1)
            except KeyError:
                flagDict[flag2] = [flag1]

    outbed.close()


def readsAreFarPair(readList, minDist, maxDist):
    """Decide whether readList is a distal pair.

    A distal pair is exactly two reads whose "flag" values are both non-empty
    and different from each other. The distance between their "start"
    positions must also lie strictly between minDist and maxDist.

    Returns True or False.
    """
    if len(readList) != 2:
        return False

    first, second = readList
    flagA = first["flag"]
    flagB = second["flag"]
    if flagA == flagB or flagA == "" or flagB == "":
        return False

    separation = abs(first["start"] - second["start"])
    return minDist < separation < maxDist


def splitReadWrite(chrom, numPieces, startList, readlen, rsense, readName, plusSenseColor, minusSenseColor):
    """Format one BED12 record that joins numPieces equal-length reads.

    chrom -- chromosome name (BED column 1).
    numPieces -- number of blocks (BED blockCount); the first numPieces
                 entries of startList are used.
    startList -- read start positions in ascending order. They are shifted
                 by -1 to obtain BED's 0-based coordinates, i.e. treated as
                 1-based.
    readlen -- length of every read, used as each block's size.
    rsense -- strand written to BED; "+" selects plusSenseColor, anything
              else selects minusSenseColor as the itemRgb.
    readName -- BED name column.

    Returns the newline-terminated record. Its chromEnd equals
    chromStart + last blockStart + last blockSize, as the BED12 format
    requires.
    """
    leftStart = startList[0] - 1
    # The first block always starts at offset 0; later blocks are placed
    # relative to the first read's start.
    blockStarts = [0] + [startList[index] - startList[0] for index in range(1, numPieces)]
    # All reads have the same length, so every block has the same size.
    blockSizes = [readlen] * len(blockStarts)
    # chromEnd must coincide with the end of the last block.
    rightStop = leftStart + blockStarts[-1] + blockSizes[-1]

    if rsense == "+":
        senseCode = plusSenseColor
    else:
        senseCode = minusSenseColor

    readSizes = ",".join(["%d" % size for size in blockSizes])
    readCoords = ",".join(["%d" % start for start in blockStarts])
    outline = "%s\t%d\t%d\t%s\t1000\t%s\t0\t0\t%s\t%d\t%s\t%s\n" % (chrom, leftStart, rightStop, readName, rsense, senseCode, numPieces, readSizes, readCoords)

    return outline


if __name__ == "__main__":
    # Script entry point; main() falls back to sys.argv when given no argv.
    main()