#
#  RNAFARpairs.py
#  ENRAGE
#
#  Created by Ali Mortazavi on 11/2/08.
#
""" usage: python rnafarpairs.py genome goodfile rdsfile outfile [options]
           looks at all chromosomes simultaneously: this is both slow and takes up a large amount of RAM
"""
# psyco is optional: if it cannot be imported or enabled, run without it.
# Catch Exception rather than using a bare except so that KeyboardInterrupt
# and SystemExit are not silently swallowed during start-up.
try:
    import psyco
    psyco.full()
except Exception:
    pass

import sys
import time
import optparse
import ReadDataset
from commoncode import getGeneInfoDict, getGeneAnnotDict, getConfigParser, getConfigIntOption, getConfigBoolOption


def main(argv=None):
    """Command-line entry point: parse the arguments and run rnaFarPairs.

    argv is the full argument vector including the program name at index 0;
    when it is empty or None, sys.argv is used instead.  Four positional
    arguments are required (genome, goodfile, rdsfile, outfile); if fewer are
    given the usage message is printed and the process exits with status 1.
    """
    if not argv:
        argv = sys.argv

    print("rnafarPairs: version 3.7")
    usage = "usage: python %prog genome goodfile rdsfile outfile [options]"

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 4:
        # Let optparse render the usage so that %prog is expanded to the
        # program name instead of being printed literally.
        parser.print_usage()
        sys.exit(1)

    genome, goodfilename, rdsfile, outfilename = args[:4]

    rnaFarPairs(genome, goodfilename, rdsfile, outfilename, options.doVerbose, options.doCache, options.maxDist)


def makeParser(usage=""):
    """Build the optparse parser for the rnafarPairs command line.

    Registers --verbose, --cache and --maxDist, then sets their defaults from
    the "rnafarPairs" section of the configuration returned by
    getConfigParser().  The last argument of each getConfig*Option call is
    presumably the fallback used when the option is absent from the
    configuration -- confirm against commoncode.
    """
    optionParser = optparse.OptionParser(usage=usage)
    optionParser.add_option("--verbose", dest="doVerbose", action="store_true",
                            help="verbose output")
    optionParser.add_option("--cache", dest="doCache", action="store_true",
                            help="use cache")
    optionParser.add_option("--maxDist", dest="maxDist", type="int",
                            help="maximum distance")

    config = getConfigParser()
    configSection = "rnafarPairs"
    verboseDefault = getConfigBoolOption(config, configSection, "doVerbose", False)
    cacheDefault = getConfigBoolOption(config, configSection, "doCache", False)
    maxDistDefault = getConfigIntOption(config, configSection, "maxDist", 500000)

    optionParser.set_defaults(doVerbose=verboseDefault,
                              doCache=cacheDefault,
                              maxDist=maxDistDefault)

    return optionParser


def _loadGoodDict(goodfilename):
    """Read goodfilename into a dict mapping each line's first field to the full line.

    Blank (whitespace-only) lines are skipped.  The file is closed even if
    reading fails part-way through.
    """
    goodDict = {}
    goodfile = open(goodfilename)
    try:
        for line in goodfile:
            fields = line.split()
            if not fields:
                # A blank line has no first field to key on.
                continue

            goodDict[fields[0]] = line
    finally:
        goodfile.close()

    return goodDict


