##################################
#                                #
# Last modified 10/30/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam

# SAM FLAG field bits (hex value, decimal value, meaning)
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) [1]
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped [1]
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate [1]
# 0x0040 64 the read is the first read in a pair [1,2]
# 0x0080 128 the read is the second read in a pair [1,2]
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 supplementary alignment (part of a chimeric/split alignment; added in later SAM spec versions)
# [1] only meaningful when 0x0001 is set.
# [2] if it is unknown which read of the pair this is, 0x0001 is set and both 0x0040 and 0x0080 are zero.

def FLAG(FLAG):
    """Decompose a SAM FLAG integer into the individual bit values that are set.

    Returns the set bits as powers of two, largest first, e.g.
    ``FLAG(83) == [64, 16, 2, 1]``.  A FLAG of 0 (or any non-positive value)
    yields an empty list.  Callers test for a property by membership, e.g.
    ``16 in FLAG(flag)`` for a reverse-strand read.

    Every bit is handled, including 0x800 (supplementary) and above; the
    previous greedy implementation only knew bits up to 1024 and produced a
    wrong decomposition for larger flags.
    """
    value = int(FLAG)
    bits = []
    if value <= 0:
        return bits
    # Walk from the highest set bit downwards so the ordering matches the
    # historical output (largest component first).
    bit = 1 << (value.bit_length() - 1)
    while bit:
        if value & bit:
            bits.append(bit)
        bit >>= 1
    return bits

# CIGAR operation codes (SAM spec): 0=M, 7='=', 8=X consume the reference and
# are aligned bases; 2=D and 3=N consume the reference without coverage.
_COVERED_CIGAR_OPS = (0, 7, 8)
_SKIPPED_CIGAR_OPS = (2, 3)


def _read_chrom_sizes(path):
    """Return ``[(chrom, size), ...]`` from a tab-separated chrom.sizes file.

    Blank lines are ignored; other lines must hold the chromosome name in
    column 1 and its length in column 2.
    """
    sizes = []
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if not fields[0]:
                continue
            sizes.append((fields[0], int(fields[1])))
    return sizes


def _info_value(info, key):
    """Return the value of ``key`` in a VCF INFO string, or None if absent."""
    prefix = key + '='
    for item in info.split(';'):
        if item.startswith(prefix):
            return item[len(prefix):]
    return None


def _sample_value(formatKeys, sampleValues, key):
    """Return the sample's value for FORMAT ``key``; None if absent, empty or '.'."""
    if key not in formatKeys:
        return None
    index = formatKeys.index(key)
    if index >= len(sampleValues) or sampleValues[index] in ('', '.'):
        return None
    return sampleValues[index]


def _parse_vcf(path, sampleField, addChr, minDP=None, minGQ=None, noInDels=False):
    """Collect homozygous non-reference variants of one sample from a VCF.

    path       -- VCF file path.
    sampleField-- 0-based column index of the sample whose genotype is used.
    addChr     -- prefix 'chr' to the CHROM column when True.
    minDP      -- if not None, skip records whose INFO DP is missing or lower.
    minGQ      -- if not None, skip records whose sample GQ is missing or lower.
    noInDels   -- if True, keep only records where REF and the chosen ALT
                  have the same length.

    Returns ``(variants, kept, total)``: ``variants`` maps
    chrom -> {0-based position: (ref, alt)}, ``kept`` counts stored variants
    and ``total`` counts non-header data lines.  A chromosome key is created
    for every record that passes the DP filter even if the record is then
    rejected; run() uses key presence to choose how a haplotype is processed.
    Both unphased ('/') and phased ('|') genotypes are accepted.
    """
    variants = {}
    kept = 0
    total = 0
    with open(path) as vcf:
        for line in vcf:
            if line.startswith('#') or not line.strip():
                continue
            total += 1
            fields = line.strip().split('\t')
            if minDP is not None:
                depth = _info_value(fields[7], 'DP')
                if depth is None or int(depth) < minDP:
                    continue
            chrom = 'chr' + fields[0] if addChr else fields[0]
            chromVariants = variants.setdefault(chrom, {})
            formatKeys = fields[8].split(':')
            sampleValues = fields[sampleField].split(':')
            genotype = _sample_value(formatKeys, sampleValues, 'GT')
            if genotype is None:
                continue
            alleles = genotype.replace('|', '/').split('/')
            # Keep only called, homozygous, non-reference genotypes (1/1, 2|2, ...).
            if '.' in alleles or alleles[0] == '0' or len(set(alleles)) != 1:
                continue
            if minGQ is not None:
                quality = _sample_value(formatKeys, sampleValues, 'GQ')
                if quality is None or int(quality) < minGQ:
                    continue
            refAllele = fields[3]
            altAllele = fields[4].split(',')[int(alleles[0]) - 1]
            if noInDels and len(refAllele) != len(altAllele):
                continue
            chromVariants[int(fields[1]) - 1] = (refAllele, altAllele)
            kept += 1
    return variants, kept, total


