##################################
#                                #
# Last modified 2024/04/22       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import math
import os
import string
import subprocess
import sys

import numpy as np

def _read_positions(path, chr_field, pos_field, window):
    """Read a tab-separated positions file and bin each position to ``window``.

    Lines starting with '#' and blank lines are skipped.  When ``pos_field``
    is the string 'narrowPeak' the position is chromStart (column index 1)
    plus the summit offset (column index 9); otherwise ``pos_field`` is the
    0-based column index holding the coordinate.

    Returns a dict mapping chromosome name -> set of bin start coordinates
    (``pos - pos % window``).  Positions that fall into the same bin collapse
    to a single entry.
    """
    use_summit = pos_field == 'narrowPeak'
    pos_col = None if use_summit else int(pos_field)
    by_chrom = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if use_summit:
                pos = int(fields[1]) + int(fields[9])
            else:
                pos = int(fields[pos_col])
            by_chrom.setdefault(fields[chr_field], set()).add(pos - pos % window)
    return by_chrom


def _dump_contacts(juicer, oe, norm, hic, chrom, window, temp_path):
    """Dump one intra-chromosomal contact map with juicer_tools and parse it.

    Runs ``java -jar <juicer> dump <oe> <norm> <hic> <chrom> <chrom> BP
    <window> <temp_path>``, reads the tab-separated columns (bin1, bin2,
    score) from ``temp_path`` and always deletes ``temp_path`` afterwards.

    Returns a dict of dicts ``contacts[bin1][bin2] = score`` filled in both
    directions so lookups are symmetric.  NaN scores are dropped, so they are
    treated like absent (zero) entries.  If the dump exits with a non-zero
    status, a warning is written to stderr and an empty dict is returned.
    """
    cmd = ['java', '-jar', juicer, 'dump', oe, norm, hic,
           chrom, chrom, 'BP', str(window), temp_path]
    contacts = {}
    try:
        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact, and lets us see the exit status.
        status = subprocess.call(cmd)
        if status != 0:
            sys.stderr.write('warning: juicer_tools dump failed for %s '
                             '(exit status %d); skipping\n' % (chrom, status))
            return contacts
        with open(temp_path) as handle:
            for line in handle:
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.strip().split('\t')
                pos1 = int(fields[0])
                pos2 = int(fields[1])
                score = float(fields[2])
                # A single NaN would otherwise poison the whole aggregated
                # cell; missing contacts already count as 0, so do the same.
                if math.isnan(score):
                    continue
                contacts.setdefault(pos1, {})[pos2] = score
                contacts.setdefault(pos2, {})[pos1] = score
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return contacts


def _accumulate(matrix, offsets, positions, contacts):
    """Add the contacts in the square window around each position into ``matrix``.

    ``matrix[a][b]`` receives ``contacts[pos + a][pos + b]`` for every
    ``pos`` in ``positions`` and every pair of offsets.  Each (a, b) pair is
    visited exactly once per position, so no double counting occurs.
    """
    for pos in positions:
        for a in offsets:
            row = contacts.get(pos + a)
            if row is None:
                continue
            out_row = matrix[a]
            for b in offsets:
                score = row.get(pos + b)
                if score is not None:
                    out_row[b] += score


def _write_matrix(path, matrix, offsets, total):
    """Write ``matrix / total`` as a tab-separated table with offset labels.

    The first line is '#' followed by the column offsets; each following line
    is a row offset followed by its averaged cells.  Any 'nan' text in a row
    is written as '0'.
    """
    with open(path, 'w') as outfile:
        outfile.write('\t'.join(['#'] + [str(o) for o in offsets]) + '\n')
        for a in offsets:
            cells = [str(matrix[a][b] / total) for b in offsets]
            outfile.write('\t'.join([str(a)] + cells).replace('nan', '0') + '\n')


def run():
    """Average Hi-C contacts in a square window centred on each listed position.

    Command-line arguments (``sys.argv[1:]``)::

        juicer_tools_location  observed/oe  NONE/VC/VC_SQRT/KR  .hic_file
        window  radius  positions_file  chrFieldID  posFieldID|narrowPeak
        outfilename

    ``window`` is the bin size passed to juicer_tools.  The aggregated matrix
    covers offsets ``range(-radius, radius, window)`` on both axes.  For every
    chromosome with at least one position, the intra-chromosomal map is
    dumped to ``<outfilename>.temp`` (removed afterwards).  The window around
    each binned position is added into the aggregate, and the sum is divided
    by the number of distinct binned positions.  The result is written to
    ``<outfilename>.matrix``.

    Exits with status 1 when too few arguments are given or no positions are
    read.
    """
    # Ten user arguments are read (sys.argv[1] .. sys.argv[10]), so argv must
    # have at least 11 entries.
    if len(sys.argv) < 11:
        print('usage: python %s juicer_tools_location observed/oe NONE/VC/VC_SQRT/KR .hic_file window radius positions_file chrFieldID posFieldID|narrowPeak outfilename' % sys.argv[0])
        sys.exit(1)

    juicer = sys.argv[1]
    oe = sys.argv[2]
    norm = sys.argv[3]
    hic = sys.argv[4]
    window = int(sys.argv[5])
    radius = int(sys.argv[6])
    positions = sys.argv[7]
    chr_field = int(sys.argv[8])
    pos_field = sys.argv[9]
    outprefix = sys.argv[10]

    by_chrom = _read_positions(positions, chr_field, pos_field, window)

    # Denominator is the number of distinct binned positions, which is exactly
    # the set of windows summed below (duplicates within a bin count once).
    # Positions on chromosomes with no dumped contacts still count and
    # contribute zeros.
    total = float(sum(len(bins) for bins in by_chrom.values()))
    print('TotalPositions %s' % total)
    if total == 0:
        sys.stderr.write('no positions read from %s\n' % positions)
        sys.exit(1)

    offsets = list(range(-radius, radius, window))
    matrix = {a: {b: 0.0 for b in offsets} for a in offsets}

    temp_path = outprefix + '.temp'
    for chrom in sorted(by_chrom):
        print(chrom)
        contacts = _dump_contacts(juicer, oe, norm, hic, chrom, window, temp_path)
        _accumulate(matrix, offsets, by_chrom[chrom], contacts)

    _write_matrix(outprefix + '.matrix', matrix, offsets, total)
   
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
