##################################
#                                #
# Last modified 10/06/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import math

def _read_blast_names(blast_path):
    """Return a dict mapping each BLAST query ORF ID to a hit-derived name.

    The BLAST file is read as tab-separated text; lines starting with '#'
    and blank lines are skipped.  Column 1 is the query (ORF) ID and column 2
    the subject name, from which the text between the first and second '-'
    is taken, e.g. 'NADH_dehydrogenase_subunit_7' from
    'gi|8928603|ref|NP_059408.1-NADH_dehydrogenase_subunit_7-Paramecium_aurelia'.

    Only the first two hit lines per ORF are considered (the file is assumed
    to list top hits first).  Their distinct names are sorted and joined
    with '/', so an ORF ends up named 'A' or 'A/B'.
    """
    top_hits = {}
    with open(blast_path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            hits = top_hits.setdefault(fields[0], [])
            if len(hits) < 2:
                hits.append(fields[1].split('-')[1])

    names = {}
    for orf, hits in top_hits.items():
        # Duplicate hit names collapse; one name stays as-is, two become 'A/B'.
        names[orf] = '/'.join(sorted(set(hits)))
    return names


def _rename_gtf(gtf_path, out_handle, blast_names, min_nohit_len=None):
    """Copy a GTF file to out_handle, replacing each ORF's quoted gene_id.

    The ORF ID is read from the gene_id attribute (column 9).  Every quoted
    occurrence of that ID on the line ("ORF") is replaced by the new name:

    * ORFs present in blast_names get their BLAST-derived name.  When more
      than one ORF carries the same name, '-1', '-2', ... is appended in the
      order those ORFs are first encountered in the GTF.
    * Other ORFs are named 'ORF-nohit-1', 'ORF-nohit-2', ...  If
      min_nohit_len is not None, lines of such ORFs whose (end - start)
      coordinate span is below 3 * min_nohit_len are dropped.

    Each ORF keeps a single name across all of its GTF lines.  Lines
    starting with '#' are copied unchanged; blank lines are dropped.
    """
    # Number of ORFs sharing each BLAST name; only shared names get a suffix.
    name_counts = {}
    for name in blast_names.values():
        name_counts[name] = name_counts.get(name, 0) + 1

    suffix_counter = {}  # shared BLAST name -> last suffix handed out
    assigned = {}        # ORF ID -> new name, so multi-line ORFs stay consistent
    nohit = 0

    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#'):
                out_handle.write(line)
                continue
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            orf = fields[8].split('gene_id "')[1].split('";')[0]

            if orf in blast_names:
                if orf not in assigned:
                    name = blast_names[orf]
                    if name_counts[name] > 1:
                        suffix_counter[name] = suffix_counter.get(name, 0) + 1
                        name = '%s-%d' % (name, suffix_counter[name])
                    assigned[orf] = name
            else:
                if min_nohit_len is not None:
                    # Columns 4/5 are start/end; the threshold is given in codons.
                    left = int(fields[3])
                    right = int(fields[4])
                    if (right - left) < 3 * min_nohit_len:
                        continue
                if orf not in assigned:
                    nohit += 1
                    assigned[orf] = 'ORF-nohit-' + str(nohit)

            outline = line.strip().replace('"' + orf + '"', '"' + assigned[orf] + '"')
            out_handle.write(outline + '\n')


def run(argv=None):
    """Rename ORF gene_ids in a GTF file after their top BLAST hits.

    Command line: gtf blast outfile [-NoHitminLen minLen]

    argv defaults to sys.argv.  argv[1] is the ORF GTF file, argv[2] the
    tab-separated BLAST results, and argv[3] the output GTF path.  With
    -NoHitminLen N, ORFs without a BLAST hit are kept only if their
    (end - start) span is at least 3 * N.

    On missing arguments, or a non-integer/missing -NoHitminLen value, a
    message is printed and the process exits with status 1.
    """
    if argv is None:
        argv = sys.argv

    # Three positional arguments are required (argv[3] is the output path).
    if len(argv) < 4:
        print('usage: python %s gtf blast outfile [-NoHitminLen minLen]' % argv[0])
        print('\tNote: assumed blast reference name format:')
        print('\tgi|8928603|ref|NP_059408.1-NADH_dehydrogenase_subunit_7-Paramecium_aurelia')
        print('\tit is assumed that the top hits are listed first')
        print('\tit is also assumed that there is one ORF per geneID!!!')
        sys.exit(1)

    gtf_path = argv[1]
    blast_path = argv[2]
    out_path = argv[3]

    min_nohit_len = None
    if '-NoHitminLen' in argv:
        try:
            min_nohit_len = int(argv[argv.index('-NoHitminLen') + 1])
        except (IndexError, ValueError):
            print('-NoHitminLen requires an integer argument')
            sys.exit(1)
        print('will only retain ORFs without BLAST hits longer than %d' % min_nohit_len)

    blast_names = _read_blast_names(blast_path)

    with open(out_path, 'w') as out_handle:
        _rename_gtf(gtf_path, out_handle, blast_names, min_nohit_len)
   
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