def rnaFarPairs(genome, goodfilename, rdsfile, outfilename, doVerbose=False, doCache=False, maxDist=500000):
    """Assign entries of goodfilename to genes using paired reads and write them to outfilename.

    For every chromosome of the read dataset (except those rejected by
    doNotProcessChromosome) the unique reads are fetched keyed by read ID.
    Each read ID that has exactly two reads and passes processReads is handed
    to writeFarPairsToFile.  Afterwards the remaining unassigned entries are
    written by writeUnassignedEntriesToFile, followed by a
    "#distinct ... total ..." summary line.  The same summary is printed to
    stdout.

    genome       -- genome name passed to getGeneInfoDict / getGeneAnnotDict
    goodfilename -- input file; the first whitespace-separated field of each
                    line is the key that read flags are matched against
    rdsfile      -- read dataset file opened through ReadDataset
    outfilename  -- output file, overwritten
    doVerbose    -- print timestamps and per-chromosome read ID counts
    doCache      -- forwarded as the cache argument of ReadDataset
    maxDist      -- exclusive upper bound on the distance between the two
                    read starts (see processReads)
    """
    goodDict = _loadGoodDict(goodfilename)

    # NOTE(review): the dataset is always opened with verbose=True, regardless
    # of doVerbose -- confirm that this is intended.
    RDS = ReadDataset.ReadDataset(rdsfile, verbose = True, cache=doCache)
    chromosomeList = RDS.getChromosomes()
    if doVerbose:
        print(time.ctime())

    distinct = 0
    total = 0
    outfile = open(outfilename,"w")
    try:
        geneinfoDict = getGeneInfoDict(genome)
        geneannotDict = getGeneAnnotDict(genome)
        assigned = {}
        farConnected = {}
        for chromosome in chromosomeList:
            if doNotProcessChromosome(chromosome):
                continue

            print(chromosome)
            uniqDict = RDS.getReadsDict(fullChrom=True, chrom=chromosome, noSense=True, withFlag=True, doUniqs=True, readIDDict=True)
            if doVerbose:
                print("%s %s" % (len(uniqDict), time.ctime()))

            for readID in uniqDict:
                readList = uniqDict[readID]
                if len(readList) != 2:
                    continue

                # total is bumped once per two-read ID here and again by the
                # count that writeFarPairsToFile reports.
                total += 1
                if processReads(readList, maxDist):
                    flags = (readList[0]["flag"], readList[1]["flag"])
                    processed, distinctPairs = writeFarPairsToFile(flags, goodDict, genome, geneinfoDict, geneannotDict, outfile, assigned, farConnected)
                    total += processed
                    distinct += distinctPairs

        entriesWritten = writeUnassignedEntriesToFile(farConnected, assigned, goodDict, outfile)
        distinct += entriesWritten
        outfile.write("#distinct: %d\ttotal: %d\n" % (distinct, total))
    finally:
        # Always release the output handle, even when processing fails.
        outfile.close()

    print("distinct: %d\ttotal: %d" % (distinct, total))
    print(time.ctime())


def doNotProcessChromosome(chromosome):
    """Return True when the chromosome should be skipped (only "chrM" is)."""
    skip = chromosome == "chrM"
    return skip


def processReads(reads, maxDist):
    """Decide whether a pair of reads should be processed.

    reads is a sequence whose first two items are mappings with "start" and
    "flag" entries.  The pair qualifies when the two flags differ, neither
    flag is "NM", and the absolute distance between the two starts is
    strictly less than maxDist.
    """
    firstStart = reads[0]["start"]
    secondStart = reads[1]["start"]
    separation = abs(firstStart - secondStart)
    firstFlag = reads[0]["flag"]
    secondFlag = reads[1]["flag"]

    if firstFlag == secondFlag:
        return False

    if firstFlag == "NM" or secondFlag == "NM":
        return False

    if separation < maxDist:
        return True

    return False


def _resolveGeneSymbol(genome, geneID, geneInfoDict, geneAnnotDict):
    """Return a display symbol for geneID with surrounding whitespace stripped
    and inner spaces/tabs replaced by "|".

    Lookup order: geneInfoDict (keyed "Dmel_<geneID>" for dmelanogaster,
    otherwise geneID), then geneAnnotDict keyed by (genome, geneID), and
    finally the synthetic name "LOC<geneID>".
    """
    infoKey = geneID
    if genome == "dmelanogaster":
        infoKey = "Dmel_%s" % geneID

    try:
        symbol = geneInfoDict[infoKey][0][0]
    except (KeyError, IndexError):
        try:
            symbol = geneAnnotDict[(genome, geneID)][0]
        except (KeyError, IndexError):
            symbol = "LOC%s" % geneID

    return symbol.strip().replace(" ", "|").replace("\t", "|")


