"""
MakeRdsFromBam

Created on Jun 3, 2010

@author: sau
"""

# Optional psyco JIT acceleration (Python 2 only).
# Catch only ImportError so that a genuine failure inside psyco.full()
# is not silently swallowed the way the old bare "except:" did.
try:
    import psyco
    psyco.full()
except ImportError:
    # psyco is not installed; run without JIT acceleration
    pass

import sys
import string
import optparse
import re
import pysam
from commoncode import writeLog, getConfigParser, getConfigBoolOption, getConfigIntOption, getReverseComplement
import ReadDataset

# Number of reads to buffer before each bulk insert into the RDS tables.
INSERT_SIZE = 100000
# Version string printed at startup and recorded in the log file.
verstring = "makeRdsFromBam: version 1.0"


def main(argv=None):
    if not argv:
        argv = sys.argv
    
    print verstring

    usage = "usage:  %prog label samfile outrdsfile [propertyName::propertyValue] [options]\
            \ninput reads must be sorted to properly record multireads"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    try:
        label = args[0]
    except IndexError:
        print "no label specified - see --help for usage"
        sys.exit(1)

    try:
        samFileName = args[1]
    except IndexError:
        print "no samfile specified - see --help for usage"
        sys.exit(1)

    try:
        outDbName = args[2]
    except IndexError:
        print "no outrdsfile specified - see --help for usage"
        sys.exit(1)

    if options.rnaDataType:
        dataType = "RNA"
    else:
        dataType = "DNA"

    print label
    print samFileName
    print outDbName
    print options.init
    print options.doIndex
    print options.useSamFile
    print options.cachePages
    print options.maxMultiReadCount
    print dataType
    print options.trimReadID

    makeRdsFromBam(label, samFileName, outDbName, options.init, options.doIndex, options.useSamFile,
                   options.cachePages, options.maxMultiReadCount, dataType, options.trimReadID)


