##################################
#                                #
# Last modified 12/29/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of ``key`` in a GTF attribute column.

    The value is the text between ``key "`` and the next ``";``.
    Raises IndexError when ``key "`` does not occur in ``attributes``.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _load_known_types(annotation):
    """Map transcript_id -> transcript_type for each non-comment annotation line."""
    known_types = {}
    with open(annotation) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            attributes = line.strip().split('\t')[8]
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            transcript_type = _gtf_attribute(attributes, 'transcript_type')
            known_types[transcript_id] = transcript_type
    return known_types


def _collect_transcripts(gtf, known_types):
    """Gather the exons of every untyped, non-'u' transcript in the input GTF.

    Returns a dict keyed by transcript_id whose values hold 'exons' (a list
    of (chrom, left, right, strand) tuples), 'class_code', 'nearestRefID'
    and, for class code '=', a 'transcriptType' copied from ``known_types``.
    """
    transcripts = {}
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            attributes = fields[8]
            # Lines that already carry a transcript_type need no classification.
            if 'transcript_type "' in attributes:
                continue
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            class_code = _gtf_attribute(attributes, 'class_code')
            # 'u' lines are skipped before nearest_ref is read, so they do
            # not need a nearest_ref attribute.
            if class_code == 'u':
                continue
            nearest_ref = _gtf_attribute(attributes, 'nearest_ref')
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            if transcript_id not in transcripts:
                record = {
                    'exons': [],
                    'class_code': class_code,
                    'nearestRefID': nearest_ref,
                }
                if class_code == '=':
                    # Inherit the type of the reference transcript; a
                    # reference missing from the annotation raises KeyError.
                    record['transcriptType'] = known_types[nearest_ref]
                transcripts[transcript_id] = record
            transcripts[transcript_id]['exons'].append(exon)
    return transcripts


def _classify_by_orf(orf, transcripts):
    """Set 'transcriptType' for every collected transcript listed in the ORF file.

    Column 3 of the ORF file is the transcript id and column 9 is
    ``left,right``; only the right coordinate is used.
    """
    with open(orf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            transcript_id = fields[2]
            if transcript_id not in transcripts:
                continue
            orf_right = int(fields[8].split(',')[1])
            exons = transcripts[transcript_id]['exons']
            strand = exons[0][3]
            # Put exons in transcript order: ascending for '+', descending
            # for '-'.
            exons.sort()
            if strand == '-':
                exons.reverse()
            # NOTE(review): exon length is taken as right - left; if the GTF
            # coordinates are 1-based inclusive this is one short per exon.
            # Kept as-is because the ORF coordinate convention is not
            # visible here -- confirm before changing.
            length_without_last_exon = sum(
                right - left for (_chrom, left, right, _strand) in exons[:-1]
            )
            # An ORF ending more than 50 positions before the start of the
            # last exon is labelled as an NMD target.
            if length_without_last_exon > orf_right + 50:
                transcript_type = 'nonsense_mediated_decay'
            else:
                transcript_type = 'protein_coding'
            transcripts[transcript_id]['transcriptType'] = transcript_type


def _write_annotated(gtf, transcripts, outfilename):
    """Copy the GTF to ``outfilename``, appending transcript_type where it was derived.

    Comment lines are dropped. Lines that already have a transcript_type and
    lines with class code 'u' are copied unchanged.
    """
    with open(outfilename, 'w') as outfile:
        with open(gtf) as handle:
            for line in handle:
                if line.startswith('#'):
                    continue
                attributes = line.strip().split('\t')[8]
                if 'transcript_type "' in attributes:
                    outfile.write(line)
                    continue
                transcript_id = _gtf_attribute(attributes, 'transcript_id')
                try:
                    class_code = _gtf_attribute(attributes, 'class_code')
                except IndexError:
                    # No class_code attribute: show the offending line and stop.
                    print(line)
                    sys.exit(1)
                if class_code == 'u':
                    outfile.write(line)
                    continue
                try:
                    transcript_type = transcripts[transcript_id]['transcriptType']
                except KeyError:
                    # Neither the annotation nor the ORF file gave a type.
                    sys.stderr.write(
                        'no transcript_type could be determined for %s:\n%s'
                        % (transcript_id, line)
                    )
                    sys.exit(1)
                outfile.write(
                    line.strip() + ' transcript_type "' + transcript_type + '";\n'
                )


def run():
    """Add a transcript_type attribute to the transcripts of a GTF file.

    Reads four positional arguments from ``sys.argv``: the input GTF, an ORF
    table, a reference annotation GTF and the output filename.

    Transcripts with class code '=' take the type of their nearest reference
    transcript. Any transcript listed in the ORF table is labelled
    'protein_coding' or 'nonsense_mediated_decay', overriding the inherited
    type. Lines with class code 'u' and lines that already have a
    transcript_type are written out unchanged.

    Exits with status 1 when arguments are missing, when a line lacks a
    class_code, or when no type could be determined for a transcript.
    """
    # Four arguments are read below, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s gtf ORF annotation_gtf outputfilename' % sys.argv[0])
        sys.exit(1)

    GTF = sys.argv[1]
    ORF = sys.argv[2]
    annotation = sys.argv[3]
    outfilename = sys.argv[4]

    known_types = _load_known_types(annotation)
    transcripts = _collect_transcripts(GTF, known_types)
    _classify_by_orf(ORF, transcripts)
    _write_annotated(GTF, transcripts, outfilename)
   
# Script entry point. The call is unconditional, so it also executes when
# this file is imported as a module.
run()
