##################################
#                                #
# Last modified 11/19/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _parse_field_ids(spec):
    """Turn the fieldID(s) argument into a list of 0-based column indices.

    'start-end' expands to every index from start through end (inclusive);
    anything else is read as a comma-separated list of indices.
    """
    if '-' in spec:
        bounds = spec.split('-')
        return list(range(int(bounds[0]), int(bounds[1]) + 1))
    return [int(token) for token in spec.split(',')]


def _orient(left, right, strand):
    """Return (donor, acceptor) coordinates for a junction.

    On the '-' strand the donor is the right coordinate. For every other
    strand value, including an unresolved '.', the left coordinate is used
    as the donor.
    NOTE: for unstranded junctions this is only a coordinate convention;
    it keeps such junctions grouped by their own ends instead of reusing
    the donor/acceptor of the previously parsed line.
    """
    if strand == '-':
        return right, left
    return left, right


def _infer_strand(donor_sites, chrom, left, right):
    """Resolve a '.' strand against junctions already stored in donor_sites.

    Returns '+' or '-' when the (left, right) pair is already recorded on
    that strand ('-' wins if both match), otherwise returns '.'.
    """
    strand = '.'
    if right in donor_sites.get((chrom, left, '+'), {}):
        strand = '+'
    if left in donor_sites.get((chrom, right, '-'), {}):
        strand = '-'
    return strand


def _load_known(path, id_fields, donor_sites, acceptor_sites):
    """Register every junction from the known-junctions file with zero counts.

    Each junction is stored in both tables (keyed by donor site and by
    acceptor site) and tagged with a 'known' key.
    """
    with open(path) as handle:
        for line in handle:
            if line[0] == '#' or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            left = int(fields[1])
            right = int(fields[2])
            strand = fields[3]
            donor, acceptor = _orient(left, right, strand)
            placements = (
                (donor_sites, (chrom, donor, strand), acceptor),
                (acceptor_sites, (chrom, acceptor, strand), donor),
            )
            for table, site, partner in placements:
                # A fresh dict per table: the two tables are updated
                # independently when detected counts are added later.
                entry = dict.fromkeys(id_fields, 0)
                entry['known'] = 0
                table.setdefault(site, {})[partner] = entry


def _load_detected(path, id_fields, donor_sites, acceptor_sites, outfile):
    """Add read counts from the junction table to both site tables.

    A line starting with '#' is treated as the column header: the names
    found in the selected columns are used to write the output header to
    outfile. Junctions absent from the tables are created (and therefore
    lack the 'known' key).
    """
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            if line[0] == '#':
                header = ['chr\tleft\tright\tstrand\tknown/novel\tdonor/acceptor']
                for field_id in id_fields:
                    name = fields[field_id]
                    header.append(name + '_counts')
                    header.append(name + '_total_counts')
                    header.append(name + '_psi')
                outfile.write('\t'.join(header) + '\n')
                continue
            chrom = fields[0]
            left = int(fields[1])
            right = int(fields[2])
            strand = fields[3]
            if strand == '.':
                strand = _infer_strand(donor_sites, chrom, left, right)
            donor, acceptor = _orient(left, right, strand)
            placements = (
                (donor_sites, (chrom, donor, strand), acceptor),
                (acceptor_sites, (chrom, acceptor, strand), donor),
            )
            for table, site, partner in placements:
                entry = table.setdefault(site, {}).setdefault(partner, {})
                for field_id in id_fields:
                    entry[field_id] = entry.get(field_id, 0) + int(fields[field_id])


def _write_site_table(outfile, table, label, id_fields):
    """Write one row per junction in table, grouped by shared site.

    For each selected column the row carries the junction's count, the
    total count over all junctions sharing the site, and their ratio
    (psi), or 'N/A' when the total is zero. label ('A' or 'D') is written
    in the donor/acceptor column.
    """
    for site in sorted(table):
        chrom, position, strand = site
        partners = table[site]
        # Totals depend only on the site, so compute them once per site.
        # They are accumulated as floats so the division below is a true
        # division and the totals are printed with a decimal point.
        totals = {}
        for field_id in id_fields:
            total = 0.0
            for partner in partners:
                total += partners[partner][field_id]
            totals[field_id] = total
        # Sorted so that the row order within a site is deterministic.
        for partner in sorted(partners):
            entry = partners[partner]
            columns = [
                chrom,
                str(min(partner, position)),
                str(max(partner, position)),
                strand,
                'known' if 'known' in entry else 'novel',
                label,
            ]
            for field_id in id_fields:
                count = entry[field_id]
                total = totals[field_id]
                psi = 'N/A' if total == 0 else count / total
                columns.append(str(count))
                columns.append(str(total))
                columns.append(str(psi))
            outfile.write('\t'.join(columns) + '\n')


def run():
    """Command-line entry point: compute per-site junction usage (psi).

    Reads four arguments from sys.argv: the known-junctions file, the
    junction count table, the count column selection and the output path.
    Known and detected junctions are merged, then one row per junction is
    written twice: first grouped by acceptor site ('A' rows), then grouped
    by donor site ('D' rows).

    Prints the usage text and exits with status 1 when fewer than four
    arguments are given.
    """
    # Four positional arguments are required; sys.argv[4] is the outfile.
    if len(sys.argv) < 5:
        print('usage: python %s known.juncs juncs.table fieldID(s) outfile' % sys.argv[0])
        print('\tknown.juncs: known junctions (output of gtf-to-juncs on annotation; chr <tab> left <tab> right <tab> strand')
        print('\tjuncs.table: read counts for junctions (need not be exactly the same as the ones in known.juncs, the script will combined them and list which ones are known and which ones are novel; ')
        print('\tjuncs.table format: chr <tab> left <tab> right <tab> strand <tab> counts1 <tab> counts2 <tab> .... countsN')
        print('\tformat of field IDs: either comma separated or start-end (including end) (0-based)')
        sys.exit(1)

    known = sys.argv[1]
    detected = sys.argv[2]
    IDfields = _parse_field_ids(sys.argv[3])
    outpath = sys.argv[4]

    # (chr, donor, strand) -> {acceptor: {fieldID: count, ['known': 0]}}
    JunctionDictD = {}
    # (chr, acceptor, strand) -> {donor: {fieldID: count, ['known': 0]}}
    JunctionDictA = {}

    _load_known(known, IDfields, JunctionDictD, JunctionDictA)
    print('finished inputting known junctions')

    # The output must be open while the table is read, because the header
    # row is emitted as soon as the '#' line of the table is seen.
    with open(outpath, 'w') as outfile:
        _load_detected(detected, IDfields, JunctionDictD, JunctionDictA, outfile)
        print('finished inputting detected junctions')

        _write_site_table(outfile, JunctionDictA, 'A', IDfields)
        _write_site_table(outfile, JunctionDictD, 'D', IDfields)
        
# Run the command-line entry point only when this file is executed as a
# script, so that importing the module does not parse sys.argv or exit.
if __name__ == '__main__':
    run()

