##################################
#                                #
# Last modified 08/17/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
from sets import Set

# FLAG field meaning
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) 1
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped 1
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate 1
# 0x0040 64 the read is the first read in a pair 1,2
# 0x0080 128 the read is the second read in a pair 1,2
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate

def FLAG(FLAG):
    """Decompose a SAM FLAG integer into the values of its set bits.

    FLAG -- integer bitwise flag of an alignment record.

    Returns a list of the powers of two that are set in FLAG, largest
    first (e.g. 99 -> [64, 32, 2, 1]).  Zero and negative values give
    an empty list.

    Bits are scanned directly instead of being subtracted from a fixed
    table, so flags above 0x400 (for example 0x800, the supplementary
    alignment bit) are decomposed correctly rather than producing a
    wrong list or an IndexError.
    """

    FLAGList = []

    # Walk the bits from least to most significant; the loop ends as soon
    # as the probe exceeds FLAG, which also covers FLAG <= 0.
    bit = 1
    while bit <= FLAG:
        if FLAG & bit:
            FLAGList.append(bit)
        bit <<= 1

    # Callers receive the bits in descending order.
    FLAGList.reverse()

    return FLAGList


def _read_chrom_sizes(path):
    """Return [(chrom, 0, length), ...] from a tab-separated chrom.sizes file."""
    regions = []
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                # Blank or truncated line: there is no length to use.
                continue
            regions.append((fields[0], 0, int(fields[1])))
    return regions


def _read_junctions(path, juncstype):
    """Load annotated junctions into {(chrom, left, right, strand): {}}.

    juncstype 'bed' reads 12-column BED lines and derives the junction
    from the first two block sizes; 'juncs' reads chrom/left/right/strand
    columns directly.  Lines starting with '#' and lines with too few
    columns are skipped.  Any other juncstype loads nothing.
    """
    junction_dict = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if juncstype == 'bed':
                if len(fields) < 11:
                    continue
                block_sizes = fields[10].split(',')
                # Last base of the left block and first base of the right
                # block, in the same coordinates the CIGAR walk produces.
                left = int(fields[1]) + int(block_sizes[0]) - 1
                right = int(fields[2]) - int(block_sizes[1])
                junction_dict[(fields[0], left, right, fields[5])] = {}
            elif juncstype == 'juncs':
                if len(fields) < 4:
                    continue
                key = (fields[0], int(fields[1]), int(fields[2]), fields[3])
                junction_dict[key] = {}
    return junction_dict


def _spliced_introns(cigar, pos):
    """Return the [(intron_left, intron_right), ...] gaps of one alignment.

    Only match (0), skipped-region (3) and soft-clip (4) operations are
    accepted; any other operation makes the function return None so the
    caller can discard the read.
    """
    introns = []
    current = pos
    for (op, length) in cigar:
        if op == 4:
            # Soft clips do not advance the reference position.
            continue
        if op == 0:
            current += length
        elif op == 3:
            introns.append((current, current + length))
            current += length
        else:
            return None
    return introns


def _unstranded_key(junction_dict, chrom, left, right):
    """Pick the junction key an unstranded read should be counted under.

    An annotated junction at these coordinates is used whatever its
    strand ('+' is preferred, then '-', then '.'); otherwise the read is
    assigned to a strandless '.' junction.
    """
    for strand in ('+', '-', '.'):
        key = (chrom, left, right, strand)
        if key in junction_dict:
            return key
    return (chrom, left, right, '.')


def _add_read(junction_dict, key, stagger):
    """Count one read for junction `key` at the (Pos, matestart) stagger."""
    per_junction = junction_dict.setdefault(key, {})
    per_junction[stagger] = per_junction.get(stagger, 0) + 1


