##################################
#                                #
# Last modified 2018/08/27       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def run():
    """Write a fixed-size window around the TSS of every gene in a GFF file.

    Command-line arguments (read from ``sys.argv``):
        gff             tab-delimited GFF annotation; only rows whose third
                        column is ``gene`` are used
        upstream_bp     integer number of bases to include upstream of the TSS
        downstream_bp   integer number of bases to include downstream of the TSS
        outputfilename  path of the tab-delimited table to write

    The TSS is taken as the start coordinate (column 4) for ``+`` genes and
    the end coordinate (column 5) for ``-`` genes.  Genes that share the same
    (chromosome, TSS, strand) are merged into a single output row whose name
    and ID columns are comma-separated lists.  Window left edges are clamped
    at 0.  Rows are written sorted by (chromosome, TSS, strand).

    Gene rows whose strand is neither ``+`` nor ``-`` are skipped, and a count
    of them is reported on stderr.  Lines with fewer than nine tab-separated
    columns (blank lines, embedded sequence, ...) are ignored.

    Exits with status 1 after printing a usage message when fewer than four
    arguments are supplied.  Raises ValueError if a bp argument or a gene
    coordinate is not an integer, and IndexError if a gene row has no ``ID=``
    attribute.
    """
    # Four positional arguments are required, so argv needs five entries
    # (argv[0] is the script name).
    if len(sys.argv) < 5:
        print('usage: python %s gff upstream_bp downstream_bp outputfilename' % sys.argv[0])
        print('Note:\tthe script will only work properly on files without alternative isoforms')
        sys.exit(1)

    GFF = sys.argv[1]
    upstreamBP = int(sys.argv[2])
    downstreamBP = int(sys.argv[3])
    outfilename = sys.argv[4]

    # (chromosome, TSS position, strand) -> set of (geneName, geneID) pairs
    TSSDict = {}
    skipped = 0

    with open(GFF) as linelist:
        for i, line in enumerate(linelist, 1):
            if i % 100000 == 0:
                print('%d lines processed' % i)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # A usable record has all nine GFF columns and describes a gene.
            if len(fields) < 9 or fields[2] != 'gene':
                continue
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]
            attributes = fields[8]
            geneID = attributes.split('ID=')[1].split(';')[0]
            if 'Name=' in attributes:
                geneName = attributes.split('Name=')[1].split(';')[0]
            else:
                # No explicit name: fall back to the ID.
                geneName = geneID
            if strand == '+':
                TSS = (chrom, left, strand)
            elif strand == '-':
                TSS = (chrom, right, strand)
            else:
                # Without a strand the TSS is undefined; never reuse the key
                # left over from a previous record.
                skipped += 1
                continue
            TSSDict.setdefault(TSS, set()).add((geneName, geneID))

    if skipped:
        sys.stderr.write('%d gene records without +/- strand were skipped\n' % skipped)

    with open(outfilename, 'w') as outfile:
        outfile.write('#chr\tleft\tright\tstrand\tgeneName(s)\tgeneID(s)\n')
        for (chrom, TSS, strand) in sorted(TSSDict):
            # Sorting makes the order of merged genes reproducible between runs.
            genes = sorted(TSSDict[(chrom, TSS, strand)])
            if strand == '+':
                windowLeft = max(0, TSS - upstreamBP)
                windowRight = TSS + downstreamBP
            else:
                # On the minus strand upstream lies at higher coordinates.
                windowLeft = max(0, TSS - downstreamBP)
                windowRight = TSS + upstreamBP
            columns = [
                chrom,
                str(windowLeft),
                str(windowRight),
                strand,
                ','.join([geneName for (geneName, geneID) in genes]),
                ','.join([geneID for (geneName, geneID) in genes]),
            ]
            # The trailing tab before the newline is kept so the row layout
            # stays identical to what this script has always produced.
            outfile.write('\t'.join(columns) + '\t\n')
   
# Executed unconditionally at module level: running (or importing) this file
# immediately calls run(), which reads its arguments from sys.argv.
run()
