##################################
#                                #
# Last modified 08/30/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

# psyco is an optional speed-up module: use it when it is installed, and
# carry on without it otherwise.
try:
    import psyco
    psyco.full()
except Exception:
    # Deliberately best-effort, but catch Exception rather than everything so
    # that KeyboardInterrupt and SystemExit are not swallowed here.
    print('psyco not running')

def _gtf_attribute(attributes, key):
    """Return the quoted value stored as `key "value";` in a GTF attributes column."""
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_transcripts(gtf):
    """Read the 'transcript' records of a GTF file into a dict keyed by transcript_id.

    Each value holds the keys 'gene', 'chr', 'strand' (taken from the first
    record seen for that transcript), 'coordinates' (every start and end seen)
    and 'FPKM' (taken from the last record seen).
    """
    transcripts = {}
    with open(gtf) as gtf_lines:
        for line in gtf_lines:
            # Comment lines and blank lines carry no record.
            if line.startswith('#') or not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            if fields[2] != 'transcript':
                continue
            attributes = fields[8]
            transcript = _gtf_attribute(attributes, 'transcript_id')
            if transcript not in transcripts:
                transcripts[transcript] = {
                    'gene': _gtf_attribute(attributes, 'gene_id'),
                    'coordinates': [],
                    'chr': fields[0],
                    'strand': fields[6],
                }
            record = transcripts[transcript]
            record['coordinates'].append(int(fields[3]))
            record['coordinates'].append(int(fields[4]))
            record['FPKM'] = float(_gtf_attribute(attributes, 'FPKM'))
    return transcripts


def _build_tss_index(transcripts):
    """Group transcripts by transcription start site.

    Returns a dict: chromosome -> {(TSS, strand): entry}, where each entry has
    'genes' and 'transcripts' (lists without duplicates) and 'FPKM' (the sum of
    the FPKM values of all transcripts sharing that TSS).
    """
    tss_index = {}
    for transcript in transcripts:
        record = transcripts[transcript]
        strand = record['strand']
        # Minus-strand transcripts start at their rightmost coordinate. Every
        # other strand value, including '.' for unstranded transcripts, is
        # anchored at the leftmost coordinate; this matches the distance
        # calculation, which only flips the sign for '-'. Assigning the TSS
        # in both branches guarantees no value leaks over from the previous
        # transcript.
        if strand == '-':
            tss = max(record['coordinates'])
        else:
            tss = min(record['coordinates'])
        entry = tss_index.setdefault(record['chr'], {}).setdefault(
            (tss, strand), {'genes': [], 'FPKM': 0, 'transcripts': []})
        entry['FPKM'] += record['FPKM']
        if record['gene'] not in entry['genes']:
            entry['genes'].append(record['gene'])
        if transcript not in entry['transcripts']:
            entry['transcripts'].append(transcript)
    return tss_index


def _nearest_tss(tss_entries, peak):
    """Return the (TSS, strand) key of `tss_entries` closest to `peak`.

    On a tie the key visited last wins.
    """
    nearest = None
    smallest_gap = None
    for (tss, strand) in tss_entries:
        gap = abs(tss - peak)
        if smallest_gap is None or gap <= smallest_gap:
            smallest_gap = gap
            nearest = (tss, strand)
    return nearest


def _annotate_peaks(peaks, peakChrID, peakID, tss_index, outfile):
    """Write one output row per peak whose chromosome has at least one TSS.

    `peakID` is the column holding the peak position, or None to use the
    midpoint of the region instead.
    """
    with open(peaks) as peak_lines:
        for line in peak_lines:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            chrom = fields[peakChrID]
            left = int(fields[peakChrID + 1])
            right = int(fields[peakChrID + 2])
            if peakID is None:
                peak = int((left + right) / 2.0)
            else:
                peak = int(fields[peakID])
            if chrom not in tss_index:
                continue
            (tss, strand) = _nearest_tss(tss_index[chrom], peak)
            entry = tss_index[chrom][(tss, strand)]
            # Flip the sign on the minus strand so the distance is reported
            # relative to the direction of transcription.
            distance = tss - peak
            if strand == '-':
                distance = -distance
            columns = [chrom, str(left), str(right), str(tss), strand,
                       ','.join(entry['genes']),
                       ','.join(entry['transcripts']),
                       str(distance), str(entry['FPKM'])]
            outfile.write('\t'.join(columns) + '\n')


def run():
    """Annotate each peak with the nearest TSS from a Cufflinks GTF file.

    Reads its arguments from sys.argv:
        peaks        tab-delimited peak file
        peakChrID    0-based column of the chromosome; the region start and
                     end are read from the two columns after it
        peakFieldID  0-based column of the peak position, or the word "bed"
                     to use the midpoint of the region
        gtf          GTF file whose 'transcript' records carry gene_id,
                     transcript_id and FPKM attributes
        outfile      path of the tab-delimited report to write

    Prints the usage text and exits with status 1 when fewer than five
    arguments are given. Peaks on chromosomes with no transcript are omitted
    from the report.
    """
    # All five arguments are required: sys.argv[5] is the output file.
    if len(sys.argv) < 6:
        print('usage: python %s peaks peakChrID peakFieldID gtf outfile' % sys.argv[0])
        print('       Note: if there is no peak field, substitue this entry with the word "bed" and the middle of the region will be taken')
        print('       Note: the GTF file has to be the output of Cufflinks with the expression estimates in the attributes fields')
        sys.exit(1)

    peaks = sys.argv[1]
    peakChrID = int(sys.argv[2])
    gtf = sys.argv[4]
    # None marks "bed" mode (no dedicated peak-position column).
    peakID = None if sys.argv[3] == 'bed' else int(sys.argv[3])

    # Open the output first so an unwritable path fails before the GTF is
    # parsed; the finally clause closes it on every exit path.
    outfile = open(sys.argv[5], 'w')
    try:
        transcripts = _read_transcripts(gtf)
        print('finished inputting GTF')

        tss_index = _build_tss_index(transcripts)
        print('finished parsing TSSs')

        outfile.write('#chr\tleft\tright\tTSS\tstrand\tgene(s)\ttranscript(s)\tDistanceToPeak\tFPKM:\n')
        _annotate_peaks(peaks, peakChrID, peakID, tss_index, outfile)
    finally:
        outfile.close()

# Script entry point: this call executes whenever the file is run or imported.
run()
