##################################
#                                #
# Last modified 08/05/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

# Maps the command-line type argument to the GTF feature (3rd column) value.
_FEATURE_FOR_TYPE = {'genes': 'gene', 'transcripts': 'transcript'}


def _gtf_attribute(attributes, key):
    """Return the quoted value of `key` in a GTF attribute column, or None.

    Looks for the first occurrence of `key "` in the column (the same
    substring match the original parser used) and returns the text up to
    the closing double quote.
    """
    _, sep, after = attributes.partition(key + ' "')
    if not sep:
        return None
    return after.partition('"')[0]


def _load_id_to_name(gtf_path, feature):
    """Map every `feature` ID in a GTF file to its name.

    feature is 'gene' or 'transcript'. Only rows whose 3rd column equals
    `feature` are used, and the `<feature>_id` / `<feature>_name` attributes
    are read from the 9th column. A record without a name attribute maps to
    its own ID. Comment lines, rows with fewer than 9 columns and records
    without an ID attribute are skipped.
    """
    id_key = feature + '_id'
    name_key = feature + '_name'
    id_to_name = {}
    with open(gtf_path) as gtf_file:
        for line in gtf_file:
            if line.startswith('#'):
                continue
            fields = line.rstrip('\r\n').split('\t')
            # Blank or truncated rows have no feature/attribute columns.
            if len(fields) < 9 or fields[2] != feature:
                continue
            record_id = _gtf_attribute(fields[8], id_key)
            if record_id is None:
                continue
            name = _gtf_attribute(fields[8], name_key)
            id_to_name[record_id] = record_id if name is None else name
    return id_to_name


def run():
    """Prepend a gene or transcript name column to a tab-separated file.

    Command line (read from sys.argv):
        inputfilename gtf IDfield [genes | transcripts] outfile

    IDfield is the 0-based column of each input row that holds the gene or
    transcript ID. The name for that ID is looked up in the GTF file and
    written in front of the row. Lines starting with '#' are treated as
    headers and get '#name\\tID\\t' prepended. Blank input lines are dropped.

    Prints the usage and exits with status 1 if arguments are missing or the
    type is neither 'genes' nor 'transcripts'. Raises KeyError if an input
    ID is not present in the GTF file.
    """
    if len(sys.argv) < 6 or sys.argv[4] not in _FEATURE_FOR_TYPE:
        print('usage: python %s inputfilename gtf IDfield [genes | transcripts] outfile' % sys.argv[0])
        sys.exit(1)

    inputfilename = sys.argv[1]
    gtf = sys.argv[2]
    field_index = int(sys.argv[3])
    feature_type = sys.argv[4]
    outfilename = sys.argv[5]

    id_to_name = _load_id_to_name(gtf, _FEATURE_FOR_TYPE[feature_type])
    print('finished inputting GTF file, found %d %s' % (len(id_to_name), feature_type))

    with open(inputfilename) as infile:
        with open(outfilename, 'w') as outfile:
            for line in infile:
                if line.startswith('#'):
                    # NOTE(review): the header gains two labels ("name", "ID")
                    # while each data row gains only one value (the name).
                    # These line up only if the input header leaves its ID
                    # column unlabelled (the strip() removes that empty
                    # label). Confirm against the input format.
                    outfile.write('#name\tID\t' + line[1:].strip() + '\n')
                    continue
                # Only drop the line terminator, so that leading empty columns
                # keep their position and IDfield still indexes correctly.
                row = line.rstrip('\r\n')
                if not row.strip():
                    continue
                record_id = row.split('\t')[field_index].strip()
                outfile.write(id_to_name[record_id] + '\t' + row + '\n')
        
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()

