##################################
#                                #
# Last modified 2025/04/17       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import gzip
import math
import string
import sys
from sets import Set

def run():

    if len(sys.argv) < 8:
        print 'usage: python %s file1 chromField1 posField file2 chromField2 posField2 maxDist outfilename' % sys.argv[0]
        print "       Note: if files are in narrowPeak format, use 'narrowPeak' for the posField parameter"
        print "       Note: if files are in ERANGE hts format, use 9 for the posField parameter"
        print "       Note: if files are in simple bed file, use 'bed' for the posField parameter and the middle of the region will be used as peak"
        sys.exit(1)

    file1 = sys.argv[1]
    file2 = sys.argv[4]
    maxDist = int(sys.argv[7])
    outfilename = sys.argv[8]

    chromField1 = int(sys.argv[2])
    posField1 = sys.argv[3]
    doBED1 = False
    doNarrowPeak1 = False
    if posField1 == 'narrowPeak':
        doNarrowPeak1 = True
    elif posField1 == 'bed':
        doBED1 = True
    else:
        peakField1 = int(posField1) 

    chromField2 = int(sys.argv[5])
    posField2 = sys.argv[6]
    doBED2 = False
    doNarrowPeak2 = False
    if posField2 == 'narrowPeak':
        doNarrowPeak2 = True
    elif posField2 == 'bed':
        doBED2 = True
    else:
        peakField2 = int(posField2) 

    File2PeakDict={}

    if file2.endswith('.gz'):
        listoflines = gzip.open(file2)
    else:
        listoflines = open(file2)
    i=0
    for line in listoflines:
        i+=1
        if line.startswith('#') or line.startswith('track type'):
            continue
        fields=line.strip().split('\t')
        chr=fields[chromField2]
        left = int(fields[chromField2+1])
        if doBED2:
            right = int(fields[chromField2+2])
            peak = (left+right)/2
        elif doNarrowPeak2:
            chr=fields[0]
            left = int(fields[1])
            peak = left + int(fields[9])
        else:
            peak = int(fields[peakField2])
        if File2PeakDict.has_key(chr):
            pass
        else:
            File2PeakDict[chr]={}
        File2PeakDict[chr][peak] = fields

    outfile=open(outfilename,'w')

    if file1.endswith('.gz'):
        listoflines = gzip.open(file1)
    else:
        listoflines = open(file1)
    k=0
    for line in listoflines:
        k+=1
        if k % 10000 == 0:
            print k, 'lines processed'
        if line.startswith('#') or line.startswith('track type'):
            continue
        fields=line.strip().split('\t')
        chr=fields[chromField1]
        if File2PeakDict.has_key(chr):
            pass
        else:
            continue
        left = int(fields[chromField1+1])
        if doBED1:
            right = int(fields[chromField1+2])
            peak = (left+right)/2
        elif doNarrowPeak1:
            chr=fields[0]
            left = int(fields[1])
            peak = left + int(fields[9])
        else:
            peak = int(fields[peakField1])
        if File2PeakDict.has_key(chr):
            pass
        else:
            outline = line.strip() + '\tInf'
            outfile.write(outline + '\n')
            continue
#        NearestLeft = 100000000
#        NearestRight = 100000000
#        for i in range(peak,0,-1):
#            if File2PeakDict[chr].has_key(i):
#                NearestLeft = peak - i
#                break
#        for i in range(peak,NearestRight):
#            if File2PeakDict[chr].has_key(i):
#                NearestRight = i-peak
#                break
        if File2PeakDict[chr].has_key(peak):
            nearest = 0
            nearestPeak = peak
        else:
            peakslist = File2PeakDict[chr].keys() + [peak]
            peakslist.sort()
            peakindex = peakslist.index(peak)
            if peakindex == 0:
                nearest = peakslist[peakindex + 1] - peak
                nearestPeak = peakslist[peakindex + 1]
            elif peakindex == len(peakslist)-1:
                nearest = peak - peakslist[peakindex - 1]
                nearestPeak = peakslist[peakindex - 1]
            else:
                if peakslist[peakindex + 1] - peak < peak - peakslist[peakindex - 1]:
                    nearest = peakslist[peakindex + 1] - peak
                    nearestPeak = peakslist[peakindex + 1]
                else:
                    nearest = peak - peakslist[peakindex - 1]
                    nearestPeak = peakslist[peakindex - 1]
        if math.fabs(nearest) <= maxDist:
            outline = line.strip() + '\t' + str(nearest)
            peak2fields = File2PeakDict[chr][nearestPeak]
            for f in peak2fields:
                outline = outline + '\t' + f
            outfile.write(outline + '\n')

    outfile.close()

# Run only when executed as a script, so importing this module does not
# parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()
