##################################
#                                #
# Last modified 02/25/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _load_junctions(path):
    """Read a classifyJunctions.py output file.

    Returns a pair ``(known, novel)``:
      * ``known`` maps each gene name to a list of
        ``(chrom, left, right, strand)`` tuples for every
        'known' / 'known exon to known exon' line annotated to that gene
        (column 10 may list several comma-separated genes; the junction is
        filed under each of them).
      * ``novel`` is the list of split field lists for every 'novel' line
        whose class (column 9) contains 'unknown internal exon'.

    Blank or truncated lines (too few tab-separated columns to classify)
    are skipped instead of raising IndexError.
    """
    known = {}
    novel = []
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < 10:
                continue
            if fields[7] == 'known' and fields[8] == 'known exon to known exon':
                entry = (fields[0], int(fields[1]), int(fields[2]), fields[3])
                for gene_name in fields[9].split(','):
                    known.setdefault(gene_name, []).append(entry)
            # Novel entries need both gene columns (10 and 11) later on.
            if (len(fields) > 10 and fields[7] == 'novel'
                    and 'unknown internal exon' in fields[8]):
                novel.append(fields)
    return known, novel


def _shared_gene(exonfields):
    """Return the gene a novel entry is attributed to, or None to skip it.

    Columns 10 and 11 hold comma-separated gene lists for the two sides.
    The entry is only used when both lists are identical, and the first
    gene of the list is returned.
    """
    genes_a = exonfields[9].split(',')
    genes_b = exonfields[10].split(',')
    # NOTE(review): entries whose gene lists overlap but are not identical
    # (e.g. 'A,B' vs 'B') are skipped, and only the first gene's known
    # junctions are consulted -- preserved from the original behaviour;
    # confirm this is intended.
    if genes_a != genes_b:
        return None
    return genes_a[0]


def _closest_known(junctions, left, right, strand):
    """Find the known junctions nearest to a novel entry's two ends.

    ``junctions`` is a non-empty list of ``(chrom, left, right, strand)``
    tuples; ``strand`` must be '+' or '-'. Returns
    ``(closest5, distance5, closest3, distance3)`` where each distance is
    the signed offset that was minimised in absolute value. Ties go to the
    first junction in list order.
    """
    # NOTE(review): on '+' the novel left is compared with the known left
    # (junction[1]) and novel right with known right (junction[2]); on '-'
    # the pairing is crossed (novel right vs junction[1], novel left vs
    # junction[2]). This mirrors the original code exactly -- confirm it
    # is the intended strand handling. Chromosomes are not compared either.
    if strand == '+':
        offset5 = lambda junction: junction[1] - left
        offset3 = lambda junction: junction[2] - right
    else:
        offset5 = lambda junction: right - junction[1]
        offset3 = lambda junction: left - junction[2]
    closest5 = min(junctions, key=lambda junction: abs(offset5(junction)))
    closest3 = min(junctions, key=lambda junction: abs(offset3(junction)))
    return closest5, offset5(closest5), closest3, offset3(closest3)


def run():
    """Annotate novel 'unknown internal exon' entries with nearby known junctions.

    Usage: ``python <script> input_file outfilename`` (read from sys.argv).

    For each novel entry attributed to a single gene that also has known
    junctions, one line is written to ``outfilename``: the original
    tab-separated fields followed by 10 extra columns -- the closest known
    junction to the 5' end (chrom, left, right, strand), the signed 5'
    distance, the closest known junction to the 3' end (chrom, left,
    right, strand) and the signed 3' distance. If any known junction of the
    gene lies on a different strand, the 10 extra columns are all 'DS'.
    Entries whose strand is neither '+' nor '-' are skipped and counted.
    """
    if len(sys.argv) < 3:
        print('usage: python %s input_file outfilename' % sys.argv[0])
        print('\tthe input file is the output of classifyJunctions.py and looks like this:')
        print('      chr1    17363   17601   -       AT      GG|AG   AA      known   known exon to known exon        WASH5P  WASH5P')
        print('      the script will only look at junctions of the _unknown internal exon_ class')
        sys.exit(1)

    inputfilename = sys.argv[1]
    outfilename = sys.argv[2]

    knownJunctionsDict, novelJunctionsList = _load_junctions(inputfilename)

    # Keep the DS placeholder the same width as a normal annotation
    # (4 + 1 + 4 + 1 = 10 columns) so every output row has equal width.
    ds_columns = ['DS'] * 10

    DS = 0
    unstranded = 0
    with open(outfilename, 'w') as outfile:
        for i, exonfields in enumerate(novelJunctionsList, 1):
            if i % 10000 == 0:
                print('%d processed' % i)
            if 'different genes' in exonfields[8]:
                continue
            gene = _shared_gene(exonfields)
            if gene is None or gene not in knownJunctionsDict:
                continue
            strand = exonfields[3]
            if strand not in ('+', '-'):
                # Previously this fell through with undefined (or stale,
                # from the previous entry) closest-junction values.
                unstranded += 1
                continue
            junctions = knownJunctionsDict[gene]
            columns = list(exonfields)
            if any(junction[3] != strand for junction in junctions):
                DS += 1
                columns.extend(ds_columns)
            else:
                closest5, distance5, closest3, distance3 = _closest_known(
                    junctions, int(exonfields[1]), int(exonfields[2]), strand)
                columns.extend([
                    closest5[0], str(closest5[1]), str(closest5[2]), closest5[3],
                    str(distance5),
                    closest3[0], str(closest3[1]), str(closest3[2]), closest3[3],
                    str(distance3),
                ])
            outfile.write('\t'.join(columns) + '\n')

    print('junctions on a different strand from the known gene: %d' % DS)
    if unstranded:
        print('junctions skipped because of an unknown strand: %d' % unstranded)
        
# Only execute when invoked as a script, so importing this module does not
# parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()

