##################################
#                                #
# Last modified 2020/05/20       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import os
import re
import string
import sys
from sets import Set

import pysam

# FLAG field meaning (SAM specification):
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) 1
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped 1
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate 1
# 0x0040 64 the read is the first read in a pair 1,2
# 0x0080 128 the read is the second read in a pair 1,2
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 supplementary alignment (one part of a chimeric/split alignment)
# Footnotes: (1) only meaningful when flag 0x0001 is set; (2) if the information
# on which read is first in the pair was lost, 0x0001 is set and both 0x0040 and
# 0x0080 are unset.

# CIGAR field decoding:
# M 0 alignment match (can be a sequence match or mismatch)
# I 1 insertion to the reference
# D 2 deletion from the reference
# N 3 skipped region from the reference
# S 4 soft clipping (clipped sequences present in SEQ)
# H 5 hard clipping (clipped sequences NOT present in SEQ)
# P 6 padding (silent deletion from padded reference)
# = 7 sequence match
# X 8 sequence mismatch

def FLAG(FLAG):
    """Decompose a SAM FLAG value into the individual bit values that are set.

    FLAG: integer FLAG field of an alignment record.
    Returns the set bit values (powers of two) in descending order, e.g.
    FLAG(99) -> [64, 32, 2, 1]. Zero or negative input returns [].

    Every set bit is reported, including 0x800 (supplementary alignment) and
    any higher bit; a fixed table of known bits silently mis-decodes those.
    """
    FLAGList = []
    bit = 1
    # Walk the powers of two up to FLAG and keep the ones present in it.
    while bit <= FLAG:
        if FLAG & bit:
            FLAGList.append(bit)
        bit <<= 1
    # Callers receive the largest bit first.
    FLAGList.reverse()
    return FLAGList

# MD tag tokens: a run of matching bases, a deletion ('^' followed by the
# deleted reference bases), or a single mismatched reference base.
_MD_TOKEN = re.compile(r'(\d+)|(\^[A-Za-z]+)|([A-Za-z])')


def _has_soft_clip(cigar):
    """Return True if a pysam CIGAR tuple list contains a soft clip (op 4)."""
    return any(op == 4 for (op, length) in (cigar or ()))


def _load_genome(fasta):
    """Read a FASTA file into a {sequence name: sequence} dictionary.

    The name is the first whitespace-delimited word of the header line, so a
    header such as '>chr1 description' is stored under 'chr1'. Lines before
    the first header are ignored.
    """
    GenomeDict = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    GenomeDict[name] = ''.join(chunks)
                words = line[1:].split()
                name = words[0] if words else ''
                chunks = []
            else:
                chunks.append(line.strip())
    if name is not None:
        GenomeDict[name] = ''.join(chunks)
    return GenomeDict


def _collect_mate(mate, chrom, L, R, variants, covered):
    """Record the indels, mismatches and covered positions of one mate.

    mate: (qname, pos, seq, cigar, CB, MD) as stored by _collect_region; pos
        is pysam's 0-based leftmost aligned position and cigar is pysam's list
        of (operation, length) tuples.
    chrom, L, R: reference name and fragment span. (L, R) is part of every
        variant key, so fragments with identical spans count once.
    variants: set receiving (L, R, chrom, pos, kind, value) tuples, where kind
        is 'MM' (value = read base), 'I' (value = inserted bases followed by
        the next read base) or 'D' (value = deletion length). pos is the
        1-based reference position of the mismatch, of the base following an
        insertion, or of the first deleted base.
    covered: set receiving the 1-based reference positions covered by the mate.

    Exits the program with diagnostics if the MD tag describes more aligned
    bases than the CIGAR string contains.
    """
    (_, pos, seq, cigar, _, MD) = mate
    refPos = pos + 1   # 1-based reference position of the next CIGAR operation
    readPos = 0        # 0-based offset into seq of the next CIGAR operation
    aligned = []       # (refPos, readPos) of every aligned base, in order
    for (op, length) in (cigar or ()):
        if op in (0, 7, 8):          # M, =, X: consume reference and read
            for j in range(length):
                aligned.append((refPos + j, readPos + j))
                covered.add(refPos + j)
            refPos += length
            readPos += length
        elif op == 1:                # I: consumes the read only
            # Right-anchored, matching the REF written for insertions
            # (the reference base at refPos, i.e. after the insertion).
            variants.add((L, R, chrom, refPos, 'I',
                          seq[readPos:readPos + length + 1]))
            covered.add(refPos)
            readPos += length
        elif op == 2:                # D: consumes the reference only
            variants.add((L, R, chrom, refPos, 'D', length))
            covered.add(refPos)
            refPos += length
        elif op == 3:                # N: skipped reference region
            refPos += length
        elif op == 4:                # S: clipped bases are still in seq
            readPos += length
        # H and P consume neither the reference nor the stored sequence.

    # MD enumerates aligned bases in order (insertions and soft clips are not
    # represented, deleted bases do not advance the count), so the n-th base
    # it describes is aligned[n].
    matchIndex = 0
    for (count, _, refBase) in _MD_TOKEN.findall(MD):
        if count:
            matchIndex += int(count)
        elif refBase:
            if matchIndex >= len(aligned):
                print('error, exiting:')
                print('%s %s' % (MD, matchIndex))
                print(cigar)
                print(mate)
                sys.exit(1)
            (mmRefPos, mmReadPos) = aligned[matchIndex]
            variants.add((L, R, chrom, mmRefPos, 'MM', seq[mmReadPos]))
            matchIndex += 1


