##################################
#                                #
# Last modified 05/16/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _parse_value_fields(spec):
    """Return the sorted 0-based column indices described by *spec*.

    *spec* is either ``start:end`` (both ends included) or a comma-separated
    list of indices, e.g. ``4:7`` or ``4,6,9``.

    Raises ValueError if any piece of *spec* is not an integer.
    """
    if ':' in spec:
        bounds = spec.split(':')
        columns = list(range(int(bounds[0]), int(bounds[1]) + 1))
    else:
        columns = [int(piece) for piece in spec.split(',')]
    columns.sort()
    return columns


def _load_junction_counts(path, value_fields):
    """Read the junction table at *path*.

    Lines starting with '#' are treated as header lines: the text found in
    each column listed in *value_fields* becomes the label for that column.
    Every other non-blank line is read as ``chr left right strand ...`` and
    the columns listed in *value_fields* are parsed as integer counts.

    Returns ``(junctions, labels)`` where ``junctions`` maps
    ``(chr, left, right, strand)`` to ``{label: count}`` and ``labels`` is the
    alphabetically sorted list of labels.

    Raises KeyError if a data line appears before any header line (no labels
    are known yet), and ValueError/IndexError on malformed rows.
    """
    junctions = {}
    label_of_column = {}
    labels = []
    with open(path) as table:
        for line in table:
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no junction.
                continue
            fields = stripped.split('\t')
            if line.startswith('#'):
                for column in value_fields:
                    label_of_column[column] = fields[column]
                    labels.append(fields[column])
                continue
            key = (fields[0], int(fields[1]), int(fields[2]), fields[3])
            junctions[key] = dict(
                (label_of_column[column], int(fields[column]))
                for column in value_fields
            )
    labels.sort()
    return junctions, labels


def _inclusion_score(skip, first, second, min_counts):
    """Return the inclusion score, or the string 'N\\A' when it is undefined.

    The score is ``((first + second) / 2) / (skip + (first + second) / 2)``.
    It is reported as 'N\\A' (letter N, backslash, letter A -- the exact
    marker this script has always written) when the summed counts are below
    *min_counts*, or when the denominator is zero.
    """
    included = (first + second) / 2.0
    denominator = skip + included
    # The zero-denominator guard matters when min_counts <= 0: without it an
    # exon with no supporting reads at all would raise ZeroDivisionError.
    if skip + first + second < min_counts or denominator == 0:
        return 'N\\A'
    return included / denominator


def run():
    """Compute per-sample inclusion scores for a list of skipped exons.

    Command line (read from ``sys.argv``):
        skipped_exons junctions_table junction_count_fields minCounts outfile

    For every non-comment line of ``skipped_exons`` the chromosome (field 2),
    strand (field 5) and the coordinates in fields 4, 7, 8 and 11 are used to
    look up three junctions in the junction table: (4, 11), (4, 7) and
    (8, 11).  Presumably these are the exon-skipping junction and the two
    junctions flanking the skipped exon -- confirm against the file format.
    Junctions missing from the table count as 0.

    The output file repeats each input line and appends one score per sample
    label (labels sorted alphabetically).

    Prints usage and exits with status 1 when fewer than five arguments are
    supplied.
    """
    if len(sys.argv) < 6:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('usage: python %s skipped_exons junctions_table junction_count_fields minCounts outfile' % sys.argv[0])
        print('       junction_count_fields format: either comma separated, or start:end (including start and end, 0-based)')
        print('       junction_table format: chr left right strand .... counts')
        print('       skipped_exons format: ENSG00000001617.7\tSEMA3F\tchr3\t50211783\t50214201\t+\tchr3\t50211783\t50212529\t+\tchr3\t50212621\t50214201\t+')
        print('       Inclusion score will be calculated as ((junctions1 + junctions2)/2)/(junctions1+2 + (junctions1 + junctions2)/2)')
        sys.exit(1)

    skipped_exons = sys.argv[1]
    counts_table = sys.argv[2]
    value_fields = _parse_value_fields(sys.argv[3])
    minimum_counts = int(sys.argv[4])

    junctions, labels = _load_junction_counts(counts_table, value_fields)

    header = 'GeneIDd\tGeneName\tchr\tleft\tright\tstrand\tchr\tleft\tright\tstrand\tchr\tleft\tright\tstrand'
    with open(sys.argv[5], 'w') as outfile:
        outfile.write('\t'.join([header] + labels) + '\n')
        with open(skipped_exons) as exons:
            for line in exons:
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.split('\t')
                chrom = fields[2]
                strand = fields[5]
                left12 = int(fields[4])
                right12 = int(fields[11])
                left1 = int(fields[4])
                right1 = int(fields[7])
                left2 = int(fields[8])
                right2 = int(fields[11])
                # Look each junction up once per exon; None means "absent".
                found = [
                    junctions.get((chrom, left12, right12, strand)),
                    junctions.get((chrom, left1, right1, strand)),
                    junctions.get((chrom, left2, right2, strand)),
                ]
                columns = [line.strip()]
                for label in labels:
                    counts = []
                    for per_label in found:
                        if per_label is None:
                            counts.append(0)
                        else:
                            counts.append(per_label[label])
                    score = _inclusion_score(counts[0], counts[1], counts[2], minimum_counts)
                    columns.append(str(score))
                outfile.write('\t'.join(columns) + '\n')
 
# Script entry point. The call is unconditional (there is no __main__ guard),
# so importing this file also executes run() against sys.argv.
run()

