##################################
#                                #
# Last modified 08/10/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

# psyco is an optional JIT accelerator for old 32-bit Python 2 interpreters.
# The script works without it, so any failure to load or enable it is ignored.
# `Exception` (rather than a bare `except:`) keeps Ctrl-C and sys.exit()
# from being swallowed here.
try:
    import psyco
    psyco.full()
except Exception:
    pass

def _print_usage():
    """Print the command-line help text."""
    rule = '====================================================================================================================================================='
    print('usage: python %s <junctions file> <transcript models gtf filename> <outfile prefix>' % sys.argv[0])
    print('\tFormat of junctions file:')
    print('\tchr\tleft\tright\tstrand\ttotal_reads\tstaggered_reads')
    print('\tExample:')
    print('\tchr1\t14830\t14970\t-\t127\t51')
    print(rule)
    print('\tNote: the script takes a list of junctions and an annotation and compares the junctions against the annotation when looking for retained introns and exon skipping.')
    print('\tIt uses only junctions when looking for alternative 5 and 3 splice sites and exons')
    print('\tIt looks only at annotation for mutually exclusive exons')
    print('\tIf expression is to be taken into account, supply an annotation that contains only genes isoforms for which expression information is available')
    print(rule)


def _gene_name(attributes):
    """Return the gene_name (or, failing that, gene_id) from a GTF attribute column."""
    for tag in ('gene_name "', 'gene_id "'):
        if tag in attributes:
            return attributes.split(tag)[1].split('";')[0]
    raise ValueError('no gene_name or gene_id in GTF attributes: %r' % attributes)


def _parse_gtf(gtf):
    """Read exon records from a GTF file.

    Returns (gene_exons, pos_genes, first_exons, last_exons):
      gene_exons  -- gene name -> sorted unique (chr, left, right, strand) exons
      pos_genes   -- (chr, exon boundary) -> sorted unique gene names
      first_exons -- '+' strand exons seen only directly after a transcript row
      last_exons  -- '-' strand exons seen only directly after a transcript row

    NOTE(review): the first/last naming presumably relies on the GTF listing
    each transcript's exons in transcription order (5' exon first), so that
    both sets hold 5'-terminal exons -- confirm against the annotation used.
    """
    gene_exons = {}
    pos_genes = {}
    first_exons = set()
    last_exons = set()
    middle_exons = set()
    after_transcript = False
    count = 0
    print('parsing gtf file')
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            count += 1
            if count % 100000 == 0:
                print('%d lines processed' % count)
            fields = line.split('\t')
            feature = fields[2]
            if feature == 'transcript':
                after_transcript = True
                continue
            if feature != 'exon':
                continue
            # Attributes are parsed for exon rows only, so 'gene' rows that
            # carry no transcript_id no longer abort the run.
            gene = _gene_name(fields[8])
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]
            exon = (chrom, left, right, strand)
            gene_exons.setdefault(gene, []).append(exon)
            pos_genes.setdefault((chrom, left), []).append(gene)
            pos_genes.setdefault((chrom, right), []).append(gene)
            if after_transcript:
                if strand == '+':
                    first_exons.add(exon)
                if strand == '-':
                    last_exons.add(exon)
                after_transcript = False
            else:
                middle_exons.add(exon)
    print('finished parsing gtf file')

    print('First exons before fitlering %d' % len(first_exons))
    print('Last exons before fitlering %d' % len(last_exons))
    # An exon that is internal in any transcript is not treated as terminal.
    first_exons -= middle_exons
    last_exons -= middle_exons
    print('First exons after fitlering %d' % len(first_exons))
    print('Last exons after fitlering %d' % len(last_exons))

    # Sorting (instead of relying on set order) keeps the output deterministic.
    for gene in gene_exons:
        gene_exons[gene] = sorted(set(gene_exons[gene]))
    for key in pos_genes:
        pos_genes[key] = sorted(set(pos_genes[key]))
    return gene_exons, pos_genes, first_exons, last_exons


def _parse_junctions(junctions):
    """Read the junction table into (chr, left-1, right, strand, total, staggered) tuples."""
    parsed = []
    with open(junctions) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.split('\t')
            # The left coordinate is shifted by one so that it can be compared
            # directly with exon end positions from the annotation.
            parsed.append((fields[0], int(fields[1]) - 1, int(fields[2]),
                           fields[3], int(fields[4]), int(fields[5])))
    return parsed


def _format_junction(junction):
    """Return the six leading tab-separated output columns for a junction."""
    return '\t'.join(str(value) for value in junction)