def _collect_region(samfile, chrom, start, end, doNoMulti, CoverageDict, VariantDict):
    """Pair the mates fetched from one region and record per-cell evidence.

    Uses only pairs where both mates were fetched and both carry a CB and an
    MD tag, have a stored sequence and contain no soft clip. For each such
    fragment, every 1-based position covered by either mate adds one to
    CoverageDict[CB][(chrom, pos)], and its variants are added to the set
    VariantDict[CB]. The barcode is taken from read 1.
    """
    matesByName = {}
    for alignedread in samfile.fetch(chrom, start, end):
        # Unmapped records have no alignment to walk; secondary (0x100) and
        # supplementary (0x800) records would overwrite the primary mate.
        if alignedread.is_unmapped or alignedread.flag & 0x900:
            continue
        if doNoMulti and alignedread.opt('NH') > 1:
            continue
        if alignedread.is_read1:
            mateNumber = 1
        elif alignedread.is_read2:
            mateNumber = 2
        else:
            continue
        try:
            CB = alignedread.opt('CB')
        except KeyError:
            # A read without a CB tag cannot be assigned to a cell.
            continue
        try:
            MD = alignedread.opt('MD')
        except KeyError:
            MD = None
        matesByName.setdefault(alignedread.qname, {})[mateNumber] = (
            alignedread.qname, alignedread.pos, alignedread.seq,
            alignedread.cigar, CB, MD)

    for mates in matesByName.values():
        if 1 not in mates or 2 not in mates:
            continue
        (_, pos1, seq1, cigar1, CB, MD1) = mates[1]
        (_, pos2, seq2, cigar2, _, MD2) = mates[2]
        if MD1 is None or MD2 is None or seq1 is None or seq2 is None:
            continue
        # Fragments with soft clippings are ignored (see the usage message).
        if _has_soft_clip(cigar1) or _has_soft_clip(cigar2):
            continue
        ends = (pos1, pos1 + len(seq1), pos2, pos2 + len(seq2))
        L = min(ends)
        R = max(ends)
        covered = set()
        variants = VariantDict.setdefault(CB, set())
        _collect_mate(mates[1], chrom, L, R, variants, covered)
        _collect_mate(mates[2], chrom, L, R, variants, covered)
        coverage = CoverageDict.setdefault(CB, {})
        for refPos in covered:
            key = (chrom, refPos)
            coverage[key] = coverage.get(key, 0) + 1


