##################################
#                                #
# Last modified 202/04/07        # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

def _register_tss(tss_index, chrom, position, strand, gene, name, transcript):
    """Record that `transcript` of `gene` has a TSS at (position, strand).

    `tss_index` maps chromosome -> {(position, strand): entry}, where each
    entry holds a de-duplicated, insertion-ordered list of (gene_id,
    gene_name) pairs under 'genes' and of transcript ids under 'transcripts'.
    """
    entry = tss_index.setdefault(chrom, {}).setdefault(
        (position, strand), {'genes': [], 'transcripts': []})
    if (gene, name) not in entry['genes']:
        entry['genes'].append((gene, name))
    if transcript not in entry['transcripts']:
        entry['transcripts'].append(transcript)


def _load_tss_index(gtf_path):
    """Build the chromosome -> TSS lookup from the 'exon' records of a GTF.

    For each transcript the smallest and largest exon coordinates are
    tracked.  The TSS is the smallest coordinate on the '+' strand and the
    largest on the '-' strand; for any other strand value (e.g. '.') both
    ends are registered because the orientation is not known.

    Every chromosome that carries at least one exon gets an entry in the
    returned dict, even if no TSS ends up being assigned to it.
    """
    transcripts = {}
    transcript_order = []  # first-seen order keeps the output deterministic
    tss_index = {}

    with open(gtf_path) as gtf_lines:
        for line in gtf_lines:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.split('\n')[0].split('\t')
            if fields[2] != 'exon':
                continue
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]
            attributes = fields[8]
            gene = attributes.split('gene_id "')[1].split('";')[0]
            transcript = attributes.split('transcript_id "')[1].split('";')[0]
            if 'gene_name' in attributes:
                name = attributes.split('gene_name "')[1].split('";')[0]
            else:
                # Fall back to the gene id when no gene name is annotated.
                name = gene

            record = transcripts.get(transcript)
            if record is None:
                # Chromosome and strand are taken from the first exon seen.
                record = {'gene': gene, 'chr': chrom, 'strand': strand,
                          'start': min(left, right), 'end': max(left, right)}
                transcripts[transcript] = record
                transcript_order.append(transcript)
            else:
                record['start'] = min(record['start'], left, right)
                record['end'] = max(record['end'], left, right)
            record['name'] = name
            tss_index.setdefault(chrom, {})

    for transcript in transcript_order:
        record = transcripts[transcript]
        strand = record['strand']
        if strand == '+':
            positions = (record['start'],)
        elif strand == '-':
            positions = (record['end'],)
        else:
            # Unstranded / unknown orientation: either end may be the TSS.
            positions = (record['start'], record['end'])
        for position in positions:
            _register_tss(tss_index, record['chr'], position, strand,
                          record['gene'], record['name'], transcript)

    return tss_index


def _peak_position(fields, peak_chr_id, peak_id, narrow_peak):
    """Return the integer peak coordinate for one split peaks-file line.

    With `narrow_peak` the position is fields[1] + fields[9]; with
    peak_id == 'midPoint' it is the midpoint of the two fields following the
    chromosome field; otherwise it is the field at index `peak_id`.
    """
    if narrow_peak:
        return int(fields[1]) + int(fields[9])
    if peak_id == 'midPoint':
        left = int(fields[peak_chr_id + 1])
        right = int(fields[peak_chr_id + 2])
        return int((left + right) / 2.)
    return int(fields[peak_id])


def run():
    """Annotate every peak with its nearest TSS and write a tab-delimited table.

    Command line (read from sys.argv):
        peaks        tab-delimited peaks file; '#' and blank lines are skipped
        peakChrID    0-based index of the chromosome field
        peakFieldID  0-based index of the peak position field, or 'midPoint'
        gtf          GTF annotation; only 'exon' records are used
        outfile      path of the output table
        -narrowPeak  optional flag; peak = fields[1] + fields[9]

    Each output row holds the chromosome, the TSS position and strand, the
    comma-joined gene ids, gene names and transcript ids sharing that TSS,
    the signed peak-to-TSS distance (sign flipped on the '-' strand), and
    the original peak line.  Peaks on chromosomes absent from the GTF are
    skipped.  Ties between equally distant TSSs are resolved in favour of
    the smaller coordinate, then the smaller strand string.

    Exits with status 1 after printing usage when arguments are missing.
    """
    # Five positional arguments are required, so argv must hold at least six
    # items; otherwise sys.argv[5] below would raise IndexError.
    if len(sys.argv) < 6:
        print('usage: python %s peaks peakChrID peakFieldID gtf outfile [-narrowPeak]' % sys.argv[0])
        print('\t if the peak is not explicitly specified and you want the midpoint of the region, enter "midPoint" for the peakFieldID; it is assumed the left and right coordinates immediately follow the chromosome field')
        sys.exit(1)

    peaks = sys.argv[1]
    peak_chr_id = int(sys.argv[2])
    if sys.argv[3] == 'midPoint':
        peak_id = 'midPoint'
    else:
        peak_id = int(sys.argv[3])
    gtf = sys.argv[4]
    narrow_peak = '-narrowPeak' in sys.argv

    # The output file is opened before the GTF is parsed so that an
    # unwritable path fails immediately.
    with open(sys.argv[5], 'w') as outfile:
        tss_index = _load_tss_index(gtf)

        outfile.write('#chr\tTSS\tstrand\tgene_id\tgene_name\ttranscript(s)\tDistanceToNearestPeak\tRegionFields:\n')

        seen = set()
        with open(peaks) as peak_lines:
            for line in peak_lines:
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.split('\n')[0].split('\t')
                chrom = fields[peak_chr_id]
                if chrom not in seen:
                    # Progress report: print each chromosome once.
                    seen.add(chrom)
                    print(chrom)
                peak = _peak_position(fields, peak_chr_id, peak_id, narrow_peak)
                if chrom not in tss_index:
                    continue
                candidates = tss_index[chrom]
                if not candidates:
                    print('no TSSs found for contig %s skipping' % chrom)
                    continue

                tss, strand = min(
                    candidates,
                    key=lambda key: (abs(key[0] - peak), key[0], key[1]))
                entry = candidates[(tss, strand)]

                # Positive means downstream of the TSS in transcript
                # orientation, hence the sign flip on the '-' strand.
                distance = peak - tss
                if strand == '-':
                    distance = -distance

                columns = [
                    chrom,
                    str(tss),
                    strand,
                    ','.join(gene for gene, _ in entry['genes']),
                    ','.join(name for _, name in entry['genes']),
                    ','.join(entry['transcripts']),
                    str(distance),
                    # Guarantee a terminated record even when the last input
                    # line has no trailing newline.
                    line.rstrip('\n') + '\n',
                ]
                outfile.write('\t'.join(columns))

# Script entry point. The call is unconditional, so run() also executes
# (and reads sys.argv) if this file is imported rather than run directly.
run()
