#
#  makerdsfrombowtie.py
#  ENRAGE
#
#  Created by Ali Mortazavi on 10/20/08.
#

# Optional psyco JIT acceleration (Python 2 only); silently skipped when the
# module is not installed.  Catch only ImportError so real errors (including
# KeyboardInterrupt/SystemExit) are not swallowed by a bare except.
try:
    import psyco
    psyco.full()
except ImportError:
    pass

import sys
import string
import optparse
from commoncode import writeLog, getConfigParser, getConfigOption, getConfigBoolOption, getConfigIntOption
import ReadDataset

# Version banner; printed as a side effect whenever this module is loaded.
verstring = "makerdsfrombowtie: version 4.2"
print verstring

def main(argv=None):
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog label infilename outrdsfile [propertyName::propertyValue] [options]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        print usage
        sys.exit(1)

    label = args[0]
    filename = args[1]
    outdbname = args[2]

    propertyList = []
    for arg in args:
        if "::" in arg:
            (pname, pvalue) = arg.strip().split("::")
            propertyList.append((pname, pvalue))

    makerdsfrombowtie(label, filename, outdbname, options.genedatafilename, options.init,
                      options.doIndex, options.spacer, options.trimReadID, options.forceID,
                      options.flip, options.verbose, options.stripSpace, options.cachePages,
                      propertyList)


