##################################
#                                #
# Last modified 2022/12/09       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def SubChains(N):
    """Return every run of consecutive indices of length >= 2 drawn from range(N).

    The result is a list of lists [i, i+1, ..., j] for all 0 <= i < j < N,
    ordered by start index and then by end index. N <= 1 yields an empty list.
    """
    return [
        list(range(start, stop + 1))
        for start in range(N - 1)
        for stop in range(start + 1, N)
    ]

def _read_exons(gtf_path, min_exon_length=None):
    """Parse the exon records of a GTF file.

    Returns a dict: chromosome -> transcript ID -> list of
    (chromosome, left, right, strand) tuples in file order.

    Lines starting with '#', lines with fewer than three tab-separated
    fields (e.g. blank lines) and non-exon features are ignored.

    When min_exon_length is given, an exon whose left and right coordinates
    are equal has its right coordinate bumped by one, and exons with
    right - left <= min_exon_length are dropped. When it is None, no
    adjustment or filtering is applied.
    """
    transcripts = {}
    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 3 or fields[2] != 'exon':
                continue
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            if min_exon_length is not None:
                if right == left:
                    right += 1
                if abs(right - left) <= min_exon_length:
                    continue
            strand = fields[6]
            transcript_id = fields[8].split('transcript_id "')[1].split('";')[0]
            by_transcript = transcripts.setdefault(chrom, {})
            by_transcript.setdefault(transcript_id, []).append((chrom, left, right, strand))
    return transcripts


def _merge_close_exons(exons, min_intron_length):
    """Fuse consecutive exons separated by a gap of at most min_intron_length.

    exons is a non-empty list of (chromosome, left, right, strand) tuples of
    one transcript. The exons are processed in sorted order; an exon whose
    left coordinate lies within min_intron_length of the previous right
    coordinate extends the previous exon instead of starting a new one.
    Every returned tuple carries the strand of the last exon in sorted order.
    """
    ordered = sorted(exons)
    chrom = ordered[0][0]
    strand = ordered[-1][3]
    coordinates = [ordered[0][1], ordered[0][2]]
    for exon in ordered[1:]:
        left, right = exon[1], exon[2]
        if abs(left - coordinates[-1]) <= min_intron_length:
            coordinates[-1] = right
        else:
            coordinates.extend((left, right))
    return [(chrom, coordinates[i], coordinates[i + 1], strand)
            for i in range(0, len(coordinates), 2)]


def _intron_chain(exons):
    """Describe a multi-exon transcript by its intron chain.

    Returns ((chromosome, strand), chain, coordinates) where coordinates is
    the flat list [left1, right1, left2, right2, ...] of the sorted exons and
    chain is the comma-joined inner coordinates (right1, left2, ..., leftN),
    i.e. the intron boundaries. Chromosome and strand are taken from the
    first exon in sorted order.
    """
    ordered = sorted(exons)
    coordinates = []
    for exon in ordered:
        coordinates.extend((exon[1], exon[2]))
    chain = ','.join(str(value) for value in coordinates[1:-1])
    return (ordered[0][0], ordered[0][3]), chain, coordinates


def _contains_subchain(ref_chain, query_chain):
    """Tell whether query_chain is a run of whole introns inside ref_chain.

    Both chains are comma-joined coordinate strings. A hit must cover whole
    coordinates (so '100,300' does not match inside '1100,3000') and must
    start at an even coordinate index, so that query introns line up with
    reference introns rather than with reference exons.
    """
    haystack = ',' + ref_chain + ','
    needle = ',' + query_chain + ','
    position = haystack.find(needle)
    while position != -1:
        # The number of commas before the hit equals the index of the first
        # matched coordinate within the reference chain.
        if haystack.count(',', 0, position) % 2 == 0:
            return True
        position = haystack.find(needle, position + 1)
    return False


