##################################
#                                #
# Last modified 10/31/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

# Best-effort: use psyco if it can be imported and enabled. Any failure is
# deliberately ignored so the script still runs without it.
try:
	import psyco
	psyco.full()
except:
	pass

# Label written to the MatchType column for an exon that matched nothing in
# the annotation, keyed by the exon's position class in the second GTF.
# NOTE(review): '5UTR' lacks the 'novel_' prefix its siblings have; this looks
# like a typo for 'novel_5UTR' but is kept as-is because downstream consumers
# may depend on the current value -- confirm before changing.
_NOVEL_LABELS = {
    'internal': 'novel_internal',
    '5UTR': '5UTR',
    '3UTR': 'novel_3UTR',
    'monoexonic': 'intergenic',
}


def _load_gtf_exons(path):
    """Read the 'exon' records of a GTF file.

    Lines starting with '#' and blank lines are skipped.  For every remaining
    tab-separated line whose third column is 'exon', the tuple
    (chr, left, right, strand) is taken from columns 1, 4, 5 and 7, and the
    gene/transcript IDs from the 'gene_id "..."' and 'transcript_id "..."'
    attributes in column 9.

    Returns a dict {gene_id: {transcript_id: [exon tuples]}} in which every
    exon list is sorted.
    """
    genes = {}
    with open(path) as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped or line.startswith('#'):
                continue
            fields = stripped.split('\t')
            if fields[2] != 'exon':
                continue
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            gene = fields[8].split('gene_id "')[1].split('";')[0]
            transcript = fields[8].split('transcript_id "')[1].split('";')[0]
            genes.setdefault(gene, {}).setdefault(transcript, []).append(exon)
    # Sort once here so callers can rely on positional order without
    # re-sorting for every query.
    for transcripts in genes.values():
        for exons in transcripts.values():
            exons.sort()
    return genes


def _exon_end(index, count, strand):
    """Classify a 0-based position within a sorted exon list of length count.

    Returns 'internal' for a position that is neither first nor last.  For the
    first position returns '5UTR' on '+' and '3UTR' on '-'; for the last
    position the two are swapped.  Returns None for a first/last position on
    any other strand value.  The first position takes precedence when the
    list has a single element.
    """
    if index == 0:
        by_strand = {'+': '5UTR', '-': '3UTR'}
    elif index == count - 1:
        by_strand = {'+': '3UTR', '-': '5UTR'}
    else:
        return 'internal'
    return by_strand.get(strand)


def _classify_exons(gene_dict):
    """Map every exon tuple in gene_dict to a position class.

    Exons of single-exon transcripts are labelled 'monoexonic' (overriding any
    earlier label).  Otherwise an exon already labelled 'internal' keeps that
    label, and any other exon is (re)labelled from its position in the current
    transcript via _exon_end; terminal exons on a strand other than '+'/'-'
    are left unlabelled.
    """
    exon_classes = {}
    for gene in gene_dict:
        for transcript in gene_dict[gene]:
            exons = gene_dict[gene][transcript]
            count = len(exons)
            for index, exon in enumerate(exons):
                if count == 1:
                    exon_classes[exon] = 'monoexonic'
                    continue
                if exon_classes.get(exon) == 'internal':
                    continue
                label = _exon_end(index, count, exon[3])
                if label is not None:
                    exon_classes[exon] = label
    return exon_classes


def _positional_match(index, count, strand, kind, current):
    """Return the updated match label for a partial overlap of type kind.

    An overlap with a non-terminal annotated exon always yields
    'internal_exon_<kind>'.  An overlap with a terminal exon yields
    '5UTR_<kind>' or '3UTR_<kind>' unless the current label is already
    'internal_exon_<kind>' or the strand is neither '+' nor '-', in which
    case the current label is returned unchanged.
    """
    internal = 'internal_exon_' + kind
    end = _exon_end(index, count, strand)
    if end == 'internal':
        return internal
    if end is None or current == internal:
        return current
    return end + '_' + kind


