##################################
#                                #
# Last modified 06/01/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value stored under `key` in a GTF attribute column."""
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_correspondences(path):
    """Parse the GTF comparison table into transcript correspondences.

    Lines starting with '#' and blank lines are skipped.  Columns are
    tab-separated; '-' marks an empty column.

    Returns a tuple ``(full, partial)``:
      full    -- transcript ID -> (dist5, dist3).  Filled when columns 0 and 2
                 both name a transcript; column 2 gets the distances read from
                 columns 4 and 5, column 0 gets their negation.
      partial -- transcript ID -> ID of the paired transcript (stored in both
                 directions), taken from columns 0/6 or columns 8/2.
    """
    full = {}
    partial = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[0] != '-':
                if fields[2] != '-':
                    dist5 = int(fields[4])
                    dist3 = int(fields[5])
                    full[fields[0]] = (-dist5, -dist3)
                    full[fields[2]] = (dist5, dist3)
                elif fields[6] != '-':
                    partial[fields[0]] = fields[6]
                    partial[fields[6]] = fields[0]
            elif fields[8] != '-':
                partial[fields[8]] = fields[2]
                partial[fields[2]] = fields[8]
    return full, partial


def _read_exons(path, with_gene):
    """Collect the 'exon' records of a GTF file, grouped per transcript.

    Returns ``{chromosome: {key: [(chromosome, left, right, strand), ...]}}``
    where `key` is ``(geneID, transcriptID)`` when `with_gene` is true and the
    bare transcript ID otherwise (gene_id is then not required in the file).
    Every exon list is sorted, so index 0 is the leftmost exon and index -1
    the rightmost one regardless of the order of the records in the file.
    """
    transcripts = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            chrom = fields[0]
            exon = (chrom, int(fields[3]), int(fields[4]), fields[6])
            transcriptID = _gtf_attribute(fields[8], 'transcript_id')
            if with_gene:
                key = (_gtf_attribute(fields[8], 'gene_id'), transcriptID)
            else:
                key = transcriptID
            transcripts.setdefault(chrom, {}).setdefault(key, []).append(exon)
    for by_transcript in transcripts.values():
        for exons in by_transcript.values():
            exons.sort()
    return transcripts


def _compute_extensions(reference, target, full, partial):
    """Work out how far each transcript of `target` may be extended.

    `reference` and `target` are the structures returned by `_read_exons`
    (keyed by transcript ID and by (geneID, transcriptID) respectively);
    `full` and `partial` come from `_read_correspondences`.

    Returns ``{(geneID, transcriptID): (dist5, dist3, exon_count, evidence)}``.
    Transcripts without any correspondence get distances (0, 0) and
    evidence '-'.

    Raises KeyError when a partially matched transcript has no exons in
    `reference` on the same chromosome.
    """
    extensions = {}
    for chrom in target:
        for key, exons in target[chrom].items():
            transcriptID = key[1]
            evidence = '-'
            dist5, dist3 = 0, 0
            if transcriptID in full:
                evidence = 'Full intron chain reconstruction'
                dist5, dist3 = full[transcriptID]
            elif transcriptID in partial:
                evidence = 'Partial intron chain reconstruction'
                matched = reference[chrom][partial[transcriptID]]
                strand = exons[0][3]
                # An end is only adjusted when the inner boundary of the
                # terminal exon is identical in both transcripts.
                if matched[0][2] == exons[0][2]:
                    if strand == '+':
                        dist5 = matched[0][1] - exons[0][1]
                    if strand == '-':
                        dist3 = exons[0][1] - matched[0][1]
                if matched[-1][1] == exons[-1][1]:
                    if strand == '+':
                        dist3 = matched[-1][2] - exons[-1][2]
                    if strand == '-':
                        dist5 = exons[-1][2] - matched[-1][2]
            extensions[key] = (dist5, dist3, len(exons), evidence)
    return extensions


