##################################
#                                #
# Last modified 09/05/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _iter_reads(path, minLength, maxLength):
    """Yield (strand, chrom, start, end) for each alignment line in *path*.

    Only lines whose read sequence (column 5) has a length within
    [minLength, maxLength] are yielded.  Columns used: 2 = strand,
    3 = chromosome, 4 = integer start coordinate, 5 = read sequence;
    ``end`` is ``start + len(sequence)``.  Blank or truncated lines are
    skipped instead of raising IndexError.
    """
    with open(path) as lines:
        for line in lines:
            fields = line.strip().split('\t')
            if len(fields) < 5:
                continue
            sequence = fields[4]
            if not minLength <= len(sequence) <= maxLength:
                continue
            start = int(fields[3])
            yield fields[1], fields[2], start, start + len(sequence)


def _index_read_ends(path, minLength, maxLength):
    """Collapse the reads in *path* into per-strand, per-chromosome position sets.

    Returns {'+': {chrom: set of starts}, '-': {chrom: set of ends}}.
    Identical positions are stored once, so duplicate reads collapse.
    Lines with a strand other than '+' or '-' are ignored.
    """
    ends = {'+': {}, '-': {}}
    for strand, chrom, start, end in _iter_reads(path, minLength, maxLength):
        if strand == '+':
            ends['+'].setdefault(chrom, set()).add(start)
        elif strand == '-':
            ends['-'].setdefault(chrom, set()).add(end)
    return ends


def _distance_histogram(path, ends, radius, minLength, maxLength):
    """Count opposite-strand partners in *ends* around each read of *path*.

    For a '+' read starting at s, every '-' partner end p with
    |p - s| <= radius adds one to counts[p - s].  For a '-' read ending at e,
    every '+' partner start p with |e - p| <= radius adds one to
    counts[e - p].  Reads in *path* are NOT collapsed: each line counts.
    Returns a dict with one key for every distance in [-radius, radius].
    """
    counts = dict.fromkeys(range(-radius, radius + 1), 0)
    for strand, chrom, start, end in _iter_reads(path, minLength, maxLength):
        if strand == '+':
            partners = ends['-'].get(chrom, ())
            # +1: the window is inclusive so that distance +radius is reachable
            for pos in range(start - radius, start + radius + 1):
                if pos in partners:
                    counts[pos - start] += 1
        elif strand == '-':
            partners = ends['+'].get(chrom, ())
            # +1: the window is inclusive so that distance -radius is reachable
            for pos in range(end - radius, end + radius + 1):
                if pos in partners:
                    counts[end - pos] += 1
    return counts


def run():
    """Command-line entry point.

    Usage: bowtie1 bowtie2 radius minLength maxLength outfilename

    Reads from the second file are collapsed into per-strand position sets.
    Each read of the first file is then checked against those positions on
    the opposite strand within +/- radius.  The resulting distance histogram
    is written to outfilename as tab-separated "distance<TAB>count" lines,
    sorted by distance, under a "#Distance\\tCounts" header.  To compare a
    set against itself, pass the same file as both bowtie1 and bowtie2.
    Prints usage and exits with status 1 if fewer than 6 arguments are given.
    """
    # Six positional arguments plus argv[0]; argv[6] is read below.
    if len(sys.argv) < 7:
        print('usage: python %s bowtie1 bowtie2 radius minLength maxLength outfilename' % sys.argv[0])
        print('       The script will collapse reads and look for the distance to reads from the second set on the opposite strand, including multireads')
        sys.exit(1)

    bowtie1 = sys.argv[1]
    bowtie2 = sys.argv[2]
    radius = int(sys.argv[3])
    minLength = int(sys.argv[4])
    maxLength = int(sys.argv[5])
    outfilename = sys.argv[6]

    # Opened before the (possibly long) input passes so a bad output path
    # fails fast; the `with` guarantees the file is closed even on error.
    with open(outfilename, 'w') as outfile:
        ends = _index_read_ends(bowtie2, minLength, maxLength)
        DistanceDict = _distance_histogram(bowtie1, ends, radius, minLength, maxLength)

        outfile.write('#Distance\tCounts\n')
        for distance in sorted(DistanceDict):
            outfile.write('%d\t%d\n' % (distance, DistanceDict[distance]))
        
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