def _match_exon(query, known_genes, annotation):
    """Compare one exon against the annotated transcripts of known_genes.

    query is a (chr, left, right, strand) tuple.  Returns 'exact' as soon as
    an identical annotated exon is found.  Otherwise, within each transcript
    only the first annotated exon that shares a boundary with, contains, or
    is contained in the query is considered, and the label it produces may be
    overwritten by later transcripts.  Returns '' when nothing overlapped.
    Only the coordinates (not chr/strand) are compared for partial overlaps,
    and the query's strand decides the 5'/3' naming.
    """
    left, right, strand = query[1], query[2], query[3]
    match = ''
    for gene in known_genes:
        for transcript in annotation[gene]:
            exons = annotation[gene][transcript]
            count = len(exons)
            for index, exon in enumerate(exons):
                if exon == query:
                    return 'exact'
                left_a, right_a = exon[1], exon[2]
                if left_a == left or right_a == right:
                    known_len = right_a - left_a
                    query_len = right - left
                    if known_len < query_len:
                        kind = 'extension'
                    elif known_len > query_len:
                        kind = 'shortening'
                    else:
                        # Same coordinates but different chr/strand: no label.
                        kind = None
                elif right_a > right and left_a < left:
                    kind = 'double_shortening'
                elif right_a < right and left_a > left:
                    kind = 'double_extension'
                else:
                    continue
                if kind is not None:
                    match = _positional_match(index, count, strand, kind, match)
                # Only the first overlapping exon of a transcript is used.
                break
    return match


def run():
    """Annotate each exon of a table with how it relates to an annotation.

    Command line: annotation.gtf gtf exon_table outputfilename.  The table is
    tab-separated (chr, left, right, strand, gene(s), transcript(s), then any
    further columns); a line starting with '#' is treated as a header.  The
    output repeats the first six columns, inserts a MatchType column, then
    the remaining columns.  Prints a usage message and exits with status 1
    when fewer than four arguments are supplied.

    Raises KeyError if an exon of a known gene matches nothing in the
    annotation and is also absent from the second GTF.
    """
    # Four positional arguments are read, so argv needs at least 5 entries.
    if len(sys.argv) < 5:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('usage: python %s annotation.gtf gtf exon_table outputfilename' % sys.argv[0])
        print('       table format: chr left right strand gene(s) transcript(s) FPKM1  etc.')
        sys.exit(1)

    annotationgtf = sys.argv[1]
    gtf = sys.argv[2]
    table = sys.argv[3]
    outputfilename = sys.argv[4]

    print('inputing annotation gtf')
    annotation = _load_gtf_exons(annotationgtf)

    print('inputing gtf')
    assembled = _load_gtf_exons(gtf)

    print('compiling exons list')
    exon_classes = _classify_exons(assembled)

    header = '#chr\tleft\tright\tstrand\tgeneID\ttranscriptID\tMatchType'
    with open(outputfilename, 'w') as outfile:
        with open(table) as infile:
            count = 0
            for line in infile:
                count += 1
                if count % 10000 == 0:
                    print(count)
                stripped = line.strip()
                if not stripped:
                    continue
                fields = stripped.split('\t')
                if line.startswith('#'):
                    outfile.write('\t'.join([header] + fields[6:]).strip() + '\n')
                    continue
                query = (fields[0], int(fields[1]), int(fields[2]), fields[3])
                known_genes = [g for g in fields[4].split(',') if g in annotation]
                if not known_genes:
                    match = 'intergenic'
                else:
                    match = _match_exon(query, known_genes, annotation)
                    if match == '':
                        match = _NOVEL_LABELS.get(exon_classes[query], '')
                columns = fields[:6] + [match] + fields[6:]
                outfile.write('\t'.join(columns).strip() + '\n')

# Run the tool only when this file is executed as a script, so that importing
# the module does not parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()

