##################################
#                                #
# Last modified 10/14/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def addnulls(counts,num_positions):
    """Return str(counts) left-padded with '0' to num_positions characters.

    Strings that are already num_positions characters or longer are
    returned unchanged (nothing is truncated).
    """
    text = str(counts)
    # A negative repeat count yields '' so long values get no padding.
    return '0' * (num_positions - len(text)) + text

def _get_attribute(attributes, key):
    """Return the value of the first ``key "value";`` in a GTF attribute column.

    Only the first occurrence of ``key "`` is used, so a key that is the
    tail of a longer key (e.g. ``gene_id`` inside ``ref_gene_id``) matches
    whichever appears first. Returns None when ``key "`` does not occur.
    """
    parts = attributes.split(key + ' "', 1)
    if len(parts) < 2:
        return None
    return parts[1].split('";', 1)[0]


def _load_annotation(path):
    """Index a reference GTF by transcript_id and by gene_name.

    Returns ``(transcript_to_gene, name_to_gene_ids)`` where
    ``transcript_to_gene`` maps transcript_id -> gene_id (the last line for
    a transcript wins) and ``name_to_gene_ids`` maps gene_name -> set of
    gene_ids seen with that name.

    Comment lines (``#``), blank/short lines and lines without gene_id are
    skipped; a line lacking transcript_id or gene_name still contributes to
    whichever mapping it has data for (e.g. ``gene`` records).
    """
    transcript_to_gene = {}
    name_to_gene_ids = {}
    with open(path) as handle:
        for line_no, line in enumerate(handle, 1):
            if line_no % 1000000 == 0:
                print('%d lines processed' % line_no)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9:
                continue
            attributes = fields[8]
            gene_id = _get_attribute(attributes, 'gene_id')
            if gene_id is None:
                continue
            transcript_id = _get_attribute(attributes, 'transcript_id')
            if transcript_id is not None:
                transcript_to_gene[transcript_id] = gene_id
            gene_name = _get_attribute(attributes, 'gene_name')
            if gene_name is not None:
                name_to_gene_ids.setdefault(gene_name, set()).add(gene_id)
    return transcript_to_gene, name_to_gene_ids


class _RunNumberer(object):
    """Give consecutive runs of the same raw ID a shared sequential number.

    The number is reused while the raw ID equals the one last numbered and
    the record is on the same chromosome as the previous line; otherwise
    the counter advances. Only the last raw ID is needed, so no history of
    earlier IDs is kept.
    """

    def __init__(self):
        self.count = 0
        self.last_raw_id = ''

    def number(self, raw_id, same_chrom):
        """Return the number for ``raw_id``.

        ``same_chrom`` must be True when the current line is on the same
        chromosome as the immediately preceding line.
        """
        if not (same_chrom and raw_id == self.last_raw_id):
            self.count += 1
            self.last_raw_id = raw_id
        return self.count


def run():
    """Rewrite the IDs of a GTF file using a reference annotation GTF.

    Command line: ``python script.py GTF annotationGTF outfilename``.
    The input looks like cuffcompare/cuffmerge output (class_code,
    nearest_ref, tss_id, p_id attributes) -- presumably; confirm upstream.

    For each record of GTF:
      * gene_id becomes the annotation gene of ``nearest_ref`` when
        class_code is 'j' or '='; otherwise the annotation gene_id when
        gene_name maps to exactly one annotation gene; otherwise a fresh
        ``XLOC_NNNNNN`` id (consecutive lines sharing a raw gene_id on the
        same chromosome share the number).
      * transcript_id, tss_id and p_id are renumbered the same way as
        ``TCONS_NNNNNNNN``, ``TSS<n>`` and ``P<n>``.
    Only gene_id, transcript_id, exon_number (non-transcript features),
    gene_name, nearest_ref, class_code, tss_id and p_id are written; any
    other attribute (e.g. oId) is dropped.

    Raises:
        ValueError: a GTF data line has fewer than 9 columns or lacks
            gene_id/transcript_id.
        KeyError: a 'j'/'=' record's nearest_ref is not in the annotation.
    """
    # Three arguments are required: argv[1], argv[2] and argv[3].
    if len(sys.argv) < 4:
        print('usage: python %s GTF annotationGTF outfilename' % sys.argv[0])
        sys.exit(1)

    gtf_path, annotation_path, outfilename = sys.argv[1:4]

    transcript_to_gene, name_to_gene_ids = _load_annotation(annotation_path)
    print('finished importing annotation')

    genes = _RunNumberer()
    transcripts = _RunNumberer()
    tss_numbers = _RunNumberer()
    p_numbers = _RunNumberer()
    last_chrom = ''

    with open(gtf_path) as gtf, open(outfilename, 'w') as outfile:
        for line_no, line in enumerate(gtf, 1):
            if line_no % 1000000 == 0:
                print('%d lines processed' % line_no)
            if not line.strip() or line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9:
                raise ValueError('%s line %d: expected 9 tab-separated columns'
                                 % (gtf_path, line_no))
            chrom = fields[0]
            attributes = fields[8]
            # Compared against the previous data line, whatever its outcome.
            same_chrom = chrom == last_chrom

            raw_gene_id = _get_attribute(attributes, 'gene_id')
            raw_transcript_id = _get_attribute(attributes, 'transcript_id')
            if raw_gene_id is None or raw_transcript_id is None:
                raise ValueError('%s line %d: missing gene_id or transcript_id'
                                 % (gtf_path, line_no))
            exon_number = _get_attribute(attributes, 'exon_number') or ''
            nearest_ref = _get_attribute(attributes, 'nearest_ref') or ''
            class_code = _get_attribute(attributes, 'class_code') or ''
            gene_name = _get_attribute(attributes, 'gene_name') or ''

            if class_code in ('j', '='):
                try:
                    gene_id = transcript_to_gene[nearest_ref]
                except KeyError:
                    raise KeyError('%s line %d: nearest_ref "%s" not found in %s'
                                   % (gtf_path, line_no, nearest_ref,
                                      annotation_path))
            elif gene_name and len(name_to_gene_ids.get(gene_name, ())) == 1:
                gene_id = next(iter(name_to_gene_ids[gene_name]))
            else:
                # Same digits as addnulls(count, 6) for non-negative counts.
                gene_id = 'XLOC_%06d' % genes.number(raw_gene_id, same_chrom)
            transcript_id = 'TCONS_%08d' % transcripts.number(raw_transcript_id,
                                                              same_chrom)

            out_attrs = ['gene_id "%s";' % gene_id,
                         'transcript_id "%s";' % transcript_id]
            if fields[2] != 'transcript':
                out_attrs.append('exon_number "%s";' % exon_number)
            if gene_name:
                out_attrs.append('gene_name "%s";' % gene_name)
            # With a gene_name, nearest_ref is written even when empty.
            if gene_name or nearest_ref:
                out_attrs.append('nearest_ref "%s";' % nearest_ref)
            if class_code:
                out_attrs.append('class_code "%s";' % class_code)
            raw_tss_id = _get_attribute(attributes, 'tss_id')
            if raw_tss_id is not None:
                out_attrs.append('tss_id "TSS%d";'
                                 % tss_numbers.number(raw_tss_id, same_chrom))
            raw_p_id = _get_attribute(attributes, 'p_id')
            if raw_p_id is not None:
                out_attrs.append('p_id "P%d";'
                                 % p_numbers.number(raw_p_id, same_chrom))

            last_chrom = chrom
            outfile.write('\t'.join(fields[:8]) + '\t' + ' '.join(out_attrs) + '\n')
        
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()