def run():
    """Count spliced reads per junction in a BAM file and write a table.

    Command line (sys.argv):
        BAM chrom.sizes <junctions file | strandedFR | strandedSR>
        <bed | juncs | strandedFR | strandedSR> outfile [-nomulti]

    With a junctions file, reads are counted against the annotated
    junctions regardless of read strand, and unannotated junctions are
    reported with strand '.'.  With strandedFR / strandedSR no annotation
    is read and the junction strand is derived from the read orientation,
    using the first or the second read of the pair, respectively, as the
    strand-determining one.  Reads that are neither first nor second in a
    pair are treated like first reads.

    -nomulti keeps only alignments whose MAPQ is exactly 255 (presumably
    the aligner's "uniquely mapped" score; confirm for the aligner used).

    Alignments whose CIGAR has operations other than M, N and S are
    ignored.  Each output line holds: chrom, left, right, strand, total
    read count, and the number of distinct (read start, mate start)
    positions supporting the junction.  Junctions without reads are not
    written.

    Exits with status 1 after printing the usage line when fewer than
    five positional arguments are given.
    """

    # Five positional arguments are read below (sys.argv[1] .. sys.argv[5]).
    if len(sys.argv) < 6:
        # Single-argument print(...) is valid in both Python 2 and 3.
        print('usage: python %s BAM chrom.sizes <junctions.bed | junctions.juncs | strandedFR | strandedSR > <bed | juncs | strandedFR | strandedSR> outfile [-nomulti]' % sys.argv[0])
        sys.exit(1)

    bam = sys.argv[1]
    junctions = sys.argv[3]
    juncstype = sys.argv[4]
    outputfilename = sys.argv[5]

    chromInfoList = _read_chrom_sizes(sys.argv[2])

    noMulti = '-nomulti' in sys.argv

    JunctionDict = {}
    print(junctions)
    doStrandFR = junctions == 'strandedFR'
    doStrandSR = junctions == 'strandedSR'
    doStranded = doStrandFR or doStrandSR
    if doStrandFR:
        print('will treat data as stranded, using the first reads as the strand-determining one')
    elif doStrandSR:
        print('will treat data as stranded, using the second read as the strand-determining one')
    else:
        JunctionDict = _read_junctions(junctions, juncstype)

    i = 0
    samfile = pysam.Samfile(bam, "rb")
    try:
        for (chrom, start, end) in chromInfoList:
            print('%s %s %s' % (chrom, start, end))
            # Probe a small window first: pysam raises ValueError for a
            # reference name that is absent from the BAM header.
            try:
                for alignedread in samfile.fetch(chrom, 0, 100):
                    pass
            except ValueError:
                print('region %s %s %s not found in bam file, skipping' % (chrom, start, end))
                continue
            for alignedread in samfile.fetch(chrom, start, end):
                i += 1
                if i % 5000000 == 0:
                    print(str(i // 1000000) + 'M alignments processed')
                cigar = alignedread.cigar
                # No CIGAR or a single operation cannot contain a splice.
                if not cigar or len(cigar) == 1:
                    continue
                if noMulti and alignedread.mapq != 255:
                    continue
                Pos = alignedread.pos
                splices = _spliced_introns(cigar, Pos)
                if splices is None:
                    continue
                if alignedread.is_paired:
                    matestart = alignedread.mpos
                else:
                    matestart = 0
                stagger = (Pos, matestart)
                if doStranded:
                    # The strand-determining read reports the transcript
                    # strand directly; its mate reports the opposite one.
                    firstRead = alignedread.is_read1 or not alignedread.is_read2
                    determining = (firstRead == doStrandFR)
                    if determining != bool(alignedread.is_reverse):
                        strand = '+'
                    else:
                        strand = '-'
                    for (IntronLeft, IntronRight) in splices:
                        key = (chrom, IntronLeft - 1, IntronRight, strand)
                        _add_read(JunctionDict, key, stagger)
                else:
                    for (IntronLeft, IntronRight) in splices:
                        key = _unstranded_key(JunctionDict, chrom,
                                              IntronLeft - 1, IntronRight)
                        _add_read(JunctionDict, key, stagger)
    finally:
        samfile.close()

    print('Finished parsing BAM file')

    with open(outputfilename, 'w') as outfile:
        for key in sorted(JunctionDict):
            counts = JunctionDict[key]
            Ntotal = sum(counts.values())
            if Ntotal == 0:
                continue
            (chrom, IntronLeft, IntronRight, strand) = key
            Nstaggered = len(counts)
            outline = '\t'.join([chrom, str(IntronLeft), str(IntronRight),
                                 strand, str(Ntotal), str(Nstaggered)])
            outfile.write(outline + '\n')

# Execute only when invoked as a script, so that importing this module
# does not parse the importer's command line or call sys.exit().
if __name__ == '__main__':
    run()
