##################################
#                                #
# Last modified 2018/07/06       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import math
from sets import Set

def _stream_command(path):
    """Return the argv list of the command that writes *path* as text to stdout.

    The command is picked from the file-name suffix: bzip2 for '.bz2', zcat
    for '.gz'/'.bgz', unzip for '.zip' and cat for everything else.
    """
    if path.endswith('.bz2'):
        return ['bzip2', '-cd', path]
    if path.endswith(('.gz', '.bgz')):
        return ['zcat', path]
    if path.endswith('.zip'):
        return ['unzip', '-p', path]
    return ['cat', path]


def _data_lines(path):
    """Yield every non-blank line of *path*, stripped of surrounding whitespace.

    The file is streamed through an external command (see _stream_command)
    that is started without a shell, so file names containing spaces or shell
    metacharacters are passed through verbatim.  Blank lines are skipped
    instead of being treated as end of input.
    """
    import subprocess

    argv = _stream_command(path)
    proc = subprocess.Popen(argv, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for raw in proc.stdout:
            line = raw.strip()
            if line:
                yield line
    finally:
        # Always release the pipe and reap the child, even on early exit.
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        # Best effort: report the failure but keep whatever was read.
        sys.stderr.write('warning: "%s" exited with status %d\n'
                         % (' '.join(argv), status))


def _parse_region_pair(line, W):
    """Return (chrom, bins1, bins2) for one tab-separated region-pair row.

    Columns 1-2 (0-based) hold the first region's coordinates and columns 6-7
    the second region's.  Each coordinate is rounded down to a multiple of W;
    bins1/bins2 are the window starts spanning each region, end window
    included.
    """
    fields = line.split('\t')
    chrom = fields[0]
    L1, R1, L2, R2 = [pos - pos % W
                      for pos in [int(fields[k]) for k in (1, 2, 6, 7)]]
    return chrom, range(L1, R1 + W, W), range(L2, R2 + W, W)


def _collect_bins(SMC, W):
    """Return {chrom: {(bin1, bin2): 0}} for every window pair in the SMC file."""
    bins = {}
    for line in _data_lines(SMC):
        if line.startswith('#'):
            continue
        chrom, first, second = _parse_region_pair(line, W)
        chrom_bins = bins.setdefault(chrom, {})
        for i in first:
            for j in second:
                chrom_bins[(i, j)] = 0
    return bins


def _add_hic_scores(HiC, W, bins):
    """Add each intra-chromosomal Hi-C score to its window pair, if tracked.

    Rows are read as chrom1, pos1, chrom2, pos2, score (tab-separated).
    NOTE(review): only the (pos1, pos2) orientation is looked up; confirm the
    matrix orientation matches the region order of the SMC file.
    """
    seen = 0
    for line in _data_lines(HiC):
        if line.startswith('#'):
            continue
        seen += 1
        if seen % 1000000 == 0:
            print(str(seen // 1000000) + 'M lines processed')
        fields = line.split('\t')
        chrom = fields[0]
        chrom2 = fields[2]
        if chrom not in bins or chrom != chrom2:
            continue
        left = int(fields[1])
        right = int(fields[3])
        key = (left - left % W, right - right % W)
        score = float(fields[4])
        chrom_bins = bins[chrom]
        if key in chrom_bins:
            chrom_bins[key] += score


def _write_scores(SMC, W, bins, outfile):
    """Re-read the SMC file and write each row with its summed Hi-C score.

    Lines starting with '#' get a 'HiC_interactions' column name appended;
    data rows get the sum over all window pairs covered by their two regions.
    """
    for line in _data_lines(SMC):
        if line.startswith('#'):
            outfile.write(line + '\t' + 'HiC_interactions' + '\n')
            continue
        chrom, first, second = _parse_region_pair(line, W)
        chrom_bins = bins[chrom]
        total = 0
        for i in first:
            for j in second:
                total += chrom_bins[(i, j)]
        outfile.write(line + '\t' + str(total) + '\n')


def run():
    """Annotate region pairs with the Hi-C signal of the windows they cover.

    Command line: SMC-file HiC-matrix window(bp) outfile.  The SMC file is
    read once to register every (window, window) pair covered by its region
    pairs, the Hi-C matrix is streamed to accumulate scores for those pairs,
    and the SMC file is read again to write each row followed by its total.

    Exits with status 1 after printing usage when fewer than four arguments
    are given, and with an error message when the window is not positive.
    """
    # Four positional arguments are read below, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s SingleMoleculeCorrelation_output HiC-matrix window(bp) outfile' % sys.argv[0])
        print('Note: the script assumes that the input is the output from SingleMoleculeCorrelation.py, i.e. regions are sorted by coordinates!!!')
        print('Note: the script assumes that the HiC matrix is the output of the InteractionMatrix.py; it can be zipped in a variety of formats')
        sys.exit(1)

    SMC = sys.argv[1]
    HiC = sys.argv[2]
    W = int(sys.argv[3])
    outfilename = sys.argv[4]

    if W <= 0:
        # A zero window would otherwise fail later with ZeroDivisionError.
        sys.exit('window size must be a positive integer, got %d' % W)

    # Open the output first so an unwritable path fails before the long reads.
    with open(outfilename, 'w') as outfile:
        bins = _collect_bins(SMC, W)
        print('finished inputting peaks')
        print(list(bins.keys()))

        _add_hic_scores(HiC, W, bins)
        print('finished inputting Hi-C matrix')

        _write_scores(SMC, W, bins, outfile)
            
# Run the pipeline only when executed as a script, so that importing this
# module does not parse the importer's command line or call sys.exit().
if __name__ == '__main__':
    run()

