##################################
#                                #
# Last modified 2024/04/22       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import gzip
import subprocess

import numpy as np

def _open_positions(path):
    """Open *path* for reading, transparently decompressing '.gz' files.

    Plain files are opened in text mode; gzip files use gzip's default binary
    mode, so callers pass every line through _to_text().
    """
    if path.endswith('.gz'):
        return gzip.open(path)
    return open(path)


def _to_text(line):
    """Return *line* as str, decoding the bytes gzip yields on Python 3."""
    if not isinstance(line, str):
        line = line.decode('utf-8')
    return line


def _read_positions(path, chrFieldID, posFieldID, window):
    """Load window-aligned positions from a tab-separated (optionally gzipped) file.

    Lines starting with '#' and blank lines are skipped.

    Args:
        path: input file; a '.gz' suffix selects gzip decompression.
        chrFieldID: 0-based column holding the chromosome name.
        posFieldID: 0-based column (as a string) holding the position, or the
            literal 'narrowPeak', in which case the position is column 1 plus
            column 9 (narrowPeak start + peak offset).
        window: bin size; each position is rounded down to a multiple of it.

    Returns:
        dict mapping chromosome name -> set of binned positions.
    """
    narrowPeak = posFieldID == 'narrowPeak'
    posColumn = None if narrowPeak else int(posFieldID)
    positions = {}
    with _open_positions(path) as handle:
        for rawLine in handle:
            line = _to_text(rawLine)
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[chrFieldID]
            if narrowPeak:
                pos = int(fields[1]) + int(fields[9])
            else:
                pos = int(fields[posColumn])
            pos -= pos % window
            positions.setdefault(chrom, set()).add(pos)
    return positions


def _load_contacts(juicer, OE, KR, HIC, chrom, window, tempname):
    """Dump one chromosome's intra-chromosomal matrix with juicer tools and load it.

    Runs `java -jar <juicer> dump <OE> <KR> <HIC> <chrom> <chrom> BP <window>
    <tempname>`, parses the tab-separated lines written to *tempname* (pos1,
    pos2 and score in the first three columns), then deletes *tempname*
    even if parsing fails.

    Returns:
        dict-of-dicts with contacts[a][b] == contacts[b][a] == score.
    """
    # Argument list, no shell: paths and chromosome names taken from the input
    # files are passed verbatim instead of being re-split by a shell.
    cmd = ['java', '-jar', juicer, 'dump', OE, KR, HIC, chrom, chrom,
           'BP', str(window), tempname]
    status = subprocess.call(cmd)
    if status != 0:
        sys.stderr.write('warning: juicer dump exited with status %d for %s\n'
                         % (status, chrom))
    contacts = {}
    try:
        with open(tempname) as handle:
            for line in handle:
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.strip().split('\t')
                pos1 = int(fields[0])
                pos2 = int(fields[1])
                score = float(fields[2])
                # Store both orientations so lookups work whichever bin comes first.
                contacts.setdefault(pos1, {})[pos2] = score
                contacts.setdefault(pos2, {})[pos1] = score
    finally:
        if os.path.exists(tempname):
            os.remove(tempname)
    return contacts


def run():
    """Pile up contact scores around pairs of positions and write the average.

    Command-line arguments (see the usage message): juicer tools jar,
    observed/oe, normalization, .hic file, bin size (window), pile-up
    half-width (radius), two position files with their chromosome/position
    columns, the maximum pair separation (maxDist), and an output prefix.

    For every chromosome present in both position files, each (pos1, pos2)
    pair with |pos1 - pos2| < maxDist contributes the scores of the bins
    within +/- radius of it, added once at (i, j) and once transposed at
    (j, i). The summed matrix divided by the number of pairs is written to
    '<outprefix>.matrix'; 'nan' cells are written as 0, and every cell is 0
    when there are no pairs at all.
    """
    if len(sys.argv) < 15:
        print('usage: python %s juicer_tools_location observed/oe NONE/VC/VC_SQRT/KR .hic_file window radius positions_file_1 chrFieldID_1 posFieldID_1|narrowPeak positions_file_2 chrFieldID_2 posFieldID_2|narrowPeak maxDist outprefix' % sys.argv[0])
        sys.exit(1)

    juicer = sys.argv[1]
    OE = sys.argv[2]
    KR = sys.argv[3]
    HIC = sys.argv[4]
    window = int(sys.argv[5])
    radius = int(sys.argv[6])
    positions1 = sys.argv[7]
    chrFieldID1 = int(sys.argv[8])
    posFieldID1 = sys.argv[9]
    positions2 = sys.argv[10]
    chrFieldID2 = int(sys.argv[11])
    posFieldID2 = sys.argv[12]
    maxDist = int(sys.argv[13])
    outprefix = sys.argv[14]

    PosDict1 = _read_positions(positions1, chrFieldID1, posFieldID1, window)
    PosDict2 = _read_positions(positions2, chrFieldID2, posFieldID2, window)

    # Pair up positions on the same chromosome that are closer than maxDist.
    PosPairDict = {}
    TotalPositions = 0.0  # float so the final division is true division on Python 2
    for chrom in PosDict2:
        if chrom not in PosDict1:
            continue
        pairs = [(pos1, pos2)
                 for pos1 in PosDict1[chrom]
                 for pos2 in PosDict2[chrom]
                 if abs(pos1 - pos2) < maxDist]
        # Chromosomes without pairs are skipped so juicer is not run for nothing.
        if pairs:
            PosPairDict[chrom] = pairs
            TotalPositions += len(pairs)

    print('Total Position Pairs %s' % TotalPositions)

    offsets = list(range(-radius, radius, window))
    DataMatrix = {i: {j: 0 for j in offsets} for i in offsets}

    tempname = outprefix + '.temp'
    for chrom in PosPairDict:
        print(chrom)
        contacts = _load_contacts(juicer, OE, KR, HIC, chrom, window, tempname)
        for (pos1, pos2) in PosPairDict[chrom]:
            # pos1/pos2 are window-aligned, so i - pos1 and j - pos2 always
            # land on keys of DataMatrix (both step by window from -radius).
            for i in range(pos1 - radius, pos1 + radius, window):
                row = contacts.get(i)
                if row is None:
                    continue
                for j in range(pos2 - radius, pos2 + radius, window):
                    if j in row:
                        # contacts is symmetric, so row[j] is also the (j, i)
                        # score; add it at the offset and at its transpose.
                        score = row[j]
                        DataMatrix[i - pos1][j - pos2] += score
                        DataMatrix[j - pos2][i - pos1] += score

    with open(outprefix + '.matrix', 'w') as outfile:
        outfile.write('#' + ''.join('\t' + str(j) for j in offsets) + '\n')
        for i in offsets:
            if TotalPositions:
                cells = [str(DataMatrix[i][j] / TotalPositions) for j in offsets]
            else:
                # No pairs: 0/0 would raise; write the 0 that the nan rule intends.
                cells = ['0'] * len(offsets)
            outline = str(i) + ''.join('\t' + cell for cell in cells)
            # NaN values (e.g. NaN scores in the dump) propagate into sums; report them as 0.
            outfile.write(outline.replace('nan', '0') + '\n')
   
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    run()

