##################################
#                                #
# Last modified 08/17/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
from sets import Set

# FLAG field meaning
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) [1]
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped [1]
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate [1]
# 0x0040 64 the read is the first read in a pair [1][2]
# 0x0080 128 the read is the second read in a pair [1][2]
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 the alignment is supplementary (SAM v1.4 and later)
# [1] only meaningful when 0x0001 is set
# [2] if the first/second-read information was lost upstream, 0x0001 is set and both 0x0040 and 0x0080 are zero

def FLAG(FLAG):
    """Split a SAM FLAG value into the individual bit values that are set.

    The set bits are returned as integers (powers of two), largest first,
    e.g. FLAG(83) == [64, 16, 2, 1]. A value of 0, or a negative value,
    yields an empty list.

    Every bit is decoded, including those above 0x400 such as 0x800
    (supplementary alignment), so callers can safely test membership
    (e.g. ``16 in FLAG(x)``) for any flag value.
    """
    FLAGList = []
    bit = 1
    # Walk the powers of two up to the value itself; no fixed upper limit.
    while bit <= FLAG:
        if FLAG & bit:
            FLAGList.append(bit)
        bit <<= 1
    # Historical contract: largest bit first.
    FLAGList.reverse()
    return FLAGList


def run():

    if len(sys.argv) < 6:
        print 'usage: python %s BAM chrom.sizes junctions.juncs GTF overhang_size outfile [-nomulti] [-stranded FRstrand|SRstrand]' % sys.argv[0]
        print '\tNote: The overhang size paramater refers to the minimum number of bases that a read has to span a splice site on each side to be considered; the suggested size is the same size used to call splice junctions'
        print '\tNote: The junctions.juncs file should be in the following format: <chr> tab <left> tab <right> <strand> <total_counts> tab <collapsed_counts>'
        print '\tNote: If the -stranded option is used, only reads compatible with the strand of transcription will be counted for each gene according to the protocol used: FRstrand - first read matches the strand of transcription; SRstrand - the second read matches the strand of transcription'
        print '\tNote: The script will count fragments, not reads'
        sys.exit(1)

    bam = sys.argv[1]
    junctions = sys.argv[3]
    GTF = sys.argv[4]
    OH = int(sys.argv[5])
    outputfilename = sys.argv[6]

    chromInfoList=[]
    linelist=open(sys.argv[2])
    for line in linelist:
        fields=line.strip().split('\t')
        chr=fields[0]
        start=0
        end=int(fields[1])
        chromInfoList.append((chr,start,end))

    noMulti = False
    if '-nomulti' in sys.argv:
        noMulti = True

    doStranded = False
    if '-stranded' in sys.argv:
        doStranded = True

    JunctionsDict = {}
    linelist = open(junctions)
    for line in linelist:
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        chr = fields[0]
        strand = fields[3]
        left = int(fields[1])
        right = int(fields[2])
        TotalCounts = int(fields[4])
        StaggeredCounts = int(fields[5])
        if JunctionsDict.has_key(chr):
            pass
        else:
            JunctionsDict[chr] = {}
        JunctionDict[chr][(left,right,strand)]['spliced_counts'] = (TotalCounts,StaggeredCounts)

    i=0
    samfile = pysam.Samfile(bam, "rb" )
    for (chr,start,end) in chromInfoList:
        print chr,start,end
        try:
            for alignedread in samfile.fetch(chr, 0, 100):
                a='b'
        except:
            print 'region', chr,start,end, 'not found in bam file, skipping'
            continue
        for alignedread in samfile.fetch(chr, start, end):
            i+=1
            if i % 5000000 == 0:
                print str(i/1000000) + 'M alignments processed'
            if len(alignedread.cigar) == 1:
                continue 
            MAPQ=alignedread.mapq
            fields=str(alignedread).split('\t')
            if noMulti and MAPQ!=255:
                continue
            Pos=alignedread.pos
            NotOnly03=False
            splices=[]
            currentPos=Pos
            for (m,bp) in alignedread.cigar:
                if m == 4:
                    continue
                if m != 0 and m!= 3 and m != 4:
                    NotOnly03=True
                    continue
                if m == 0:
                    currentPos = currentPos + bp
                if m == 3:
                    intronLeft = currentPos
                    intronRight = currentPos + bp
                    currentPos = currentPos + bp
                    splices.append((intronLeft,intronRight))
            if NotOnly03:
                continue
            if alignedread.is_paired:
                matestart = alignedread.mpos
            else:
                matestart=0
            FLAGfields = FLAG(int(fields[1]))
            if 16 in FLAGfields:
                if alignedread.is_read2 and doStrandSR:
                    strand = '-'
                if alignedread.is_read1 and doStrandSR:
                    strand = '+'
                if alignedread.is_read2 and doStrandFR:
                    strand = '+'
                if alignedread.is_read1 and doStrandFR:
                    strand = '-'
            else:
                if alignedread.is_read2 and doStrandSR:
                    strand = '+'
                if alignedread.is_read1 and doStrandSR:
                    strand = '-'
                if alignedread.is_read2 and doStrandFR:
                    strand = '-'
                if alignedread.is_read1 and doStrandFR:
                    strand = '+'
            if doStranded:
                for (IntronLeft,IntronRight) in splices:
                    if JunctionDict.has_key((chr,IntronLeft-1,IntronRight,strand)):
                        if JunctionDict[(chr,IntronLeft-1,IntronRight,strand)].has_key((Pos,matestart)):
                            JunctionDict[(chr,IntronLeft-1,IntronRight,strand)][(Pos,matestart)]+=1
                        else:
                            JunctionDict[(chr,IntronLeft-1,IntronRight,strand)][(Pos,matestart)]=1
                    else:
                        JunctionDict[(chr,IntronLeft-1,IntronRight,strand)]={}
                        JunctionDict[(chr,IntronLeft-1,IntronRight,strand)][(Pos,matestart)]=1
            else:
                for (IntronLeft,IntronRight) in splices:
                    if JunctionDict.has_key((chr,IntronLeft-1,IntronRight,strand)):
                        if JunctionDict[(chr,IntronLeft-1,IntronRight,strand)].has_key((Pos,matestart)):
                            JunctionDict[(chr,IntronLeft-1,IntronRight,strand)][(Pos,matestart)]+=1
                        else:
                            JunctionDict[(chr,IntronLeft-1,IntronRight,strand)][(Pos,matestart)]=1
                    elif strand == '-' and JunctionDict.has_key((chr,IntronLeft-1,IntronRight,'+')):
                        if JunctionDict[(chr,IntronLeft-1,IntronRight,'+')].has_key((Pos,matestart)):
                            JunctionDict[(chr,IntronLeft-1,IntronRight,'+')][(Pos,matestart)]+=1
                        else:
                            JunctionDict[(chr,IntronLeft-1,IntronRight,'+')][(Pos,matestart)]=1
                    elif strand == '+' and JunctionDict.has_key((chr,IntronLeft-1,IntronRight,'-')):
                        if JunctionDict[(chr,IntronLeft-1,IntronRight,'-')].has_key((Pos,matestart)):
                            JunctionDict[(chr,IntronLeft-1,IntronRight,'-')][(Pos,matestart)]+=1
                        else:
                            JunctionDict[(chr,IntronLeft-1,IntronRight,'-')][(Pos,matestart)]=1
                    else:
                        JunctionDict[(chr,IntronLeft-1,IntronRight,'.')]={}
                        JunctionDict[(chr,IntronLeft-1,IntronRight,'.')][(Pos,matestart)]=1

    print 'Finished parsing BAM file'
    outfile = open(outputfilename, 'w')

    keys=JunctionDict.keys()
    keys.sort()
    for (chr,IntronLeft,IntronRight,strand) in keys:
        Nstaggered=len(JunctionDict[(chr,IntronLeft,IntronRight,strand)].keys())
        Ntotal=0
        for key in JunctionDict[(chr,IntronLeft,IntronRight,strand)].keys():
            Ntotal+=JunctionDict[(chr,IntronLeft,IntronRight,strand)][key]
        if Ntotal == 0:
            continue
        outline=chr+'\t'+str(IntronLeft)+'\t'+str(IntronRight) + '\t' + strand
        outline=outline+'\t'+str(Ntotal)+'\t'+str(Nstaggered)
        outfile.write(outline+'\n')

    outfile.close()

# Run only when executed as a script, so the module can be imported
# (e.g. to reuse FLAG) without parsing sys.argv or touching any files.
if __name__ == '__main__':
    run()