def _genes_at(pos_genes, chrom, left, right):
    """Return sorted unique genes at both junction ends, or None if either end is unannotated."""
    left_key = (chrom, left)
    right_key = (chrom, right)
    if left_key not in pos_genes or right_key not in pos_genes:
        return None
    return sorted(set(pos_genes[left_key] + pos_genes[right_key]))


def _write_intergenic(path, junctions, pos_genes):
    """Write junctions whose two ends do not fall on exon boundaries of a common gene."""
    with open(path, 'w') as outfile:
        outfile.write('#chr\tleft\tright\tstrand\ttotal_read\tstaggered_reads\tleft_gene\tright_gene\n')
        for junction in junctions:
            chrom, left, right = junction[:3]
            left_genes = pos_genes.get((chrom, left), [])
            right_genes = pos_genes.get((chrom, right), [])
            if set(left_genes) & set(right_genes):
                continue
            outfile.write(_format_junction(junction) + '\t'
                          + ''.join(gene + ' ' for gene in left_genes) + '\t'
                          + ''.join(gene + ' ' for gene in right_genes) + '\n')


def _write_exon_skipping(path, junctions, gene_exons, pos_genes, terminal_exons):
    """Write junctions that span one or more non-terminal annotated exons."""
    with open(path, 'w') as outfile:
        outfile.write('#chr\tleft\tright\tstrand\ttotal_read\tstaggered_reads\tGene(s)\tskipped_exon(s)\n')
        for junction in junctions:
            chrom, left, right, strand = junction[:4]
            genes = _genes_at(pos_genes, chrom, left, right)
            if genes is None:
                continue
            skipped = []
            for gene in genes:
                for exon in gene_exons[gene]:
                    exon_chrom, exon_left, exon_right, exon_strand = exon
                    if exon in terminal_exons:
                        continue
                    # Either mismatch disqualifies the exon ('or', not 'and').
                    if exon_chrom != chrom or exon_strand != strand:
                        continue
                    if left < exon_left < right and left < exon_right < right:
                        skipped.append('%s:%d-%d%s' % (chrom, exon_left, exon_right, strand))
            if skipped:
                outfile.write(_format_junction(junction) + '\t' + ' '.join(genes)
                              + '\t' + ' '.join(skipped) + '\n')


def _write_retained_introns(path, junctions, gene_exons, pos_genes):
    """Write junctions lying strictly inside an annotated exon (one row per gene)."""
    with open(path, 'w') as outfile:
        # Header order matches the columns actually written below.
        outfile.write('#chr\tleft\tright\tstrand\ttotal_read\tstaggered_reads\tcontaining_exon\tGene\n')
        for junction in junctions:
            chrom, left, right, strand = junction[:4]
            genes = _genes_at(pos_genes, chrom, left, right)
            if genes is None:
                continue
            gene_column = ''.join(' ' + gene for gene in genes)
            for gene in genes:
                for (exon_chrom, exon_left, exon_right, _) in gene_exons[gene]:
                    # Coordinates are only comparable on the same chromosome.
                    if exon_chrom != chrom:
                        continue
                    if exon_left < left and exon_right > right:
                        containing = '%s:%d-%d%s' % (chrom, exon_left, exon_right, strand)
                        outfile.write(_format_junction(junction) + '\t' + containing
                                      + '\t' + gene_column + '\n')
                        break  # report only the first containing exon per gene


def _find_alternative_sites(junctions_by_chrom, shared_index):
    """Return the sorted junctions that share one end but differ at the other.

    shared_index maps a strand to the tuple index (1 = left, 2 = right) of the
    end that must be identical; the other end must take at least two values.
    Junctions on strands absent from shared_index are ignored.
    """
    found = set()
    for chrom in junctions_by_chrom:
        print(chrom)
        groups = {}
        for junction in junctions_by_chrom[chrom]:
            strand = junction[3]
            if strand not in shared_index:
                continue
            anchor = junction[shared_index[strand]]
            groups.setdefault((strand, anchor), []).append(junction)
        for (strand, _), members in groups.items():
            varying = 3 - shared_index[strand]  # 1 <-> 2
            if len(set(member[varying] for member in members)) > 1:
                found.update(members)
    return sorted(found)


def _write_alternative_sites(path, found, pos_genes):
    """Write alternative-splice-site junctions with the genes at their own ends."""
    with open(path, 'w') as outfile:
        outfile.write('#chr\tleft\tright\tstrand\ttotal_read\tstaggered_reads\tGene\n')
        for junction in found:
            chrom, left, right = junction[:3]
            genes = _genes_at(pos_genes, chrom, left, right)
            if genes is None:
                genes = ['unknown']
            outfile.write(_format_junction(junction) + '\t'
                          + ''.join(' ' + gene for gene in genes) + '\n')


