##################################
#                                #
# Last modified 2017/02/14       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gc
import pysam
import string
from sets import Set
import os

def _read_chrom_sizes(path):
    """Parse a chrom.sizes file into a list of (chrom, 0, size) tuples.

    Each non-blank line is split on tabs; the first field is the chromosome
    name and the second its length. Blank lines (e.g. a trailing empty line)
    are skipped instead of raising IndexError.
    """
    regions = []
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if not fields[0]:
                continue
            regions.append((fields[0], 0, int(fields[1])))
    return regions


def _region_is_fetchable(samfile, chrom, start, end):
    """Return True if `samfile.fetch(chrom, start, end)` can be iterated.

    Pulls at most one read, so a contig missing from the BAM (or a missing
    index) is detected before counting starts.
    """
    try:
        for _ in samfile.fetch(chrom, start, end):
            break
    except Exception:
        return False
    return True


def _fragment_of(alignedread):
    """Return (end1Pos, end2Pos, strand) for a countable read, else None.

    Only read 1 of a pair whose NH tag is <= 1 is counted, and only when both
    mates are mapped to the same reference. Otherwise the mate position does
    not describe a fragment on this chromosome. Flag checks come before the
    NH lookup, so unpaired/unmapped records lacking an NH tag are skipped
    rather than raising KeyError.
    """
    if not alignedread.is_paired or not alignedread.is_read1:
        return None
    if alignedread.is_unmapped or alignedread.mate_is_unmapped:
        return None
    if alignedread.next_reference_id != alignedread.reference_id:
        return None
    if alignedread.opt('NH') > 1:
        return None
    # Fragment ends are extended by the read's sequence length, not its
    # reference span; per the usage note both mates are assumed to be the
    # same length.
    read_length = len(alignedread.seq)
    if alignedread.is_reverse:
        return (alignedread.next_reference_start,
                alignedread.pos + read_length, '-')
    return (alignedread.pos,
            alignedread.next_reference_start + read_length, '+')


def _count_bam(bam_path, chrom, start, end, column, n_columns, counts):
    """Add fragment counts for one BAM/region into `counts[fragment][column]`.

    `counts` maps (end1Pos, end2Pos, strand) to a list of `n_columns` ints.
    The BAM is always closed, including when the region cannot be fetched
    (in which case a message is printed and the file is skipped).
    """
    samfile = pysam.Samfile(bam_path, "rb")
    try:
        if not _region_is_fetchable(samfile, chrom, start, end):
            print('problem with region: %s %s %s in BAM file %s skipping'
                  % (chrom, start, end, bam_path))
            return
        for alignedread in samfile.fetch(chrom, start, end):
            fragment = _fragment_of(alignedread)
            if fragment is None:
                continue
            row = counts.get(fragment)
            if row is None:
                row = counts[fragment] = [0] * n_columns
            row[column] += 1
    finally:
        samfile.close()


def _format_row(chrom, fragment, row_counts, single_field):
    """Render one output line (without newline) for a fragment and its counts."""
    frag_start, frag_end, strand = fragment
    if single_field:
        coords = '%s:%s-%s|%s' % (chrom, frag_start, frag_end, strand)
    else:
        coords = '%s\t%s\t%s\t%s' % (chrom, frag_start, frag_end, strand)
    return '\t'.join([coords] + [str(c) for c in row_counts])


def run():
    """Count paired-end fragments per BAM and write a fragment x BAM table.

    Command line:
        RNA_bam1,...,RNA_bamN Input_bam1,...,Input_bamN chrom.sizes
        outputfilename [-singleFieldCoords]

    For each chromosome in chrom.sizes, every BAM is scanned over the whole
    chromosome. Each countable read 1 (see `_fragment_of`) contributes to a
    fragment keyed by its two outer ends and strand. The output has one
    column per BAM, in the order given on the command line (RNA BAMs first,
    then input BAMs), matching the header.
    Exits with status 1 and a usage message if fewer than four arguments
    are supplied.
    """
    # Four positional arguments are read below (argv[1]..argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s RNA_bam1,RNA_bam2,...,RNA_bamN Input_bam1,Input_bam2,...,Input_bamN chrom.sizes outputfilename [-singleFieldCoords]' % sys.argv[0])
        print('\tNote: The BAM files have to be indexed')
        print('\tNote: Only unique alignments will be considered')
        print('\tNote: It is assumed that both reads in a pair are of the same length')
        sys.exit(1)

    starr_bams = sys.argv[1].split(',')
    control_bams = sys.argv[2].split(',')
    chrominfo = sys.argv[3]
    outputfilename = sys.argv[4]
    single_field = '-singleFieldCoords' in sys.argv

    chromosomes = _read_chrom_sizes(chrominfo)

    # One output column per BAM; list order fixes both header and row order.
    bam_files = starr_bams + control_bams
    n_bams = len(bam_files)

    if single_field:
        header = '#chr:start-end|strand'
    else:
        header = '#chr\tstart\tend\tstrand'

    with open(outputfilename, 'w') as outfile:
        outfile.write('\t'.join([header] + bam_files) + '\n')
        for (chrom, start, end) in chromosomes:
            print('%s %s %s' % (chrom, start, end))
            counts = {}
            for column, bam_path in enumerate(bam_files):
                _count_bam(bam_path, chrom, start, end, column, n_bams, counts)
            # Chromosome is constant here, so sorting on (end1, end2, strand)
            # gives the same order as sorting full (chr, ...) tuples.
            for fragment in sorted(counts):
                outfile.write(_format_row(chrom, fragment, counts[fragment],
                                          single_field) + '\n')

# Only run when executed as a script, not when imported.
if __name__ == '__main__':
    run()
