#
#  makerdsfromeland2.py
#  ENRAGE
#
# Optional: enable psyco when it is installed.  Any failure while importing
# or enabling it is ignored and the script simply runs without it.
try:
    import psyco
    psyco.full()
except:
    pass

import sys
import string
import optparse
import ReadDataset
from commoncode import getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption

def main(argv=None):
    if not argv:
        argv = sys.argv

    verstring = "makerdsfromeland2: version 3.5"
    print verstring

    usage = "usage:  %prog label infilename outrdsfile [propertyName::propertyValue] [options]\
            \ninput reads must be sorted to properly record multireads"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        print usage
        sys.exit(1)

    label = args[0]
    filename = args[1]
    outdbname = args[2]

    delimiter = '|'
    if options.useOldDelimiter:
        delimiter = ':'

    paired = False
    pairID = '1'
    if options.pairID is not None:
        paired = True
        if options.pairID not in ['1','2']:
            print 'pairID value must be 1 or 2'
            sys.exit(-1)

        print 'Treating read IDs as paired with label = %s and pairID = %s' % (label, pairID)

    dataType = 'DNA'
    if options.geneDataFileName is not None:
        dataType = 'RNA'

    makeRDSFromEland2(label, filename, outdbname, options.doIndex, delimiter, paired, options.init,
                      options.pairID, dataType, options.geneDataFileName, options.cachePages,
                      options.maxLines, options.extended, options.verbose)