def _write_updated_annotation(gtf_path, outfileprefix, target, extensions):
    """Re-read `gtf_path` and write the updated GTF and the extensions table.

    Output goes to ``outfileprefix + '.gtf'`` and
    ``outfileprefix + '.extensions'``.  CDS lines are copied unchanged, exon
    lines are written with their (possibly extended) coordinates and every
    other feature type is dropped.  Only the leftmost and rightmost exon of a
    transcript can change, and only outwards: a negative dist5 moves the 5'
    end, a positive dist3 moves the 3' end.
    """
    header = '#5p/3p\tgeneID\ttranscriptID\tchr\tleft\tright\tstrand\tchr\tnewleft\tnewright\text_chr\text_left\text_right\text_length\tevidence\tNumExons'
    # Which transcript end sits at the (left, right) side for each strand.
    end_labels = {'+': ('5p', '3p'), '-': ('3p', '5p')}

    with open(outfileprefix + '.gtf', 'w') as out_gtf, \
            open(outfileprefix + '.extensions', 'w') as out_ext, \
            open(gtf_path) as handle:
        out_ext.write(header + '\n')
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] == 'CDS':
                out_gtf.write(line)
                continue
            if fields[2] != 'exon':
                continue
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]
            geneID = _gtf_attribute(fields[8], 'gene_id')
            transcriptID = _gtf_attribute(fields[8], 'transcript_id')
            dist5, dist3, TL, evidence = extensions[(geneID, transcriptID)]

            # Terminal exons are recognised by value, so a repeated record of
            # a terminal exon is treated as terminal as well.
            exons = target[chrom][(geneID, transcriptID)]
            exon = (chrom, left, right, strand)
            extend_left = exon == exons[0]
            extend_right = exon == exons[-1] and (TL == 1 or not extend_left)
            if not (extend_left or extend_right):
                out_gtf.write(line.strip() + '\n')
                continue

            # Start from the original coordinates for every record so that
            # nothing leaks over from the previous exon (e.g. for strand '.').
            newleft, newright = left, right
            if strand == '+':
                if extend_left and dist5 < 0:
                    newleft = left + dist5
                if extend_right and dist3 > 0:
                    newright = right + dist3
            elif strand == '-':
                if extend_left and dist3 > 0:
                    newleft = left - dist3
                if extend_right and dist5 < 0:
                    newright = right - dist5

            common = [geneID, transcriptID, chrom, str(left), str(right),
                      strand, chrom, str(newleft), str(newright), chrom]
            tail = [evidence, str(TL)]
            if newleft != left:
                added = [str(newleft), str(left), str(left - newleft)]
                out_ext.write('\t'.join([end_labels[strand][0]] + common + added + tail) + '\n')
            if newright != right:
                added = [str(right), str(newright), str(newright - right)]
                out_ext.write('\t'.join([end_labels[strand][1]] + common + added + tail) + '\n')

            out_gtf.write('\t'.join(fields[0:3] + [str(newleft), str(newright)] + fields[5:9]) + '\n')


def run():
    """Update the transcript ends of the second GTF using the first one.

    Command line (read from sys.argv):
        gtf1 gtf2 GTFcomparison_output outputfileprefix

    Writes ``<outputfileprefix>.gtf`` (exon and CDS records of gtf2 with
    extended terminal exons) and ``<outputfileprefix>.extensions`` (one line
    per extended transcript end).  Prints a usage message and exits with
    status 1 when fewer than four arguments are supplied.
    """
    if len(sys.argv) < 5:
        print('usage: python %s gtf1 gtf2 GTFcomparison_output outputfileprefix' % sys.argv[0])
        print('Note: the script will try to update the transcripts in the second gtf file with the transcripts in the first')
        sys.exit(1)

    GTF1 = sys.argv[1]
    GTF2 = sys.argv[2]
    GTFcomparison = sys.argv[3]
    outfileprefix = sys.argv[4]

    full, partial = _read_correspondences(GTFcomparison)

    TranscriptDict1 = _read_exons(GTF1, False)
    print('finished inputting ' + GTF1)

    TranscriptDict2 = _read_exons(GTF2, True)
    print('finished inputting ' + GTF2)

    ExtensionDict = _compute_extensions(TranscriptDict1, TranscriptDict2, full, partial)
    _write_updated_annotation(GTF2, outfileprefix, TranscriptDict2, ExtensionDict)
   
# Script entry point.  The call is unconditional (there is no __main__ guard),
# so the tool also executes if this file is imported as a module.
run()