def writeFarPairsToFile(flags, goodDict, genome, geneInfoDict, geneAnnotDict, outfile, assigned, farConnected):
    """Record what a pair of read flags says about the entries in goodDict.

    flags is a (flag1, flag2) pair.

    * Both flags are keys of goodDict: the larger flag is appended to
      farConnected under the smaller one; nothing is written.
    * Exactly one flag is a key of goodDict: the other flag is treated as a
      gene ID.  If the goodDict flag has no entry in assigned yet, it is
      assigned (symbol, geneID), and "<symbol> <geneID> <goodDict line>" is
      printed and written to outfile.
    * Neither flag is a key of goodDict: nothing happens.

    assigned and farConnected are updated in place.

    Returns (total, distinct): total is 1 when exactly one flag is in
    goodDict (0 otherwise); distinct is 1 when a line was written (0
    otherwise).
    """
    flag1, flag2 = flags
    read1IsGood = flag1 in goodDict
    read2IsGood = flag2 in goodDict

    if read1IsGood and read2IsGood:
        if flag1 < flag2:
            geneID = flag1
            farFlag = flag2
        else:
            geneID = flag2
            farFlag = flag1

        farConnected.setdefault(geneID, []).append(farFlag)
        return 0, 0

    if not (read1IsGood or read2IsGood):
        return 0, 0

    if read2IsGood:
        farFlag = flag2
        geneID = flag1
    else:
        farFlag = flag1
        geneID = flag2

    if farFlag in assigned:
        # Already assigned: skip the symbol lookup, its result would be unused.
        return 1, 0

    symbol = _resolveGeneSymbol(genome, geneID, geneInfoDict, geneAnnotDict)
    assigned[farFlag] = (symbol, geneID)
    print("%s %s %s" % (symbol, geneID, goodDict[farFlag].strip()))
    outfile.write("%s %s %s" % (symbol, geneID, goodDict[farFlag]))

    return 1, 1


def writeUnassignedEntriesToFile(farConnected, assigned, goodDict, outfile):
    """Write every goodDict entry that has not been assigned yet.

    First the connected groups in farConnected are written via
    writeUnassignedPairsToFile; the FAR counter it returns then seeds
    writeUnassignedGoodReadsToFile, which writes the remaining entries.

    Returns only the number of lines written by the first step.
    """
    pairResult = writeUnassignedPairsToFile(farConnected, assigned, goodDict, outfile)
    lastFarIndex, pairsWritten = pairResult
    writeUnassignedGoodReadsToFile(lastFarIndex, goodDict, assigned, outfile)

    return pairsWritten


def writeUnassignedPairsToFile(farConnected, assigned, goodDict, outfile):
    """Write the still-unassigned members of each connected group.

    Each key of farConnected together with its list of linked flags forms a
    group.  If any member already has an entry in assigned, the group takes
    that (symbol, geneID); when several members are assigned, the last one in
    the group wins.  Otherwise a new name "FAR<n>" with gene ID -n is created,
    n counting up from 1.  Every member not yet in assigned is printed,
    written to outfile as "<symbol> <geneID> <goodDict line>", and recorded in
    assigned, so repeated flags in a group are written once.

    Returns (total, written): the number of FAR names created and the number
    of lines written.
    """
    total = 0
    written = 0
    for farFlag in farConnected:
        geneID = ""
        symbol = ""
        idList = [farFlag] + farConnected[farFlag]
        for ID in idList:
            if ID in assigned:
                (symbol, geneID) = assigned[ID]

        if geneID == "":
            # No member of the group carries a gene yet: mint a new FAR name.
            total += 1
            symbol = "FAR%d" % total
            geneID = -1 * total

        for ID in idList:
            if ID not in assigned:
                print("%s %s %s" % (symbol, geneID, goodDict[ID].strip()))
                outfile.write("%s %s %s" % (symbol, geneID, goodDict[ID]))
                written += 1
                assigned[ID] = (symbol, geneID)

    return total, written


def writeUnassignedGoodReadsToFile(farIndex, goodDict, assigned, outfile):
    """Write every goodDict entry whose key is not in assigned under a new FAR name.

    farIndex is the last FAR number already used; each unassigned entry gets
    the next number n and is printed and written to outfile as
    "FAR<n> -<n> <goodDict line>".  assigned is not modified and nothing is
    returned.
    """
    for farFlag in goodDict:
        if farFlag not in assigned:
            farIndex += 1
            line = "FAR%d %d %s" % (farIndex, -1 * farIndex, goodDict[farFlag])
            print(line.strip())
            outfile.write(line)


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main(argv=sys.argv)