##################################
#                                #
# Last modified 10/08/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam


# Bit flags stored per base while scanning the alignments of one chromosome.
_PLUS_FLAG = 1   # an alignment starts at this base
_MINUS_FLAG = 2  # (alignment start + read length) lands on this base


def _read_chrom_sizes(path):
    """Parse a tab-separated chrom.sizes file into (name, 0, length) tuples."""
    regions = []
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                # Blank or truncated line: there is no size to read.
                continue
            regions.append((fields[0], 0, int(fields[1])))
    return regions


class _RunWriter(object):
    """Collapse a stream of per-base scores into 'chrom, start, end, score' lines.

    The zero-score stretch before the first non-zero base and the zero-score
    stretch after the last one are not written; zero-score runs in between are.
    """

    def __init__(self, outfile, chrom):
        self.outfile = outfile
        self.chrom = chrom
        self.start = None  # start of the open run; None until a score first changes
        self.score = 0

    def update(self, position, score):
        """Register the score at `position`, closing the open run if it changed."""
        if score == self.score:
            return
        if self.start is not None:
            self._emit(position)
        self.start = position
        self.score = score

    def finish(self, end):
        """Write the run still open at the chromosome end, if it is non-zero."""
        if self.start is not None and self.score != 0:
            self._emit(end)

    def _emit(self, stop):
        self.outfile.write('%s\t%d\t%d\t%d\n'
                           % (self.chrom, self.start, stop, self.score))


def _collect_flags(samfile, chrom, start, end, reads_seen):
    """Scan the alignments of one chromosome and flag the bases they mark.

    Returns (flags, reads_seen): `flags` is a bytearray with one entry per
    base in [0, end) holding _PLUS_FLAG / _MINUS_FLAG bits, and `reads_seen`
    is the running alignment count used for progress reporting.
    """
    flags = bytearray(max(end, 0))
    for alignedread in samfile.fetch(chrom, start, end):
        reads_seen += 1
        pos = alignedread.pos
        # NOTE(review): pos + len(seq) is one past the last read base for an
        # ungapped alignment; confirm that pos + len(seq) - 1 is not intended.
        far_end = pos + len(alignedread.seq)
        # Positions outside [0, end) never reach the output, so drop them
        # (a negative index would otherwise wrap around in the bytearray).
        if 0 <= pos < end:
            flags[pos] |= _PLUS_FLAG
        if 0 <= far_end < end:
            flags[far_end] |= _MINUS_FLAG
        if reads_seen % 5000000 == 0:
            # Reported after `pos` is read so the message shows the current read.
            print('counting total number of reads %dM alignments processed %s %s %s'
                  % (reads_seen // 1000000, chrom, pos, end))
    return flags, reads_seen


def _write_tracks(flags, chrom, end, outfile_plus, outfile_minus, outfile_combined):
    """Run-length encode the per-base flags into the three output tracks."""
    plus = _RunWriter(outfile_plus, chrom)
    minus = _RunWriter(outfile_minus, chrom)
    combined = _RunWriter(outfile_combined, chrom)
    previous = 0
    for position, flag in enumerate(flags):
        if flag == previous:
            # Every score is derived from the flag, so no track can change here.
            continue
        previous = flag
        plus_score = flag & _PLUS_FLAG
        minus_score = (flag & _MINUS_FLAG) >> 1
        plus.update(position, plus_score)
        minus.update(position, minus_score)
        combined.update(position, plus_score + minus_score)
    # Flush runs that extend to the chromosome end; without this the last
    # non-zero run of each track would be lost.
    plus.finish(end)
    minus.finish(end)
    combined.finish(end)


def run():
    """Build plus, minus and combined mappability tracks from a BAM file.

    Reads three command-line arguments from sys.argv: the BAM file name, a
    tab-separated chrom.sizes file (name, length) and an output prefix.
    Writes '<prefix>.plus.wig', '<prefix>.minus.wig' and
    '<prefix>.combined.wig', each holding tab-separated
    'chrom, start, end, score' lines. Exits with status 1 after printing the
    usage text when fewer than three arguments are given. Chromosomes that
    the BAM file cannot fetch are reported and skipped.
    """
    if len(sys.argv) < 4:
        print('usage: python %s mappability_BAM_filename chrom.sizes outputfile_prefix' % sys.argv[0])
        print('\tNote1: the script assumes that the mappability BAM file has been generated allowing for uniquely mapping reads only')
        print('\tNote2: the script will produce three tracks:')
        print("\t\t\t 1) a plus strand one with scores of 1 for bases for which a read in the 5' direction is mappable")
        print("\t\t\t 2) a minus strand one with scores of 1 for bases for which a read in the 3' direction is mappable")
        print("\t\t\t 3) a combined one with scores of 2 for bases for which a read in the both directions is mappable and score of 1 if only one direction is mappable")
        sys.exit(1)

    bam_path = sys.argv[1]
    regions = _read_chrom_sizes(sys.argv[2])
    outfile_prefix = sys.argv[3]

    samfile = pysam.Samfile(bam_path, "rb")
    opened = [samfile]
    try:
        outfile_plus = open(outfile_prefix + '.plus.wig', 'w')
        opened.append(outfile_plus)
        outfile_minus = open(outfile_prefix + '.minus.wig', 'w')
        opened.append(outfile_minus)
        outfile_combined = open(outfile_prefix + '.combined.wig', 'w')
        opened.append(outfile_combined)

        reads_seen = 0
        for (chrom, start, end) in regions:
            # Probe a small window first: fetch raises ValueError for a
            # reference name the BAM file does not know about.
            try:
                for _ in samfile.fetch(chrom, 0, 100):
                    pass
            except ValueError:
                print('region %s %s %s not found in bam file, skipping'
                      % (chrom, start, end))
                continue
            flags, reads_seen = _collect_flags(samfile, chrom, start, end, reads_seen)
            _write_tracks(flags, chrom, end,
                          outfile_plus, outfile_minus, outfile_combined)
    finally:
        # Close everything that was opened, even when processing fails midway.
        for handle in reversed(opened):
            handle.close()
            
# Script entry point. The call sits at module top level with no __main__
# guard, so it executes whenever this file is run or imported.
run()
