##################################
#                                #
# Last modified 2022/07/30       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam
import os

# FLAG field meaning (SAM specification)
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) [1]
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped [1]
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate [1]
# 0x0040 64 the read is the first read in a pair [1,2]
# 0x0080 128 the read is the second read in a pair [1,2]
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 supplementary alignment
# [1] only meaningful when flag 0x0001 is set
# [2] if the first/second-read information is lost upstream, 0x0001 is set and 0x0040 and 0x0080 are both zero

def FLAG(FLAG):
    """Split a SAM FLAG value into the individual bit values that are set.

    Args:
        FLAG: integer SAM FLAG (valid SAM flags lie in 0..4095).

    Returns:
        A list of the set powers of two, in descending order.
        For example, ``FLAG(99)`` returns ``[64, 32, 2, 1]``.
        ``0`` and negative values yield ``[]``.
        Bits above 0x800 are reported as well, rather than raising an error.
    """
    value = int(FLAG)
    if value <= 0:
        return []
    # Walk from the highest set bit down to bit 0 so the output is descending.
    return [1 << shift
            for shift in range(value.bit_length() - 1, -1, -1)
            if value & (1 << shift)]

def read_chrom_sizes(path):
    """Parse a chrom.sizes file into whole-chromosome regions.

    Each non-blank line must hold at least ``<chrom><TAB><length>``; any
    further columns are ignored.

    Returns:
        list of ``(chrom, 0, length)`` tuples in file order.
    """
    regions = []
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if not fields[0]:
                # Skip blank lines (e.g. a trailing empty line at end of file).
                continue
            regions.append((fields[0], 0, int(fields[1])))
    return regions


def tally_alignment(pos, cigar, coverage, indels):
    """Add one alignment's per-base counts to ``coverage`` and ``indels``.

    Args:
        pos: 0-based leftmost reference position of the alignment.
        cigar: iterable of ``(operation, length)`` pairs using BAM CIGAR op
            codes (0=M, 1=I, 2=D, 3=N, 4=S, 5=H, 6=P, 7='=', 8=X).
        coverage: dict of 1-based position -> aligned-base count, updated in place.
        indels: dict of 1-based position -> indel count, updated in place.

    How each operation is counted:
        - Aligned bases (M, =, X) add 1.0 to ``coverage`` at every position
          they span.
        - Each deleted base (D) adds 1.0 to ``indels`` at its own position.
        - An insertion (I) adds 1.0 to ``indels`` at the reference base
          immediately preceding it. It consumes no reference, so the
          reference cursor is not advanced.
        - Skipped regions (N) advance the cursor without being counted.
        - S, H and P consume no reference and are ignored.
    """
    ref = pos  # 0-based position of the next reference base
    for op, length in cigar:
        if op in (0, 7, 8):
            for p in range(ref + 1, ref + length + 1):
                coverage[p] = coverage.get(p, 0.0) + 1.0
            ref += length
        elif op == 1:
            # 0-based "next base" == 1-based position of the preceding base.
            indels[ref] = indels.get(ref, 0.0) + 1.0
        elif op == 2:
            for p in range(ref + 1, ref + length + 1):
                indels[p] = indels.get(p, 0.0) + 1.0
            ref += length
        elif op == 3:
            ref += length


def indel_fractions(coverage, indels):
    """Yield ``(position, indel_count / aligned_count)`` pairs.

    Pairs are produced in ascending 1-based position order, for every
    position present in both ``coverage`` and ``indels``.
    """
    for p in sorted(coverage):
        if p in indels:
            yield p, indels[p] / coverage[p]


def run():
    """Command-line entry point: write per-base indel fractions for a BAM file.

    Usage: python <script> BAMfilename chrom.sizes outputfilename

    For every chromosome listed in chrom.sizes:
        - Each mapped alignment is tallied with ``tally_alignment``.
        - One line ``chrom<TAB>start<TAB>end<TAB>fraction`` is written per
          position that has both aligned-base coverage and at least one indel.
          Coordinates are 0-based half-open and one base wide, and
          fraction = indel count / aligned-base count.
        - Chromosomes whose region cannot be fetched are reported and skipped.
    """
    # Three positional arguments are required (sys.argv[1..3]).
    if len(sys.argv) < 4:
        print('usage: python %s BAMfilename chrom.sizes outputfilename ' % sys.argv[0])
        sys.exit(1)

    BAM = sys.argv[1]
    chromInfoList = read_chrom_sizes(sys.argv[2])
    outfilename = sys.argv[3]

    samfile = pysam.Samfile(BAM, "rb")
    RN = 0
    try:
        with open(outfilename, 'w') as outfile:
            for (chrom, start, end) in chromInfoList:
                # Probe the region first: fetch() may raise, e.g. for a contig
                # that is absent from the BAM header.
                try:
                    next(iter(samfile.fetch(chrom, start, end)), None)
                except Exception:
                    print('problem with region: %s %s %s skipping' % (chrom, start, end))
                    continue
                coverageDict = {}
                InDelDict = {}
                for alignedread in samfile.fetch(chrom, start, end):
                    RN += 1
                    if RN % 100000 == 0:
                        print('%s M alignments processed %s %s %s'
                              % (RN / 1000000., chrom, alignedread.pos, end))
                    # Skip alignments with the unmapped flag (0x4) set.
                    if alignedread.is_unmapped:
                        continue
                    tally_alignment(alignedread.pos, alignedread.cigar or [],
                                    coverageDict, InDelDict)
                for (i, fraction) in indel_fractions(coverageDict, InDelDict):
                    outfile.write(chrom + '\t' + str(i - 1) + '\t' + str(i) + '\t'
                                  + str(fraction) + '\n')
    finally:
        samfile.close()
            
# Run only when executed as a script, so FLAG and the helpers can be imported
# without triggering argument parsing and BAM processing.
if __name__ == '__main__':
    run()
