##################################
#                                #
# Last modified 02/05/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _read_known_junctions(path):
    """Load known junctions as ``{chrom: {strand: [(left, right), ...]}}``.

    Lines starting with '#' and blank lines are skipped. Every other line must
    have at least four tab-separated fields: chrom, left, right, strand.
    Each per-strand list is returned sorted by (left, right), which
    ``_nearest_distances`` relies on for its early exit.
    """
    known = {}
    with open(path) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            strand = fields[3]
            if chrom not in known:
                known[chrom] = {}
                # Progress message, as emitted by the original script.
                sys.stdout.write('known %s\n' % chrom)
            known[chrom].setdefault(strand, []).append(
                (int(fields[1]), int(fields[2])))
    for by_strand in known.values():
        for junctions in by_strand.values():
            junctions.sort()
    return known


def _nearest_distances(left, right, known_junctions):
    """Return signed ``(left - known_left, right - known_right)`` minima.

    Each of the two distances is minimised independently by absolute value
    over ``known_junctions`` (which must be sorted by (left, right)); on ties
    the earliest junction wins. Returns ``(None, None)`` when the list is
    empty.
    """
    best_left = None
    best_right = None
    for known_left, known_right in known_junctions:
        if best_left is None or abs(left - known_left) < abs(best_left):
            best_left = left - known_left
        if best_right is None or abs(right - known_right) < abs(best_right):
            best_right = right - known_right
        # Later junctions start at or beyond known_left. Assuming every known
        # junction has known_right >= known_left, once a junction starts
        # further past `right` than both current best (absolute) distances,
        # neither distance can shrink any more. Absolute values are essential
        # here: comparing signed distances can stop before the true nearest
        # boundary has been seen.
        if (known_left > right
                and known_left - right > max(abs(best_left), abs(best_right))):
            break
    return best_left, best_right


def run():
    """Annotate novel junctions with distances to the nearest known boundaries.

    Usage: ``python script.py <known_junctions> <novel_junctions> <outfile>``.
    Both input files are tab-separated with columns chrom, left, right, strand;
    lines starting with '#' are headers.

    The novel file's header line(s) are copied to the output with two extra
    column names appended. Each novel junction line is then written once,
    followed by two tab-separated values:

    * strand '+': (left - nearest known left, right - nearest known right)
    * strand '-': (-(right distance), -(left distance))
    * 'NA' for both when there are no known junctions on the same chromosome
      and strand, or when the strand is neither '+' nor '-'.

    Rows are ordered by chromosome name, then by (left, right, strand).
    Duplicate (chrom, left, right, strand) entries keep the last line seen.
    Exits with status 1 and a usage message if fewer than three arguments
    are given.
    """
    if len(sys.argv) < 4:
        sys.stdout.write('usage: python %s <known_junctions> <novel junctions> outfilename\n' % sys.argv[0])
        sys.stdout.write('\tjunction file format:\n')
        sys.stdout.write('      chr1    17363   17601   -/+\n')
        sys.exit(1)

    known_path, novel_path, outfilename = sys.argv[1:4]
    known = _read_known_junctions(known_path)

    with open(outfilename, 'w') as outfile:
        novel = {}
        with open(novel_path) as infile:
            for line in infile:
                if line.startswith('#'):
                    # A tab separates the existing header from the new columns.
                    outfile.write(line.strip() + '\t' + 'Nearest_5p_exon_boundary_distance\tNearest_3p_exon_boundary_distance\n')
                    continue
                if not line.strip():
                    continue
                fields = line.strip().split('\t')
                key = (int(fields[1]), int(fields[2]), fields[3])
                novel.setdefault(fields[0], {})[key] = line.strip()

        for chrom in sorted(novel):
            sys.stdout.write('%s\n' % chrom)
            known_by_strand = known.get(chrom, {})
            for key in sorted(novel[chrom]):
                left, right, strand = key
                d_left, d_right = _nearest_distances(
                    left, right, known_by_strand.get(strand, []))
                if d_left is None or strand not in ('+', '-'):
                    distance5 = distance3 = 'NA'
                elif strand == '+':
                    distance5, distance3 = d_left, d_right
                else:
                    # On the minus strand the 5' end is the right coordinate.
                    distance5, distance3 = -d_right, -d_left
                # Every novel junction is written, not only those for which
                # the search stopped early.
                outfile.write('%s\t%s\t%s\n'
                              % (novel[chrom][key], distance5, distance3))
        
# Run only when executed as a script, so importing this module has no side
# effects (it would otherwise parse sys.argv and possibly exit).
if __name__ == '__main__':
    run()

