##################################
#                                #
# Last modified 2017/07/05       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import os
import re
import string
import subprocess
import sys
from sets import Set

import pysam


# SAM FLAG field bits (see the SAM format specification)
# 0x0001    1 the read is paired in sequencing, whether or not it is mapped in a pair
# 0x0002    2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment)
# 0x0004    4 the query sequence itself is unmapped
# 0x0008    8 the mate is unmapped
# 0x0010   16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020   32 strand of the mate
# 0x0040   64 the read is the first read in a pair
# 0x0080  128 the read is the second read in a pair
# 0x0100  256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200  512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 the alignment is supplementary (part of a chimeric alignment)

def FLAG(FLAG):
    """Split a SAM FLAG value into the individual bit values that are set.

    Args:
        FLAG: integer FLAG field of a SAM/BAM record.

    Returns:
        List of the set bits as powers of two, highest first, e.g.
        FLAG(19) == [16, 2, 1]. Zero or negative input gives [].
    """
    if FLAG <= 0:
        return []
    # Test every bit up to the highest one set. Values at or above 0x800
    # (supplementary alignment) then decompose correctly, instead of being
    # forced onto a fixed table of powers that stops at 1024.
    return [1 << bit
            for bit in range(FLAG.bit_length() - 1, -1, -1)
            if FLAG & (1 << bit)]

# CIGAR operation tokens. M/=/X consume read and reference; D/N consume
# reference only; I/S consume read only; H/P consume neither.
_CIGAR_OP = re.compile(r'(\d+)([MIDNSHP=X])')
# MD tag tokens: a run of matching bases, a deleted reference run (^ACG),
# or a single mismatched reference base.
_MD_TOKEN = re.compile(r'(\d+)|(\^[A-Za-z]+)|([A-Za-z])')


def _read_chrom_sizes(path):
    """Return [(chrom, 0, length), ...] from a chrom.sizes file, skipping blank lines."""
    regions = []
    with open(path) as handle:
        for line in handle:
            fields = line.split()
            if fields:
                regions.append((fields[0], 0, int(fields[1])))
    return regions


def _load_genome(path):
    """Load a FASTA file into {name: sequence}.

    The name is the first whitespace-delimited word of the header line.
    Sequence lines are stripped and concatenated as-is (case preserved).
    """
    genome = {}
    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks)
                words = line[1:].split()
                name = words[0] if words else ''
                chunks = []
            else:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks)
    return genome


def _ensure_nh_tags(BAM):
    """Check the BAM for NH tags; rebuild it with them if -noNH was given.

    Only the first record of an unrestricted fetch() is inspected. If the tag
    is missing (or that fetch fails), the program exits unless
    '-noNH samtools' is on the command line. In that case bamPreprocessing.py,
    located next to this script, writes BAM + '.NH'. That file replaces BAM
    only if preprocessing succeeded, and samtools then re-indexes it.
    """
    samfile = pysam.Samfile(BAM, "rb")
    try:
        print('testing for NH tags presence')
        for alignedread in samfile.fetch():
            alignedread.opt('NH')
            print('file has NH tags')
            break
        return
    except (KeyError, ValueError):
        # KeyError: tag absent. ValueError: pysam refused the fetch
        # (presumably e.g. an unindexed BAM -- verify for the pysam version used).
        pass
    finally:
        samfile.close()

    if '-noNH' not in sys.argv:
        print('no NH: tags in BAM file, exiting')
        sys.exit(1)
    print('no NH: tags in BAM file, will replace with a new BAM file with NH tags')
    samtools = sys.argv[sys.argv.index('-noNH') + 1]
    script = os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])),
                          'bamPreprocessing.py')
    rewritten = BAM + '.NH'
    # Argument lists (no shell) so paths with spaces/metacharacters are safe.
    status = subprocess.call([sys.executable or 'python', script, BAM, rewritten])
    if status != 0 or not os.path.exists(rewritten):
        # Never delete the input BAM unless a replacement was actually produced.
        print('BAM preprocessing failed, leaving the original BAM untouched, exiting')
        sys.exit(1)
    os.rename(rewritten, BAM)
    subprocess.call([samtools, 'index', BAM])


def _aligned_bases(start1, cigar):
    """Return [(read_index, genome_pos), ...] for every aligned (M/=/X) base.

    start1 is the 1-based genome position of the first reference-consuming
    operation. read_index is 0-based into the read sequence. Soft-clipped and
    inserted bases advance it but are never reported. Deletions and skipped
    regions (D/N) advance the genome position only.
    """
    pairs = []
    read_index = 0
    genome_pos = start1
    for length, op in _CIGAR_OP.findall(cigar):
        length = int(length)
        if op in ('M', '=', 'X'):
            for j in range(length):
                pairs.append((read_index + j, genome_pos + j))
            read_index += length
            genome_pos += length
        elif op in ('D', 'N'):
            genome_pos += length
        elif op in ('I', 'S'):
            read_index += length
    return pairs