def getParser(usage):
    """Build the optparse parser for makerdsfromeland2.

    usage -- usage string handed to optparse.OptionParser.

    Option defaults are looked up in the "makerdsfromeland2" section of the
    parser returned by getConfigParser(); the last argument of each
    getConfig*Option call is presumably the fallback used when the option is
    not configured (confirm against commoncode).

    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--append", action="store_false", dest="init",
                      help="append to existing rds file [default: create new]")
    parser.add_option("--RNA", dest="geneDataFileName",
                      help="set data type to RNA [default: DNA]")
    parser.add_option("--index", action="store_true", dest="doIndex",
                      help="index the output rds file")
    parser.add_option("--cache", type="int", dest="cachePages",
                      help="number of cache pages to use [default: 100000]")
    parser.add_option("--olddelimiter", action="store_true", dest="useOldDelimiter",
                      help="use : as the delimiter")
    parser.add_option("--paired", dest="pairID",
                      help="pairID value")
    parser.add_option("--extended", action="store_true", dest="extended",
                      help="use eland_extended input")
    parser.add_option("--verbose", action="store_true", dest="verbose")
    parser.add_option("--maxlines", type="int", dest="maxLines",
                      help="[default: 1000000000]")

    configParser = getConfigParser()
    section = "makerdsfromeland2"
    init = getConfigBoolOption(configParser, section, "init", True)
    doIndex = getConfigBoolOption(configParser, section, "doIndex", False)
    cachePages = getConfigIntOption(configParser, section, "cachePages", 100000)
    geneDataFileName = getConfigOption(configParser, section, "geneDataFileName", None)
    useOldDelimiter = getConfigBoolOption(configParser, section, "useOldDelimiter", False)
    pairID = getConfigOption(configParser, section, "pairID", None)
    maxLines = getConfigIntOption(configParser, section, "maxLines", 1000000000)
    extended = getConfigBoolOption(configParser, section, "extended", False)
    verbose = getConfigBoolOption(configParser, section, "verbose", False)

    parser.set_defaults(init=init, doIndex=doIndex, cachePages=cachePages,
                        geneDataFileName=geneDataFileName, useOldDelimiter=useOldDelimiter,
                        pairID=pairID, maxLines=maxLines, extended=extended, verbose=verbose)

    return parser


def makeRDSFromEland2(label, filename, outdbname, doIndex=False, delimiter="|", paired=False,
                      init=True, pairID="1", dataType="DNA", geneDataFileName=None,
                      cachePages=100000, maxLines=1000000000, extended=False, verbose=False):

    maxBorder = 0
    index = 0
    insertSize = 100000

    geneDict = {}
    if dataType == 'RNA':
        genedatafile = open(geneDataFileName)
        for line in genedatafile:
            fields = line.strip().split('\t')
            blockCount = int(fields[7])
            if blockCount < 2:
                continue

            uname = fields[0]
            chrom = fields[1]
            sense = fields[2]
            chromstarts = fields[8][:-1].split(',')
            for index in range(blockCount):
                chromstarts[index] = int(chromstarts[index])

            geneDict[uname] = (sense, blockCount, chrom, chromstarts)

        genedatafile.close()

    rds = ReadDataset.ReadDataset(outdbname, init, dataType, verbose=True)

    if cachePages > rds.getDefaultCacheSize():
        if init:
            rds.setDBcache(cachePages, default=True)
        else:
            rds.setDBcache(cachePages)

    if not init and doIndex:
        try:
            if rds.hasIndex():
                rds.dropIndex()
        except:
            if verbose:
                print "couldn't drop Index"

    propertyList = []
    for arg in sys.argv:
        if '::' in arg:
            (pname, pvalue) = arg.strip().split('::')
            if pname == 'flowcell' and paired:
                pvalue = pvalue + '/' + pairID

            propertyList.append((pname, pvalue))

    if len(propertyList) > 0:
        rds.insertMetadata(propertyList)

    infile = open(filename,'r')
    line = infile.readline()
    fields = line.split()
    readsize = len(fields[1])
    readsizeString = str(readsize)
    if dataType == 'RNA' and readsize > 32:
        splicesizeString = '32'
    else:
        splicesizeString = readsizeString

    print 'read size: %d bp' % readsize
    if init:
        rds.insertMetadata([('readsize', readsize)])
        rds.insertMetadata([('eland_mapped', 'True')])
        if extended:
            rds.insertMetadata([('eland_extended', 'True')])

        if paired:
            rds.insertMetadata([('paired', 'True')])

    trim = -4
    if dataType == 'RNA':
        maxBorder = readsize + trim

    insertList = []
    infile = open(filename,'r')
    print 'mapping unique reads...'
    lineIndex = 0
    for line in infile:
        lineIndex += 1
        if lineIndex > maxLines:
            break

        fields = line.split()
        if fields[2] in  ['QC','NM']:
            continue

        (matchType, bestMatch) = getUniqueMatch(fields[2])
        if matchType == -1:
            continue

        bestpos = []
        try:
            pos = fields[3].split(',')
        except:
            if verbose:
                print 'problem with line: %s' % line.strip()
            continue

        matchDict = {0:[], 1:[], 2:[], 3:[]}
        if len(pos) == 1:
            if 'splice' in pos:
                continue

            bestpos = pos
        else:
            currentChr = ''
            for apos in pos:
                if 'splice' in apos:
                    continue

                if ':' in apos:
                    (front, back) = apos.split(':')
                    currentChr = front
                else:
                    back = apos
                    apos = currentChr + ':' + apos

                if extended:
                    matchType = back.count('A') + back.count('C') + back.count('G') + back.count('T')
                    if matchType > 2:
                        matchType = 3
                else:
                    matchType = int(apos[-1])

                matchDict[matchType].append(apos)
                if bestMatch[matchType]:
                    bestpos.append(apos)

        # for padded reads, mapped read might have more mismatches!
        if len(bestpos) == 0:
            # let's not worry about these yet.
            if 'splice' in line:
                continue

            for matchType in [1, 2, 3]:
                if len(matchDict[matchType]) > 0:
                    if len(matchDict[matchType]) == 1 and 'splice' not in matchDict[matchType][0]:
                        bestpos = matchDict[matchType]
                    break

            if len(bestpos) == 0 and verbose:
                print "couldn't pick best read from line: %s" % line

        for apos in bestpos:
            try:
                (chrom, back) = apos.split(':')
            except:
                continue

            if 'splice' in chrom:
                continue

            if '/' in chrom:
                chromfields = chrom.split('/')
                chrom = chromfields[-1]

            if '.' in chrom:
                try:
                    (chrom, fileExt) = chrom.split('.')
                except:
                    if verbose:
                        print 'problem with chromosome on line %s' % line.strip()

                    continue

            if extended:
                if 'F' in back:
                    sense = '+'
                    (start, matchPart) = back.split('F')
                else:
                    sense = '-'
                    (start, matchPart) = back.split('R')

                start = int(start) 
                if matchPart == readsizeString:
                    matchType = ''
                else:
                    matchType = decodeMismatches(fields[1], matchPart)
            else:
                start = int(back[:-2])
                if back[-2] == 'F':
                    sense = '+'        
                else:
                    sense = '-'

            stop = int(start) + readsize - 1
            if paired:
                readID = label + '-' + str(lineIndex) + '/' + pairID
            else:
                readID = label + '-' + str(index)

            if len(chrom) > 0:
                insertList.append((readID, chrom, start, stop, sense, 1.0, '', matchType))

            if index % insertSize == 0:
                rds.insertUniqs(insertList)
                insertList = []
                print '.',
                sys.stdout.flush()

            index += 1

    if len(insertList) > 0:
        rds.insertUniqs(insertList)
        insertList = []

    print
    print '%d unique reads' % index
    infile.close()

    seenSpliceList = []
    if dataType == 'RNA':
        print 'mapping splices...'
        index = 0
        lineIndex = 0
        mapfile = open(filename,'r')
        for line in mapfile:
            lineIndex += 1
            if lineIndex > maxLines:
                break

            if 'splice' not in line:
                continue

            fields = line.strip().split()
            (matchType, bestMatch) = getUniqueMatch(fields[2])
            if matchType == -1:
                continue

            bestpos = []
            pos = fields[3].split(',')
            matchDict = {0:[], 1:[], 2:[], 3:[]}
            if len(pos) == 1:
                if 'chr' in pos:
                    continue

                bestpos = pos
            else:
                currentSplice = ''
                for apos in pos:
                    if 'splice' not in apos:
                        continue

                    if ':' in apos:
                        if delimiter == ':':
                            try:
                                (extmodel, spliceID, regionStart, thepos) = apos.split(':')
                            except:
                                try:
                                    (extmodel1, extmodel2, spliceID, regionStart, thepos) = apos.split(':')
                                    extmodel = extmodel1 + ':' + extmodel2
                                except:
                                    print 'warning: could not process splice %s' % apos
                                    continue

                            currentSplice = extmodel + ':' + spliceID + ':' + regionStart
                        else:
                            try:
                                (currentSplice, thepos) = apos.split(':')
                            except:
                                try:
                                    (extmodel1, restSplice, thepos) = apos.split(':')
                                    currentSplice = extmodel1 + ':' + restSplice
                                    (extmodel, spliceID, regionStart) = currentSplice.split(delimiter)
                                except:
                                    print 'warning: could not process splice %s' % apos
                                    continue
                    else:
                        thepos = apos
                        apos = currentSplice + ':' + apos

                    if extended:
                        matchType = thepos.count('A') + thepos.count('C') + thepos.count('G') + thepos.count('T')
                        if matchType > 2:
                            matchType = 3

                        # if readsize > 32, we risk loosing pefect matches that go beyond our expanded genome splices, so only ask for 32bp match
                        if thepos[:2] == splicesizeString:
                            matchType = 0
                    else:
                        matchType = int(apos[-1])

                    if bestMatch[matchType]:
                        bestpos.append(apos)

            # for padded reads, mapped read might have more mismatches!
            if len(bestpos) == 0:
                for matchType in [1, 2, 3]:
                    if len(matchDict[matchType]) > 0:
                        if len(matchDict[matchType]) == 1 and 'splice' in matchDict[matchType][0]:
                            bestpos = matchDict[matchType]

                        break
                if len(bestpos) == 0 and verbose:
                    print "couldn't pick best read from line: %s" % line

            for apos in bestpos:
                if delimiter == ':':
                    try:
                        (extmodel, spliceID, regionStart, thepos) = apos.split(':')
                    except:
                        try:
                            (extmodel1, extmodel2, spliceID, regionStart, thepos) = apos.split(':')
                            extmodel = extmodel1 + ':' + extmodel2
                        except:
                            print 'warning: could not process splice %s' % apos
                            continue
                else:
                    try:
                        (currentSplice, thepos) = apos.split(':')
                    except:
                        try:
                            (extmodel1, restSplice, thepos) = apos.split(':')
                            currentSplice = extmodel1 + ':' + restSplice
                        except:
                            print 'warning: could not process splice %s' % apos
                            continue

                    (extmodel, spliceID, regionStart) = currentSplice.split(delimiter)

                modelfields = extmodel.split('/')
                if len(modelfields) > 2:
                    model = string.join(modelfields[1:],'/')
                else:
                    model = modelfields[1]

                if model not in geneDict:
                    print fields
                    continue

                (sense, blockCount, chrom, chromstarts) = geneDict[model]
                if extended:
                    if 'F' in thepos:
                        rsense = '+'
                        (start, matchPart) = thepos.split('F')
                    else:
                        rsense = '-'
                        (start, matchPart) = thepos.split('R')

                    rstart = int(start) - 2 
                    if matchPart == readsizeString:
                        matchType = ''
                    elif matchPart[:2] == splicesizeString:
                        matchType = ''
                    else:
                        matchType = decodeMismatches(fields[1], matchPart)
                else:
                    rstart = int(thepos[:-2]) - 2
                    if thepos[-2] == 'F':
                        rsense = '+'
                    else:
                        rsense = '-'

                if trim <= rstart <= maxBorder:
                    pass
                else:
                    print rstart
                    continue

                currentSplice = model + delimiter + spliceID + delimiter + regionStart
                spliceID = int(spliceID)
                lefthalf = maxBorder - rstart
                if lefthalf < 1 or lefthalf > maxBorder:
                    continue

                righthalf = readsize - lefthalf
                startL = int(regionStart)  + rstart
                stopL = startL + lefthalf
                startR = chromstarts[spliceID + 1]
                stopR = chromstarts[spliceID + 1] + righthalf
                if paired:
                    readName = label + '-' + str(lineIndex) + '/' + pairID
                else:
                    readName = model + '-' + str(thepos)

                insertList.append((readName, chrom, startL, stopL, startR, stopR, rsense, 1.0, '', matchType))
                index += 1
                if index % insertSize == 0:
                    rds.insertSplices(insertList)
                    print '.',
                    sys.stdout.flush()
                    insertList = []

                if currentSplice not in seenSpliceList:
                    seenSpliceList.append(currentSplice)

        mapfile.close()
        if len(insertList) > 0:
            rds.insertSplices(insertList)
            insertList = []

        print
        print 'saw %d spliced reads accross %d distinct splices' % (index, len(seenSpliceList))

    infile = open(filename,'r')
    print 'mapping multireads...'
    lineIndex = 0
    origReadid = rds.getMultiCount()
    try:
        readid = int(origReadid) + 1
    except:
        readid = 0
        origReadid = 0

    print 'starting at %d' % (readid + 1)

    for line in infile:
        lineIndex += 1
        if lineIndex > maxLines:
            break

        fields = line.split()
        if len(fields) < 4:
            continue

        if fields[2] == 'QC' or fields[2] == 'NM' or fields[3] == '-':
            continue

        (zero, one, two) = fields[2].split(':')
        zero = int(zero)
        one = int(one)
        two = int(two)

        bestMatch = [False] * readsize
        if zero > 1:
            bestMatch[0] = True
        elif zero == 0 and one > 1:
            bestMatch[1] = True
        elif zero == 0 and one == 0 and two > 1:
            bestMatch[2] = True
        else:
            continue

        readcount = 0
        bestpos = []
        pos = fields[3].split(',')
        matchDict = {0:[], 1:[], 2:[], 3:[]}
        currentChr = ''
        for apos in pos:
            if ':' in apos:
                try:
                    (front, back) = apos.split(':')
                except:
                    if verbose:
                        print "problem splitting %s" % str(apos)
                    continue

                currentChr = front
            else:
                back = apos
                apos = currentChr + ':' + apos

            if extended:
                matchType = back.count('A') + back.count('C') + back.count('G') + back.count('T')
            else:
                matchType = int(apos[-1])

            try:
                matchDict[matchType].append(apos)
            except:
                matchDict[matchType] = [apos]

            if bestMatch[matchType]:
                bestpos.append(apos)

        # for padded reads, mapped read might have more mismatches!
        if len(bestpos) == 0:
            for matchType in [1, 2, 3]:
                if len(matchDict[matchType]) > 0:
                    if len(matchDict[matchType]) > 1:
                        noSplice = True
                        for arg in matchDict[matchType]:
                            if 'splice' in arg:
                                noSplice = False

                        if noSplice:
                            bestpos = matchDict[matchType]
                    break

            if len(bestpos) == 0 and verbose:
                print "couldn't pick best read from line: %s" % line
                continue

        hasSplice = False
        for apos in bestpos:
            if 'splice' in apos:
                hasSplice = True

        # do not allow multireads that can also map accross splices for now
        if hasSplice:
            if verbose:
                print "throwing out multiread because of splice conflict"
            continue

        if len(bestpos) > 0:
            readid += 1

        for apos in bestpos:
            readcount += 1
            (front, back) = apos.split(':')
            chrom = front[:-3]
            if extended:
                if 'F' in back:
                    sense = '+'
                    (start, matchPart) = back.split('F')
                else:
                    sense = '-'
                    (start, matchPart) = back.split('R')

                start = int(start)
                if matchPart == readsizeString:
                    matchType = ''
                else:
                    matchType = decodeMismatches(fields[1], matchPart)
            else:
                start = int(back[:-2])
                if back[-2] == 'F':
                    sense = '+'
                else:
                    sense = '-'

            stop = int(start) + readsize
            readName = '%dx%d' % (readid, len(bestpos))
            if paired:
                readName = label + '-' + str(lineIndex) + '/' + pairID + '::' + readName

            insertList.append((readName, chrom, start, stop, sense, 1.0/len(bestpos), '', matchType))
            if index % insertSize == 0:
                rds.insertMulti(insertList)
                insertList = []
                print '.',
                sys.stdout.flush()

            index += 1

    if len(insertList) > 0:
        rds.insertMulti(insertList)
        insertList = []

    print
    print '%d multireads' % (readid - origReadid)

    if doIndex:
        print 'building index....'
        rds.buildIndex(cachePages)


def getUniqueMatch(elandCode):
    """Classify an eland "zero:one:two" match-count code for unique reads.

    elandCode -- three colon-separated integers: the number of hits with
                 0, 1 and 2 mismatches.

    Returns (matchType, bestMatch).  matchType is the lowest mismatch class
    that has any hits, provided it has exactly one hit; otherwise -1.
    bestMatch is a four-element list of booleans with only the entry at
    matchType set (all False when matchType is -1).

    Raises ValueError when the code does not hold exactly three integers.
    """
    hitCounts = [int(count) for count in elandCode.split(':')]
    (exactHits, oneOffHits, twoOffHits) = hitCounts

    flags = [False] * 4
    level = -1
    # The first mismatch class with any hits decides the outcome: the read is
    # unique only when that class holds exactly one hit.
    for mismatches, hits in enumerate((exactHits, oneOffHits, twoOffHits)):
        if hits != 0:
            if hits == 1:
                level = mismatches
                flags[mismatches] = True
            break

    return (level, flags)


def decodeMismatches(origSeq, code):
    """Translate an eland_extended match descriptor into a mismatch list.

    origSeq -- the read sequence
    code    -- descriptor in which runs of digits give the number of matching
               bases and each non-digit character is a mismatching base

    Returns a comma-separated string with one entry per mismatch, formatted
    as <read base><1-based position><descriptor base>; an empty string when
    the descriptor contains no mismatch.
    """
    output = []
    # digits are accumulated as text; the leading '0' keeps int() valid when
    # two mismatches are adjacent
    number = '0'
    index = 0
    for pos in code:
        if pos.isdigit():
            number += pos
        else:
            # skip the matched run, then step onto the mismatching base
            index += int(number) + 1
            origNT = origSeq[index - 1]
            output.append('%s%d%s' % (origNT, index, pos))
            number = '0'

    return ','.join(output)


# Run as a script: hand the full argument vector (program name included) to main().
if __name__ == "__main__":
    main(sys.argv)