##################################
#                                #
# Last modified 11/15/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of *key* from a GTF attribute column.

    Raises IndexError when the attribute is absent from the column.
    """
    return attributes.split('%s "' % key)[1].split('";')[0]


def _read_exons(gtf_path):
    """Group the exon records of a GTF file by gene and by transcript.

    Returns a dict mapping ``(gene_id, gene_name)`` to a dict mapping
    ``transcript_id`` to a sorted, duplicate-free list of
    ``(chr, left, right, strand)`` tuples, with coordinates exactly as
    written in the file.
    """
    genes = {}
    with open(gtf_path) as gtf:
        for line in gtf:
            # Comment lines and blank lines carry no records.
            if line.startswith('#') or not line.strip():
                continue
            fields = line.split('\t')
            if fields[2] != 'exon':
                continue
            gene = (_gtf_attribute(fields[8], 'gene_id'),
                    _gtf_attribute(fields[8], 'gene_name'))
            transcript_id = _gtf_attribute(fields[8], 'transcript_id')
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            # The builtin set replaces the deprecated sets.Set for dedup.
            genes.setdefault(gene, {}).setdefault(transcript_id, set()).add(exon)

    for transcripts in genes.values():
        for transcript_id, exons in list(transcripts.items()):
            transcripts[transcript_id] = sorted(exons)
    return genes


def _splice_junctions(genes):
    """Return the set of junctions joining adjacent exons of any transcript.

    Each junction is ``(chr, right end of the first exon, left end of the
    next exon, strand)``; chr and strand are taken from the first exon of
    the pair.
    """
    junctions = set()
    for transcripts in genes.values():
        for exons in transcripts.values():
            for current, following in zip(exons, exons[1:]):
                junctions.add((current[0], current[2], following[1], current[3]))
    return junctions


def _skipped_exons(genes, junctions):
    """Return the sorted, duplicate-free list of skipped-exon records.

    Only genes with at least two transcripts are examined.  For every three
    consecutive exons of a transcript, the middle exon is reported when a
    junction linking the two outer exons exists in *junctions* and the
    middle exon is not the first or last exon of any transcript of the gene.

    Each record is ``(gene_id, gene_name, chr, exon1_left, exon1_right,
    exon2_left, exon2_right, exon3_left, exon3_right, strand)``.
    """
    skipped = set()
    for (gene_id, gene_name), transcripts in genes.items():
        if len(transcripts) < 2:
            continue

        # First and last exon of every transcript of this gene.
        terminal_exons = set()
        for exons in transcripts.values():
            terminal_exons.add(exons[0])
            terminal_exons.add(exons[-1])

        for exons in transcripts.values():
            # chr and strand of the transcript are read from its first exon.
            chrom = exons[0][0]
            strand = exons[0][3]
            for first, middle, last in zip(exons, exons[1:], exons[2:]):
                if (chrom, first[2], last[1], strand) not in junctions:
                    continue
                if (chrom, middle[1], middle[2], strand) in terminal_exons:
                    continue
                skipped.add((gene_id, gene_name, first[0],
                             first[1], first[2],
                             middle[1], middle[2],
                             last[1], last[2],
                             strand))
    return sorted(skipped)


def run():
    """Find skipped exons in a GTF annotation and write them to a table.

    Command line: ``python <script> gtf outfile``.

    Reads the ``exon`` records of the GTF file named by ``sys.argv[1]``,
    prints the number of skipped exons found, and writes one tab-separated
    row per skipped exon (preceded by a ``#`` header row) to the file named
    by ``sys.argv[2]``.  Every coordinate is written with 1 subtracted from
    the value found in the GTF.

    Exits with status 1 after printing a usage message when fewer than two
    arguments are supplied.
    """
    # Both the input and the output path are required.
    if len(sys.argv) < 3:
        print('usage: python %s gtf outfile' % sys.argv[0])
        sys.exit(1)

    inputfilename = sys.argv[1]
    outputfilename = sys.argv[2]

    genes = _read_exons(inputfilename)
    skipped = _skipped_exons(genes, _splice_junctions(genes))

    print('found %d skipped exons' % len(skipped))

    # Column order matches the data rows: each exon is chr, left, right, strand.
    header = ('#GeneID\tGeneName'
              '\tchr\texon1_left\texon1_right\tstrand'
              '\tchr\texon2_left\texon2_right\tstrand'
              '\tchr\texon3_left\texon3_right\tstrand\n')
    row = '%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n'

    # The output is opened only after parsing succeeded, so a failure while
    # reading the GTF does not leave a truncated output file behind.
    with open(outputfilename, 'w') as outfile:
        outfile.write(header)
        for (gene_id, gene_name, chrom,
             exon1left, exon1right,
             exon2left, exon2right,
             exon3left, exon3right, strand) in skipped:
            outfile.write(row % (gene_id, gene_name,
                                 chrom, exon1left - 1, exon1right - 1, strand,
                                 chrom, exon2left - 1, exon2right - 1, strand,
                                 chrom, exon3left - 1, exon3right - 1, strand))
 
# Execute only when invoked as a script, so that importing this module
# does not parse sys.argv or call sys.exit as a side effect.
if __name__ == '__main__':
    run()

