#
#  makebedfromrds.py
#  ENRAGE
#
#  Created by Ali Mortazavi on 7/19/08.
#

# psyco is an optional accelerator: if importing or enabling it fails for any
# reason, the script simply runs without it.
try:
    import psyco
    psyco.full()
except:
    pass

import sys
import optparse
import ReadDataset
from commoncode import getConfigParser, getConfigOption, getConfigBoolOption, getConfigIntOption

# "R,G,B" strings written into the color column of the BED lines (the track
# header sets itemRgb="On").

# Colors for plus- and minus-strand entries (see getSingleSenseColor).
PLUS_COLOR = "0,0,255"
MINUS_COLOR = "255,0,0"
# Strand colors picked by getMultiSenseColor for reads with weight below 1.0.
MULTI_PLUS_COLOR = "64,64,64"
MULTI_MINUS_COLOR = "192,192,192"
# Colors returned by getSpliceColor: spliced entries, entries with a weight of
# 1.0, and everything else, respectively.
SPLICE_COLOR = "255,0,0"
UNIQUE_COLOR = "0,0,0"
MULTI_COLOR = "128,128,128"


def main(argv=None):
    """Command-line entry point: parse the arguments and write the BED track.

    argv -- full argument vector with the program name at index 0; sys.argv
            is used when argv is None or empty.

    Three positional arguments are required, in this order: the track label,
    the RDS file to read and the BED file to write.  The program prints a
    message and exits with status 1 when any of them is missing.
    """
    if not argv:
        argv = sys.argv

    verstring = "makebedfromrds: version 3.2"
    print(verstring)

    # The third positional argument is the BED file that gets written.
    usage = "usage:  %prog trackLabel rdsFile bedFile [options]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    # Report the first positional argument that is missing.
    requiredArgs = ("track", "RDS file", "output file")
    if len(args) < len(requiredArgs):
        print("no %s specified - see --help for usage" % requiredArgs[len(args)])
        sys.exit(1)

    trackType, rdsfile, outfilename = args[:3]

    # Pair mode is switched on by the presence of a pair distance.
    doPairs = options.pairDist is not None

    # An explicit chromosome list overrides "all chromosomes".
    if options.chromList:
        options.allChrom = False

    # NOTE(review): options.doCache only ever comes from the parser defaults;
    # --cache changes cachePages but does not turn doCache on - confirm intended.
    outputBedFromRds(trackType, rdsfile, outfilename, options.withUniqs, options.withMulti,
                     options.doSplices, options.doSpliceColor, doPairs, options.pairDist,
                     options.withFlag, options.useFlagLike, options.enforceChr, options.senseStrand,
                     options.allChrom, options.doCache, options.cachePages, options.chromList)


