##################################
#                                #
# Last modified 10/10/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _read_gtf_exons(path):
    """Group the exon records of a GTF file by gene and transcript.

    Only rows whose feature column (3rd column) is 'exon' are kept; comment
    lines (starting with '#') and blank lines are ignored.

    Returns:
        dict mapping gene_id -> {transcript_id: [(chr, left, right, strand), ...]},
        with each transcript's exons in file order.
    """
    genes = {}
    with open(path) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            attributes = fields[8]
            gene_id = attributes.split('gene_id "')[1].split('";')[0]
            transcript_id = attributes.split('transcript_id "')[1].split('";')[0]
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            genes.setdefault(gene_id, {}).setdefault(transcript_id, []).append(exon)
    return genes


def _retained_intron_spans(internal_exons):
    """Return the set of (left, right) spans joining two non-overlapping exons.

    Each span starts at the left end of one annotated internal exon and ends at
    the right end of a later exon that begins after the first one ends. An exon
    with exactly these coordinates covers the gap between the two annotated
    exons, so it is treated as a possible retained intron.
    ``internal_exons`` must be a list of (left, right) pairs sorted ascending.
    """
    spans = set()
    for k, (left1, right1) in enumerate(internal_exons):
        for left2, right2 in internal_exons[k + 1:]:
            if left2 > right1:
                spans.add((left1, right2))
    return spans


def _classify_transcripts(GeneDict, AnnotationGeneDict, maxExonLength,
                          max3UTRLength, maxExonLengthRatio):
    """Decide, per transcript, whether to drop it or to trim its 3'-most exon.

    Only genes that also occur in the annotation are examined; transcripts of
    other genes are left alone. For each examined gene, the longest annotated
    internal exon (the first and last exon of every annotated transcript are
    excluded) sets the limit ``max(maxExonLength, maxExonLengthRatio * longest)``.

    A transcript is marked for dropping when
      * any exon other than its 3'-most one is longer than that limit, or
      * any exon exactly spans two non-overlapping annotated internal exons.
    When its 3'-most exon is longer than ``max3UTRLength``, that exon is
    recorded to be trimmed to ``max3UTRLength`` measured from its 5' end
    (only for '+' and '-' strands).

    Exon length is computed as ``right - left`` throughout. Progress counts and
    over-long annotated internal exons are printed to stdout.

    Returns:
        (SkipDict, UTRDict): SkipDict maps dropped transcript ids to [];
        UTRDict maps transcript id -> {original exon tuple: trimmed exon tuple}.
        A transcript may appear in both; dropping takes precedence on output.
    """
    SkipDict = {}
    UTRDict = {}

    remaining = len(GeneDict)
    print(remaining)
    for geneID in GeneDict:
        remaining -= 1
        if remaining % 1000 == 0:
            print(remaining)  # progress indicator
        if geneID not in AnnotationGeneDict:
            continue

        # Annotated internal exons of this gene and the length of the longest.
        longestExon = 0
        internal = set()
        for transcriptID, exons in AnnotationGeneDict[geneID].items():
            for (chrom, left, right, strand) in sorted(exons)[1:-1]:
                internal.add((left, right))
                length = right - left
                if length > longestExon:
                    longestExon = length
                    if longestExon > maxExonLength:
                        # Report annotated internal exons already above the absolute cap.
                        print('%s %s %s %s %s %s %s' % (longestExon, chrom, left, right,
                                                        strand, geneID, transcriptID))
        retained = _retained_intron_spans(sorted(internal))
        limit = max(maxExonLength, maxExonLengthRatio * longestExon)

        for transcriptID, exons in GeneDict[geneID].items():
            # Order exons 5' -> 3'. Sorting first makes the result independent
            # of the order in which the GTF lists a '-' strand transcript's exons.
            ordered = sorted(exons)
            if exons[0][3] == '-':
                ordered.reverse()
            last = len(ordered) - 1
            for idx, (chrom, left, right, strand) in enumerate(ordered):
                if idx != last and math.fabs(right - left) > limit:
                    SkipDict[transcriptID] = []
                    break
                if idx == last and (right - left) > max3UTRLength:
                    UTRDict[transcriptID] = {}
                    exon = (chrom, left, right, strand)
                    if strand == '+':
                        UTRDict[transcriptID][exon] = (chrom, left, left + max3UTRLength, strand)
                    elif strand == '-':
                        UTRDict[transcriptID][exon] = (chrom, right - max3UTRLength, right, strand)
                if (left, right) in retained:
                    SkipDict[transcriptID] = []
                    break
    return SkipDict, UTRDict


def _write_filtered_gtf(gtf, outfile, SkipDict, UTRDict):
    """Copy ``gtf`` into the open ``outfile``, applying the filter decisions.

    Every record of a transcript in SkipDict is dropped. Records matching an
    exon listed in UTRDict are rewritten with the trimmed coordinates (only the
    first nine columns are written). Comment lines, blank lines and records
    without a transcript_id attribute are copied verbatim.
    """
    with open(gtf) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                outfile.write(line)
                continue
            fields = line.strip().split('\t')
            attributes = fields[8] if len(fields) > 8 else ''
            if 'transcript_id "' not in attributes:
                outfile.write(line)
                continue
            transcriptID = attributes.split('transcript_id "')[1].split('"')[0]
            if transcriptID in SkipDict:
                continue
            trims = UTRDict.get(transcriptID)
            if trims:
                key = (fields[0], int(fields[3]), int(fields[4]), fields[6])
                if key in trims:
                    (newchr, newleft, newright, newstrand) = trims[key]
                    outline = '\t'.join([newchr, fields[1], fields[2], str(newleft),
                                         str(newright), fields[5], newstrand,
                                         fields[7], fields[8]])
                    outfile.write(outline + '\n')
                    continue
            outfile.write(line)


def run():
    """Command-line entry point: filter a GTF file against an annotation GTF.

    Usage:
        python script GTF annotationGTF maxExonLength max3UTRLength
               max_longest_internal_exon_ratio outfilename

    For genes present in the annotation, transcripts whose non-3'-terminal
    exons are too long, or which contain an exon spanning two annotated
    internal exons (possible retained intron), are removed. Over-long 3'-most
    exons are trimmed to ``max3UTRLength``. All other lines of GTF are copied
    to outfilename unchanged. Exits with status 1 and a usage message when
    fewer than six arguments are given.
    """
    # Six positional arguments are read below (sys.argv[1]..sys.argv[6]).
    if len(sys.argv) < 7:
        print('usage: python %s GTF annotationGTF maxExonLength max3UTRLength max_longest_internal_exon_ratio outfilename' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    annotation = sys.argv[2]
    maxExonLength = int(sys.argv[3])
    max3UTRLength = int(sys.argv[4])
    maxExonLengthRatio = float(sys.argv[5])
    outfilename = sys.argv[6]

    AnnotationGeneDict = _read_gtf_exons(annotation)
    GeneDict = _read_gtf_exons(gtf)

    # Open the output before the (potentially long) classification so an
    # unwritable path fails early.
    with open(outfilename, 'w') as outfile:
        SkipDict, UTRDict = _classify_transcripts(GeneDict, AnnotationGeneDict,
                                                  maxExonLength, max3UTRLength,
                                                  maxExonLengthRatio)
        _write_filtered_gtf(gtf, outfile, SkipDict, UTRDict)
        
# Run the filter only when executed as a script, not when imported.
if __name__ == '__main__':
    run()