def getParser(usage):
    """Build the optparse parser for makeRdsFromBam.

    Defaults for every option come from the shared erange configuration
    file (section "makeRdsFromBam"), so command-line flags only need to
    be given to override configured values.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--append", action="store_false", dest="init",
                      help="append to existing rds file [default: create new]")
    parser.add_option("--RNA", action="store_true", dest="rnaDataType",
                      help="set data type to RNA [default: DNA]")
    parser.add_option("-S", "--sam", action="store_true", dest="useSamFile",
                      help="input file is in sam format")
    parser.add_option("--index", action="store_true", dest="doIndex",
                      help="index the output rds file")
    # fixed: help text was missing the closing bracket ("[default: 100000")
    parser.add_option("--cache", type="int", dest="cachePages",
                      help="number of cache pages to use [default: 100000]")
    parser.add_option("-m", "--multiCount", type="int", dest="maxMultiReadCount",
                      help="multi counts over this value are discarded [default: 10]")
    parser.add_option("--rawreadID", action="store_false", dest="trimReadID",
                      help="use the raw read names")

    # Configuration-file defaults (overridable on the command line).
    configParser = getConfigParser()
    section = "makeRdsFromBam"
    init = getConfigBoolOption(configParser, section, "init", True)
    doIndex = getConfigBoolOption(configParser, section, "doIndex", False)
    useSamFile = getConfigBoolOption(configParser, section, "useSamFile", False)
    cachePages = getConfigIntOption(configParser, section, "cachePages", 100000)
    maxMultiReadCount = getConfigIntOption(configParser, section, "maxMultiReadCount", 10)
    rnaDataType = getConfigBoolOption(configParser, section, "rnaDataType", False)
    trimReadID = getConfigBoolOption(configParser, section, "trimReadID", True)

    parser.set_defaults(init=init, doIndex=doIndex, useSamFile=useSamFile, cachePages=cachePages,
                        maxMultiReadCount=maxMultiReadCount, rnaDataType=rnaDataType, trimReadID=trimReadID)

    return parser


def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useSamFile=False,
                   cachePages=100000, maxMultiReadCount=10, dataType="DNA", trimReadID=True):
    """Populate (or append to) the RDS database outDbName from a BAM/SAM file.

    The input is scanned twice: getMultiReadIDCounts() first tallies how many
    alignments each read name has, then this pass classifies every mapped read
    as unique, multi, splice, or multisplice and bulk-inserts the entries in
    batches of INSERT_SIZE.

    label             -- experiment label used in generated read IDs
    samFileName       -- path to the BAM (or SAM, if useSamFile) input
    outDbName         -- path of the RDS database to create or append to
    init              -- True to create a new RDS file, False to append
    doIndex           -- build the RDS index after loading
    useSamFile        -- input is text SAM ("r") rather than binary BAM ("rb")
    cachePages        -- DB cache size, used when larger than the RDS default
    maxMultiReadCount -- reads with more alignments than this are discarded
    dataType          -- "DNA" or "RNA"; splice tables are only filled for RNA
    trimReadID        -- despite the name: when True, reads are stored under a
                         generated "label:qname:readNumber[/pair]" ID; when
                         False, the raw qname is stored

    NOTE(review): property metadata ("name::value") is taken from sys.argv
    directly, so it only takes effect when this module runs as a script.
    """

    # SAM is plain text; BAM is the binary form.
    if useSamFile:
        fileMode = "r"
    else:
        fileMode = "rb"

    # Record the invocation in the per-database log.
    writeLog("%s.log" % outDbName, verstring, string.join(sys.argv[1:]))
    rds = ReadDataset.ReadDataset(outDbName, init, dataType, verbose=True)
    # When appending with indexing requested, drop any existing index first
    # so the bulk inserts are not slowed down (best-effort: errors ignored).
    if not init and doIndex:
        try:
            if rds.hasIndex():
                rds.dropIndex()
        except:
            pass

    # Mark this RDS as having been built from SAM/BAM input.
    if "sam_mapped" not in rds.getMetadata():
        rds.insertMetadata([("sam_mapped", "True")])

    defaultCacheSize = rds.getDefaultCacheSize()

    # Enlarge the DB cache only if the caller asked for more than the default.
    if cachePages > defaultCacheSize:
        if init:
            rds.setDBcache(cachePages, default=True)
        else:
            rds.setDBcache(cachePages)

    # Collect "name::value" metadata pairs from the command line.
    # NOTE(review): reads sys.argv even when called as a library function.
    propertyList = []
    for arg in sys.argv:
        if "::" in arg:
            (pname, pvalue) = arg.strip().split("::")
            propertyList.append((pname, pvalue))

    if len(propertyList) > 0:
        rds.insertMetadata(propertyList)

    # Running statistics, reported and logged at the end.
    totalReadCounts = {"unmapped": 0,
                       "total": 0,
                       "unique": 0,
                       "multi": 0,
                       "multiDiscard": 0,
                       "splice": 0,
                       "multisplice": 0
    }

    # Read length; inferred from the first mapped read's CIGAR (see below).
    readsize = 0

    # Batched insert buffers, flushed every INSERT_SIZE reads.
    uniqueInsertList = []
    multiInsertList = []
    spliceInsertList = []
    multispliceInsertList = []

    # processedEntryDict is only referenced by the commented-out legacy code
    # below; it is never filled in the current implementation.
    processedEntryDict = {}
    uniqueReadDict = {}
    multiReadDict = {}
    multispliceReadDict = {}
    spliceReadDict = {}
    # First pass over the file: per-readName alignment counts (>1 only).
    multireadCounts = getMultiReadIDCounts(samFileName, fileMode)

    # Count each over-threshold read ID once as discarded.
    for readID in multireadCounts:
        if multireadCounts[readID] > maxMultiReadCount:
            totalReadCounts["multiDiscard"] += 1

    try:
        samfile = pysam.Samfile(samFileName, fileMode)
    except ValueError:
        print "samfile index not found"
        sys.exit(1)

    samFileIterator = samfile.fetch(until_eof=True)

    #TODO: here's the plan:
    #DONE1) go through the file and get a count of all readIDs + pairID
    #DONE2) use this to generate a list of the readIDs that are multireads
    #    3) go through again to process the reads
    #    4) each multi is a separate entry in the table (by alignment) but only counts as 1 for statistics purposes
    #DONE5) add a multisplice table to the RDS for the multisplices
    #    6) multisplices will look just like splices but have a weight of 1/N
    #    7) recall that flags can't be assumed so that we can't use the NH flag at all
    for read in samFileIterator:
        if read.is_unmapped:
            totalReadCounts["unmapped"] += 1
            continue

        # Infer the read size once, from the first mapped read: sum of the
        # lengths of M (op 0) and I (op 1) CIGAR operations.
        if readsize == 0:
            take = (0, 1) # CIGAR operation (M/match, I/insertion)
            readsize = sum([length for op,length in read.cigar if op in take])
            if init:
                rds.insertMetadata([("readsize", readsize)])

        # readName keys the multiread count table; "/1" or "/2" is appended
        # for proper pairs so mates are counted separately.
        pairReadSuffix = getPairedReadNumberSuffix(read)
        readName = "%s%s" % (read.qname, pairReadSuffix)
        if trimReadID:
            rdsEntryName = "%s:%s:%d%s" % (label, read.qname, totalReadCounts["total"], pairReadSuffix)
        else:
            rdsEntryName = read.qname

        # Reads absent from multireadCounts aligned exactly once.
        try:
            count = multireadCounts[readName]
        except KeyError:
            count = 1

        # Classify: unique vs multi, spliced vs unspliced; reads with more
        # than maxMultiReadCount alignments are silently dropped here
        # (already counted as multiDiscard above).
        if count == 1:
            if isSpliceEntry(read.cigar):
                spliceReadDict[readName] = (read,rdsEntryName)
            else:
                uniqueReadDict[readName] = (read, rdsEntryName)
        elif count <= maxMultiReadCount:
            if isSpliceEntry(read.cigar):
                multispliceReadDict[readName] = (read, count, rdsEntryName)
            else:
                multiReadDict[readName] = (read, count, rdsEntryName)

        """
        if not processedEntryDict.has_key(readName):
            processedEntryDict[readName] = ""
            count = getReadCount(read)
            if isSpliceEntry(read.cigar):
                if count == 1:
                    spliceReadDict[readName] = (read,rdsEntryName)
            elif count == 1:
                uniqueReadDict[readName] = (read, rdsEntryName)
            else:
                multiReadDict[readName] = (read, count, rdsEntryName)

        """
        """
        if processedEntryDict.has_key(readName):
            if isSpliceEntry(read.cigar):
                if spliceReadDict.has_key(readName):
                    del spliceReadDict[readName]
            else:
                if uniqueReadDict.has_key(readName):
                    del uniqueReadDict[readName]

                if multiReadDict.has_key(readName):
                    (read, priorCount, rdsEntryName) = multiReadDict[readName]
                    count = priorCount + 1
                    multiReadDict[readName] = (read, count, rdsEntryName)
                else:
                    multiReadDict[readName] = (read, 1, rdsEntryName)
        else:
            processedEntryDict[readName] = ""
            if isSpliceEntry(read.cigar):
                spliceReadDict[readName] = (read,rdsEntryName)
            else:
                uniqueReadDict[readName] = (read, rdsEntryName)
        """

        # Flush the buffers every INSERT_SIZE reads.
        # NOTE(review): "total" is tested before being incremented, so this
        # branch also fires on the very first read (total == 0), flushing
        # empty lists once — harmless but wasteful; confirm intended.
        if totalReadCounts["total"] % INSERT_SIZE == 0:
            for entry in uniqueReadDict.keys():
                (readData, rdsEntryName) = uniqueReadDict[entry]
                chrom = samfile.getrname(readData.rname)
                uniqueInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize))
                totalReadCounts["unique"] += 1

            for entry in multiReadDict.keys():
                (readData, count, rdsEntryName) = multiReadDict[entry]
                chrom = samfile.getrname(readData.rname)
                if count > maxMultiReadCount:
                    pass
                    #totalReadCounts["multiDiscard"] += 1
                else:
                    multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count)) 
                    #totalReadCounts["multi"] += 1

            # Splice tables only exist for RNA datasets.
            if dataType == "RNA":
                for entry in spliceReadDict.keys():
                    (readData, rdsEntryName) = spliceReadDict[entry]
                    chrom = samfile.getrname(readData.rname)
                    spliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize))
                    totalReadCounts["splice"] += 1

                for entry in multispliceReadDict.keys():
                    (readData, count, rdsEntryName) = multispliceReadDict[entry]
                    chrom = samfile.getrname(readData.rname)
                    multispliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize, weight=count))
                    totalReadCounts["multisplice"] += 1

            rds.insertUniqs(uniqueInsertList)
            rds.insertMulti(multiInsertList)
            uniqueInsertList = []
            uniqueReadDict = {}
            multiInsertList = []
            multiReadDict = {}
            if dataType == "RNA":
                rds.insertSplices(spliceInsertList)
                spliceInsertList = []
                spliceReadDict = {}
                rds.insertMultisplices(multispliceInsertList)
                multispliceInsertList = []
                multispliceReadDict = {}

            # Progress dot per flushed batch.
            print ".",
            sys.stdout.flush()

        totalReadCounts["total"] += 1

    # Final flush of whatever is left in each buffer after the loop.
    if len(uniqueReadDict.keys()) > 0:
        for entry in uniqueReadDict.keys():
            (readData, rdsEntryName) = uniqueReadDict[entry]
            chrom = samfile.getrname(readData.rname)
            uniqueInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize))
            totalReadCounts["unique"] += 1

        rds.insertUniqs(uniqueInsertList)

    if len(multiReadDict.keys()) > 0:
        for entry in multiReadDict.keys():
            (readData, count, rdsEntryName) = multiReadDict[entry]
            chrom = samfile.getrname(readData.rname)
            if count > maxMultiReadCount:
                pass
                #totalReadCounts["multiDiscard"] += 1
            else:
                multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count))
                #totalReadCounts["multi"] += 1

        rds.insertMulti(multiInsertList)

    if len(spliceReadDict.keys()) > 0 and dataType == "RNA":
        for entry in spliceReadDict.keys():
            (readData, rdsEntryName) = spliceReadDict[entry]
            chrom = samfile.getrname(readData.rname)
            spliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize))
            totalReadCounts["splice"] += 1

        rds.insertSplices(spliceInsertList)

    if len(multispliceReadDict.keys()) > 0 and dataType == "RNA":
        for entry in multispliceReadDict.keys():
            (readData, count, rdsEntryName) = multispliceReadDict[entry]
            chrom = samfile.getrname(readData.rname)
            multispliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize, weight=count))
            totalReadCounts["multisplice"] += 1

        rds.insertMultisplices(multispliceInsertList)

    # The multi statistic is derived from the first-pass counts rather than
    # accumulated per-insert (see commented-out increments above).
    totalReadCounts["multi"] = len(multireadCounts) - totalReadCounts["multiDiscard"] - totalReadCounts["multisplice"]
    countStringList = ["\n%d unmapped reads discarded" % totalReadCounts["unmapped"]]
    countStringList.append("%d unique reads" % totalReadCounts["unique"])
    countStringList.append("%d multi reads" % totalReadCounts["multi"])
    countStringList.append("%d multi reads count > %d discarded" % (totalReadCounts["multiDiscard"], maxMultiReadCount))
    if dataType == "RNA":
        countStringList.append("%d spliced reads" % totalReadCounts["splice"])
        countStringList.append("%d spliced multireads" % totalReadCounts["multisplice"])

    # Print the summary (newline-separated) and log it (tab-separated).
    print string.join(countStringList, "\n")
    outputCountText = string.join(countStringList, "\t")
    writeLog("%s.log" % outDbName, verstring, outputCountText)

    # Rebuild the index after loading, with the larger cache if requested.
    if doIndex:
        print "building index...."
        if cachePages > defaultCacheSize:
            rds.setDBcache(cachePages)
            rds.buildIndex(cachePages)
        else:
            rds.buildIndex(defaultCacheSize)


def getMultiReadIDCounts(samFileName, fileMode):
    try:
        samfile = pysam.Samfile(samFileName, fileMode)
    except ValueError:
        print "samfile index not found"
        sys.exit(1)

    readIDCounts = {}
    for read in samfile.fetch(until_eof=True):
        pairReadSuffix = getPairedReadNumberSuffix(read)
        readName = "%s%s" % (read.qname, pairReadSuffix)
        try:
            readIDCounts[readName] += 1
        except KeyError:
            readIDCounts[readName] = 1

    for readID in readIDCounts.keys():
        if readIDCounts[readID] == 1:
            del readIDCounts[readID]

    return readIDCounts


#TODO: axe this as we can't really use it
def getReadCount(read):
    """Return the value of the read's NH tag (alignment count), or None if absent."""
    for tagName, tagValue in read.tags:
        if tagName == "NH":
            return tagValue

    return None