def getParser(usage):
    """Build the command-line parser for makebedfromrds.

    usage -- usage string handed to optparse.OptionParser.

    Option defaults are taken from the "makebedfromrds" section of the
    configuration returned by getConfigParser(), with the hard-coded values
    below as fallbacks.  Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--nouniq", action="store_false", dest="withUniqs")
    parser.add_option("--nomulti", action="store_false", dest="withMulti")
    parser.add_option("--splices", action="store_true", dest="doSplices")
    parser.add_option("--spliceColor", action="store_true", dest="doSpliceColor")
    parser.add_option("--flag", dest="withFlag")
    parser.add_option("--flaglike", action="store_true", dest="useFlagLike")
    parser.add_option("--pairs", type="int", dest="pairDist")
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--enforceChr", action="store_true", dest="enforceChr")
    parser.add_option("--chrom", action="append", dest="chromList")
    # Stored as senseStrand so the value lands on the attribute that has a
    # default below and that main() reads; with dest="strand" it was ignored.
    parser.add_option("--strand", dest="senseStrand")

    configParser = getConfigParser()
    section = "makebedfromrds"
    withUniqs = getConfigBoolOption(configParser, section, "withUniqs", True)
    withMulti = getConfigBoolOption(configParser, section, "withMulti", False)
    doSplices = getConfigBoolOption(configParser, section, "doSplices", False)
    doSpliceColor = getConfigBoolOption(configParser, section, "doSpliceColor", False)
    # pairDist and cachePages are compared against integers later on, and
    # their command-line options are type="int", so the configured values are
    # read as integers too.
    # NOTE(review): assumes getConfigIntOption takes the same
    # (parser, section, option, default) arguments as getConfigOption and
    # returns the default unchanged when the option is absent - confirm
    # against commoncode.
    pairDist = getConfigIntOption(configParser, section, "pairDist", None)
    withFlag = getConfigOption(configParser, section, "withFlag", "")
    useFlagLike = getConfigBoolOption(configParser, section, "useFlagLike", False)
    enforceChr = getConfigBoolOption(configParser, section, "enforceChr", False)
    senseStrand = getConfigOption(configParser, section, "senseStrand", "")
    allChrom = getConfigBoolOption(configParser, section, "allChrom", True)
    doCache = getConfigBoolOption(configParser, section, "doCache", False)
    cachePages = getConfigIntOption(configParser, section, "cachePages", 100000)

    parser.set_defaults(withUniqs=withUniqs, withMulti=withMulti, doSplices=doSplices, doSpliceColor=doSpliceColor,
                        pairDist=pairDist, withFlag=withFlag, useFlagLike=useFlagLike, enforceChr=enforceChr,
                        senseStrand=senseStrand, allChrom=allChrom, doCache=doCache, cachePages=cachePages,
                        chromList=[])

    return parser


def outputBedFromRds(trackType, rdsfile, outfilename, withUniqs=True, withMulti=True,
                     doSplices=False, doSpliceColor=False, doPairs=False, pairDist=1000000,
                     withFlag="", useFlagLike=False, enforceChr=False, senseStrand="",
                     allChrom=True, doCache=False, cachePages=100000, chromList=None):
    """Write the reads of an RDS dataset to a BED track file.

    trackType     -- label used as the name in the BED track line
    rdsfile       -- RDS dataset handed to ReadDataset.ReadDataset
    outfilename   -- BED file to write (opened with mode "w")
    withUniqs     -- passed to the dataset as doUniqs
    withMulti     -- passed to the dataset as doMulti
    doSplices     -- also output spliced reads as two-block entries
    doSpliceColor -- in pair mode, pick colors with getSpliceColor
    doPairs       -- merge mated reads into a single multi-block entry
    pairDist      -- exclusive upper bound on the gap between two mates
    withFlag      -- passed to the dataset queries as flag
    useFlagLike   -- passed to the dataset queries as flagLike
    enforceChr    -- forwarded to doNotOutputChromosome
    senseStrand   -- passed to the dataset queries as strand (pair mode only)
    allChrom      -- take the chromosome list from the dataset
    doCache       -- passed to ReadDataset as cache
    cachePages    -- applied with setDBcache only when larger than the
                     dataset's default cache size
    chromList     -- chromosomes to output when allChrom is false

    Prints a message and exits with status 1 when none of withUniqs,
    withMulti and doSplices is set.
    """
    if not withUniqs and not withMulti and not doSplices:
        print("must be outputing at least one of uniqs, multi, or -splices - exiting")
        sys.exit(1)

    if chromList is None:
        chromList = []

    print("\nsample:")
    RDS = ReadDataset.ReadDataset(rdsfile, verbose = True, cache=doCache)

    # Only change the cache size when it beats the dataset's default.
    if cachePages > RDS.getDefaultCacheSize():
        RDS.setDBcache(cachePages)

    readlength = RDS.getReadSize()

    if allChrom:
        if withUniqs:
            chromList = RDS.getChromosomes()
        elif withMulti:
            chromList = RDS.getChromosomes(table="multi")
        else:
            chromList = RDS.getChromosomes(table="splices")

        chromList.sort()

    outfile = open(outfilename, "w")
    try:
        outfile.write('track name="%s" visibility=4 itemRgb="On"\n' % (trackType))

        if withUniqs or withMulti:
            for achrom in chromList:
                if doNotOutputChromosome(achrom, enforceChr):
                    continue

                print("chromosome %s" % (achrom))

                if doPairs:
                    _writePairedReads(outfile, RDS, achrom, readlength, pairDist, withUniqs,
                                      withMulti, doSplices, doSpliceColor, withFlag,
                                      useFlagLike, senseStrand)
                else:
                    # NOTE(review): senseStrand is not passed to the dataset
                    # queries outside pair mode - confirm this is intended.
                    hitDict = RDS.getReadsDict(fullChrom=True, chrom=achrom, flag=withFlag,
                                               withWeight=True, withID=True, doUniqs=withUniqs,
                                               doMulti=withMulti, readIDDict=False,
                                               flagLike=useFlagLike)
                    # A chromosome without reads is not an error; anything
                    # else (e.g. a failed write) is allowed to propagate.
                    try:
                        readList = hitDict[achrom]
                    except KeyError:
                        readList = []

                    for read in readList:
                        pos = read["start"]
                        splitReadWrite(outfile, achrom, 1, [pos], [pos + readlength - 1],
                                       read["sense"], read["readID"], PLUS_COLOR, MINUS_COLOR)

                    if doSplices:
                        _writeSpliceReads(outfile, RDS, achrom, withFlag, useFlagLike)

        elif doSplices:
            # index is the number of entries written for the chromosome
            # handled last (0 when it was skipped); starting at 0 keeps the
            # final print from failing on an empty chromosome list.
            index = 0
            for achrom in chromList:
                index = 0
                if doNotOutputChromosome(achrom, enforceChr):
                    continue

                print("chromosome %s" % (achrom))
                index = _writeSpliceReads(outfile, RDS, achrom, withFlag, useFlagLike)

            print(index)
    finally:
        # Close the track file even when a query or a write fails.
        outfile.close()


def _parseRead(read, readlength):
    """Return (start, stop, sense, weight, pairID, numPieces, startList, stopList).

    A read with a "start" key is a single block of readlength bases.  A read
    without it is treated as a splice described by startL/stopL/startR/stopR
    and is given a weight of 1.0.
    """
    try:
        start = read["start"]
        sense = read["sense"]
        weight = read["weight"]
        pairID = read["pairID"]
    except KeyError:
        start = read["startL"]
        # The read ends where its right-hand block ends.
        stop = read["stopR"]
        return (start, stop, read["sense"], 1.0, read["pairID"], 2,
                [start, read["startR"]], [read["stopL"], stop])

    stop = start + readlength - 1
    return (start, stop, sense, weight, pairID, 1, [start], [stop])


def _writeSpliceReads(outfile, RDS, achrom, withFlag, useFlagLike):
    """Write the splices of achrom as two-block entries; return how many were written."""
    spliceDict = RDS.getSplicesDict(fullChrom=True, chrom=achrom, flag=withFlag,
                                    withID=True, flagLike=useFlagLike)
    if achrom not in spliceDict:
        return 0

    count = 0
    for read in spliceDict[achrom]:
        splitReadWrite(outfile, achrom, 2, [read["startL"], read["startR"]],
                       [read["stopL"], read["stopR"]], read["sense"], read["readID"],
                       PLUS_COLOR, MINUS_COLOR)
        count += 1

    return count


def _writePairedReads(outfile, RDS, achrom, readlength, pairDist, withUniqs, withMulti,
                      doSplices, doSpliceColor, withFlag, useFlagLike, senseStrand):
    """Write the reads of achrom, merging mated reads into one entry each."""
    # Mates may overlap by up to one read length.
    minDist = -1 * readlength

    hitDict = RDS.getReadsDict(fullChrom=True, chrom=achrom, flag=withFlag,
                               withWeight=True, withPairID=True, doUniqs=withUniqs,
                               doMulti=withMulti, readIDDict=True,
                               flagLike=useFlagLike, strand=senseStrand)

    readIDList = hitDict.keys()
    if doSplices:
        spliceDict = RDS.getSplicesDict(fullChrom=True, chrom=achrom, flag=withFlag,
                                        withPairID=True, readIDDict=True,
                                        flagLike=useFlagLike, strand=senseStrand)

        # Union of the read IDs seen in either dictionary.
        combDict = {}
        for readID in readIDList:
            combDict[readID] = 1

        for readID in spliceDict.keys():
            combDict[readID] = 1

        combinedIDList = combDict.keys()
    else:
        spliceDict = {}
        combinedIDList = readIDList

    for readID in combinedIDList:
        # Work on a copy so the list owned by hitDict is not extended in place.
        try:
            localList = list(hitDict[readID])
        except KeyError:
            localList = []

        try:
            localList += spliceDict[readID]
        except KeyError:
            pass

        localList.sort()
        lastIndex = len(localList) - 1
        localIndex = 0
        while localIndex <= lastIndex:
            (leftpos, leftstop, leftsense, leftweight, lPairID,
             lpart, startList, stopList) = _parseRead(localList[localIndex], readlength)

            if localIndex < lastIndex:
                (rightpos, rightstop, rightsense, rightweight, rPairID,
                 rpart, rstartList, rstopList) = _parseRead(localList[localIndex + 1], readlength)
            else:
                # No following read: a "+" sense makes the mate test below
                # fail, and every right-hand value is defined for the
                # unpaired branch.
                rightsense = "+"
                rightpos = 0
                rightweight = None
                rPairID = None
                rpart = 0
                rstartList = []
                rstopList = []

            isMatedPair = (leftsense == "+" and rightsense == "-"
                           and minDist < (rightpos - leftstop) < pairDist
                           and lPairID != rPairID)

            if isMatedPair:
                if doSpliceColor:
                    plusSenseColor, minusSenseColor = getSpliceColor(lpart, rpart, leftweight, rightweight, hackType="1")
                elif leftweight == 1.0 or rightweight == 1.0:
                    plusSenseColor = UNIQUE_COLOR
                    minusSenseColor = MINUS_COLOR
                else:
                    plusSenseColor = MULTI_COLOR
                    minusSenseColor = MULTI_MINUS_COLOR

                # Both mates are written as one "+" entry and consumed together.
                splitReadWrite(outfile, achrom, lpart + rpart, startList + rstartList,
                               stopList + rstopList, "+", readID, plusSenseColor, minusSenseColor)
                localIndex += 2
            else:
                if doSpliceColor:
                    plusSenseColor, minusSenseColor = getSpliceColor(lpart, rpart, leftweight, rightweight)
                    outputSense = "+"
                else:
                    plusSenseColor = PLUS_COLOR
                    minusSenseColor = MINUS_COLOR
                    outputSense = leftsense

                splitReadWrite(outfile, achrom, lpart, startList, stopList, outputSense,
                               readID, plusSenseColor, minusSenseColor)
                localIndex += 1


def singleReadWrite(chrom, pos, sense, weight, readID, readlength, outfile):
    """Write one read of readlength bases starting at pos as a space-separated line.

    The line holds chrom, start, stop (pos + readlength - 1), readID, the
    weight with one decimal, the sense, two literal zeros and the color
    chosen by getSenseColor.
    """
    lineFields = (chrom, pos, pos + readlength - 1, readID, weight, sense,
                  getSenseColor(sense, weight))
    outfile.write("%s %d %d %s %.1f %s 0 0 %s\n" % lineFields)


def getSenseColor(sense, weight):
    """Return the color for a read: the multiread palette when weight is
    below 1.0, the single-read palette otherwise."""
    if weight < 1.0:
        return getMultiSenseColor(sense)

    return getSingleSenseColor(sense)


def getMultiSenseColor(sense):
    """Return MULTI_PLUS_COLOR for a "+" sense and MULTI_MINUS_COLOR for anything else."""
    if sense == "+":
        return MULTI_PLUS_COLOR

    return MULTI_MINUS_COLOR


def getSingleSenseColor(sense):
    """Return PLUS_COLOR for a "+" sense and MINUS_COLOR for anything else."""
    if sense == "+":
        return PLUS_COLOR

    return MINUS_COLOR


def splitReadWrite(outfile, chrom, numPieces, startList, stopList, rsense, readName, plusSense, minusSense):
    """Write one tab-separated 12-column BED line made of numPieces blocks.

    The entry spans startList[0] to stopList[-1], always carries a score of
    1000 and two zero columns, and is colored with plusSense when rsense is
    "+" and with minusSense otherwise.
    """
    blockSizes = getReadSizes(numPieces, startList, stopList)
    blockStarts = getReadCoords(numPieces, startList)

    itemColor = minusSense
    if rsense == "+":
        itemColor = plusSense

    columnFormats = ["%s", "%d", "%d", "%s", "1000", "%s", "0", "0", "%s", "%d", "%s", "%s"]
    lineFormat = "\t".join(columnFormats) + "\n"
    outfile.write(lineFormat % (chrom, startList[0], stopList[-1], readName, rsense,
                                itemColor, numPieces, blockSizes, blockStarts))


def getReadSizes(numPieces, startList, stopList):
    """Return the comma-separated block sizes (stop - start) of the first
    numPieces pieces; the first piece is always included."""
    blockSizes = [stopList[0] - startList[0]]
    for piece in range(1, numPieces):
        blockSizes.append(stopList[piece] - startList[piece])

    return ",".join(["%d" % size for size in blockSizes])


def getReadCoords(numPieces, startList):
    """Return the comma-separated block offsets relative to startList[0];
    the first offset is always the literal "0"."""
    offsets = ["0"]
    offsets.extend(["%d" % (startList[piece] - startList[0]) for piece in range(1, numPieces)])

    return ",".join(offsets)


def getSpliceColor(lpart, rpart, leftweight, rightweight, hackType=None):
    """Return a (plusColor, minusColor) pair; both entries are the same color.

    With hackType "1" the left and right reads are judged together: more than
    two pieces in total gives SPLICE_COLOR, a weight of 1.0 on either side
    gives UNIQUE_COLOR, anything else MULTI_COLOR.  Otherwise only the left
    read is looked at (lpart > 1, then leftweight == 1.0).
    """
    if hackType == "1":
        isSpliced = (lpart + rpart) > 2
        isUnique = leftweight == 1.0 or rightweight == 1.0
    else:
        isSpliced = lpart > 1
        isUnique = leftweight == 1.0

    if isSpliced:
        color = SPLICE_COLOR
    elif isUnique:
        color = UNIQUE_COLOR
    else:
        color = MULTI_COLOR

    return color, color


def doNotOutputChromosome(achrom, enforceChr):
    """Return True when achrom must be left out of the output.

    "chrM" is always left out; when enforceChr is set, so is every name that
    does not contain "chr".
    """
    if achrom == "chrM":
        return True

    if enforceChr and "chr" not in achrom:
        return True

    return False


# Run the converter only when this file is executed as a script.
if __name__ == "__main__":
    main(sys.argv)