##################################
#                                #
# Last modified 2020/03/13       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import math

def _gtf_attribute(attributes, key):
    """Return the quoted value stored under *key* in a GTF attribute column.

    The column is expected to contain ``key "value";``.  An IndexError is
    raised when *key* is absent.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _extend_gene_with_tif(features, median5, median3):
    """Extend the terminal exons of one gene and record its UTR intervals.

    *features* maps a feature name to a sorted list of
    ``(left, right, chr, strand, geneName, transcriptName, transcriptID)``
    tuples and is modified in place.  Genes lacking ``CDS`` or ``exon``
    records, or whose strand is neither ``+`` nor ``-``, are left untouched.
    """
    if 'CDS' not in features or 'exon' not in features:
        return
    coding = features['CDS']
    exons = features['exon']
    first_left, first_right = coding[0][0], coding[0][1]
    last_left, last_right = coding[-1][0], coding[-1][1]
    # Descriptive fields (chr, strand, names, IDs) come from the last CDS.
    meta = coding[-1][2:]
    strand = meta[1]

    # On the minus strand the genomic left end is the 3' end of the gene.
    if strand == '+':
        left_name, left_len = 'UTR5', median5
        right_name, right_len = 'UTR3', median3
    elif strand == '-':
        left_name, left_len = 'UTR3', median3
        right_name, right_len = 'UTR5', median5
    else:
        return

    single_exon = len(exons) == 1

    # NOTE(review): each UTR interval shares its boundary coordinate with
    # the CDS (it is not shifted by one base); confirm this is intended.
    if left_len > 0:
        # GTF coordinates are 1-based, so never extend past position 1.
        new_left = max(1, first_left - left_len)
        exons[0] = (new_left, first_right) + meta
        features[left_name] = [(new_left, first_left) + meta]
    if right_len > 0:
        exon_left = last_left
        if single_exon and left_len > 0:
            # exons[0] and exons[-1] are the same entry here: keep the left
            # extension applied above instead of overwriting it.
            exon_left = exons[0][0]
        new_right = last_right + right_len
        exons[-1] = (exon_left, new_right) + meta
        features[right_name] = [(last_right, new_right) + meta]


def _format_gtf_line(gene_id, feature, record):
    """Return one newline-terminated GTF output line for *record*."""
    (left, right, chrom, strand,
     gene_name, transcript_name, transcript_id) = record
    attributes = ('gene_id "%s"; transcript_id "%s"; '
                  'gene_name "%s"; transcript_name "%s";'
                  % (gene_id, transcript_id, gene_name, transcript_name))
    columns = [chrom, '.', feature, str(left), str(right),
               '.', strand, '.', attributes]
    return '\t'.join(columns) + '\n'


def run():
    """Add TIF-seq derived UTRs to a GTF annotation.

    Reads the GTF path, the TIF-seq table path and the output path from
    ``sys.argv[1:4]``.  For every gene that has a TIF-seq entry, the first
    and last exons are rebuilt from the first and last CDS records extended
    by the median 5' and 3' lengths, and matching ``UTR5``/``UTR3`` records
    are created.  All records are then written to the output file in GTF
    format.  Exits with status 1 after printing usage when fewer than three
    arguments are supplied.
    """
    # print() is always given one parenthesised string so that the statement
    # produces the same text under Python 2 and Python 3.
    if len(sys.argv) < 4:
        print('usage: python %s gtf TIF-seq_file outfile ' % sys.argv[0])
        print('\tNote: TIF-seq file format:')
        print('\t\t#chr\tstrand\tgene\ttype\tmTIF_number\tmedian5\tmedian3\tsd5\tsd3')
        print('\t\t1\t-\tYAL067C\tVerified\t1\t33\t222\tNA\tNA')
        print('\tNote: the script will fail when there are exons in UTRs and when there are multiple transcripts per gene ID')
        sys.exit(1)

    gtf_path, tif_path, out_path = sys.argv[1:4]

    # gene ID -> feature name -> list of record tuples
    transcripts = {}
    # The progress counter runs across both input files.
    lines_seen = 0

    with open(gtf_path) as gtf:
        for line in gtf:
            lines_seen += 1
            if lines_seen % 100000 == 0:
                print('%d lines processed' % lines_seen)
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            attributes = fields[8]
            gene_id = _gtf_attribute(attributes, 'gene_id')
            if 'gene_name' in attributes:
                gene_name = _gtf_attribute(attributes, 'gene_name')
            else:
                gene_name = gene_id
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            if 'transcript_name' in attributes:
                transcript_name = _gtf_attribute(attributes, 'transcript_name')
            else:
                transcript_name = transcript_id
            record = (int(fields[3]), int(fields[4]), fields[0], fields[6],
                      gene_name, transcript_name, transcript_id)
            by_feature = transcripts.setdefault(gene_id, {})
            by_feature.setdefault(fields[2], []).append(record)

    # gene ID -> (median5, median3); a later line for the same gene replaces
    # an earlier one.
    tif_medians = {}
    unknown_genes = 0
    with open(tif_path) as tif:
        for line in tif:
            lines_seen += 1
            if lines_seen % 100000 == 0:
                print('%d lines processed' % lines_seen)
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            gene_id = fields[2]
            if gene_id not in transcripts:
                # Nothing in the annotation to extend for this gene.
                unknown_genes += 1
                continue
            tif_medians[gene_id] = (int(float(fields[5])),
                                    int(float(fields[6])))
    if unknown_genes:
        sys.stderr.write('%d TIF-seq lines skipped: gene ID not found in GTF\n'
                         % unknown_genes)

    with open(out_path, 'w') as out:
        for gene_id in transcripts:
            features = transcripts[gene_id]
            # Sort so that index 0 / -1 are the leftmost / rightmost records.
            for records in features.values():
                records.sort()
            if gene_id in tif_medians:
                median5, median3 = tif_medians[gene_id]
                _extend_gene_with_tif(features, median5, median3)
            for feature in features:
                # Drop duplicate records and write them in coordinate order.
                for record in sorted(set(features[feature])):
                    out.write(_format_gtf_line(gene_id, feature, record))
   
# Script entry point: the conversion runs as soon as this file is executed
# (there is no __main__ guard, so importing the file runs it as well).
run()
