##################################
#                                #
# Last modified 12/14/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

# Splice-site motifs reported under their own column; every other motif is
# pooled into the single 'other' column.
_CANONICAL_MOTIFS = ('GT|AG', 'GC|AG', 'AT|AC')


def _read_staggered_counts(path):
    """Read one junctions file into {(chrom, left, right, strand, motif): staggered}.

    The file is tab-separated; columns 0-3 and 7 form the junction key and
    column 5 is the integer staggered count.  The header row (second column
    equal to 'left') and blank lines are skipped.
    """
    counts = {}
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line (e.g. trailing newline padding): nothing to parse.
                continue
            if fields[1] == 'left':
                continue
            key = (fields[0], fields[1], fields[2], fields[3], fields[7])
            counts[key] = int(fields[5])
    return counts


def _read_catalogue(path, wanted):
    """Return {key: (novelty, connectivity)} for catalogue rows whose key is in wanted.

    The catalogue is tab-separated; columns 0-3 and 5 form the junction key,
    column 9 is the novelty class and column 10 the connectivity class.
    """
    annotations = {}
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if fields == ['']:
                continue
            key = (fields[0], fields[1], fields[2], fields[3], fields[5])
            if key in wanted:
                annotations[key] = (fields[9], fields[10])
    return annotations


def _tally(junction_keys, rep1, rep2, annotations, max_counts):
    """Count junctions per class at every threshold 0..max_counts.

    Returns {'type' | 'novelty' | 'connectivity': {label: [count per threshold]}}.
    A junction contributes to threshold t when min(rep1, rep2) >= t; a
    junction absent from one replicate has a count of 0 there.

    Raises KeyError when a junction has no catalogue annotation.
    """
    tables = {'type': {}, 'novelty': {}, 'connectivity': {}}
    for key in junction_keys:
        if key not in annotations:
            raise KeyError(
                'junction %s:%s-%s (%s, %s) is not in the junctions catalogue' % key)
        support = min(rep1.get(key, 0), rep2.get(key, 0))
        novelty, connectivity = annotations[key]
        motif = key[4] if key[4] in _CANONICAL_MOTIFS else 'other'
        labelled = (('type', motif),
                    ('novelty', novelty),
                    ('connectivity', connectivity))
        # Thresholds 0..min(support, max_counts) are exactly those with
        # support >= threshold inside the reported range.
        reached = min(support, max_counts) + 1
        for category, label in labelled:
            row = tables[category].setdefault(label, [0] * (max_counts + 1))
            for threshold in range(reached):
                row[threshold] += 1
    return tables


def run():
    """Tabulate junction classes supported by both replicates at each count threshold.

    Command line: junctions1 junctions2 junctions-catalogue maxCounts outfilename

    For every threshold from 0 to maxCounts, counts the junctions whose
    staggered count is at least the threshold in both replicates, broken
    down by splice motif, novelty class and connectivity class.  The table
    is written tab-separated to outfilename (with a header row) and the
    data rows are echoed to stdout.

    Exits with status 1 after printing usage when arguments are missing.
    Raises KeyError when a junction is absent from the catalogue.
    """
    # Five positional arguments are read below, so argv needs six entries.
    if len(sys.argv) < 6:
        print('usage: python %s junctions1 junctions2 junctions-catalogue maxCounts outfilename ' % sys.argv[0])
        print('junctions files format: chr     left    right   strand  total   staggered       5-exon  5-intron|3-intron       3-exon')
        print('junctions catalogue file format: chr1    14829   14929   -       TT      GT|AG   GC              10      novel   known exon to unknown internal exon     WASH5P  WASH5P')
        sys.exit(1)

    junctions1 = sys.argv[1]
    junctions2 = sys.argv[2]
    junctionscatalogue = sys.argv[3]
    max_counts = int(sys.argv[4])
    outputfilename = sys.argv[5]

    rep1 = _read_staggered_counts(junctions1)
    rep2 = _read_staggered_counts(junctions2)
    junction_keys = set(rep1) | set(rep2)
    annotations = _read_catalogue(junctionscatalogue, junction_keys)
    tables = _tally(junction_keys, rep1, rep2, annotations, max_counts)

    # Column order: motifs, then novelty classes, then connectivity classes,
    # each group sorted alphabetically.
    columns = []
    for category in ('type', 'novelty', 'connectivity'):
        for label in sorted(tables[category]):
            columns.append((category, label))

    # The output is opened only after all inputs were processed, so a failed
    # run does not truncate a previous result.
    with open(outputfilename, 'w') as outfile:
        header = ['minCounts'] + [label for (category, label) in columns]
        outfile.write('\t'.join(header) + '\n')
        for threshold in range(max_counts + 1):
            cells = [str(threshold)]
            for (category, label) in columns:
                cells.append(str(tables[category][label][threshold]))
            outline = '\t'.join(cells)
            print(outline)
            outfile.write(outline + '\n')

# Run only when executed as a script, so that importing this module does not
# parse sys.argv or touch any files.
if __name__ == '__main__':
    run()