def getRDSEntry(alignedRead, readName, chrom, readSize, weight=1):
    """Build a uniqs/multi RDS tuple for one aligned read.

    Returns (readName, chrom, start, stop, sense, weight, flag, mismatches)
    where weight is 1.0/alignmentCount and stop = start + readSize.
    """
    start = int(alignedRead.pos)
    stop = int(start + readSize)
    sense = getReadSense(alignedRead.is_reverse)

    try:
        # The MD tag describes reference mismatches; absent tag -> no mismatches.
        mismatches = getMismatches(alignedRead.opt("MD"), alignedRead.seq, sense)
    except KeyError:
        mismatches = ""

    return (readName, chrom, start, stop, sense, 1.0/weight, '', mismatches)


def getRDSSpliceEntry(alignedRead, readName, chrom, readSize, weight=1):
    """Build a splice/multisplice RDS tuple for a spliced (N-containing) read.

    Reuses getRDSEntry() for the shared fields, then splits the read into its
    left and right exonic halves around the first intron in the CIGAR.
    """
    (readName, chrom, start, stop,
     sense, weight, flag, mismatches) = getRDSEntry(alignedRead, readName, chrom, readSize, weight)
    startL, startR, stopL, stopR = getSpliceBounds(start, readSize, alignedRead.cigar)

    return (readName, chrom, startL, stopL, startR, stopR, sense, weight, "", mismatches)


