##################################
#                                #
# Last modified 09/23/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of ``key`` from a GTF attribute column.

    ``attributes`` is the raw text of the 9th GTF column.  The value is the
    text between ``key "`` and the following ``";``.  Returns None when the
    key does not occur.
    """
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('";')[0]


def _read_exons_by_transcript(gtf_path):
    """Collect the exon records of a GTF file, grouped by gene and transcript.

    Returns ``{(geneID, geneName): {transcriptID: [(chrom, left, right,
    strand), ...]}}``.  Comment lines, lines with fewer than nine
    tab-separated fields and non-exon features are skipped.  When an exon has
    no ``gene_name`` attribute the gene ID is used as the name.

    Raises ValueError if an exon record has no gene_id or transcript_id.
    """
    genes = {}
    with open(gtf_path) as gtf:
        for line_number, line in enumerate(gtf, 1):
            # Progress indicator for large annotation files.
            if line_number % 100000 == 0:
                print(line_number)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            gene_id = _gtf_attribute(attributes, 'gene_id')
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            if gene_id is None or transcript_id is None:
                raise ValueError(
                    '%s, line %d: exon record without gene_id or transcript_id'
                    % (gtf_path, line_number))
            gene_name = _gtf_attribute(attributes, 'gene_name')
            if gene_name is None:
                gene_name = gene_id
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts = genes.setdefault((gene_id, gene_name), {})
            transcripts.setdefault(transcript_id, []).append(exon)
    return genes


def _tss_and_first_junction(exons):
    """Return ``((tss_position, strand), junction)`` for a sorted exon list.

    The TSS is the left end of the first exon on '+' and the right end of the
    last exon on '-'.  The junction is the opposite end of that same exon.
    Returns None when the strand of the first exon is neither '+' nor '-'.
    """
    strand = exons[0][3]
    if strand == '+':
        return (exons[0][1], strand), exons[0][2]
    if strand == '-':
        # The exon that carries the TSS on the minus strand is the rightmost
        # one, so its inner boundary is the left end of exons[-1].
        return (exons[-1][2], strand), exons[-1][1]
    return None


def _new_tss_entry(with_junctions):
    """Create the empty per-TSS record used by both TSS indexes."""
    entry = {'genesIDs': [], 'genesNames': [], 'transcripts': []}
    if with_junctions:
        entry['junctions'] = []
    return entry


def _index_known_annotation(genes):
    """Index the annotation by TSS, exon boundary and gene ID.

    Returns ``(known_tss, known_junctions, known_gene_ids)`` where
    ``known_tss`` is ``{chrom: {(position, strand): entry}}``,
    ``known_junctions`` is ``{chrom: set(coordinates)}`` and
    ``known_gene_ids`` is a set.  Transcripts whose strand is neither '+' nor
    '-' contribute only their gene ID.
    """
    known_tss = {}
    known_junctions = {}
    known_gene_ids = set()
    # Sorted traversal keeps the order of the collected ID lists reproducible.
    for gene_key in sorted(genes):
        gene_id, gene_name = gene_key
        known_gene_ids.add(gene_id)
        transcripts = genes[gene_key]
        for transcript_id in sorted(transcripts):
            exons = transcripts[transcript_id]
            exons.sort()
            located = _tss_and_first_junction(exons)
            if located is None:
                continue
            tss = located[0]
            chrom = exons[0][0]
            chrom_tss = known_tss.setdefault(chrom, {})
            if tss not in chrom_tss:
                chrom_tss[tss] = _new_tss_entry(False)
            entry = chrom_tss[tss]
            entry['genesIDs'].append(gene_id)
            entry['genesNames'].append(gene_name)
            entry['transcripts'].append(transcript_id)
            # Record the right end of every exon but the last and the left
            # end of every exon but the first; a single-exon transcript
            # contributes its right end only.
            boundaries = known_junctions.setdefault(chrom, set())
            last = len(exons) - 1
            for index, exon in enumerate(exons):
                if index > 0:
                    boundaries.add(exon[1])
                if index == 0 or index < last:
                    boundaries.add(exon[2])
    return known_tss, known_junctions, known_gene_ids


def _index_novel_tss(genes):
    """Index the novel transcripts by TSS.

    Returns ``{chrom: {(position, strand): entry}}`` where each entry also
    lists, per transcript, the inner boundary of the exon carrying the TSS
    under ``'junctions'``.  Transcripts that are neither '+' nor '-' are
    skipped.
    """
    novel_tss = {}
    for gene_key in sorted(genes):
        gene_id, gene_name = gene_key
        transcripts = genes[gene_key]
        for transcript_id in sorted(transcripts):
            exons = transcripts[transcript_id]
            exons.sort()
            located = _tss_and_first_junction(exons)
            if located is None:
                continue
            tss, junction = located
            chrom_tss = novel_tss.setdefault(exons[0][0], {})
            if tss not in chrom_tss:
                chrom_tss[tss] = _new_tss_entry(True)
            entry = chrom_tss[tss]
            entry['genesIDs'].append(gene_id)
            entry['genesNames'].append(gene_name)
            entry['transcripts'].append(transcript_id)
            entry['junctions'].append(junction)
    return novel_tss


def _nearest_known_tss(sorted_known, position):
    """Return the ``(position, strand)`` key closest to ``position``.

    ``sorted_known`` must be sorted ascending.  On equal distance the key
    that sorts first wins.  Returns None for an empty list.
    """
    nearest = None
    best = None
    for candidate in sorted_known:
        # Everything further right can only be farther away.
        if best is not None and candidate[0] > position + best:
            break
        gap = abs(candidate[0] - position)
        if best is None or gap < best:
            best = gap
            nearest = candidate
    return nearest


def _novelty_class(entry, chrom_junctions, known_gene_ids):
    """Classify a novel TSS as extension, novel_upstream_exon or intergenic.

    NOTE(review): when several transcripts share the TSS only the junction of
    the last one listed decides the class, as in the original loop.
    """
    if entry['junctions'][-1] in chrom_junctions:
        return 'extension'
    if any(gene_id in known_gene_ids for gene_id in entry['genesIDs']):
        return 'novel_upstream_exon'
    return 'intergenic'


def run():
    """Report TSSs of a GTF that are absent from an annotation GTF.

    Command line: ``GTF AnnotationGTF outputfilename``.  Prints a usage
    message and exits with status 1 when fewer than three arguments are
    given.  For every TSS of the first GTF that is not a TSS (same
    chromosome, position and strand) in the annotation, one tab-separated
    row is written with its novelty class, its genes and transcripts, and
    the nearest annotated TSS on the same chromosome with the signed
    distance to it.  The nearest-TSS columns hold ``NA`` when the annotation
    has no stranded transcript on that chromosome.
    """
    # Three positional arguments are read below, so argv needs four entries.
    if len(sys.argv) < 4:
        print('usage: python %s GTF AnnotationGTF outputfilename' % sys.argv[0])
        sys.exit(1)

    novelgtf = sys.argv[1]
    knowngtf = sys.argv[2]
    outfilename = sys.argv[3]

    # Column order matches the rows written below (class before strand).
    header = ('#chr\tTSS\tNovelty_class\tstrand\tgene_ID(s)\tgene_Names(s)'
              '\ttranscript_ID(s)\tnearest_TSS\tNearest_TSS_strand'
              '\tdistance_to_nearestTSS\tnearest_TSS_gene_ID(s)'
              '\tnearest_TSS_gene_Name(s)\tnearest_TSS_transcript_ID(s)')

    # The output is opened first so an unwritable path fails before parsing.
    with open(outfilename, 'w') as outfile:
        known_tss, known_junctions, known_gene_ids = _index_known_annotation(
            _read_exons_by_transcript(knowngtf))
        novel_tss = _index_novel_tss(_read_exons_by_transcript(novelgtf))

        outfile.write(header + '\n')

        for chrom in sorted(novel_tss):
            print(chrom)
            # A chromosome may be missing from the annotation altogether.
            chrom_known = known_tss.get(chrom, {})
            chrom_junctions = known_junctions.get(chrom, ())
            sorted_known = sorted(chrom_known)
            for tss in sorted(novel_tss[chrom]):
                if tss in chrom_known:
                    continue
                position, strand = tss
                entry = novel_tss[chrom][tss]
                row = [
                    chrom,
                    str(position),
                    _novelty_class(entry, chrom_junctions, known_gene_ids),
                    strand,
                    ','.join(entry['genesIDs']),
                    ','.join(entry['genesNames']),
                    ','.join(entry['transcripts']),
                ]
                nearest = _nearest_known_tss(sorted_known, position)
                if nearest is None:
                    row.extend(['NA'] * 6)
                else:
                    # Signed so that the sign follows the novel TSS strand.
                    if strand == '+':
                        distance = nearest[0] - position
                    else:
                        distance = position - nearest[0]
                    nearest_entry = chrom_known[nearest]
                    row.extend([
                        str(nearest[0]),
                        nearest[1],
                        str(distance),
                        ','.join(nearest_entry['genesIDs']),
                        ','.join(nearest_entry['genesNames']),
                        ','.join(nearest_entry['transcripts']),
                    ])
                outfile.write('\t'.join(row) + '\n')
   
# Run only when executed as a script, so importing the module has no side effects.
if __name__ == '__main__':
    run()
