##################################
#                                #
# Last modified 11/04/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _parse_read_fields(spec):
    """Expand a -minReads field spec into a list of 0-based column indices.

    The spec is a comma-separated list whose items are either a single
    integer or an inclusive 'from:to' range, e.g. '6,8:10' -> [6, 8, 9, 10].
    """
    columns = []
    for part in spec.split(','):
        if ':' in part:
            bounds = part.split(':')
            # Ranges are inclusive on both ends.
            columns.extend(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
            columns.append(int(part))
    return columns


def _overlaps(leftc, rightc, leftnc, rightnc):
    """Return True if an endpoint of one junction lies strictly inside the other."""
    return ((leftc < rightnc < rightc) or (leftc < leftnc < rightc)
            or (leftnc < rightc < rightnc) or (leftnc < leftc < rightnc))


def run():
    """Report overlapping canonical / noncanonical junction pairs.

    Command line (read from sys.argv):
        junctions outfilename [-minReads threshold fields]

    The junctions file is tab-separated: column 0 is the chromosome,
    columns 1 and 2 are the integer left/right coordinates, column 3 is the
    strand and column 5 is the splice-site type. Lines starting with '#'
    and blank lines are skipped. 'GT|AG' junctions are treated as canonical,
    'GC|AG' and 'AT|AC' as noncanonical; every other type is ignored.

    With -minReads, the integer values in the listed columns are summed for
    each line and the line is dropped if the sum is below the threshold.

    For every chromosome, each same-strand canonical/noncanonical pair whose
    intervals overlap is written as one tab-separated line to outfilename.
    Progress messages are printed to stdout.

    Exits with status 1 (after printing usage) when fewer than two
    positional arguments are given.
    """
    # Both the input and the output path are required.
    if len(sys.argv) < 3:
        print('usage: python %s junctions outfilename [-minReads threshold fields]' % sys.argv[0])
        print('\tExpected format: chr1\t10000695\t10002879\t-\tTA\tGC|AG\tGA')
        print('\tThe script will output all pairs of overlapping cannonical (GT|AG) and noncannonical (GC|AG or AT|AC) junctions on the same strand')
        print('\tUse the -minReads option with a combination of comma separated and from:to (including) fields')
        sys.exit(1)

    junctions = sys.argv[1]
    outfilename = sys.argv[2]

    doMinReads = '-minReads' in sys.argv
    threshold = 0
    readFields = []
    if doMinReads:
        flag = sys.argv.index('-minReads')
        threshold = int(sys.argv[flag + 1])
        readFields = _parse_read_fields(sys.argv[flag + 2])
        print('will require at least %s in fields %s' % (threshold, readFields))

    C = 0
    NC = 0

    # chromosome -> {(left, right, strand): splice-site type}
    JunctionsDictCannonical = {}
    JunctionsDictNonCannonical = {}
    with open(junctions) as lineslist:
        for i, line in enumerate(lineslist, 1):
            if i % 10000 == 0:
                print('%d lines processed' % i)
            # Skip comments and blank lines (a trailing empty line would
            # otherwise fail on the column lookups below).
            if line.startswith('#') or not line.strip():
                continue
            columns = line.strip().split('\t')
            chrom = columns[0]
            left = int(columns[1])
            right = int(columns[2])
            strand = columns[3]
            jtype = columns[5]
            if doMinReads:
                reads = sum(int(columns[ID]) for ID in readFields)
                if reads < threshold:
                    continue
            # Both dictionaries always get the same set of chromosome keys.
            if chrom not in JunctionsDictNonCannonical:
                JunctionsDictNonCannonical[chrom] = {}
                JunctionsDictCannonical[chrom] = {}
            if jtype == 'GT|AG':
                C += 1
                JunctionsDictCannonical[chrom][(left, right, strand)] = jtype
            elif jtype == 'GC|AG' or jtype == 'AT|AC':
                NC += 1
                JunctionsDictNonCannonical[chrom][(left, right, strand)] = jtype

    print('found %s cannonical junctions' % C)
    print('found %s noncannonical junctions' % NC)

    found = 0
    with open(outfilename, 'w') as outfile:
        # Header text is kept exactly as before (downstream parsers may rely on it).
        outline = '#chr\tstran\tcannonical_left\tcannonical_right\tcannonical_type\tnoncannonical_left\tnoncannonical_right\tnoncannonical_type\t'
        outfile.write(outline + '\n')

        # Sorted so the output order does not depend on dict ordering.
        for chrom in sorted(JunctionsDictCannonical):
            print(found)
            cannonical = sorted(JunctionsDictCannonical[chrom])
            noncannonical = sorted(JunctionsDictNonCannonical[chrom])
            print(chrom)
            for (leftc, rightc, strandc) in cannonical:
                for (leftnc, rightnc, strandnc) in noncannonical:
                    # Noncanonical junctions are sorted by left coordinate, so
                    # none of the remaining ones can start inside this one.
                    if leftnc > rightc:
                        break
                    if strandc != strandnc:
                        continue
                    if _overlaps(leftc, rightc, leftnc, rightnc):
                        # Report the strand of this pair, not whatever strand
                        # was parsed last from the input file.
                        outline = '\t'.join([
                            chrom,
                            strandc,
                            str(leftc),
                            str(rightc),
                            JunctionsDictCannonical[chrom][(leftc, rightc, strandc)],
                            str(leftnc),
                            str(rightnc),
                            JunctionsDictNonCannonical[chrom][(leftnc, rightnc, strandnc)],
                        ])
                        found += 1
                        outfile.write(outline + '\n')
        
# Run only when executed as a script, so that importing this module does not
# trigger argument parsing and file processing.
if __name__ == '__main__':
    run()

