##################################
#                                #
# Last modified 10/26/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _spliceKey(chrom, start, stop, strand):
    """Return the lookup key 'chrom:start-stop' followed by strand ('' when strand is ignored)."""
    return '%s:%s-%s%s' % (chrom, start, stop, strand)


def _loadKnownSplices(path, doStrand):
    """Read the known-splice list at `path` and return the set of lookup keys.

    Each usable line is tab-separated: chr, left, right and, when `doStrand`
    is true, the strand in the fourth column. Lines with too few columns
    (blank lines, for example) are ignored instead of raising IndexError.
    """
    required = 4 if doStrand else 3
    known = set()
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < required:
                continue
            strand = fields[3] if doStrand else ''
            known.add(_spliceKey(fields[0], fields[1], fields[2], strand))
    return known


def _classifyJunctions(junctions, known, doStrand, offsets, outfileKnown, outfileNovel):
    """Copy every junction line to the known or the novel output stream.

    `offsets` lists the shifts applied independently to the start and the stop
    coordinate when probing `known`: (0,) for exact matching, (-1, 0, 1) to
    tolerate a 1bp difference at either end.
    """
    for line in junctions:
        if line.startswith('#'):
            # Comment/header lines are propagated to both outputs.
            outfileKnown.write(line)
            outfileNovel.write(line)
            continue
        if 'track' in line:
            continue
        fields = line.strip().split('\t')
        try:
            chrom = fields[0]
            start = int(fields[1])
            stop = int(fields[2])
            strand = fields[3] if doStrand else ''
        except (ValueError, IndexError):
            # Non-numeric coordinates or missing columns: report and move on.
            print('skipping ' + line)
            continue
        found = any(
            _spliceKey(chrom, start + startShift, stop + stopShift, strand) in known
            for startShift in offsets
            for stopShift in offsets
        )
        if found:
            outfileKnown.write(line)
        else:
            outfileNovel.write(line)


def run():
    """Split a junctions file into known and novel junctions.

    Command line (read from sys.argv):
        knownJunctions junctions outputfile-prefix [-nostrand] [-exactmatch]

    Junctions found in the known list are written to '<prefix>-known.txt',
    all others to '<prefix>-novel.txt'. By default a junction matches when
    its start and stop are each within 1bp of a known junction and the strand
    agrees; '-nostrand' ignores the strand column and '-exactmatch' requires
    identical coordinates.

    Exits with status 1 after printing usage when fewer than three positional
    arguments are supplied.
    """
    # Three positional arguments are needed (sys.argv[3] is the output prefix).
    if len(sys.argv) < 4:
        print('usage: python %s  knownJunctions junctions outputfile-prefix -nostrand -exactmatch' % sys.argv[0])
        print('\tsplices list file format: chr <tab> left <tab> right <tab> +/-')
        print('\tjunctions file format: chr\tleft\tright\torientation\ttotal_reads\tstaggered_reads')
        print('\tNote: by default, the script will try all combinations of start and stop positions within 1bp of the listed ones')
        print('            use the -exactmatch option for exact match')
        sys.exit(1)

    knownPath = sys.argv[1]
    junctionsPath = sys.argv[2]
    prefix = sys.argv[3]
    doStrand = '-nostrand' not in sys.argv
    doExact = '-exactmatch' in sys.argv

    # Exact matching probes only the listed coordinates; otherwise try all
    # nine combinations of start/stop shifted by -1, 0 or +1.
    offsets = (0,) if doExact else (-1, 0, 1)

    known = _loadKnownSplices(knownPath, doStrand)

    with open(junctionsPath) as junctions:
        with open(prefix + '-known.txt', 'w') as outfileKnown:
            with open(prefix + '-novel.txt', 'w') as outfileNovel:
                _classifyJunctions(junctions, known, doStrand, offsets,
                                   outfileKnown, outfileNovel)

# Script entry point. There is no __main__ guard, so run() executes whenever
# this file is loaded, including when it is imported as a module.
run()
