#
#  weightMultireads.py
#  ENRAGE
#

#  Created by Ali Mortazavi on 10/02/08.
#

try:
    import psyco
    psyco.full()
except:
    pass

import sys
import time
import string
import optparse
import ReadDataset
from commoncode import getConfigParser, getConfigBoolOption, getConfigOption


# Version banner. As a module-level statement this runs whenever the file is
# imported, not only when it is executed as a script.
print("weighMultireads: version 3.3")

def main(argv=None):
    if not argv:
        argv = sys.argv

    usage = "usage: python %s rdsfile [--radius bp] [--noradius] [--usePairs maxDist] [--verbose] [--cache pages]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 1:
        print usage
        sys.exit(1)

    rdsfile = args[0]

    weighMultireads(rdsfile, options.radius, options.doRadius, options.pairDist, options.verbose, options.cachePages)


def getParser(usage):
    """Build the optparse parser for weighMultireads.

    Defaults are read from the "weighMultireads" section of the configuration
    returned by getConfigParser(); values given on the command line override
    them.

    usage -- the usage string shown by the parser.
    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--radius", type="int", dest="radius")
    parser.add_option("--noradius", action="store_false", dest="doRadius")
    parser.add_option("--usePairs", type="int", dest="pairDist")
    parser.add_option("--verbose", action="store_true", dest="verbose")
    parser.add_option("--cache", type="int", dest="cachePages")

    configParser = getConfigParser()
    section = "weighMultireads"

    def toInt(value):
        # NOTE(review): getConfigOption may return config values as strings.
        # The command-line values are ints (type="int"), and downstream code
        # formats them with "%d" and compares them numerically, so normalize
        # the config defaults too. None stays None; int() is a no-op on ints.
        if value is None:
            return None

        return int(value)

    radius = toInt(getConfigOption(configParser, section, "radius", None))
    doRadius = getConfigBoolOption(configParser, section, "doRadius", True)
    pairDist = toInt(getConfigOption(configParser, section, "pairDist", None))
    verbose = getConfigBoolOption(configParser, section, "verbose", False)
    cachePages = toInt(getConfigOption(configParser, section, "cachePages", None))

    parser.set_defaults(radius=radius, doRadius=doRadius, pairDist=pairDist, verbose=verbose, cachePages=cachePages)

    return parser


def weighMultireads(rdsfile, radius=None, doRadius=True, pairDist=None, verbose=False, cachePages=None):

    if cachePages is not None:
        doCache = True
    else:
        doCache = False
        cachePages = 1

    RDS = ReadDataset.ReadDataset(rdsfile, verbose = True, cache=doCache)
    if cachePages > RDS.getDefaultCacheSize():
        RDS.setDBcache(cachePages)

    if verbose:
        print time.ctime()

    multiIDs = RDS.getReadIDs(uniqs=False, multi=True)
    if verbose:
        print "got multiIDs ", time.ctime()

    fixedReads = []
    if pairDist is not None:
        fixedReads = reweighUsingPairs(RDS, pairDist, multiIDs, verbose)

    if radius is not None:
        doRadius = True
    else:
        radius = 100

    if doRadius:
        reweighUsingRadius(RDS, radius, multiIDs, fixedReads, verbose)

    if doCache:
        RDS.saveCacheDB(rdsfile)

    if verbose:
        print "finished", time.ctime()


def reweighUsingPairs(RDS, pairDist, multiIDs, verbose=False):
    fixedPair = 0
    tooFar = pairDist * 10
    readlen = RDS.getReadSize()
    fixedReads = []
    print "doing pairs with pairDist = %d" % pairDist
    hasSplices = RDS.dataType == "RNA"
    uniqIDs = RDS.getReadIDs(uniqs=True, multi=False, splices=hasSplices)

    if verbose:
        print "got uniqIDs ", time.ctime()

    jointList, bothMultiList = getReadIDLists(uniqIDs, multiIDs, verbose)
    uniqDict = getUniqAndSpliceReadsFromReadIDs(RDS, jointList, verbose)
    if verbose:
        print "guDict actual ", len(uniqDict), time.ctime()

    multiDict = getMultiReadsFromReadIDs(RDS, jointList, bothMultiList, verbose)
    if verbose:
        print "muDict actual ", len(multiDict), time.ctime()

    RDS.setSynchronousPragma("OFF")
    for readID in jointList:
        try:
            ustart = uniqDict[readID]["start"]
            ustop = ustart + readlen
        except KeyError:
            ustart = uniqDict[readID]["startL"]
            ustop = uniqDict[readID]["stopR"]

        uniqReadChrom = uniqDict[readID]["chrom"]
        multiReadList = multiDict[readID]
        numMultiReads = len(multiReadList)
        bestMatch = [tooFar] * numMultiReads
        found = False
        for index in range(numMultiReads):
            mstart = multiReadList[index]["start"]
            multiReadChrom = multiReadList[index]["chrom"]
            mpair = multiReadList[index]["pairID"]
            if uniqReadChrom != multiReadChrom:
                continue

            if abs(mstart - ustart) < pairDist:
                bestMatch[index] = abs(mstart - ustart)
                found = True
            elif abs(mstart - ustop) < pairDist:
                bestMatch[index] = abs(mstart - ustop)
                found = True

        if found:
            theMatch = -1
            theDist = tooFar
            reweighList = []
            for index in range(numMultiReads):
                if theDist > bestMatch[index]:
                    theMatch = index
                    theDist = bestMatch[index]

            theID = string.join([readID, mpair], "/")
            for index in range(numMultiReads):
                if index == theMatch:
                    score = 1 - ((numMultiReads - 1) / (100. * numMultiReads))
                else:
                    score = 1 / (100. * numMultiReads)

                start = multiReadList[index][0]
                chrom = "chr%s" % multiReadList[index][1]
                reweighList.append((round(score,3), chrom, start, theID))

            if theMatch > 0:
                RDS.reweighMultireads(reweighList)
                fixedPair += 1
                if verbose and fixedPair % 10000 == 1:
                    print "fixed %d" % fixedPair
                    print uniqDict[readID]
                    print multiDict[readID]
                    print reweighList

                fixedReads.append(theID)

    RDS.setSynchronousPragma("ON")

    print "fixed %d pairs" % fixedPair
    print time.ctime()

    return fixedReads


def getReadIDLists(uniqIDs, multiIDs, verbose=False):
    uidDict = {}
    mainIDList = []
    for readID in uniqIDs:
        (mainID, pairID) = readID.split("/")
        try:
            uidDict[mainID].append(pairID)
        except:
            uidDict[mainID] = [pairID]
            mainIDList.append(mainID)

    if verbose:
        print "uidDict all ", len(uidDict), time.ctime()

    for mainID in mainIDList:
        if len(uidDict[mainID]) == 2:
            del uidDict[mainID]

    if verbose:
        print "uidDict first candidates ", len(uidDict), time.ctime()

    midDict = {}
    for readID in multiIDs:
        (frontID, multiplicity) = readID.split("::")
        (mainID, pairID) = frontID.split("/")
        try:
            if pairID not in midDict[mainID]:
                midDict[mainID].append(pairID)
        except:
            midDict[mainID] = [pairID]

    if verbose:
        print "all multis ", len(midDict), time.ctime()

    mainIDList = uidDict.keys()
    for mainID in mainIDList:
        if mainID not in midDict:
            del uidDict[mainID]

    if verbose:
        print "uidDict actual candidates ", len(uidDict), time.ctime()

    jointList = []
    bothMultiList = []
    for readID in midDict:
        listLen = len(midDict[readID])
        if listLen == 1:
            if readID in uidDict:
                jointList.append(readID)
        elif listLen == 2:
            bothMultiList.append(readID)

    if verbose:
        print "joint ", len(jointList), time.ctime()
        print "bothMulti ", len(bothMultiList), time.ctime()

    return jointList, bothMultiList


def getUniqAndSpliceReadsFromReadIDs(RDS, jointList, verbose=False):
    uniqReadsDict = {}
    uniqDict = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=True, readIDDict=True)
    if verbose:
        print "got uniq dict ", len(uniqDict), time.ctime()

    if RDS.dataType == "RNA":
        spliceDict = RDS.getSplicesDict(noSense=True, withChrom=True, withPairID=True, readIDDict=True)
        if verbose:
            print "got splice dict ", len(spliceDict), time.ctime()

    for readID in jointList:
        try:
            uniqReadsDict[readID] = uniqDict[readID][0]
        except KeyError:
            if RDS.dataType == "RNA":
                uniqReadsDict[readID] = spliceDict[readID][0]

    return uniqReadsDict


def getMultiReadsFromReadIDs(RDS, jointList, bothMultiList, verbose=False):
    multiReadSubsetDict = {}
    multiDict = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=False, doMulti=True, readIDDict=True)
    if verbose:
        print "got multi dict ", len(multiDict), time.ctime()

    for readID in jointList:
        multiReadSubsetDict[readID] = multiDict[readID]

    for readID in bothMultiList:
        multiReadSubsetDict[readID] = multiDict[readID]

    return multiReadSubsetDict


def reweighUsingRadius(RDS, radius, multiIDs, readsToSkip=[], verbose=False):
    skippedReads = 0
    readlen = RDS.getReadSize()
    halfreadlen = readlen / 2
    print "doing uniq read radius with radius = %d" % radius
    multiDict = RDS.getReadsDict(noSense=True, withWeight=True, withChrom=True, withID=True, doUniqs=False, doMulti=True, readIDDict=True)
    print "got multiDict"
    RDS.setSynchronousPragma("OFF")
    reweighedCount = 0
    for readID in multiIDs:
        originalMultiReadID = readID
        if originalMultiReadID in readsToSkip:
            skippedReads += 1
            continue

        if "::" in readID:
            (readID, multiplicity) = readID.split("::")

        scores = []
        coords = []
        for read in multiDict[readID]:
            start = read["start"]
            chromosome = "chr%s" % read["chrom"]
            regionStart = start + halfreadlen - radius
            regionStop = start + halfreadlen + radius 
            uniqs = RDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=False, splices=False, reportCombined=True)
            scores.append(uniqs + 1)
            coords.append((chromosome, start, originalMultiReadID))

        total = float(sum(scores))
        reweighList = []
        for index in range(len(scores)):
            reweighList.append((round(scores[index]/total,2), coords[index][0], coords[index][1], coords[index][2]))

        RDS.reweighMultireads(reweighList)
        reweighedCount += 1
        if reweighedCount % 10000 == 0:
            print reweighedCount

    RDS.setSynchronousPragma("ON")
    if verbose:
        print "skipped ", skippedReads

    print "reweighted ", reweighedCount


if __name__ == "__main__":
    # main() falls back to sys.argv when called without an argument vector.
    main()
