##################################
#                                #
# Last modified 2016/06/12       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of attribute `key` from a GTF attribute column.

    The value is whatever sits between `key "` and the next `";`.
    Raises IndexError when `key "` does not occur in `attributes`.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _iter_exon_fields(gtf_path):
    """Yield the tab-split fields of every exon record in the GTF at `gtf_path`.

    Comment lines (starting with '#'), lines with fewer than the nine GTF
    columns (e.g. blank lines) and non-exon features are skipped.
    """
    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            yield fields


def run():
    """Rewrite a StringTie merge GTF using gene IDs/names from an annotation GTF.

    Command line (read from sys.argv):
        stringtie_merge.gtf annotation.gtf outputfilename
        [-removeUnstrandedTranscripts]

    Only exon records are written to the output (comment lines are copied
    through unchanged).  For each exon:
      * if its transcript_id occurs in the annotation, gene_id/gene_name are
        taken from that annotated transcript;
      * otherwise, if its gene_id groups at least one annotated transcript,
        the gene_id/gene_name of that annotated transcript are used;
      * otherwise the original line is written unchanged.
    With -removeUnstrandedTranscripts, exons whose strand is neither '+' nor
    '-' are dropped.

    Exits with status 1 after printing a usage message when fewer than three
    positional arguments are supplied.  Raises IndexError if an exon record
    lacks a transcript_id or gene_id attribute (or, in the StringTie file,
    an exon_number attribute).
    """
    # Three positional arguments are required; sys.argv[3] is read below.
    if len(sys.argv) < 4:
        print('usage: python %s stringtie_merge.gtf annotation.gtf outputfilename [-removeUnstrandedTranscripts]' % sys.argv[0])
        sys.exit(1)

    stringtieGTF = sys.argv[1]
    annotationGTF = sys.argv[2]
    outfilename = sys.argv[3]

    doRemoveUnstrandedTranscripts = False
    if '-removeUnstrandedTranscripts' in sys.argv:
        print('will remove unstranded transcripts')
        doRemoveUnstrandedTranscripts = True

    # transcript_id -> {'geneID': ..., 'geneName': ...} from the annotation.
    TranscriptToGeneIDDict = {}
    # gene_id (from either file) -> {transcript_id: 1} of its transcripts.
    GeneTranscripts = {}

    for fields in _iter_exon_fields(annotationGTF):
        attributes = fields[8]
        transcriptID = _gtf_attribute(attributes, 'transcript_id')
        geneID = _gtf_attribute(attributes, 'gene_id')
        if 'gene_name "' in attributes:
            geneName = _gtf_attribute(attributes, 'gene_name')
        else:
            # No gene_name attribute: fall back to the gene ID.
            geneName = geneID
        TranscriptToGeneIDDict[transcriptID] = {'geneID': geneID,
                                                'geneName': geneName}
        GeneTranscripts.setdefault(geneID, {})[transcriptID] = 1

    print('finished inputting annotation GTF')

    for fields in _iter_exon_fields(stringtieGTF):
        attributes = fields[8]
        transcriptID = _gtf_attribute(attributes, 'transcript_id')
        geneID = _gtf_attribute(attributes, 'gene_id')
        GeneTranscripts.setdefault(geneID, {})[transcriptID] = 1

    print('finished first pass through stringtie GTF')

    # gene_id -> annotation entry of the first annotated transcript found
    # while iterating that gene's transcripts (dictionary iteration order).
    RealGeneDict = {}
    for geneID, transcripts in GeneTranscripts.items():
        for transcriptID in transcripts:
            if transcriptID in TranscriptToGeneIDDict:
                RealGeneDict[geneID] = TranscriptToGeneIDDict[transcriptID]
                break

    print('finished reannotating genes')

    with open(outfilename, 'w') as outfile:
        with open(stringtieGTF) as infile:
            for line in infile:
                if line.startswith('#'):
                    outfile.write(line)
                    continue
                fields = line.strip().split('\t')
                if len(fields) < 9 or fields[2] != 'exon':
                    continue
                if doRemoveUnstrandedTranscripts and fields[6] not in ('+', '-'):
                    continue
                attributes = fields[8]
                transcriptID = _gtf_attribute(attributes, 'transcript_id')
                geneID = _gtf_attribute(attributes, 'gene_id')
                exon_number = _gtf_attribute(attributes, 'exon_number')

                # Prefer the transcript's own annotation; otherwise fall back
                # to the annotation resolved for its gene.
                annotated = TranscriptToGeneIDDict.get(transcriptID)
                if annotated is None:
                    annotated = RealGeneDict.get(geneID)
                if annotated is None:
                    # Entirely novel gene: keep the record as StringTie wrote it.
                    outfile.write(line)
                    continue

                outfile.write(
                    '%s\tgene_id "%s"; transcript_id "%s"; gene_name "%s"; exon_number "%s";\n'
                    % ('\t'.join(fields[:8]), annotated['geneID'], transcriptID,
                       annotated['geneName'], exon_number))

    print('finished printing new gtf')
   
# Entry point: there is no __main__ guard, so the conversion runs as soon as
# this file is executed (or imported), reading its arguments from sys.argv.
run()