def getPairedReadNumberSuffix(read):
    """Return "/1" or "/2" for the mate number of a properly paired read, else ""."""
    if not isPairedRead(read):
        return ""

    if read.is_read1:
        return "/1"

    if read.is_read2:
        return "/2"

    return ""


def isPairedRead(read):
    """A read counts as paired only when flagged a proper pair AND marked as mate 1 or 2."""
    if not read.is_proper_pair:
        return False

    return read.is_read1 or read.is_read2


def isSpliceEntry(cigarTupleList):
    """Return True if the CIGAR contains an N operation (code 3, skipped
    region/intron), i.e. the alignment is spliced.

    cigarTupleList -- iterable of (operation, length) pairs as produced by
                      pysam's read.cigar.
    """
    # any() short-circuits on the first intron, like the original break did.
    return any(operation == 3 for operation, length in cigarTupleList)


def getReadSense(reverse):
    """Map a reverse-strand flag to the strand character: "-" if reverse, else "+"."""
    return "-" if reverse else "+"


def getMismatches(mismatchTag, querySequence="", sense="+", logErrors=False):
    """Convert a SAM MD tag into erange's comma-separated mismatch string.

    Each substitution becomes "<refBase><1basedPos><readBase>"; deletions
    from the reference ("^...") are skipped. Bases are reverse-complemented
    for minus-strand reads. When the query sequence is empty the read base
    is reported as "N". Returns "" on an out-of-range position (optionally
    logging it when logErrors is set).
    """
    deletionMarker = "^"
    mismatchDescriptions = []
    position = 0

    # Parallel lists: run lengths of matching bases, and the mismatch events
    # (a single substituted base, or "^" plus the deleted reference bases).
    matchLengths = re.findall("\d+", mismatchTag)
    mismatchSequences = re.findall("\d+([ACGTN]|\\^[ACGTN]+)", mismatchTag)

    for matchLength, mismatch in zip(matchLengths, mismatchSequences):
        position = position + int(matchLength)
        if mismatch.startswith(deletionMarker):
            # deletions consume no read bases and are not reported
            continue

        try:
            if querySequence:
                genomicNucleotide = querySequence[position]
            else:
                genomicNucleotide = "N"

            if sense == "-":
                mismatch = getReverseComplement(mismatch)
                genomicNucleotide = getReverseComplement(genomicNucleotide)

            # erange positions are 1-based (eland-compatible)
            mismatchDescriptions.append("%s%d%s" % (mismatch, position + 1, genomicNucleotide))
            position += 1
        except IndexError:
            if logErrors:
                errorMessage = "getMismatch IndexError; tag: %s, seq: %s, pos: %d" % (mismatchTag, querySequence, position)
                writeLog("MakeRdsFromBamError.log", "1.0", errorMessage)

            return ""

    return ",".join(mismatchDescriptions)


def getSpliceBounds(start, readsize, cigarTupleList):
    """Compute the exon bounds of a spliced read around its first intron.

    Returns (startL, startR, stopL, stopR): the left piece runs from start
    to stopL, the right piece from startR to stopR = start + readsize.
    Returns None implicitly if the CIGAR contains no N (op 3) operation.
    """
    stopR = int(start + readsize)
    offset = 0

    for operation, length in cigarTupleList:
        if operation != 3:
            # accumulate read-consuming span until the first intron
            offset += length
            continue

        stopL = int(start + offset)
        startR = int(stopL + length)

        return start, startR, stopL, stopR


# Script entry point: delegate to main() with the raw command line.
if __name__ == "__main__":
    main(sys.argv)