##################################
#                                #
# Last modified 02/05/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
import string
from sets import Set

def FLAG(FLAG):
    """Split an integer bit-flag value into its individual set bits.

    Parameters:
        FLAG: integer flag value (e.g. the FLAG field of a SAM record).

    Returns:
        A list of the powers of two that are set in FLAG, largest first,
        e.g. FLAG(83) == [64, 16, 2, 1].  Zero and negative values yield
        an empty list.

    Every bit is handled, including 2048 and above; a fixed table ending
    at 1024 would mis-decompose such values.
    """
    setBits = []
    bit = 1
    # Walk the bits from least to most significant; the loop is skipped
    # entirely for FLAG <= 0.
    while bit <= FLAG:
        if FLAG & bit:
            setBits.append(bit)
        bit <<= 1

    # Callers receive the bits in descending order.
    setBits.reverse()
    return setBits

def _option_args(flag, count):
    """Return the `count` command-line values that follow `flag` in sys.argv.

    Exits with an error message if fewer than `count` values are present.
    """
    first = sys.argv.index(flag) + 1
    values = sys.argv[first:first + count]
    if len(values) < count:
        sys.exit('error: %s expects %d argument(s)' % (flag, count))
    return values


def _read_chrom_sizes(path):
    """Parse a tab-separated chrom.sizes file into (chrom, 0, size) tuples."""
    chromSizes = []
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            chromSizes.append((fields[0], 0, int(fields[1])))
    return chromSizes


def _read_region_positions(path, chrFieldID, leftFieldID, rightFieldID):
    """Collect every position covered by the regions listed in `path`.

    Lines starting with '#' and blank lines are skipped.  Returns a tuple
    (positions, total): `positions` maps each chromosome name to the set
    of integers in [left, right) over all of its regions, and `total` is
    the summed region length (overlapping regions are counted each time).
    """
    positions = {}
    total = 0
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            covered = positions.setdefault(fields[chrFieldID], set())
            left = int(fields[leftFieldID])
            right = int(fields[rightFieldID])
            covered.update(range(left, right))
            total += max(0, right - left)
    return positions, total


def _contig_is_fetchable(samfile, chrom):
    """Return True if a small probe fetch on `chrom` completes without error."""
    try:
        for _ in samfile.fetch(chrom, 0, 100):
            pass
    except Exception:
        # Any failure of the probe means the chromosome is skipped, but
        # KeyboardInterrupt/SystemExit are allowed to propagate.
        return False
    return True


def run():
    """Copy alignments from a BAM file, adding an XS strand tag to each.

    Command line (read from sys.argv):
        BAM chrom.sizes outfilename
        [-chr chrN1[,...,chrNN]]
        [-withinRegions regions_file chrFieldID leftFieldID rightFieldID]

    For every chromosome listed in chrom.sizes that can be fetched from
    the BAM, each alignment is written to the output BAM with an extra
    ("XS", '+' or '-') tag; '-' is used when the alignment is flagged as
    reverse strand.  With -chr only the named chromosomes are processed
    and -withinRegions is ignored.  With -withinRegions alone, only
    alignments whose `pos` falls inside one of the listed regions are
    written.

    Exits with status 1 after printing usage if the three positional
    arguments are not all present.
    """
    # Three positional arguments are required, so argv needs >= 4 entries.
    if len(sys.argv) < 4:
        print('usage: python %s BAM chrom.sizes outfilename [-chr chrN1[,....,chrNN]] [-withinRegions regions_file chrFieldID leftFieldID rightFieldID]' % sys.argv[0])
        print('       Note: the script assumes no duplicate readIDs')
        print('       -withinRegions option does not work together with the -ignoreChromSizes option')
        print('       the -chr option will output all alignments on the set of chromosomes indicate and will override the -withinRegions option')
        sys.exit(1)

    BAM = sys.argv[1]
    chrominfo = sys.argv[2]
    outputfilename = sys.argv[3]
    chromInfoList = _read_chrom_sizes(chrominfo)

    wantedChroms = None
    if '-chr' in sys.argv:
        wantedChroms = set(_option_args('-chr', 1)[0].split(','))
        print('will output all reads aligning to the following chromosomes:')
        print(sorted(wantedChroms))

    regionPositions = None
    if '-withinRegions' in sys.argv:
        regionsFile, chrFieldID, leftFieldID, rightFieldID = _option_args('-withinRegions', 4)
        print('will only output reads with alignments within regions defined in %s' % regionsFile)
        regionPositions, T = _read_region_positions(
            regionsFile, int(chrFieldID), int(leftFieldID), int(rightFieldID))
        print('found %s nucleotides in wanted regions' % T)

    # -chr overrides -withinRegions: no region filtering in that mode.
    if wantedChroms is not None:
        regionPositions = None

    print('will write reads into:  %s' % outputfilename)

    processed = 0
    samfile = pysam.Samfile(BAM, "rb")
    try:
        outfile = pysam.Samfile(outputfilename, "wb", template=samfile)
        try:
            for (chrom, start, end) in chromInfoList:
                if wantedChroms is not None and chrom not in wantedChroms:
                    continue
                if not _contig_is_fetchable(samfile, chrom):
                    continue

                allowed = None
                if regionPositions is not None:
                    # A chromosome without regions keeps no alignments.
                    allowed = regionPositions.get(chrom, ())

                for alignedread in samfile.fetch(chrom, start, end):
                    processed += 1
                    if processed % 5000000 == 0:
                        print('%sM alignments processed processed %s %s %s %s' % (
                            processed // 1000000, chrom, start, alignedread.pos, end))
                    if allowed is not None and alignedread.pos not in allowed:
                        continue
                    # Reverse-strand flag (bit 16) decides the tag value.
                    strand = '-' if alignedread.is_reverse else '+'
                    # NOTE(review): an XS tag already present on the read is
                    # not removed, so the read may end up with two -- confirm
                    # this is acceptable for the inputs used.
                    alignedread.tags = alignedread.tags + [("XS", strand)]
                    outfile.write(alignedread)
        finally:
            outfile.close()
    finally:
        samfile.close()

# Only execute when invoked as a script, so that importing this module
# (e.g. to reuse FLAG) does not start processing sys.argv.
if __name__ == '__main__':
    run()