def _write_variants(outfile, GenomeDict, VariantDict, CoverageDict):
    """Write one line per (cell barcode, variant) with support and coverage.

    A deletion at pos is written as REF = deleted bases + following base and
    ALT = following base; an insertion at pos as REF = reference base at pos
    and ALT = inserted bases + the read base at pos; a mismatch as the
    reference base and the read base at pos.
    """
    for CB in sorted(VariantDict):
        VarDict = {}
        for (L, R, chrom, pos, kind, value) in VariantDict[CB]:
            sequence = GenomeDict[chrom]
            if kind == 'D':
                REF = sequence[pos - 1:pos + value]
                # ALT must be REF's anchor: the base after the deleted run.
                ALT = sequence[pos + value - 1:pos + value]
            else:
                REF = sequence[pos - 1:pos]
                ALT = value
            variant = (chrom, pos, REF, ALT)
            VarDict[variant] = VarDict.get(variant, 0) + 1
        coverage = CoverageDict.get(CB, {})
        for variant in sorted(VarDict):
            (chrom, pos, REF, ALT) = variant
            fields = [CB, chrom, str(pos), REF.upper(), ALT.upper(),
                      str(VarDict[variant]),
                      str(coverage.get((chrom, pos), 0))]
            outfile.write('\t'.join(fields) + '\n')


def run():
    """Command-line entry point: per-cell variant table from a 10X scATAC BAM.

    Arguments come from sys.argv (see the usage message). Writes a
    tab-separated table with one row per (cell barcode, variant):
    supporting_reads counts distinct fragment spans carrying the variant, and
    total_reads counts the processed fragments covering the position.
    """
    # Four positional arguments are required (argv[1]..argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s BAMfilename reference.fa chrom.sizes outputfilename [-noMulti] [-chr chrN1(,chrN2....)] [-regions bed chr left right] [-region chr left right]' % sys.argv[0])
        print('\t!!!Note: this is for 10X scATAC only!')
        print('\t!!!Note: fragments with soft clippings will be ignored!!!')
        sys.exit(1)

    BAM = sys.argv[1]
    fasta = sys.argv[2]
    chrominfo = sys.argv[3]
    outfilename = sys.argv[4]

    # Whole chromosomes from the chrom.sizes file unless -region(s) is given.
    chromInfoList = []
    with open(chrominfo) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if not fields[0]:
                continue
            chromInfoList.append((fields[0], 0, int(fields[1])))

    doNoMulti = '-noMulti' in sys.argv

    doChrSubset = '-chr' in sys.argv
    WantedChroms = set()
    if doChrSubset:
        WantedChroms = set(sys.argv[sys.argv.index('-chr') + 1].split(','))

    if '-region' in sys.argv:
        i = sys.argv.index('-region')
        regionchr = sys.argv[i + 1]
        chromInfoList = [(regionchr, int(sys.argv[i + 2]), int(sys.argv[i + 3]))]
        WantedChroms = set([regionchr])

    if '-regions' in sys.argv:
        i = sys.argv.index('-regions')
        regionchrBED = sys.argv[i + 1]
        regionchrID = int(sys.argv[i + 2])
        regionleftID = int(sys.argv[i + 3])
        regionrightID = int(sys.argv[i + 4])
        chromInfoList = []
        WantedChroms = set()
        with open(regionchrBED) as handle:
            for line in handle:
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.strip().split('\t')
                regionchr = fields[regionchrID]
                chromInfoList.append((regionchr,
                                      int(fields[regionleftID]),
                                      int(fields[regionrightID])))
                WantedChroms.add(regionchr)

    GenomeDict = _load_genome(fasta)
    print('finished inputting genomic sequence')

    samfile = pysam.Samfile(BAM, "rb")
    CoverageDict = {}
    VariantDict = {}
    with open(outfilename, 'w') as outfile:
        outfile.write('#CellBarcode\tchr\tpos\tREF\tALT\tsupporting_reads\ttotal_reads' + '\n')
        for (chrom, start, end) in chromInfoList:
            print('%s %s %s' % (chrom, start, end))
            if doChrSubset and chrom not in WantedChroms:
                continue
            # Probe the contig; fetch raises when it is absent from the BAM.
            try:
                for _ in samfile.fetch(chrom, 0, 100):
                    break
            except Exception:
                print('region %s %s %s not found in bam file, skipping' % (chrom, start, end))
                continue
            _collect_region(samfile, chrom, start, end, doNoMulti,
                            CoverageDict, VariantDict)
        _write_variants(outfile, GenomeDict, VariantDict, CoverageDict)
    samfile.close()
            
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