def run():
    """Compare the intron chains of a query GTF against a reference GTF.

    Command line (read from sys.argv):
        gtf1 referenceGTF outputfilename [-minExonLength bp] [-minIntronLength bp]

    Query exons are filtered by -minExonLength and neighbouring query exons
    separated by at most -minIntronLength are merged (both default to 1).
    Single-exon transcripts of either file are ignored. For every remaining
    query transcript one tab-separated line is written to outputfilename:

        chr, TranscriptID1, NumExons, TranscriptID2, NumExons,
        Match(Full/Partial), Chain1, Chain2

    'F' marks an identical intron chain, 'P' a query chain that is a
    contiguous run of introns of a longer reference chain, and '-' fills the
    reference columns when nothing matches. Reference transcripts sharing an
    intron chain are all listed, comma separated. For 'P' rows the exon
    counts are comma separated and the reference chains '|' separated, one
    entry per listed reference transcript.

    Prints the usage message and exits with status 1 when fewer than three
    positional arguments are supplied.
    """
    # Three positional arguments are required; sys.argv[3] is read below.
    if len(sys.argv) < 4:
        print('usage: python %s gtf1 referenceGTF outputfilename [-minExonLength bp] [-minIntronLength bp]' % sys.argv[0])
        print('\tNote: transcripts that contain the same intron chain will be collapsed (though still listed separately) for the reference GTF, but not for the first GTF!!!')
        print('\tNote: the [-minExonLength bp] and [-minIntronLength bp] options apply to the querry GTF file')
        sys.exit(1)

    GTF1 = sys.argv[1]
    GTF2 = sys.argv[2]
    outfilename = sys.argv[3]

    minIL = 1
    if '-minIntronLength' in sys.argv:
        minIL = int(sys.argv[sys.argv.index('-minIntronLength') + 1])
        print('will skip introns shorter than %s' % minIL)

    minEL = 1
    if '-minExonLength' in sys.argv:
        minEL = int(sys.argv[sys.argv.index('-minExonLength') + 1])
        print('will skip exons shorter than %s' % minEL)

    PreliminaryTranscriptDict1 = _read_exons(GTF1, minEL)
    print('finished inputting %s' % GTF1)

    TranscriptDict2 = _read_exons(GTF2)
    print('finished inputting %s' % GTF2)

    # Query: (chr, strand) -> transcript ID -> (chain, exon coordinates).
    ChainDict1 = {}
    for transcripts in PreliminaryTranscriptDict1.values():
        for transcriptID, exons in transcripts.items():
            merged = _merge_close_exons(exons, minIL)
            if len(merged) == 1:
                continue
            key, chain, coordinates = _intron_chain(merged)
            ChainDict1.setdefault(key, {})[transcriptID] = (chain, coordinates)

    # Reference: (chr, strand) -> chain -> [(transcript ID, exon coordinates), ...].
    # Every transcript sharing a chain is recorded, not only the first one.
    ChainDict2 = {}
    for transcripts in TranscriptDict2.values():
        for transcriptID, exons in transcripts.items():
            if len(exons) == 1:
                continue
            key, chain, coordinates = _intron_chain(exons)
            ChainDict2.setdefault(key, {}).setdefault(chain, []).append((transcriptID, coordinates))

    with open(outfilename, 'w') as outfile:
        outfile.write("#chr\tTranscriptID1\tNumExons\tTranscriptID2\tNumExons\tMatch(Full/Partial)\tChain1\tChain2" + '\n')

        # Only keys holding query transcripts produce output lines.
        for key in sorted(ChainDict1):
            chrom = key[0]
            refChains = ChainDict2.get(key, {})
            for transcriptID, (T1chain, coordinates) in ChainDict1[key].items():
                numExons = str(len(coordinates) // 2)
                if T1chain in refChains:
                    # Identical chains imply identical exon counts, so a
                    # single count and a single chain describe all listed
                    # reference transcripts.
                    refTIDs = [refTID for (refTID, FCs) in refChains[T1chain]]
                    columns = [','.join(refTIDs), numExons, 'F', T1chain, T1chain]
                else:
                    hits = [(refChain, refTID, FCs)
                            for refChain in refChains
                            if _contains_subchain(refChain, T1chain)
                            for (refTID, FCs) in refChains[refChain]]
                    if hits:
                        columns = [
                            ','.join(refTID for (refChain, refTID, FCs) in hits),
                            ','.join(str(len(FCs) // 2) for (refChain, refTID, FCs) in hits),
                            'P',
                            T1chain,
                            '|'.join(refChain for (refChain, refTID, FCs) in hits),
                        ]
                    else:
                        columns = ['-', '-', '-', T1chain, '-']
                outfile.write('\t'.join([chrom, transcriptID, numExons] + columns) + '\n')
   
# Script entry point: the comparison is executed unconditionally at module
# level, so it also runs if this file is imported rather than executed.
run()
