##################################
#                                #
# Last modified 01/23/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
import os
import string
import time
import math
import random
from sets import Set

def _print_usage():
    """Print the command-line synopsis and the notes on the arguments."""
    # print(<single string>) behaves the same as the Python 2 print
    # statement, so the emitted text is unchanged.
    print('usage: python %s BAM regions.bed chrFieldID leftFieldID rightFieldID number_alignments_to_sample minLength maxLength overlap color1 color2 color3 color4 outfileprefix [-nomulti] [-noNH samtools] [-collapseDups]' % sys.argv[0])
    print('\tNote1: if you want to run the script over the whole genome, have the regions.bed file be the chrom.sizes file in BED format')
    print('\tNote2: the script will sample alignments, not reads!!!')
    print('\tcolor1: plus-strand reads in ping-pong pairs, RGB, comma-separated')
    print('\tcolor2: minus-strand reads in ping-pong pairs, RGB, comma-separated')
    print('\tcolor3: plus-strand reads NOT in ping-pong pairs, RGB, comma-separated')
    print('\tcolor4: minus-strand reads NOT in ping-pong pairs, RGB, comma-separated')


def _first_alignment_has_nh(BAM):
    """Return False if the first alignment in BAM lacks an NH tag, else True.

    Only the first alignment is inspected; a file that yields no alignments
    is reported as True (nothing contradicts the presence of the tag).
    """
    samfile = pysam.Samfile(BAM, "rb")
    try:
        for alignedread in samfile.fetch():
            alignedread.opt('NH')
            break
    except KeyError:
        # Only a missing tag means "no NH"; any other failure propagates
        # instead of being mistaken for a tag-less file.
        return False
    finally:
        samfile.close()
    return True


def _add_nh_tags(BAM, samtools):
    """Replace BAM with an NH-tagged copy and re-index it.

    Runs bamPreprocessing.py (looked up next to this script) to write
    BAM + '.NH', swaps that file in for BAM, then runs `samtools index`.
    Exits with status 1, leaving the original BAM in place, if the
    preprocessing step fails.
    """
    import subprocess

    # dirname(abspath(...)) also works when the script is started without a
    # directory component (e.g. "python script.py").
    scriptDir = os.path.dirname(os.path.abspath(sys.argv[0]))
    preprocessingScript = os.path.join(scriptDir, 'bamPreprocessing.py')
    taggedBAM = BAM + '.NH'

    status = subprocess.call(['python', preprocessingScript, BAM, taggedBAM])
    if status != 0 or not os.path.exists(taggedBAM):
        # Never delete the input unless its replacement really exists.
        print('bamPreprocessing.py failed, original BAM file left untouched, exiting')
        sys.exit(1)

    os.remove(BAM)
    os.rename(taggedBAM, BAM)

    # The file content changed, so the index has to be rebuilt.
    if subprocess.call([samtools, 'index', BAM]) != 0:
        print('samtools index failed, exiting')
        sys.exit(1)


def _load_regions(bed, chrFieldID, leftFieldID, rightFieldID):
    """Parse a tab-separated region file into {chromosome: [(left, right), ...]}.

    Lines starting with '#' and blank lines are skipped. The three field
    IDs are 0-based column indices.
    """
    RegionDict = {}
    with open(bed) as linelist:
        for line in linelist:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            region = (int(fields[leftFieldID]), int(fields[rightFieldID]))
            RegionDict.setdefault(fields[chrFieldID], []).append(region)
    return RegionDict


def _collect_alignments(BAM, RegionDict, minLength, maxLength, noMulti):
    """Return a list of (chromosome, position, strand, length) per alignment.

    Alignments are fetched region by region. Those whose sequence length is
    outside [minLength, maxLength] are dropped, as are alignments with
    NH > 1 when noMulti is set. The stored position is `pos` for forward
    alignments and `pos + length` for reverse ones (presumably the read's
    5' end -- TODO confirm).
    """
    ReadList = []
    samfile = pysam.Samfile(BAM, "rb")
    try:
        for chrom in RegionDict:
            for (left, right) in RegionDict[chrom]:
                for alignedread in samfile.fetch(chrom, left, right):
                    if noMulti and alignedread.opt('NH') > 1:
                        continue
                    readLen = len(alignedread.seq)
                    if readLen < minLength or readLen > maxLength:
                        continue
                    pos = alignedread.pos
                    strand = '+'
                    if alignedread.is_reverse:
                        pos = pos + readLen
                        strand = '-'
                    ReadList.append((chrom, pos, strand, readLen))
    finally:
        samfile.close()
    return ReadList