def getParser(usage):
    """Build the option parser for makerdsfrombowtie.

    Option defaults are seeded from the "makerdsfrom bowtie" section of the
    shared ERANGE configuration file, falling back to the built-in values
    shown below when the config file does not override them.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--RNA", dest="genedatafilename")
    parser.add_option("--append", action="store_false", dest="init")
    parser.add_option("--index", action="store_true", dest="doIndex")
    parser.add_option("--spacer", type="int", dest="spacer")
    parser.add_option("--rawreadID", action="store_false", dest="trimReadID")
    parser.add_option("--forcepair", type="int", dest="forceID")
    parser.add_option("--flip", action="store_true", dest="flip")
    parser.add_option("--verbose", action="store_true", dest="verbose")
    parser.add_option("--strip", action="store_true", dest="stripSpace")
    parser.add_option("--cache", type="int", dest="cachePages")

    configParser = getConfigParser()
    section = "makerdsfrom bowtie"
    configDefaults = {
        "genedatafilename": getConfigOption(configParser, section, "genedatafilename", None),
        "init": getConfigBoolOption(configParser, section, "init", True),
        "doIndex": getConfigBoolOption(configParser, section, "doIndex", False),
        "spacer": getConfigIntOption(configParser, section, "spacer", 2),
        "trimReadID": getConfigBoolOption(configParser, section, "trimReadID", True),
        "forceID": getConfigOption(configParser, section, "forceID", None),
        "flip": getConfigBoolOption(configParser, section, "flip", False),
        "verbose": getConfigBoolOption(configParser, section, "verbose", False),
        "stripSpace": getConfigBoolOption(configParser, section, "stripSpace", False),
        "cachePages": getConfigIntOption(configParser, section, "cachePages", 100000),
    }
    parser.set_defaults(**configDefaults)

    return parser


def makerdsfrombowtie(label, filename, outdbname, genedatafilename=None, init=True,
                      doIndex=False, spacer=2, trimReadID=True, forceID=None,
                      flip=False, verbose=False, stripSpace=False, cachePages=100000,
                      propertyList=[]):
    # NOTE(review): propertyList=[] is a shared mutable default argument; it
    # is only read here, so it is harmless, but a None sentinel would be safer.
    """Load a bowtie mapping file into an RDS ReadDataset database.

    Input lines are grouped by read ID; a group of size 1 becomes a unique
    read, a group of size N > 1 becomes N multiread entries with weight 1/N,
    and (RNA mode only) single hits on splice pseudo-chromosomes named
    "model|spliceID|regionStart" are projected back to genomic coordinates
    and stored as spliced reads.  Grouping assumes all lines for one read ID
    are adjacent in the file -- TODO confirm with the bowtie invocation used.

    label            prefix prepended to trimmed read IDs
    filename         bowtie output file (default/legacy text format)
    outdbname        RDS database file to create (init=True) or append to
    genedatafilename gene-model annotation file; switches dataType to "RNA"
    init             create a fresh dataset and store readsize/paired metadata
    doIndex          (re)build the database index when done
    spacer           splice spacer length subtracted from splice-read starts
    trimReadID       drop the first ":"-field of read IDs and prefix label
    forceID          when not None, treat reads as paired with this pair ID
    flip             store unique/multi reads on the opposite strand
    verbose          report non-fatal problems
    stripSpace       delete all spaces from each input line before parsing
    cachePages       sqlite cache pages to request if above the default
    propertyList     extra (name, value) metadata pairs to insert
    """

    writeLog("%s.log" % outdbname, verstring, string.join(sys.argv[1:]))

    # RNA mode: build a lookup of gene models so splice hits can be projected
    # back onto the genome.  Assumes a tab-delimited annotation with the exon
    # block count at field 7 and comma-terminated exon start/stop lists at
    # fields 8 and 9 (UCSC-style) -- TODO confirm the annotation source.
    geneDict = {}
    dataType = "DNA"
    if genedatafilename is not None:
        dataType = "RNA"
        genedatafile = open(genedatafilename)
        for line in genedatafile:
            fields = line.strip().split("\t")
            blockCount = int(fields[7])
            if blockCount < 2:
                # single-exon models have no junctions to splice across
                continue

            uname = fields[0]
            chrom = fields[1]
            sense = fields[2]
            chromstarts = fields[8][:-1].split(",")
            chromstops = fields[9][:-1].split(",")
            exonLengths = []
            totalLength = 0
            for index in range(blockCount):
                chromstarts[index] = int(chromstarts[index])
                chromstops[index] = int(chromstops[index])
                exonLengths.append(chromstops[index] - chromstarts[index])
                totalLength += exonLengths[index]

            geneDict[uname] = (sense, blockCount, totalLength, chrom, chromstarts, exonLengths)

        genedatafile.close()

    rds = ReadDataset.ReadDataset(outdbname, init, dataType, verbose=True)

    # only raise the sqlite cache if it beats the dataset's default cache size
    defaultCacheSize = rds.getDefaultCacheSize()
    if cachePages > defaultCacheSize:
        if init:
            rds.setDBcache(cachePages, default=True)
        else:
            rds.setDBcache(cachePages)

    # drop any existing index before a bulk append so inserts are cheaper
    if not init and doIndex:
        try:
            if rds.hasIndex():
                rds.dropIndex()
        except:
            # NOTE(review): bare except hides every failure mode here, not
            # just a missing index
            if verbose:
                print "couldn't drop Index"

    if len(propertyList) > 0:
        rds.insertMetadata(propertyList)

    # make some assumptions based on first read
    infile = open(filename, "r")
    line = infile.readline()
    if stripSpace:
        line = line.replace(" ","")

    # field 5 of the first line is the read sequence; its length is taken as
    # the read size for the whole file (uniform-length reads assumed)
    fields = line.split()
    readsize = len(fields[5])
    pairedTest = fields[0][-2:]
    forcePair = False
    if forceID is not None:
        forcePair = True
    else:
        forceID = 0

    # a "/1" or "/2" read-ID suffix marks paired-end data
    paired = False
    if pairedTest in ["/1", "/2"] or forcePair:
        print "assuming reads are paired"
        paired = True

    print "read size: %d bp" % readsize
    if init:
        rds.insertMetadata([("readsize", readsize)])
        if paired:
            rds.insertMetadata([("paired", "True")])

    if "bowtie_mapped" not in rds.getMetadata():
        rds.insertMetadata([("bowtie_mapped", "True")])

    if dataType == "RNA" and "spacer" not in rds.getMetadata():
        rds.insertMetadata([("spacer", spacer)])

    infile.close()

    # splice reads are only kept when the left half lands in [1, maxBorder];
    # trim = -4 presumably enforces a minimum 4 bp junction overhang -- TODO
    # confirm
    maxBorder = 0
    if dataType == "RNA":
        trim = -4
        maxBorder = readsize + trim

    infile = open(filename, "r")  # NOTE(review): never closed on this pass
    prevID = ""
    readList = []          # accumulated (sense, chrom, start, mismatches) for prevID
    uInsertList = []       # pending unique-read rows
    mInsertList = []       # pending multiread rows
    sInsertList = []       # pending splice-read rows
    index = uIndex = mIndex = sIndex = lIndex = 0
    delimiter = "|"        # separator inside splice pseudo-chromosome names
    insertSize = 100000    # flush pending rows every insertSize read groups
    for line in infile:
        lIndex += 1
        if stripSpace:
            line = line.replace(" ","")

        fields = line.strip().split()
        readID = fields[0]
        if trimReadID:
            # keep everything after the first ":"-delimited field
            readID = string.join(readID.split(":")[1:], ":")

        if readID != prevID:
            # new read ID: classify and store the lines collected for prevID
            listlen = len(readList)
            if trimReadID:
                prevID = "%s-%s" % (label, prevID)

            if forcePair:
                prevID += "/%d" % forceID

            if listlen == 1:
                # single hit: unique genomic read, or a splice hit in RNA mode
                (sense, chrom, start, mismatches) = readList[0]
                if flip:
                    if sense == "+":
                        sense = "-"
                    else:
                        sense = "+"

                if "|" not in chrom:
                    stop = start + readsize
                    uInsertList.append((prevID, chrom, start, stop, sense, 1.0, "", mismatches))
                    uIndex += 1
                elif dataType == "RNA":
                    # splice pseudo-chromosome: "model|spliceID|regionStart"
                    currentSplice = chrom
                    (model, spliceID, regionStart) = currentSplice.split(delimiter)
                    if model not in geneDict:
                        # NOTE(review): this prevID assignment (and the one in
                        # the rejected-border branch below) is a dead write --
                        # prevID is unconditionally reset to readID at the end
                        # of this if-block
                        prevID = readID
                    else:
                        (gsense, blockCount, transLength, chrom, chromstarts, blockSizes) = geneDict[model]
                        spliceID = int(spliceID)
                        rstart = int(start) - spacer
                        lefthalf = maxBorder - rstart
                        if lefthalf < 1 or lefthalf > maxBorder:
                            prevID = readID
                        else:
                            # project the two halves of the read onto the exon
                            # ends flanking junction spliceID
                            righthalf = readsize - lefthalf
                            startL = int(regionStart)  + rstart
                            stopL = startL + lefthalf
                            startR = chromstarts[spliceID + 1]
                            stopR = chromstarts[spliceID + 1] + righthalf
                            sInsertList.append((prevID, chrom, startL, stopL, startR, stopR, sense, 1.0, "", mismatches))
                            sIndex += 1
            elif listlen > 1:
                # multiread: weight each reported location by 1/listlen
                prevID = "%s::%s" % (prevID, str(listlen))
                mIndex += 1
                # ignore multireads that can also map across splices
                skip = False
                for (sense, chrom, start, mismatches) in readList:
                    if "|" in chrom:
                        skip = True

                if not skip:
                    for (sense, chrom, start, mismatches) in readList:
                        stop = start + readsize
                        if flip:
                            if sense == "+":
                                sense = "-"
                            else:
                                sense = "+"

                        mInsertList.append((prevID, chrom, start, stop, sense, 1.0 / listlen, "", mismatches))
            else:
                prevID = readID

            # periodic bulk flush of the pending rows (also fires as a no-op
            # on the very first group, while index is still 0)
            if index % insertSize == 0:
                rds.insertUniqs(uInsertList)
                rds.insertMulti(mInsertList)
                uInsertList = []
                mInsertList = []
                if dataType == "RNA":
                    rds.insertSplices(sInsertList)
                    sInsertList = []

                print ".",
                sys.stdout.flush()

            # start processing new read
            readList = []
            prevID = readID
            index += 1

        # add the new read
        # bowtie legacy columns: 1 = strand, 2 = reference name,
        # 3 = 0-based leftmost offset, last = mismatch descriptors (if any)
        sense = fields[1]
        chrom = fields[2]
        # for eland compat, we are 1-based
        start = int(fields[3]) + 1
        mismatches = ""
        if ":" in fields[-1]:
            mismatches = decodeMismatches(fields[-1], sense)

        readList.append((sense, chrom, start, mismatches))
        if lIndex % 1000000 == 0:
            print "processed %d lines" % lIndex

    # NOTE(review): the lines accumulated for the final read ID are never
    # classified -- the loop above only processes a group when the NEXT ID
    # appears, and there is no post-loop pass over readList, only a flush of
    # the already-classified insert lists below.  Looks like the last read in
    # the file is silently dropped; confirm and fix upstream.
    print "%d lines processed" % lIndex

    if len(uInsertList) > 0:
        rds.insertUniqs(uInsertList)

    if len(mInsertList) > 0:
        rds.insertMulti(mInsertList)

    if len(sInsertList) > 0:
        rds.insertSplices(sInsertList)

    combString = "%d unique reads" % uIndex
    combString += "\t%d multi reads" % mIndex
    if dataType == "RNA":
        combString += "\t%d spliced reads" % sIndex

    print
    print combString.replace("\t", "\n")

    writeLog("%s.log" % outdbname, verstring, combString)

    if doIndex:
        print "building index...."
        if cachePages > defaultCacheSize:
            rds.setDBcache(cachePages)
            rds.buildIndex(cachePages)
        else:
            rds.buildIndex(defaultCacheSize)


def decodeMismatches(mString, rsense):
    """Convert a bowtie mismatch-descriptor string to eland-style tokens.

    mString is a comma-separated list of "offset:refNT>readNT" entries with
    0-based offsets, as emitted in the last column of bowtie's legacy output.
    Each entry is rewritten as "<readNT><pos><refNT>" with a 1-based position;
    for minus-strand alignments both nucleotides are complemented first.

    Returns the rewritten entries joined with commas.
    """
    complement = {"A": "T",
                  "T": "A",
                  "C": "G",
                  "G": "C",
                  "N": "N"
    }

    output = []
    for mismatch in mString.split(","):
        (pos, change) = mismatch.split(":")
        (genNT, readNT) = change.split(">")
        if rsense == "-":
            # minus-strand hits are reported in read orientation; flip both
            # bases back to the reference strand
            readNT = complement[readNT]
            genNT = complement[genNT]

        # eland positions are 1-based; bowtie offsets are 0-based
        elandCompatiblePos = int(pos) + 1
        output.append("%s%d%s" % (readNT, elandCompatiblePos, genNT))

    # str.join instead of the deprecated string.join (also Python 3 safe)
    return ",".join(output)


# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main(sys.argv)