##################################
#                                #
# Last modified 02/09/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _log(*parts):
    """Write *parts* to stdout separated by single spaces, then a newline.

    Produces the same text as a Python 2 ``print a, b, c`` statement for the
    values logged in this script, while remaining valid Python 3 syntax.
    """
    sys.stdout.write(' '.join(str(part) for part in parts) + '\n')


# Preference order for the type2 column when two master-set records share the
# same coordinates: a lower rank wins.  Any value not listed here (other
# motifs, and the 'ND' placeholder stored for short chrM lines) ranks last.
_MOTIF_RANK = {'GT|AG': 0, 'GC|AG': 1, 'AT|AC': 2}
_OTHER_MOTIF_RANK = 3


def _merge_strand_records(old, new):
    """Choose between two master-set records with identical coordinates.

    Both records are ``(strand, type1, type2, type3)`` tuples.  ``type2`` is
    compared against 'GT|AG', 'GC|AG' and 'AT|AC' (presumably splice-site
    motifs -- confirm).  The record whose ``type2`` ranks better is returned
    unchanged; on a tie the existing record is kept with its strand set to
    '+/-'.
    """
    old_rank = _MOTIF_RANK.get(old[2], _OTHER_MOTIF_RANK)
    new_rank = _MOTIF_RANK.get(new[2], _OTHER_MOTIF_RANK)
    if new_rank < old_rank:
        return new
    if new_rank == old_rank:
        # Tuples are immutable: build a new record instead of assigning
        # into the existing one.
        return ('+/-',) + tuple(old[1:])
    return old


def _load_strand_records(path):
    """Read the stranded master junction file into a dict.

    Returns a dict mapping ``(chrom, left, right)`` (coordinates as ints) to a
    ``(strand, type1, type2, type3)`` tuple built from columns 4-7.  Lines
    starting with '#' and blank lines are skipped.  A chrM line with fewer
    than five columns is stored as ``('+/-', 'ND', 'ND', 'ND')``; every other
    line must have at least seven tab-separated columns.  Records sharing the
    same coordinates are logged and resolved by ``_merge_strand_records``.
    """
    records = {}
    duplicates = 0
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            key = (chrom, int(fields[1]), int(fields[2]))
            strand = fields[3]
            if len(fields) < 5 and chrom == 'chrM':
                # Short chrM lines carry no type columns; this unconditionally
                # replaces any earlier record at the same coordinates.
                records[key] = ('+/-', 'ND', 'ND', 'ND')
                continue
            candidate = (strand, fields[4], fields[5], fields[6])
            existing = records.get(key)
            if existing is None:
                records[key] = candidate
                continue
            duplicates += 1
            _log('same coordinates, both strand case #', duplicates,
                 key[0], key[1], key[2], existing, fields)
            records[key] = _merge_strand_records(existing, candidate)
            _log('picked', records[key])
    return records


def _annotate_junctions(in_path, records, out_path):
    """Write each junction from *in_path* with its strand taken from *records*.

    Input lines are tab-separated ``chr, left, right, total, staggered``
    (further columns are ignored; '#' and blank lines are skipped).  The key
    ``(chr, left - 2, right - 1)`` is looked up in *records*.  The offsets
    presumably convert to the master set's coordinate convention -- confirm.
    Matches are written to *out_path* as
    ``chr, left - 2, right - 1, strand, total, staggered``.  Unmatched
    junctions are logged and skipped.

    Returns the number of unmatched junctions.
    """
    not_found = 0
    with open(in_path) as source, open(out_path, 'w') as out:
        for line in source:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            left = int(fields[1]) - 2
            right = int(fields[2]) - 1
            total = fields[3]
            staggered = fields[4]
            record = records.get((chrom, left, right))
            if record is None:
                not_found += 1
                _log(chrom, left, right, total, staggered,
                     'not found in master juntions set #', not_found,
                     'skipping')
                continue
            out.write('\t'.join([chrom, str(left), str(right), record[0],
                                 total, staggered]) + '\n')
    return not_found


def run(argv=None):
    """Command-line entry point: add strand calls to a junctions file.

    ``argv`` defaults to ``sys.argv`` and must hold the program name followed
    by the junctions file, the stranded master junction file and the output
    path.  With fewer arguments a usage message is printed and the process
    exits with status 1 (``SystemExit``).
    """
    if argv is None:
        argv = sys.argv
    # Three positional arguments are read (argv[1..3]), so at least four
    # entries are required.
    if len(argv) < 4:
        # NOTE(review): this usage text disagrees with the columns actually
        # read.  The junctions file is parsed as chr/left/right/total/
        # staggered (no strand column), and the stranded file needs 7
        # columns (except short chrM lines).  Confirm and update the text.
        _log('usage: python %s junctions junctionsfilewithstrand outfilename' % argv[0])
        _log('      assumed formats:')
        _log('      junctions: chr left right strand total_counts staggered_counts')
        _log('      junctions with strand: chr left right strand')
        sys.exit(1)

    inputfilename, stranded, outfilename = argv[1:4]

    records = _load_strand_records(stranded)
    _annotate_junctions(inputfilename, records, outfilename)

# Run only when executed as a script, so importing this module has no side
# effects.
if __name__ == '__main__':
    run()