def _write_tracks(sampledReadList, Overlap, colors, outfilename):
    """Write the sampled alignments to <outfilename>.plus.bed / .minus.bed.

    colors is (RGB1, RGB2, RGB3, RGB4). A plus-strand alignment at position
    p gets RGB1 if a sampled minus-strand alignment sits at p + Overlap,
    else RGB3. A minus-strand alignment at p gets RGB2 if a sampled
    plus-strand alignment sits at p - Overlap, else RGB4.
    """
    (RGB1, RGB2, RGB3, RGB4) = colors
    # Strand-aware positions of every sampled alignment, for O(1) partner lookup.
    sampledEnds = set((chrom, pos, strand) for (chrom, pos, strand, readLen) in sampledReadList)

    with open(outfilename + '.plus.bed', 'w') as outfile_plus:
        with open(outfilename + '.minus.bed', 'w') as outfile_minus:
            outline = 'track type=bed name="' + outfilename + '.plus' + '" priority=0.010 visibility=full itemRgb=On'
            outfile_plus.write(outline + '\n')
            outline = 'track type=bed name="' + outfilename + '.minus' + '" priority=0.010 visibility=full itemRgb=On'
            outfile_minus.write(outline + '\n')

            for (chrom, pos, strand, readLen) in sampledReadList:
                if strand == '+':
                    paired = (chrom, pos + Overlap, '-') in sampledEnds
                    RGB = RGB1 if paired else RGB3
                    left = pos
                    right = pos + readLen
                    outfile = outfile_plus
                else:
                    paired = (chrom, pos - Overlap, '+') in sampledEnds
                    RGB = RGB2 if paired else RGB4
                    right = pos
                    left = pos - readLen
                    outfile = outfile_minus
                # Columns: chrom, start, end, name (the strand symbol), score,
                # strand, thickStart, thickEnd, itemRgb.
                fields = [chrom, str(left), str(right), strand, str(1000),
                          strand, str(left), str(right), RGB]
                outfile.write('\t'.join(fields) + '\n')


def run():
    """Sample alignments from a BAM file and write strand-specific BED tracks.

    Command line (read from sys.argv):
        BAM regions.bed chrFieldID leftFieldID rightFieldID
        number_alignments_to_sample minLength maxLength overlap
        color1 color2 color3 color4 outfileprefix
        [-nomulti] [-noNH samtools] [-collapseDups]

    Steps:
      1. Check that the first alignment carries an NH tag. If it does not,
         either rebuild the BAM with NH tags (-noNH samtools) or exit.
      2. Collect alignments overlapping the listed regions, filtered by
         length and, with -nomulti, by NH <= 1.
      3. Optionally collapse identical (chrom, position, strand, length)
         records (-collapseDups).
      4. Randomly sample up to number_alignments_to_sample records and write
         them to <outfileprefix>.plus.bed and <outfileprefix>.minus.bed,
         coloured by whether a partner on the other strand was sampled at
         the given overlap offset.

    Exits with status 1 on a usage error or when NH tags are missing and
    cannot be added.
    """
    # 14 positional arguments plus the program name: sys.argv[14] is read below.
    if len(sys.argv) < 15:
        _print_usage()
        sys.exit(1)

    noMulti = '-nomulti' in sys.argv
    doCollapseDups = '-collapseDups' in sys.argv

    BAM = sys.argv[1]
    bed = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    leftFieldID = int(sys.argv[4])
    rightFieldID = int(sys.argv[5])
    numReads = int(sys.argv[6])
    minLength = int(sys.argv[7])
    maxLength = int(sys.argv[8])
    Overlap = int(sys.argv[9])
    colors = (sys.argv[10], sys.argv[11], sys.argv[12], sys.argv[13])
    outfilename = sys.argv[14]

    print('testing for NH tags presence')
    if not _first_alignment_has_nh(BAM):
        if '-noNH' not in sys.argv:
            print('no NH: tags in BAM file, exiting')
            sys.exit(1)
        samtoolsIndex = sys.argv.index('-noNH') + 1
        if samtoolsIndex >= len(sys.argv):
            print('-noNH must be followed by the path to samtools, exiting')
            sys.exit(1)
        print('no NH: tags in BAM file, will replace with a new BAM file with NH tags')
        _add_nh_tags(BAM, sys.argv[samtoolsIndex])

    RegionDict = _load_regions(bed, chrFieldID, leftFieldID, rightFieldID)
    ReadList = _collect_alignments(BAM, RegionDict, minLength, maxLength, noMulti)

    if doCollapseDups:
        ReadList = list(set(ReadList))

    sampledReadList = random.sample(ReadList, min(numReads, len(ReadList)))
    _write_tracks(sampledReadList, Overlap, colors, outfilename)
        
# Run the pipeline only when executed as a script, so that importing this
# module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

