##################################
#                                #
# Last modified 10/12/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def getChain(exonList):
    """Return the intron chain of a transcript as a hashable tuple.

    exonList is a list of (chr, start, stop, strand) exon tuples.  The result
    contains the inner exon boundaries in ascending order:
    - the stop of the first exon,
    - the start and stop of every middle exon,
    - the start of the last exon.
    The strand of the last exon (after sorting) follows them.  Two transcripts
    with the same splice junctions therefore produce equal chains, regardless
    of where their outermost exons begin and end.

    A single-exon list yields (stop, strand).  The input list is not modified.

    Raises ValueError if exonList is empty.
    """
    if not exonList:
        # The original fell through to an UnboundLocalError on `strand` here.
        raise ValueError('getChain() needs at least one exon')
    # sorted() instead of list.sort(): do not reorder the caller's list.
    exons = sorted(exonList)
    strand = exons[-1][3]
    if len(exons) == 1:
        boundaries = [exons[0][2]]
    else:
        boundaries = [exons[0][2], exons[-1][1]]
        for (chrom, start, stop, exon_strand) in exons[1:-1]:
            boundaries.append(start)
            boundaries.append(stop)
    boundaries.sort()
    boundaries.append(strand)
    return tuple(boundaries)

def _gtf_fields(line):
    """Split one GTF line into its tab-separated columns.

    Returns None for comment lines, blank lines and lines with fewer than the
    nine mandatory GTF columns, so callers can skip them uniformly.
    """
    if line.startswith('#'):
        return None
    fields = line.strip().split('\t')
    if len(fields) < 9:
        return None
    return fields


def _gtf_attribute(attributes, key):
    """Return the value of `key "value";` in a GTF attribute column, or None."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker, 1)[1].split('";')[0]


def _read_gtf_exons(path, echo=None):
    """Group the exon records of a GTF file by transcript_id.

    Returns a dict mapping each transcript_id to a list of
    (chr, start, stop, strand) tuples, with start/stop converted to int.
    Comment lines, malformed lines, non-exon features and exons without a
    transcript_id are skipped.  When echo is a writable file, every
    non-comment line is copied to it verbatim before being parsed.
    """
    transcripts = {}
    with open(path) as gtf:
        for line in gtf:
            if line.startswith('#'):
                continue
            if echo is not None:
                echo.write(line)
            fields = _gtf_fields(line)
            if fields is None or fields[2] != 'exon':
                continue
            transcript_id = _gtf_attribute(fields[8], 'transcript_id')
            if transcript_id is None:
                continue
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(transcript_id, []).append(exon)
    return transcripts


def _build_coverage(transcripts):
    """Index exon coverage per chromosome as merged, sorted closed intervals.

    Returns {chr: (starts, stops)}, where the i-th merged interval is
    [starts[i], stops[i]] (GTF coordinates are 1-based and inclusive).
    Both lists are ascending and the intervals are disjoint.  Strand is
    ignored.
    """
    by_chrom = {}
    for exons in transcripts.values():
        for (chrom, start, stop, strand) in exons:
            by_chrom.setdefault(chrom, []).append((start, stop))
    coverage = {}
    for chrom, intervals in by_chrom.items():
        starts = []
        stops = []
        for start, stop in sorted(intervals):
            if stops and start <= stops[-1]:
                # Overlaps the previous merged interval: extend it.
                stops[-1] = max(stops[-1], stop)
            else:
                starts.append(start)
                stops.append(stop)
        coverage[chrom] = (starts, stops)
    return coverage


def _is_covered(coverage, chrom, left, right):
    """Return True if any indexed exon on chrom overlaps closed [left, right]."""
    from bisect import bisect_right
    if chrom not in coverage:
        return False
    starts, stops = coverage[chrom]
    # The rightmost merged interval starting at or before `right` is the only
    # candidate that can reach `left`: merged intervals are disjoint and
    # ascending, so every earlier one ends before this one starts.
    i = bisect_right(starts, right) - 1
    return i >= 0 and stops[i] >= left


def run():
    """Append annotated transcripts that Cuffmerge did not reconstruct.

    Usage: python script.py Cuffmerge_gtf annotation_gtf outputfilename

    The output file receives every non-comment line of the Cuffmerge GTF.
    It then receives every non-comment annotation line whose transcript is
    missing from Cuffmerge:
    - a multi-exon annotation transcript is missing when no multi-exon
      Cuffmerge transcript on the same chromosome has the same chain, as
      computed by getChain (inner exon boundaries plus strand);
    - a single-exon annotation transcript is missing when no Cuffmerge exon
      overlaps it at all (strand is ignored).
    """
    # Three arguments are read below, so argv must have at least 4 entries.
    if len(sys.argv) < 4:
        print('usage: python %s Cuffmerge_gtf annotation_gtf outputfilename' % sys.argv[0])
        sys.exit(1)

    CuffmergeGTF = sys.argv[1]
    AnnotationGTF = sys.argv[2]
    outfilename = sys.argv[3]

    with open(outfilename, 'w') as outfile:
        cuffmerge = _read_gtf_exons(CuffmergeGTF, echo=outfile)
        print('finished inputting Cuffmerge GTF file')

        # Key chains by chromosome, so a chromosome absent from Cuffmerge
        # simply matches nothing instead of raising KeyError.
        cuffmerge_chains = set()
        for exons in cuffmerge.values():
            if len(exons) > 1:
                cuffmerge_chains.add((exons[0][0], getChain(exons)))
        # Coverage uses all Cuffmerge exons, including single-exon transcripts.
        cuffmerge_coverage = _build_coverage(cuffmerge)
        print('finished creating Cuffmerge chains')

        annotation = _read_gtf_exons(AnnotationGTF)
        print('finished inputting Annotation GTF file')

        to_keep = set()
        for transcript_id, exons in annotation.items():
            if len(exons) == 1:
                (chrom, left, right, strand) = exons[0]
                if not _is_covered(cuffmerge_coverage, chrom, left, right):
                    to_keep.add(transcript_id)
            elif (exons[0][0], getChain(exons)) not in cuffmerge_chains:
                to_keep.add(transcript_id)

        print('found %d transcripts in annotation missing in Cuffmerge output' % len(to_keep))

        # Copy every feature line (exon, CDS, transcript, ...) of the kept
        # transcripts, not just the exons.
        with open(AnnotationGTF) as annotation_gtf:
            for line in annotation_gtf:
                fields = _gtf_fields(line)
                if fields is None:
                    continue
                if _gtf_attribute(fields[8], 'transcript_id') in to_keep:
                    outfile.write(line)
   
# Run only when executed as a script, so the module can be imported (e.g. to
# reuse getChain) without triggering the pipeline or sys.exit.
if __name__ == '__main__':
    run()