def _write_alt5_exons(path, alt5_junctions, gene_exons, pos_genes, first_exons, last_exons):
    """Write acceptor sites reached from two or more distinct 5'-terminal exons."""
    found = {}
    for junction in alt5_junctions:
        chrom, left, right, strand = junction[:4]
        genes = _genes_at(pos_genes, chrom, left, right)
        if genes is None:
            continue
        # '+' junctions share their right end and start at an exon's right
        # edge; '-' junctions share their left end and start at an exon's left
        # edge.  '-' strand terminal exons are recorded in last_exons.
        if strand == '+':
            key = (chrom, right)
            terminal = first_exons
        elif strand == '-':
            key = (chrom, left)
            terminal = last_exons
        else:
            continue
        entries = found.setdefault(key, set())
        for gene in genes:
            for exon in gene_exons[gene]:
                exon_chrom, exon_left, exon_right, exon_strand = exon
                if exon_chrom != chrom or exon_strand != strand:
                    continue
                if strand == '+':
                    adjacent = (left == exon_right)
                else:
                    adjacent = (right == exon_left)
                if adjacent and exon in terminal:
                    label = '%s:%d-%d%s' % (exon_chrom, exon_left, exon_right, exon_strand)
                    entries.add((label, junction))

    with open(path, 'w') as outfile:
        outfile.write('#chr\tleft\tright\tstrand\ttotal_read\tstaggered_reads\t5-exon\tGene\n')
        for key in sorted(found):
            entries = sorted(found[key])
            if len(entries) < 2:
                continue
            genes = ''.join(gene + ' ' for gene in pos_genes[key])
            for (label, junction) in entries:
                outfile.write(_format_junction(junction) + '\t' + label + '\t' + genes + '\n')


def run():
    """Classify splice junctions against a GTF annotation.

    Command line: <junctions file> <transcript models gtf> <outfile prefix>.
    Prints the usage text and exits with status 1 when fewer than three
    arguments are given.

    Writes six tab-separated files named <prefix>.<suffix>:
      intergenic      -- junction ends not on exon boundaries of a common gene
      exonSkipping    -- junctions spanning non-terminal annotated exons
      retainedIntrons -- junctions lying inside an annotated exon
      alt5splice      -- junctions sharing the acceptor end, differing at the donor
      alt3splice      -- junctions sharing the donor end, differing at the acceptor
      alt5exons       -- acceptors reached from two or more 5'-terminal exons

    Gene lists and exon lists are emitted in sorted order.
    """
    if len(sys.argv) < 4:
        _print_usage()
        sys.exit(1)

    junctions = sys.argv[1]
    gtf = sys.argv[2]
    outputfileprefix = sys.argv[3]

    gene_exons, pos_genes, first_exons, last_exons = _parse_gtf(gtf)

    print('parsing juncitons file')
    junctions_list = _parse_junctions(junctions)

    print('looking for intergenic junctions')
    _write_intergenic(outputfileprefix + '.intergenic', junctions_list, pos_genes)

    print('looking for exon skipping events')
    _write_exon_skipping(outputfileprefix + '.exonSkipping', junctions_list,
                         gene_exons, pos_genes, first_exons | last_exons)

    print('looking for retained introns')
    _write_retained_introns(outputfileprefix + '.retainedIntrons', junctions_list,
                            gene_exons, pos_genes)

    junctions_by_chrom = {}
    for junction in junctions_list:
        junctions_by_chrom.setdefault(junction[0], []).append(junction)

    print('looking for alternative 5` splice sites')
    # Alternative 5' sites share the acceptor: right end on '+', left end on '-'.
    alt5 = _find_alternative_sites(junctions_by_chrom, {'+': 2, '-': 1})
    _write_alternative_sites(outputfileprefix + '.alt5splice', alt5, pos_genes)

    print('looking for alternative 3` splice sites')
    # Alternative 3' sites share the donor: left end on '+', right end on '-'.
    alt3 = _find_alternative_sites(junctions_by_chrom, {'+': 1, '-': 2})
    _write_alternative_sites(outputfileprefix + '.alt3splice', alt3, pos_genes)

    print('looking for alternative 5` exons')
    _write_alt5_exons(outputfileprefix + '.alt5exons', alt5, gene_exons,
                      pos_genes, first_exons, last_exons)

# Entry point: the analysis starts as soon as this file is executed.
# This call is unguarded, so it also runs if the file is imported as a module.
run()

