##################################
#                                #
# Last modified 2021/01/02       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _merge_intervals(intervals):
    """Merge closed integer intervals that overlap or are directly adjacent.

    Takes an iterable of (start, stop) pairs with inclusive coordinates and
    returns a list of [start, stop] lists sorted by start.  Two intervals are
    fused when the gap between them is less than one base (next start minus
    previous stop < 2).  Intervals with stop < start cover no positions and
    are ignored.
    """
    merged = []
    for start, stop in sorted(intervals):
        if stop < start:
            continue
        if merged and start - merged[-1][1] < 2:
            # Overlapping or touching the previous block: extend it.
            merged[-1][1] = max(merged[-1][1], stop)
        else:
            merged.append([start, stop])
    return merged


def _gtf_attribute(attributes, key):
    """Return the quoted value of ``key`` in a GTF attribute column, or None if absent."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('";')[0]


def run():
    """Collapse exon/UTR5/UTR3 records of a GTF file into merged exon records.

    Command line: ``python <script> gtf outputfilename``.

    Lines starting with '#' are copied to the output as they are read.  For
    every transcript (keyed by gene name, gene id, transcript id and
    transcript name) the 'exon', 'UTR5' and 'UTR3' intervals are merged into
    'exon' records and the 'CDS' intervals are merged into 'CDS' records;
    overlapping or directly adjacent intervals are fused.  Records of any
    other feature type and blank lines are ignored.  When 'gene_name' or
    'transcript_name' is absent, the corresponding id is used instead.

    Exits with status 1 after printing usage when fewer than two arguments
    are given.  Raises ValueError when a relevant record lacks 'gene_id' or
    'transcript_id'.
    """
    # Both the input and the output path are required.
    if len(sys.argv) < 3:
        print('usage: python %s gtf outputfilename' % sys.argv[0])
        print('\tUse this script to convert GTF files with the following confirugation:')
        print('\t\tchrI . UTR5 136866 136914')
        print('\t\tchrI . exon 136914 137642')
        print('\t\tchrI . UTR3 137510 137642')
        print('\t\tchrI . CDS 136914 137510')
        print('\tinto:')
        print('\t\tchrI . exon 136866 137642')
        print('\t\tchrI . CDS 136914 137510')
        sys.exit(1)

    GTF = sys.argv[1]
    outfilename = sys.argv[2]

    record_format = ('%s\t.\t%s\t%d\t%d\t.\t%s\t.\tgene_id "%s"; transcript_id "%s"; '
                     'gene_name "%s"; transcript_name "%s";\n')
    relevant_features = ('CDS', 'exon', 'UTR5', 'UTR3')

    # Maps (geneName, geneID, transcriptID, transcriptName) to the
    # transcript's chromosome, strand and lists of (start, stop) intervals.
    TranscriptDict = {}

    with open(GTF) as linelist, open(outfilename, 'w') as outfile:
        for line_number, line in enumerate(linelist, 1):
            if line_number % 100000 == 0:
                print('%d lines processed' % line_number)
            if line.startswith('#'):
                outfile.write(line)
                continue
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            feature = fields[2]
            # Other feature types contribute nothing to the output, so they
            # are skipped before their attributes are parsed.
            if feature not in relevant_features:
                continue

            attributes = fields[8]
            geneID = _gtf_attribute(attributes, 'gene_id')
            transcriptID = _gtf_attribute(attributes, 'transcript_id')
            if geneID is None or transcriptID is None:
                raise ValueError('line %d: missing gene_id or transcript_id'
                                 % line_number)
            geneName = _gtf_attribute(attributes, 'gene_name')
            if geneName is None:
                geneName = geneID
            transcriptName = _gtf_attribute(attributes, 'transcript_name')
            if transcriptName is None:
                transcriptName = transcriptID

            key = (geneName, geneID, transcriptID, transcriptName)
            if key not in TranscriptDict:
                TranscriptDict[key] = {'exon': [], 'CDS': [],
                                       'chr': fields[0], 'strand': fields[6]}
            # UTRs are folded into the exon track; CDS is tracked separately.
            track = 'CDS' if feature == 'CDS' else 'exon'
            TranscriptDict[key][track].append((int(fields[3]), int(fields[4])))

        for key in TranscriptDict:
            (geneName, geneID, transcriptID, transcriptName) = key
            transcript = TranscriptDict[key]
            # Exon records come first, then CDS records; a track with no
            # intervals simply produces no output.
            for track in ('exon', 'CDS'):
                for left, right in _merge_intervals(transcript[track]):
                    outfile.write(record_format % (
                        transcript['chr'], track, left, right,
                        transcript['strand'], geneID, transcriptID,
                        geneName, transcriptName))
   
# Run the conversion only when this file is executed as a script, so that
# importing the module does not immediately parse sys.argv and process files.
if __name__ == '__main__':
    run()
