##################################
#                                #
# Last modified 01/23/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
import os
import string
import time
import math
import random
from sets import Set

def _usage():
    """Print the command-line help message and exit with status 1."""
    print('usage: python %s BAM regions.bed chrFieldID leftFieldID rightFieldID number_alignments_to_sample iterations minLength maxLength overlap outfilename [-nomulti] [-noNH samtools] [-collapseDups]' % sys.argv[0])
    print('\tNote1: if you want to run the script over the whole genome, have the regions.bed file be the chrom.sizes file in BED format')
    print('\tNote2: the script will sample alignments, not reads!!!')
    sys.exit(1)


def _has_nh_tags(BAM):
    """Return False if the first alignment in BAM lacks an NH tag, else True.

    Only the first record yielded by fetch() is inspected; an empty BAM is
    treated as having NH tags. Other errors (e.g. an unreadable or
    unindexed file) propagate instead of being misreported as missing tags.
    """
    samfile = pysam.Samfile(BAM, "rb")
    try:
        for alignedread in samfile.fetch():
            alignedread.opt('NH')  # raises KeyError when the tag is absent
            break
    except KeyError:
        return False
    finally:
        samfile.close()
    return True


def _add_nh_tags(BAM, samtools):
    """Replace BAM with a copy carrying NH tags, then re-index it.

    Runs bamPreprocessing.py (looked up in the same directory as this
    script) to write BAM + '.NH', moves that file over BAM, and runs
    `<samtools> index BAM`. Exits with status 1 if the preprocessing
    script or samtools returns a non-zero status; the original BAM is only
    replaced once the tagged copy has been produced successfully.
    """
    import subprocess
    scriptDir = os.path.dirname(os.path.abspath(sys.argv[0]))
    preprocessingScript = os.path.join(scriptDir, 'bamPreprocessing.py')
    taggedBAM = BAM + '.NH'
    # Argument lists (no shell) so paths with spaces/metacharacters survive.
    status = subprocess.call([sys.executable or 'python', preprocessingScript, BAM, taggedBAM])
    if status != 0 or not os.path.exists(taggedBAM):
        print('adding NH: tags failed, original BAM file left untouched, exiting')
        sys.exit(1)
    # On POSIX rename() replaces BAM in one step, so the input is never
    # deleted before its replacement exists.
    os.rename(taggedBAM, BAM)
    if subprocess.call([samtools, 'index', BAM]) != 0:
        print('samtools index failed on ' + BAM + ', exiting')
        sys.exit(1)


def _read_regions(bed, chrFieldID, leftFieldID, rightFieldID):
    """Parse a tab-delimited region file into {chrom: [(left, right), ...]}.

    Field IDs are 0-based column indices. Lines starting with '#' and
    blank lines are skipped; regions keep their file order per chromosome.
    """
    regions = {}
    with open(bed) as regionFile:
        for line in regionFile:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[chrFieldID]
            region = (int(fields[leftFieldID]), int(fields[rightFieldID]))
            regions.setdefault(chrom, []).append(region)
    return regions


def _collect_alignments(BAM, regions, minLength, maxLength, noMulti):
    """Return (chrom, position, strand) tuples for alignments in the regions.

    Alignments whose read sequence length is outside [minLength, maxLength]
    are dropped, as are alignments with NH > 1 when noMulti is set.
    Forward-strand entries use the alignment start (pysam `pos`);
    reverse-strand entries use `pos + len(seq)`. An alignment overlapping
    several regions is collected once per region.
    """
    alignments = []
    samfile = pysam.Samfile(BAM, "rb")
    try:
        for chrom in regions:
            for (left, right) in regions[chrom]:
                for alignedread in samfile.fetch(chrom, left, right):
                    if noMulti and alignedread.opt('NH') > 1:
                        continue
                    length = len(alignedread.seq)
                    if length < minLength or length > maxLength:
                        continue
                    if alignedread.is_reverse:
                        alignments.append((chrom, alignedread.pos + length, '-'))
                    else:
                        alignments.append((chrom, alignedread.pos, '+'))
    finally:
        samfile.close()
    return alignments


def _count_pingpong(sampled, Overlap):
    """Count entries in `sampled` whose opposite-strand partner is also sampled.

    A '+' entry at p pairs with ('-', p + Overlap); a '-' entry at q pairs
    with ('+', q - Overlap). Duplicate entries are each counted.
    """
    present = set(sampled)
    count = 0
    for (chrom, pos, strand) in sampled:
        if strand == '+':
            partner = (chrom, pos + Overlap, '-')
        else:
            partner = (chrom, pos - Overlap, '+')
        if partner in present:
            count += 1
    return count


def run():
    """Estimate the fraction of sampled alignments that form ping-pong pairs.

    Positional command-line arguments (see the usage message):
      1     BAM                          indexed BAM file (NH tags needed)
      2     regions.bed                  tab-delimited region file
      3-5   chr/left/right field IDs     0-based columns in regions.bed
      6     number_alignments_to_sample  alignments drawn per iteration
      7     iterations                   number of sampling rounds
      8-9   minLength, maxLength         inclusive read-length bounds
      10    overlap                      offset between partner positions
      11    outfilename                  one tab-delimited row per iteration
    Optional flags: -nomulti (skip alignments with NH > 1), -noNH <samtools>
    (add missing NH tags via bamPreprocessing.py), -collapseDups (keep one
    copy of identical (chrom, position, strand) entries).
    """
    # sys.argv[11] is read below, so 12 entries are required.
    if len(sys.argv) < 12:
        _usage()

    noMulti = '-nomulti' in sys.argv
    doCollapseDups = '-collapseDups' in sys.argv

    BAM = sys.argv[1]
    bed = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    leftFieldID = int(sys.argv[4])
    rightFieldID = int(sys.argv[5])
    numReads = int(sys.argv[6])
    iterations = int(sys.argv[7])
    minLength = int(sys.argv[8])
    maxLength = int(sys.argv[9])
    Overlap = int(sys.argv[10])
    outfilename = sys.argv[11]

    print('testing for NH tags presence')
    if not _has_nh_tags(BAM):
        if '-noNH' not in sys.argv:
            print('no NH: tags in BAM file, exiting')
            sys.exit(1)
        samtoolsIndex = sys.argv.index('-noNH') + 1
        if samtoolsIndex >= len(sys.argv):
            print('-noNH must be followed by the path to samtools, exiting')
            sys.exit(1)
        print('no NH: tags in BAM file, will replace with a new BAM file with NH tags')
        _add_nh_tags(BAM, sys.argv[samtoolsIndex])

    regions = _read_regions(bed, chrFieldID, leftFieldID, rightFieldID)
    ReadList = _collect_alignments(BAM, regions, minLength, maxLength, noMulti)
    if doCollapseDups:
        ReadList = list(set(ReadList))

    with open(outfilename, 'w') as outfile:
        outfile.write('#Iteration\tTotalAlignments\tSampledAlignments\tSampledAlignmentsInPingPongPairs\tFraction' + '\n')
        for i in range(iterations):
            print(i)
            start = time.time()
            sampled = random.sample(ReadList, min(numReads, len(ReadList)))
            # Kept as float so the output column reads e.g. "12.0" as before.
            InPingPongPairs = float(_count_pingpong(sampled, Overlap))
            # No alignments passed the filters -> report 0 instead of dividing by zero.
            fraction = InPingPongPairs / len(sampled) if sampled else 0.0
            row = [str(i), str(len(ReadList)), str(len(sampled)), str(InPingPongPairs), str(fraction)]
            outfile.write('\t'.join(row) + '\n')
            print(time.time() - start)
        
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    run()