def _mismatch_offsets(md):
    """Return the 0-based indices, among aligned bases, of MD-tag mismatches.

    Deleted reference runs (^ACG) consume no aligned read bases, so they are
    skipped without advancing the offset.
    """
    offsets = []
    offset = 0
    for matches, _deletion, mismatch in _MD_TOKEN.findall(md):
        if matches:
            offset += int(matches)
        elif mismatch:
            offsets.append(offset)
            offset += 1
    return offsets


def _collect_fragments(samfile, chrom, start, end):
    """Group alignments overlapping chrom:start-end by read name.

    Returns {read_name: {1: mate, 2: mate}}. Each mate is the tuple
    (flag, start0, cigar, sequence, md), taken from fields 1, 3, 5 and 9 of
    the pysam record's str() form plus its MD tag. Reads flagged as neither
    first nor second in pair get an empty dict. A later alignment of the same
    mate overwrites an earlier one.
    """
    fragments = {}
    for alignedread in samfile.fetch(chrom, start, end):
        if alignedread.is_unmapped:
            # Unmapped mates placed at their partner's position typically
            # carry no usable CIGAR or MD tag.
            continue
        fields = str(alignedread).split('\t')
        mates = fragments.setdefault(fields[0], {})
        if not (alignedread.is_read1 or alignedread.is_read2):
            continue
        mate = (int(fields[1]), int(fields[3]), fields[5], fields[9],
                alignedread.opt('MD'))
        if alignedread.is_read1:
            mates[1] = mate
        if alignedread.is_read2:
            mates[2] = mate
    return fragments


def _mate_calls(mate):
    """Return (covered, variants) for one mate tuple.

    covered is the set of 1-based genome positions under aligned bases.
    variants maps the 1-based genome position of each MD mismatch to the read
    base there. Bases called N/n are skipped.
    """
    _flag, start0, cigar, sequence, md = mate
    md = str(md)
    aligned = _aligned_bases(start0 + 1, cigar)
    covered = set(pos for _index, pos in aligned)
    variants = {}
    if md != str(len(sequence)):  # MD equal to the read length: perfect match
        for k in _mismatch_offsets(md):
            # The k-th aligned base itself, so the position, reference base
            # and coverage all refer to the mismatched base.
            read_index, genome_pos = aligned[k]
            base = sequence[read_index]
            if base not in ('N', 'n'):
                variants[genome_pos] = base
    return covered, variants


def _drop_unshared_overlap_calls(vars1, vars2, mate1, mate2):
    """Inside the mates' overlap, delete calls reported by only one mate.

    The overlap is estimated from each mate's 1-based start and read length,
    not its spliced genomic span. Only presence is compared: if both mates
    call a variant at a position, both calls are kept even if the bases
    differ. vars1 and vars2 are modified in place.
    """
    pos1, len1 = mate1[1] + 1, len(mate1[3])
    pos2, len2 = mate2[1] + 1, len(mate2[3])
    if pos1 >= pos2:
        if pos1 - pos2 > len2:
            return
        window = range(pos1, pos2 + len2 + 1)
    else:
        if pos2 - pos1 > len1:
            return
        window = range(pos2, pos1 + len1 + 1)
    for pos in window:
        if (pos in vars1) != (pos in vars2):
            vars1.pop(pos, None)
            vars2.pop(pos, None)


def _fragment_strand(mates):
    """Return the '+'/'-' strand label of a fragment.

    Read 1 is used when present: reverse (FLAG 0x10) gives '-'. Otherwise
    read 2 is used with the opposite sense: reverse gives '+'.
    """
    if 1 in mates:
        return '-' if mates[1][0] & 0x10 else '+'
    return '+' if mates[2][0] & 0x10 else '-'


def _tally_region(fragments):
    """Build per-position fragment coverage and per-allele supporting read names.

    Returns (coverage, variants):
      - coverage maps a 1-based position to the number of fragments with an
        aligned base there. A fragment is counted once even if both mates
        cover the position.
      - variants maps position -> base -> {'+': read names, '-': read names}.
    """
    coverage = {}
    variants = {}
    for read_id, mates in fragments.items():
        if not mates:
            continue
        covered1, vars1 = _mate_calls(mates[1]) if 1 in mates else (set(), {})
        covered2, vars2 = _mate_calls(mates[2]) if 2 in mates else (set(), {})
        for pos in covered1 | covered2:
            coverage[pos] = coverage.get(pos, 0) + 1
        if 1 in mates and 2 in mates:
            _drop_unshared_overlap_calls(vars1, vars2, mates[1], mates[2])
        strand = _fragment_strand(mates)
        for calls in (vars1, vars2):
            for pos, base in calls.items():
                by_base = variants.setdefault(pos, {})
                supporters = by_base.setdefault(base, {'+': set(), '-': set()})
                supporters[strand].add(read_id)
    return coverage, variants


