##################################
#                                #
# Last modified 2018/01/19       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import math

def _loadCAGEpeaks(path):
    """Read a CAGE peak file into a nested lookup table.

    Lines starting with '#' and blank lines are skipped. Every other line
    is split on tabs; column 1 is the chromosome, columns 2 and 3 the
    left/right coordinates, column 6 the strand, and columns 11 and 14 the
    two signal values.

    Returns a dict of the form ``peaks[chrom][strand][anchor]`` whose value
    is ``(otherEnd, score1, score2)``. The anchor is the left coordinate
    for '+' peaks and the right coordinate for '-' peaks. Peaks whose
    strand is neither '+' nor '-' are not stored.
    """
    peaks = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            left = int(fields[1])
            right = int(fields[2])
            strand = fields[5]
            score1 = float(fields[10])
            score2 = float(fields[13])
            if chrom not in peaks:
                peaks[chrom] = {'+': {}, '-': {}}
            if strand == '+':
                peaks[chrom]['+'][left] = (right, score1, score2)
            elif strand == '-':
                peaks[chrom]['-'][right] = (left, score1, score2)
    return peaks


def _closestCAGEpeak(peaksByAnchor, left, right, tssMid):
    """Pick the peak anchored in [left, right) whose midpoint is nearest tssMid.

    ``peaksByAnchor`` maps an anchor coordinate to
    ``(otherEnd, score1, score2)``. Returns
    ``(peakLeft, peakRight, score1, score2, peakMid)`` for the nearest
    peak, or None when no anchor falls inside the window. On a tie the
    peak with the smallest anchor coordinate wins.
    """
    best = None
    bestDistance = None
    for position in range(left, right):
        if position not in peaksByAnchor:
            continue
        (otherEnd, score1, score2) = peaksByAnchor[position]
        peakLeft = min(position, otherEnd)
        peakRight = max(position, otherEnd)
        peakMid = (peakLeft + peakRight) / 2.0
        distance = math.fabs(peakMid - tssMid)
        # Strict '<' keeps the first (lowest-coordinate) peak on ties.
        if bestDistance is None or distance < bestDistance:
            bestDistance = distance
            best = (peakLeft, peakRight, score1, score2, peakMid)
    return best


def run():
    """Annotate each TSS window with its closest overlapping CAGE peak.

    Command-line arguments (read from ``sys.argv``):
        1. TSS bed file: chr, left, right, strand, then further columns
        2. CAGE peak file (see the usage text for the assumed layout)
        3. output file path

    Every non-comment, non-blank TSS line is copied to the output followed
    by five extra columns: peak left, peak right, the two peak signals and
    the signed distance between the peak midpoint and the window midpoint
    (peak minus TSS on '+', TSS minus peak on '-'). When no same-strand
    peak is anchored inside the window, the five columns are 'nan'.

    Prints usage and exits with status 1 when fewer than three arguments
    are supplied.
    """
    # Three positional arguments are needed, so argv must hold at least 4 items.
    if len(sys.argv) < 4:
        print('usage: python %s TSS.bed CAGE_peaks outfile' % sys.argv[0])
        print('\tAssumed CAGE format:')
        print('\tchr1\t634003\t634027\t.\t1000\t+\t4.19\t4.80\t634003\t634026\t1422.00000\t634005\t634027\t1420.00000')
        print('\tUse a TSS bed file with the desired radius of overlap, in the following format')
        print('\tchr1\t35873\t36273\t-\tFAM138A\tENSG00000237613.2\tENST00000461467.1')
        sys.exit(1)

    tssPath = sys.argv[1]
    cagePath = sys.argv[2]
    outPath = sys.argv[3]

    CAGEpeaksDict = _loadCAGEpeaks(cagePath)

    header = '#chr\tleft\tright\tstrand\tgeneName\tgeneID\ttranscriptID\tCAGE_peak_left\tCAGE_peak_right\tCAGE_peak_signal1\tCAGE_peak_signal2\tCAGE_peak_distance'
    noPeak = '\tnan\tnan\tnan\tnan\tnan'

    with open(tssPath) as tssFile:
        with open(outPath, 'w') as outfile:
            outfile.write(header + '\n')
            for line in tssFile:
                if line.startswith('#') or not line.strip():
                    continue
                record = line.strip()
                fields = record.split('\t')
                chrom = fields[0]
                left = int(fields[1])
                right = int(fields[2])
                tssMid = (left + right) / 2.0
                strand = fields[3]

                best = None
                if chrom in CAGEpeaksDict and strand in CAGEpeaksDict[chrom]:
                    best = _closestCAGEpeak(CAGEpeaksDict[chrom][strand],
                                            left, right, tssMid)

                if best is None:
                    outline = record + noPeak
                else:
                    (peakLeft, peakRight, score1, score2, peakMid) = best
                    # Only '+' and '-' are ever keys of the per-chromosome
                    # table, so strand is one of the two here.
                    if strand == '+':
                        offset = peakMid - tssMid
                    else:
                        offset = tssMid - peakMid
                    outline = '\t'.join([record, str(peakLeft), str(peakRight),
                                         str(score1), str(score2), str(offset)])
                outfile.write(outline + '\n')
   
# Invoke the TSS/CAGE overlap immediately. There is no __main__ guard, so
# this call happens whenever the file is executed or imported.
run()
