##################################
#                                #
# Last modified 03/27/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _load_orfs(orf_path):
    """Return a {transcriptID: protein} dict parsed from the ORF table.

    Lines starting with '#' and blank lines are ignored.  The transcript ID
    is read from column 3 and the protein sequence from column 12
    (tab-separated).  Records with fewer than 12 columns are echoed to
    stdout and skipped, so that no bogus sequence is stored for them.
    """
    orf_by_transcript = {}
    with open(orf_path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            # Only the newline is cut off (no strip()) so that empty
            # trailing columns survive the split on tabs.
            fields = line.split('\n')[0].split('\t')
            if len(fields) < 12:
                print(line.strip())
                continue
            orf_by_transcript[fields[2]] = fields[11]
    return orf_by_transcript


def _load_known_transcripts(gtf_path):
    """Return a {geneID: set(transcriptIDs)} dict built from GTF exon lines.

    Only records whose third column is 'exon' are used; gene_id and
    transcript_id are extracted from the attribute column (column 9).
    A progress message is printed every 100000 input lines.
    """
    transcripts_by_gene = {}
    with open(gtf_path) as handle:
        for line_number, line in enumerate(handle, 1):
            if line_number % 100000 == 0:
                print('%d lines processed' % line_number)
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            attributes = fields[8]
            transcript_id = attributes.split('transcript_id "')[1].split('";')[0]
            gene_id = attributes.split('gene_id "')[1].split('";')[0]
            transcripts_by_gene.setdefault(gene_id, set()).add(transcript_id)
    return transcripts_by_gene


def run():
    """Annotate each TSS record with whether it introduces a new ORF.

    Command line: annotation_gtf TSS_file ORF outfilename (read from
    sys.argv).  With fewer than four arguments the usage text is printed
    and the process exits with status 1.

    For every TSS record, the proteins of all annotated (GTF) transcripts
    of the listed gene IDs (column 5) are collected from the ORF table.
    The record's own transcripts (column 8) are then looked up in the same
    table: if any of their proteins equals one of the annotated proteins,
    'no' is written in the ORF_change column, otherwise 'yes'.  A final
    extra column holds 'f' when at least one transcript of the record was
    missing from the ORF table (so the call may be incomplete) and '0'
    otherwise.
    """
    # Four positional arguments are required, i.e. len(sys.argv) >= 5.
    if len(sys.argv) < 5:
        print('usage: python %s annotation_gtf TSS_file ORF outfilename' % sys.argv[0])
        print('TSS format: #chr\tTSS\tstrand\t000-distance\t000-geneIDs\t000-geneNames\t000-novelty_class\t000-transcripts')
        print('ORF format: #GeneID GeneName        TranscriptID    TranscriptName  transcriptType  chr     LeftPos,RightPos        orientation     Start_codon_pos,Stop_codon_pos  RNA_length      protein_length  protein')
        sys.exit(1)

    gtf_path, tss_path, orf_path, outfilename = sys.argv[1:5]

    orf_by_transcript = _load_orfs(orf_path)
    transcripts_by_gene = _load_known_transcripts(gtf_path)

    missing_total = 0   # annotated transcripts absent from the ORF table
    lookup_total = 0    # annotated transcripts looked up so far

    with open(outfilename, 'w') as outfile:
        outfile.write('#chr\tTSS\tstrand\tdistance\tgeneID(s)\tgeneName(s)\tnovely_class\ttranscript(s)\tORF_change\n')
        with open(tss_path) as tss_file:
            for line in tss_file:
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.strip().split('\t')
                lookup_failed = False

                # Proteins encoded by the annotated transcripts of the genes.
                annotated_orfs = set()
                for gene_id in fields[4].split(','):
                    for transcript_id in transcripts_by_gene.get(gene_id, ()):
                        lookup_total += 1
                        if transcript_id in orf_by_transcript:
                            annotated_orfs.add(orf_by_transcript[transcript_id])
                        else:
                            missing_total += 1
                            lookup_failed = True
                            print('%s not found in ORF file %d %d'
                                  % (transcript_id, missing_total, lookup_total))

                # The ORF is "new" unless one of the record's transcripts
                # encodes a protein already seen among the annotated ones.
                new_orf = 'yes'
                for transcript_id in fields[7].split(','):
                    if transcript_id not in orf_by_transcript:
                        # Flag the row instead of aborting the whole run.
                        lookup_failed = True
                        print('%s (TSS transcript) not found in ORF file' % transcript_id)
                        continue
                    if orf_by_transcript[transcript_id] in annotated_orfs:
                        new_orf = 'no'

                flag = 'f' if lookup_failed else '0'
                outfile.write('\t'.join(fields[:8] + [new_orf, flag]) + '\n')
   
# Script entry point.  NOTE: there is no __main__ guard, so run() executes
# (and reads sys.argv) whenever this file is run or imported.
run()