def _iter_weighted_reads(samfile, contig, start, end, strand, progress):
    """Yield ``(read, weight, multiplicity)`` for reads returned by ``samfile.fetch``.

    ``multiplicity`` is the NH tag and ``weight`` is ``1.0 / NH``.  When
    ``strand`` is not None, reads whose strand (FLAG bit 0x10: '-' if set,
    '+' otherwise) differs from it are skipped.  ``progress`` is a
    one-element list shared as an alignment counter for progress messages.
    """
    for read in samfile.fetch(contig, start, end):
        progress[0] += 1
        if progress[0] % 5000000 == 0:
            print('%dM alignments processed %s %d'
                  % (progress[0] // 1000000, contig, read.pos))
        multiplicity = read.opt('NH')
        if strand is not None:
            readStrand = '-' if read.is_reverse else '+'
            if readStrand != strand:
                continue
        yield read, 1.0 / multiplicity, multiplicity


def _add_cigar_coverage(coverage, refStart, cigar, weight):
    """Add ``weight`` to ``coverage`` at every reference base a read aligns to.

    ``refStart`` is the 0-based leftmost reference position of the read;
    ``coverage`` keys are 1-based positions.  M/=/X bases get coverage, D/N
    advance the position without coverage, and I/S/H/P consume no reference.
    """
    pos = refStart
    for op, length in cigar or ():
        if op in _COVERED_CIGAR_OPS:
            for p in range(pos + 1, pos + length + 1):
                coverage[p] = coverage.get(p, 0) + weight
            pos += length
        elif op in _SKIPPED_CIGAR_OPS:
            pos += length


def _accumulate_phased(samfile, contig, variants, chromSize, uniqueCov,
                       commonCov, strand, progress):
    """Add reads from a haplotype contig to coverage in reference coordinates.

    The contig is walked in windows delimited by the indels in ``variants``
    ({0-based reference position: (ref, alt)}); ``phase`` is the cumulative
    length difference used to map haplotype positions back to the reference.
    A sentinel insertion at ``chromSize`` makes the last window reach the end
    of the contig.  Each read is assigned to the window containing its start
    so reads overlapping a window boundary are counted only once.  Reads with
    NH == 1 go to ``uniqueCov``, NH == 2 to ``commonCov``; others are ignored.
    """
    indels = dict(variants)
    indels[chromSize] = ('A', 'AA')
    windowStart = 0
    phase = 0
    for pos in sorted(indels):
        ref, alt = indels[pos]
        if len(ref) == len(alt):
            continue
        for read, weight, multiplicity in _iter_weighted_reads(
                samfile, contig, windowStart, pos + phase, strand, progress):
            if read.pos < windowStart:
                continue  # starts in the previous window; already counted there
            if multiplicity == 2:
                _add_cigar_coverage(commonCov, read.pos - phase, read.cigar, weight)
            elif multiplicity == 1:
                _add_cigar_coverage(uniqueCov, read.pos - phase, read.cigar, weight)
        windowStart = pos + phase
        phase += len(alt) - len(ref)


def _format_value(value):
    """Render a value with at most four decimals (truncated, not rounded).

    Matches the historical ``str(v)``-based output for ordinary values
    (0.5 -> '0.5', 1.0 -> '1.0', 1/3. -> '0.3333') but never uses exponent
    notation, on which the old string-splitting code raised IndexError.
    """
    whole, frac = ('%.10f' % value).split('.')
    frac = frac[:4].rstrip('0') or '0'
    return whole + '.' + frac


def _write_bedgraph(outfile, chrom, coverage, chromSize, divisor, sign):
    """Write ``coverage`` as run-length-encoded ``chrom\\tstart\\tend\\tvalue`` lines.

    ``coverage`` maps 1-based positions to values; consecutive positions with
    identical values are merged and written 0-based half-open.  Positions
    outside 1..chromSize are dropped so intervals stay inside the chromosome.
    Values are divided by ``divisor`` and prefixed with ``sign`` ('' or '-').
    The final run is written too.
    """
    runs = []
    for p in sorted(p for p in coverage if 1 <= p <= chromSize):
        value = coverage[p]
        if runs and runs[-1][1] + 1 == p and runs[-1][2] == value:
            runs[-1][1] = p
        else:
            runs.append([p, p, value])
    for first, last, value in runs:
        outfile.write('%s\t%d\t%d\t%s%s\n'
                      % (chrom, first - 1, last, sign, _format_value(value / divisor)))


def run():
    """Command-line entry point; all arguments are read from ``sys.argv``.

    Positional arguments: BAM, reference chrom.sizes, haplotype chrom.sizes,
    then for each of the two strains a VCF path, the 0-based column index of
    the sample in that VCF and the strain ID, then the output prefix.
    Haplotype contigs in the BAM are named ``chrom::ID``.

    Steps:
      1. parse homozygous non-reference variants from both VCFs;
      2. count reads (weighted 1/NH) on the contigs of the haplotype
         chrom.sizes file to obtain the RPM normalization factor;
      3. for each reference chromosome, project reads from both haplotype
         contigs to reference coordinates and split them into NH == 1
         (strain-specific) and NH == 2 (common) coverage;
      4. write <prefix>.<ID1>.wig, <prefix>.<ID2>.wig and <prefix>.common.wig
         as bedGraph-formatted lines.

    Prints usage and exits with status 1 if fewer than 10 positional
    arguments are supplied.
    """
    if len(sys.argv) < 11:
        print('usage: python %s BAM chrom.sizes.reference chrom.sizes vcf1 fieldID1 ID1 vcf2 fieldID2 ID2 outfile_prefix [-DP minDP] [-GQ minGQ] [-addChr] [-RPM] [-stranded - | +] [-chr chrN1(,chrN2....)] [-noInDels]' % sys.argv[0])
        print('\tthis script will take a BAM and output three wiggle files - one with the reads unique to each haplotype and one with the common reads (all normalized to the total number of reads if the -RPM option is used')
        print('\tthis script is intended to work on unspliced alignments only - the genome phasing may not work if longe splices are present')
        print('\tit is assumed chromosomes are specified as follows: chr::haplotype/strain')
        print('\tit is assumed the BAM file has NH tags and mulitplicity up to 2 is allowed, with NH:i:2 reads corresponding to reads mapping to both haplotypes and NH:i:1 reads mapping to only one')
        sys.exit(1)

    BAM = sys.argv[1]
    chromsizesRef = sys.argv[2]
    chromsizesVar = sys.argv[3]
    VCF1 = sys.argv[4]
    fieldID1 = int(sys.argv[5])
    ID1 = sys.argv[6]
    VCF2 = sys.argv[7]
    fieldID2 = int(sys.argv[8])
    ID2 = sys.argv[9]
    outprefix = sys.argv[10]

    doNoInDels = '-noInDels' in sys.argv
    if doNoInDels:
        print('will omit indels')

    chromSizes = dict(_read_chrom_sizes(chromsizesRef))
    chromInfoList = [(chrom, 0, size) for chrom, size in _read_chrom_sizes(chromsizesVar)]

    minGQ = int(sys.argv[sys.argv.index('-GQ') + 1]) if '-GQ' in sys.argv else None
    minDP = int(sys.argv[sys.argv.index('-DP') + 1]) if '-DP' in sys.argv else None
    doAddChr = '-addChr' in sys.argv
    doRPM = '-RPM' in sys.argv

    strand = None
    if '-stranded' in sys.argv:
        strand = sys.argv[sys.argv.index('-stranded') + 1]
        print('will only consider %s strand reads' % strand)

    wantedChroms = None
    if '-chr' in sys.argv:
        wantedChroms = set(sys.argv[sys.argv.index('-chr') + 1].split(','))

    VariantDict1, kept, total = _parse_vcf(VCF1, fieldID1, doAddChr, minDP, minGQ, doNoInDels)
    print('parsed first vcf files %s %s' % (kept, total))
    VariantDict2, kept, total = _parse_vcf(VCF2, fieldID2, doAddChr, minDP, minGQ, doNoInDels)
    print('parsed second vcf files %s %s' % (kept, total))

    samfile = pysam.Samfile(BAM, "rb")
    try:
        bamContigs = set(samfile.references)
        progress = [0]

        TotalNumberRead = 0.0
        for (chrom, start, end) in chromInfoList:
            if chrom not in bamContigs:
                print('region %s %s %s not found in bam file, skipping' % (chrom, start, end))
                continue
            for _read, weight, _nh in _iter_weighted_reads(samfile, chrom, start, end, None, progress):
                TotalNumberRead += weight
        TotalNumberRead = float(round(TotalNumberRead))

        print('found %s reads' % TotalNumberRead)
        normFactor = TotalNumberRead / 1000000.
        print('RPM normalization Factor = %s' % normFactor)
        if doRPM and normFactor == 0:
            print('no reads found on the haplotype contigs; cannot apply -RPM normalization')
            sys.exit(1)

        divisor = normFactor if doRPM else 1.0
        sign = '-' if strand == '-' else ''

        with open(outprefix + '.' + ID1 + '.wig', 'w') as outfile1, \
                open(outprefix + '.' + ID2 + '.wig', 'w') as outfile2, \
                open(outprefix + '.common.wig', 'w') as outfile12:
            for realChrom in sorted(chromSizes):
                if wantedChroms is not None and realChrom not in wantedChroms:
                    continue
                chromSize = chromSizes[realChrom]
                coverage1, coverage2, coverage12 = {}, {}, {}

                contig2 = realChrom + '::' + ID2
                if contig2 not in bamContigs:
                    print('haplotype contig %s not found in bam file, skipping' % contig2)
                elif realChrom in VariantDict2:
                    _accumulate_phased(samfile, contig2, VariantDict2[realChrom], chromSize,
                                       coverage2, coverage12, strand, progress)
                else:
                    # No VCF records for this chromosome: every read on the
                    # second haplotype's contig (any NH) goes to the common track.
                    for read, weight, _nh in _iter_weighted_reads(samfile, contig2, None, None, strand, progress):
                        _add_cigar_coverage(coverage12, read.pos, read.cigar, weight)

                contig1 = realChrom + '::' + ID1
                if realChrom in VariantDict1:
                    if contig1 not in bamContigs:
                        print('haplotype contig %s not found in bam file, skipping' % contig1)
                    else:
                        _accumulate_phased(samfile, contig1, VariantDict1[realChrom], chromSize,
                                           coverage1, coverage12, strand, progress)

                print('strain2: %d strain1: %d both strains %d'
                      % (len(coverage2), len(coverage1), len(coverage12)))
                _write_bedgraph(outfile2, realChrom, coverage2, chromSize, divisor, sign)
                _write_bedgraph(outfile1, realChrom, coverage1, chromSize, divisor, sign)
                _write_bedgraph(outfile12, realChrom, coverage12, chromSize, divisor, sign)
    finally:
        samfile.close()
    

# Only run the command-line entry point when executed as a script, so the
# module can be imported (e.g. to reuse FLAG) without side effects.
if __name__ == '__main__':
    run()