##################################
#                                #
# Last modified 2019/09/11       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam

# FLAG field meaning
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) 1
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped 1
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate 1
# 0x0040 64 the read is the first read in a pair 1,2
# 0x0080 128 the read is the second read in a pair 1,2
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 supplementary alignment

def getReverseComplement(preliminarysequence):
    
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','X':'X','a':'t','t':'a','g':'c','c':'g','n':'n','x':'x','R':'R','r':'r','M':'M','m':'m','Y':'Y','y':'y','S':'S','s':'s','K':'K','k':'k','W':'W','w':'w'}
    sequence=''
    for j in range(len(preliminarysequence)):
        sequence=sequence+DNA[preliminarysequence[len(preliminarysequence)-j-1]]
    return sequence

def FLAG(FLAG):
    """Decompose a SAM FLAG integer into the defined bit values it contains.

    Args:
        FLAG: the integer FLAG field (column 2) of a SAM record. The
            parameter name shadows the function name; it is kept so keyword
            callers keep working.

    Returns:
        A list of the set bit values among 0x1..0x800 (the bits listed in
        the table at the top of this file), in descending order, e.g.
        99 -> [64, 32, 2, 1]. Bits above 0x800 are ignored. Returns [] for
        0 or negative input.
    """
    flag = FLAG
    if flag <= 0:
        # Invalid or empty flag: nothing is set.
        return []
    # A direct bitmask test is correct for any integer. The previous greedy
    # subtraction produced garbage once bits above 0x800 were set.
    return [bit for bit in (2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1)
            if flag & bit]

def run():
    """Accumulate per-position alignment counts from SAM text on stdin.

    Command line (read from sys.argv):
        outputfilename          required; the track is written here.
        -stranded + | -         keep only reads on the given strand.
        -end1only / -end2only   drop second-end / first-end reads
                                (mutually exclusive).
        -readLength min max     keep only reads whose SEQ length lies in
                                [min, max] (inclusive).
        -chr c1,c2,...          keep only alignments to these references.
        -absValue               in stranded '-' mode, do not negate scores.
    '-singlebasepair' and '-uniqueBAM' are accepted but have no effect.

    Each kept alignment adds 1/NH at one position: SAM POS for '+' strand
    reads, and POS + len(SEQ) for '-' strand reads. That ignores the CIGAR,
    hence the usage warning about split reads. In stranded '-' mode, unless
    -absValue is given, the contribution is negated.

    Adjacent positions with equal scores are merged into runs. Runs are
    written, sorted by reference name and position, as tab-separated
    ``chrom  start  end  score`` lines with end = last position + 1.
    NOTE(review): start is the 1-based SAM-derived coordinate, not a 0-based
    bedGraph start; confirm this matches downstream expectations.

    Exits with status 1 when the output file name is missing or when both
    -end1only and -end2only are given.
    """
    argv = sys.argv
    if len(argv) < 2:
        print('usage: python %s outputfilename [-stranded + | -] [-end2only] [-end1only] [-readLength min max] [-chr chrN1(,chrN2....)] [-absValue] [-uniqueBAM]' % argv[0])
        print('\tthe script assumes stdin')
        print('\tthe script assumes alignments with NH tags that need not be sorted')
        print('\tsplit by chromosome for memory efficiency')
        print('\tNote: this script will not normalize to RPMs!!!')
        print('\tDo not run in stranded mode on split-read alignments!!')
        sys.exit(1)

    outfilename = argv[1]

    doEnd1Only = '-end1only' in argv
    doEnd2Only = '-end2only' in argv
    if doEnd1Only and doEnd2Only:
        print('both -end1only and -end2only option specified, a logical impossiblity, exiting')
        sys.exit(1)
    if doEnd1Only:
        print('will only consider the first end of read pairs')
    if doEnd2Only:
        print('will only consider the second end of read pairs')

    doReadLength = '-readLength' in argv
    if doReadLength:
        rlIndex = argv.index('-readLength')
        minRL = int(argv[rlIndex + 1])
        maxRL = int(argv[rlIndex + 2])
        print('will only consider reads between %d and %d bp length' % (minRL, maxRL))

    doAbs = '-absValue' in argv
    if doAbs:
        print('will output absolute values')

    doChrSubset = '-chr' in argv
    if doChrSubset:
        WantedChrDict = dict((c, '') for c in argv[argv.index('-chr') + 1].split(','))
        print(WantedChrDict)

    doStranded = '-stranded' in argv
    if doStranded:
        strand = argv[argv.index('-stranded') + 1]
        print('will only consider %s strand reads' % strand)

    # reference name -> {position: accumulated score}
    CoverageDict = {}

    # Open the output before reading stdin so an unwritable path fails early;
    # 'with' guarantees the file is closed even if parsing raises.
    with open(outfilename, 'w') as outfile:
        RN = 0
        for line in sys.stdin:
            # Skip every header line (@HD, @SQ, @RG, @PG, @CO, ...) and blank
            # lines. QNAME may not start with '@', so this never drops a record.
            if line.startswith('@') or not line.strip():
                continue
            RN += 1
            if RN % 1000000 == 0:
                print('%dM alignments processed' % (RN // 1000000))
            fields = line.strip().split('\t')
            chrom = fields[2]
            if chrom == '*':
                continue
            if doChrSubset and chrom not in WantedChrDict:
                continue
            seqLen = len(fields[9])
            if doReadLength and not (minRL <= seqLen <= maxRL):
                continue
            flag = int(fields[1])
            if doEnd1Only and flag & 128:  # second read in pair
                continue
            if doEnd2Only and flag & 64:  # first read in pair
                continue
            scaleby = 1.0 / _nh_multiplicity(fields)
            s = '-' if flag & 16 else '+'  # 0x10: read is reverse-complemented
            if doStranded:
                if s != strand:
                    continue
                if s == '-' and not doAbs:
                    scaleby = -scaleby
            pos = int(fields[3])
            if s == '-':
                pos += seqLen
            chromCoverage = CoverageDict.setdefault(chrom, {})
            chromCoverage[pos] = chromCoverage.get(pos, 0) + scaleby

        for chrom in sorted(CoverageDict):
            _write_runs(outfile, chrom, CoverageDict[chrom])


def _nh_multiplicity(fields):
    """Return the NH:i: value from a SAM record's optional fields.

    Only the optional fields (column 12 onward) are searched, so the read name
    or other tag values cannot be mistaken for the tag. Returns 1 when no NH
    tag is present, i.e. the alignment is treated as unique.
    """
    for tag in fields[11:]:
        if tag.startswith('NH:i:'):
            return int(tag[5:])
    return 1


def _write_runs(outfile, chrom, coverage):
    """Write runs of adjacent, equal-score positions as chrom/start/end/score lines.

    ``coverage`` maps position -> score and must be non-empty. Each run is
    written with end = last position + 1, so single-base and multi-base runs
    use the same convention.
    """
    positions = sorted(coverage)
    runStart = runEnd = positions[0]
    runScore = coverage[runStart]
    for pos in positions[1:]:
        score = coverage[pos]
        if pos == runEnd + 1 and score == runScore:
            runEnd = pos
            continue
        outfile.write(chrom + '\t' + str(runStart) + '\t' + str(runEnd + 1) + '\t' + str(runScore) + '\n')
        runStart = runEnd = pos
        runScore = score
    outfile.write(chrom + '\t' + str(runStart) + '\t' + str(runEnd + 1) + '\t' + str(runScore) + '\n')
            
# Run only when executed as a script, so importing this module (e.g. to reuse
# getReverseComplement or FLAG) does not consume stdin or write files.
if __name__ == '__main__':
    run()
