##################################
#                                #
# Last modified 06/01/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import string

def _attribute(attributes, key):
    """Return the text between ``key "`` and the next ``";`` in *attributes*.

    *attributes* is the ninth (attribute) column of a GTF line. Raises
    IndexError when ``key "`` does not occur in it.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_exons(gtf):
    """Group the exon coordinates found in the GTF file *gtf* by transcript.

    Returns a dict mapping
    (chr, strand, geneID, geneName, transcriptID, transcriptName) to the list
    of (left, right) integer pairs of that transcript's exons, in file order.
    """
    transcripts = {}
    with open(gtf) as handle:
        for count, line in enumerate(handle, 1):
            if count % 100000 == 0:
                # Single-argument form prints the same text on Python 2 and 3.
                print('%d lines processed' % count)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank or truncated lines have no feature column; skip them
            # together with every non-exon feature.
            if len(fields) < 3 or fields[2] != 'exon':
                continue
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]
            attributes = fields[8]
            gene_id = _attribute(attributes, 'gene_id')
            transcript_id = _attribute(attributes, 'transcript_id')
            # The name attributes are optional: fall back to the matching ID.
            if 'gene_name "' in attributes:
                gene_name = _attribute(attributes, 'gene_name')
            else:
                gene_name = gene_id
            if 'transcript_name "' in attributes:
                transcript_name = _attribute(attributes, 'transcript_name')
            else:
                transcript_name = transcript_id
            key = (chrom, strand, gene_id, gene_name,
                   transcript_id, transcript_name)
            transcripts.setdefault(key, []).append((left, right))
    return transcripts


def _write_exons(transcripts, outputfilename):
    """Write one tab-separated row per exon of every multi-exon transcript."""
    header = '#geneID\tgeneName\ttranscriptID\ttranscriptName\tchr\tstrand\texon_position\tleft\tright\texon_length\tphasing'
    with open(outputfilename, 'w') as outfile:
        outfile.write(header + '\n')
        for key, exons in transcripts.items():
            # Transcripts with a single exon are not reported.
            if len(exons) == 1:
                continue
            (chrom, strand, gene_id, gene_name,
             transcript_id, transcript_name) = key
            # Ascending coordinates on '+', descending on '-', so 'first'
            # and 'last' follow the transcript's own orientation.
            ordered = sorted(exons, reverse=(strand == '-'))
            prefix = '\t'.join([gene_id, gene_name, transcript_id,
                                transcript_name, chrom, strand])
            total = len(ordered)
            for index, (left, right) in enumerate(ordered, 1):
                # Both coordinates are treated as inclusive.
                length = right + 1 - left
                phase = length % 3
                if index == 1:
                    pos = 'first'
                elif index == total:
                    pos = 'last'
                else:
                    pos = 'middle'
                outfile.write('\t'.join([prefix, pos, str(left), str(right),
                                         str(length), str(phase)]) + '\n')


def run():
    """Tabulate exon position, length and length modulo 3 from a GTF file.

    Reads ``sys.argv[1]`` (input GTF path) and ``sys.argv[2]`` (output path).
    Every ``exon`` record is grouped by transcript; for each transcript with
    more than one exon, one row per exon is written with the columns geneID,
    geneName, transcriptID, transcriptName, chr, strand, exon_position
    (first/middle/last), left, right, exon_length and phasing
    (exon_length % 3).

    Prints a usage message and exits with status 1 when fewer than two
    arguments are supplied. Raises IndexError when an exon record lacks a
    ``gene_id`` or ``transcript_id`` attribute.
    """
    # Two positional arguments are required, so argv needs three entries.
    if len(sys.argv) < 3:
        print('usage: python %s gtf outfilename' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    outputfilename = sys.argv[2]

    _write_exons(_read_exons(gtf), outputfilename)

# Run the conversion only when this file is executed as a script, so that
# importing the module does not read sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

