##################################
#                                #
# Last modified 03/19/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
from sets import Set

try:
    # psyco is an optional Python 2 JIT accelerator. The script works without
    # it, so any failure to import or enable it is ignored. Catching Exception
    # (rather than a bare except) still lets Ctrl-C / SystemExit through.
    import psyco
    psyco.full()
except Exception:
    pass

# CIGAR operation characters in BAM op-code order (SAM spec, section 4.2):
# the index of a character is the integer op code pysam reports for a read.
_CIGAR_OPS = 'MIDNSHP=X'

# Operations that advance along the reference, other than 'N' (the intron
# itself, handled separately). I, S, H and P consume no reference bases.
_REF_CONSUMING_OPS = 'MD=X'

# Junctions missing from the reference junction file are echoed to stdout
# when they are supported by more reads than this.
_UNANNOTATED_REPORT_THRESHOLD = 50


def _usage_and_exit():
    """Print the command-line usage message and exit with status 1."""
    print('usage: python %s <SAM filename> <output filename> [-TopHat junctions.bed] [-junctions junctions.junctions (chr left right strand)] [-nomulti] [-bam chrom.sizes] [-fragments]' % sys.argv[0])
    print('   with -bam the input is read as an indexed BAM file, one chromosome from chrom.sizes at a time')
    print('   the -TopHat and -junctions options can not be used in the same time')
    sys.exit(1)


def _option_value(flag):
    """Return the argument following *flag* in sys.argv; exit with usage if it is missing."""
    index = sys.argv.index(flag) + 1
    if index >= len(sys.argv):
        print('missing value for option %s' % flag)
        _usage_and_exit()
    return sys.argv[index]


def _load_chrom_sizes(path):
    """Read a tab-separated chrom.sizes file into (chrom, 0, size) region tuples."""
    regions = []
    with open(path) as sizes:
        for line in sizes:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                continue  # tolerate blank or malformed lines
            regions.append((fields[0], 0, int(fields[1])))
    return regions


def _load_reference_junctions(path, junctions_format):
    """Map (chrom, intron_left, intron_right) -> strand from a reference junction file.

    With junctions_format=True the file is tab-separated "chr left right strand"
    ('#' lines skipped) and the coordinates are used as given. Otherwise it is a
    TopHat junctions.bed: the intron is derived from chromStart/chromEnd and the
    two blockSizes (column 11). Assuming standard 0-based BED coordinates, this
    yields the same convention as _cigar_junctions: the 1-based first intronic
    base and the first base after the intron. The strand is BED column 6.
    """
    reference = {}
    with open(path) as lines:
        for line in lines:
            if junctions_format:
                if line.startswith('#'):
                    continue
                fields = line.strip().split('\t')
                if len(fields) < 4:
                    continue  # blank or truncated line
                reference[(fields[0], int(fields[1]), int(fields[2]))] = fields[3]
            else:
                fields = line.strip().split('\t')
                if len(fields) < 11:
                    continue  # e.g. the BED "track" header line
                block_sizes = fields[10].split(',')
                left = int(fields[1]) + int(block_sizes[0]) + 1
                right = int(fields[2]) - int(block_sizes[1]) + 1
                reference[(fields[0], left, right)] = fields[5]
    return reference


def _parse_cigar(cigar):
    """Split a CIGAR string into (length, op) pairs; raise ValueError if it is malformed."""
    operations = []
    digits = ''
    for char in cigar:
        if char.isdigit():
            digits += char
        elif char in _CIGAR_OPS and digits:
            operations.append((int(digits), char))
            digits = ''
        else:
            raise ValueError('malformed CIGAR string: %r' % cigar)
    if digits:
        raise ValueError('malformed CIGAR string: %r' % cigar)
    return operations


def _cigar_junctions(cigar, pos):
    """Return (intron_left, intron_right, anchor) for every 'N' gap of an alignment.

    pos is the 1-based leftmost aligned reference base (SAM POS). intron_left is
    the first intronic base, and intron_right is the first base after the intron.
    anchor is the number of reference bases covered between the previous intron
    (or the alignment start) and this one. Each distinct anchor counts as one
    "staggered" read in the output. The whole CIGAR is parsed before anything is
    returned, so a malformed string never yields a partial result.
    """
    junctions = []
    anchor = 0
    for length, op in _parse_cigar(cigar):
        if op == 'N':
            intron_left = pos + anchor
            intron_right = intron_left + length
            junctions.append((intron_left, intron_right, anchor))
            pos = intron_right
            anchor = 0
        elif op in _REF_CONSUMING_OPS:
            anchor += length
    return junctions


def _add_alignment(junction_dict, chrom, pos, cigar, mate_start):
    """Count each junction of one alignment in junction_dict[(chrom, left, right)][anchor].

    When mate_start is not None (fragment mode), the anchor key becomes
    (anchor, mate_start), so read pairs are distinguished by their mate position.
    """
    for intron_left, intron_right, anchor in _cigar_junctions(cigar, pos):
        if mate_start is not None:
            anchor = (anchor, mate_start)
        anchors = junction_dict.setdefault((chrom, intron_left, intron_right), {})
        anchors[anchor] = anchors.get(anchor, 0) + 1