def _strand_bias(plus, minus):
    """Return max(plus/minus, minus/plus); infinite when either count is zero."""
    if plus == 0 or minus == 0:
        return float('inf')
    return max(plus / float(minus), minus / float(plus))


def run():
    """Call single-nucleotide variants from a paired-end BAM file.

    Command line (see the usage message): BAM, reference FASTA, chrom.sizes,
    minSingleStrandCounts, minBothStrandsCounts, output file. Optional
    switches: -minFreq, -maxStrandBias, -chr, -noNH and -region.

    For each region, records are grouped by read name and mismatches are read
    from each mate's MD tag. Inside the mates' overlap, a call is kept only if
    both mates report one at that position. Each fragment counts once toward
    coverage. It also counts once toward each allele it carries, on the strand
    given by read 1 (read 2 inverted). Only records flagged as first or second
    in pair contribute.

    Output is tab-separated, one line per passing allele, with the 1-based
    position of the mismatched base:
    chr, pos, Ref, Variant, forward supporting, reverse supporting, and
    covering fragments not supporting the variant.
    """
    if len(sys.argv) < 7:
        print('usage: python %s BAMfilename reference.fa chrom.sizes minSingleStrandCounts minBothStrandsCounts outputfilename [-minFreq float] [-maxStrandBias float] [-chr chrN1(,chrN2....)] [-noNH samtools] [-region chr left right]' % sys.argv[0])
        print('\t -minSingleStrandCounts option: minimum counts on a single strand (either)')
        print('\t -minBothStrandCounts option: minimum counts on each strand')
        print('\t -maxStrandBias: maximum ratio between forward and reverse strand read counts')
        print('\t -minFreq: minimum ratio between supporting and total reads')
        print('\t!!!Note: indels will be ignored!!!')
        print('\t!!!Note: currently overlapping spliced reads may get counted twice toward coverage')
        sys.exit(1)

    BAM = sys.argv[1]
    fasta = sys.argv[2]
    regions = _read_chrom_sizes(sys.argv[3])
    minSSCounts = int(sys.argv[4])
    minBSCounts = int(sys.argv[5])
    outfilename = sys.argv[6]

    def option(flag, offset=1):
        # Value that appears `offset` words after a command-line switch.
        return sys.argv[sys.argv.index(flag) + offset]

    MinFreq = float(option('-minFreq')) if '-minFreq' in sys.argv else None
    maxStrandBias = (float(option('-maxStrandBias'))
                     if '-maxStrandBias' in sys.argv else None)

    wanted = None
    if '-chr' in sys.argv:
        wanted = set(option('-chr').split(','))
    if '-region' in sys.argv:
        # -region replaces the chrom.sizes list; its chromosome is always scanned.
        regions = [(option('-region'),
                    int(option('-region', 2)),
                    int(option('-region', 3)))]
        wanted = None

    genome = _load_genome(fasta)
    _ensure_nh_tags(BAM)

    header = '#chr\tpos\tRef\tVariant\tForward_reads_supporting\tReverse_reads_supporting\tReads_not_supporting'
    samfile = pysam.Samfile(BAM, "rb")
    try:
        bam_contigs = set(samfile.references)
        with open(outfilename, 'w') as outfile:
            outfile.write(header + '\n')
            for chrom, start, end in regions:
                if wanted is not None and chrom not in wanted:
                    continue
                if chrom not in bam_contigs:
                    print('region %s %s %s not found in bam file, skipping' % (chrom, start, end))
                    continue
                print('%s %s %s' % (chrom, start, end))
                fragments = _collect_fragments(samfile, chrom, start, end)
                coverage, variants = _tally_region(fragments)
                positions = sorted(variants)
                print('reads considered %s' % len(fragments))
                print('variant positions %s' % len(positions))
                for pos in positions:
                    depth = coverage[pos]
                    for base in sorted(variants[pos]):
                        plus = len(variants[pos][base]['+'])
                        minus = len(variants[pos][base]['-'])
                        if min(plus, minus) < minBSCounts:
                            continue
                        if max(plus, minus) < minSSCounts:
                            continue
                        if MinFreq is not None and (plus + minus) / float(depth) < MinFreq:
                            continue
                        if maxStrandBias is not None and _strand_bias(plus, minus) > maxStrandBias:
                            continue
                        outfile.write('\t'.join([chrom, str(pos), genome[chrom][pos - 1],
                                                 base, str(plus), str(minus),
                                                 str(depth - plus - minus)]) + '\n')
    finally:
        samfile.close()
            
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
