##################################
#                                #
# Last modified 10/20/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of ``key`` in a GTF attribute column, or None if absent."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('"')[0]


def _read_wanted_transcripts(path, field_id):
    """Return the set of transcript IDs found in column ``field_id`` of the target file."""
    wanted = set()
    with open(path) as handle:
        for line in handle:
            # '#' lines are headers/comments; blank lines carry no transcript.
            if line.startswith('#') or not line.strip():
                continue
            wanted.add(line.strip().split('\t')[field_id])
    return wanted


def _read_gene_spans(gtf_path, wanted):
    """Return {chromosome: {ID: [leftmost, rightmost]}} accumulated over all GTF records.

    A record whose transcript_id is in ``wanted`` is keyed by that transcript ID
    (a string); every other record is keyed by the tuple (gene_id, gene_name),
    with gene_name falling back to gene_id when the attribute is missing.
    """
    spans = {}
    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            attributes = fields[8]
            gene_id = _gtf_attribute(attributes, 'gene_id')
            if gene_id is None:
                raise ValueError('GTF record has no gene_id attribute: %r' % line)
            # Gene-level records carry no transcript_id; they count towards the gene.
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            if transcript_id is not None and transcript_id in wanted:
                key = transcript_id
            else:
                gene_name = _gtf_attribute(attributes, 'gene_name')
                if gene_name is None:
                    gene_name = gene_id
                key = (gene_id, gene_name)
            chrom_spans = spans.setdefault(chrom, {})
            low = min(left, right)
            high = max(left, right)
            if key in chrom_spans:
                span = chrom_spans[key]
                span[0] = min(span[0], low)
                span[1] = max(span[1], high)
            else:
                chrom_spans[key] = [low, high]
    return spans


def _entry_sort_key(entry):
    """Order entries by (left, right); on ties wanted transcripts (str) precede genes (tuple)."""
    # The explicit type flag avoids comparing str with tuple, which Python 3 rejects.
    return (entry[0], entry[1], isinstance(entry[2], tuple), entry[2])


def _neighbor_gene_names(entries, position, step, wanted, limit):
    """Walk from ``position`` in direction ``step`` and collect up to ``limit`` gene names.

    Entries that are themselves wanted transcripts are skipped without being counted.
    The returned list is ordered from nearest to farthest.
    """
    names = []
    index = position + step
    while len(names) < limit and 0 <= index < len(entries):
        entry_id = entries[index][2]
        if entry_id not in wanted:
            # Non-wanted entries are always (gene_id, gene_name) tuples.
            names.append(entry_id[1])
        index += step
    return names


def run():
    """Annotate each target transcript with the names of its N nearest genes per side.

    Command-line arguments (read from sys.argv):
        1. GTF annotation file
        2. tab-delimited target file, one transcript per line
        3. zero-based column of the chromosome field in the target file; anything
           after the first ':' in that field is ignored
        4. zero-based column of the transcript ID in the target file
        5. N, the number of genes to report on each side
        6. output file name

    Every target row is written back with 2*N extra columns: the N genes to the
    left (farthest first), then the N genes to the right (nearest first), in the
    order of the per-chromosome list sorted by (left, right) coordinate. Missing
    neighbours are written as '-'. Lines starting with '#' get matching column
    labels appended. Blank lines are skipped. A target transcript that is absent
    from the GTF on its chromosome gets '-' in every added column and a warning
    on stderr.

    Exits with status 1 after printing usage when too few arguments are given.
    Raises ValueError for a GTF record without a gene_id attribute.
    """
    # Six positional arguments are required, so argv needs at least 7 entries.
    if len(sys.argv) < 7:
        print('usage: python %s GTF <target genes file> <chr field ID> <transcript field ID> <number genes on each side of transcript> <outfilename>' % sys.argv[0])
        print('\tNote: if the chr field is in the form of chr:left-right, the script will split by the ":" sign')
        sys.exit(1)

    GTF = sys.argv[1]
    inputfilename = sys.argv[2]
    chrfieldID = int(sys.argv[3])
    fieldID = int(sys.argv[4])
    N = int(sys.argv[5])
    outputfilename = sys.argv[6]

    wanted = _read_wanted_transcripts(inputfilename, fieldID)

    print('fisnished inputting target gene list')

    spans = _read_gene_spans(GTF, wanted)

    print('fisnished inputting GTF')

    entries_by_chrom = {}
    for chrom in spans:
        entries_by_chrom[chrom] = [
            (span[0], span[1], key) for key, span in spans[chrom].items()
        ]

    print('fisnished finalizing gene coordinates')

    for chrom in entries_by_chrom:
        entries_by_chrom[chrom].sort(key=_entry_sort_key)

    print('fisnished sorting genes')

    # Map each ID to its index in the sorted list so target rows avoid a linear search.
    positions = {}
    for chrom in entries_by_chrom:
        positions[chrom] = dict(
            (entry[2], index) for index, entry in enumerate(entries_by_chrom[chrom])
        )

    print(len(wanted))

    header_columns = ['%d_gene' % i for i in range(-N, 0)]
    header_columns += ['+%d_gene' % i for i in range(1, N + 1)]

    with open(inputfilename) as infile:
        with open(outputfilename, 'w') as outfile:
            for line in infile:
                stripped = line.strip()
                if not stripped:
                    continue
                if line.startswith('#'):
                    outfile.write('\t'.join([stripped] + header_columns) + '\n')
                    continue
                fields = stripped.split('\t')
                transcriptID = fields[fieldID]
                chrom = fields[chrfieldID].split(':')[0]
                position = positions.get(chrom, {}).get(transcriptID)
                if position is None:
                    sys.stderr.write(
                        'warning: transcript %s not found on %s in the GTF; writing placeholders\n'
                        % (transcriptID, chrom)
                    )
                    upstream = []
                    downstream = []
                else:
                    entries = entries_by_chrom[chrom]
                    upstream = _neighbor_gene_names(entries, position, -1, wanted, N)
                    downstream = _neighbor_gene_names(entries, position, 1, wanted, N)
                # Pad to N names per side, then lay out as -N..-1 followed by +1..+N.
                upstream = upstream + ['-'] * (N - len(upstream))
                downstream = downstream + ['-'] * (N - len(downstream))
                columns = upstream[::-1] + downstream
                outfile.write('\t'.join([stripped] + columns) + '\n')

# Run only when executed as a script, so that importing this module does not
# parse sys.argv or call sys.exit as a side effect.
if __name__ == '__main__':
    run()