def _count_sam_junctions(path, no_multi, do_fragments):
    """Count spliced alignments in a text SAM file; returns the junction dictionary."""
    junction_dict = {}
    with open(path) as sam:
        for line_number, line in enumerate(sam, 1):
            if line_number % 1000000 == 0:
                print('%d lines processed' % line_number)
            if line.startswith('@'):
                continue  # header; @PG command lines can otherwise look like records
            fields = line.strip().split('\t')
            if len(fields) < 6 or 'N' not in fields[5]:
                continue
            if no_multi and fields[4] != '255':
                continue  # MAPQ 255 is kept as the unique-alignment marker
            try:
                mate_start = fields[7] if do_fragments else None
                _add_alignment(junction_dict, fields[2], int(fields[3]), fields[5], mate_start)
            except (ValueError, IndexError):
                print('could not parse alignment: %r' % (fields,))
    return junction_dict


def _count_bam_junctions(path, regions, no_multi, do_fragments):
    """Count spliced alignments in an indexed BAM file, one (chrom, start, end) region at a time.

    NOTE(review): uses the pre-0.8 pysam read attributes (cigar, pos, mapq, mpos),
    which later pysam versions keep as deprecated aliases -- confirm against the
    installed pysam.
    """
    open_alignment_file = getattr(pysam, 'AlignmentFile', None) or pysam.Samfile
    junction_dict = {}
    bam = open_alignment_file(path, 'rb')
    try:
        for chrom, start, end in regions:
            print('%s %d %d' % (chrom, start, end))
            for read in bam.fetch(chrom, start, end):
                if not read.cigar:
                    continue  # unmapped read
                cigar = ''.join('%d%s' % (length, _CIGAR_OPS[op]) for op, length in read.cigar)
                if 'N' not in cigar:
                    continue
                if no_multi and read.mapq != 255:
                    continue
                # pysam positions are 0-based; SAM POS/PNEXT (used by the SAM path) are 1-based.
                mate_start = str(read.mpos + 1) if do_fragments else None
                _add_alignment(junction_dict, chrom, read.pos + 1, cigar, mate_start)
    finally:
        bam.close()
    return junction_dict


def _write_junctions(outputfilename, junction_dict, reference):
    """Write one tab-separated line per junction, sorted by (chrom, left, right).

    Columns: chrom, intron_left, intron_right, [strand,] total reads, distinct
    anchors. When reference is not None, only junctions found in it are written,
    with its strand. Well-supported junctions that are missing from it are
    echoed to stdout instead.
    """
    with open(outputfilename, 'w') as outfile:
        for key in sorted(junction_dict):
            anchors = junction_dict[key]
            n_total = sum(anchors.values())
            n_staggered = len(anchors)
            columns = [key[0], str(key[1]), str(key[2])]
            if reference is not None:
                strand = reference.get(key)
                if strand is None:
                    if n_total > _UNANNOTATED_REPORT_THRESHOLD:
                        print('%s %d %d %d %d' % (key + (n_total, n_staggered)))
                    continue
                columns.append(strand)
            columns.extend((str(n_total), str(n_staggered)))
            outfile.write('\t'.join(columns) + '\n')


def run():
    """Count spliced reads per intron in a SAM (or indexed BAM) file and write a table.

    Arguments are read from sys.argv:
        <input> <output> [-TopHat junctions.bed | -junctions file] [-nomulti]
        [-bam chrom.sizes] [-fragments]

    See _write_junctions for the output format. Exits with status 1 on bad usage.
    """
    if len(sys.argv) < 3:
        _usage_and_exit()
    inputfilename = sys.argv[1]
    outputfilename = sys.argv[2]

    do_fragments = '-fragments' in sys.argv
    if do_fragments:
        print('will treat reads as paired and return fragments')

    if '-TopHat' in sys.argv and '-junctions' in sys.argv:
        _usage_and_exit()  # the usage text documents these as mutually exclusive

    reference = None
    if '-TopHat' in sys.argv:
        print('will get strand from TopHat junctions.bed file')
        reference = _load_reference_junctions(_option_value('-TopHat'), False)
    elif '-junctions' in sys.argv:
        print('will get strand from provided junctions file')
        reference = _load_reference_junctions(_option_value('-junctions'), True)

    no_multi = '-nomulti' in sys.argv

    if '-bam' in sys.argv:
        regions = _load_chrom_sizes(_option_value('-bam'))
        junction_dict = _count_bam_junctions(inputfilename, regions, no_multi, do_fragments)
    else:
        junction_dict = _count_sam_junctions(inputfilename, no_multi, do_fragments)

    print('Finished parsing SAM file')
    _write_junctions(outputfilename, junction_dict, reference)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
