##################################
#                                #
# Last modified 2025/05/28       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import math
import os
import string
import subprocess
import sys
from sets import Set

import numpy as np

def _iter_fields(path, report_progress=False):
    """Yield the tab-split fields of every data line in *path*.

    Files ending in ``.bz2``, ``.gz``/``.bgz`` or ``.zip`` are decompressed
    through ``bzip2 -cd``, ``zcat`` and ``unzip -p`` respectively; any other
    file is read directly. Lines are stripped; blank lines and lines starting
    with ``#`` are skipped (a blank line does not end the read). When
    *report_progress* is true, a message is printed every million lines.
    """
    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith(('.gz', '.bgz')):
        cmd = ['zcat', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        cmd = None

    proc = None
    if cmd is None:
        stream = open(path)
    else:
        # Argument list rather than a shell string, so file names are never
        # interpreted by a shell.
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                universal_newlines=True)
        stream = proc.stdout
    try:
        for count, raw in enumerate(stream, 1):
            if report_progress and count % 1000000 == 0:
                print(str(count // 1000000) + 'M lines processed')
            line = raw.strip()
            if not line or line.startswith('#'):
                continue
            yield line.split('\t')
    finally:
        stream.close()
        if proc is not None:
            proc.wait()


def _read_regions(path, chr_field, pos_field, strand_field, radius, do_bed):
    """Read motif centres from *path* and build the regions around them.

    Returns ``(regions, wanted)``: *regions* is a list of
    ``(chrom, left, right, strand)`` with the half-open window
    ``[centre - radius, centre + radius)``; *wanted* maps
    chrom -> {position: ''} for every position covered by any region, to be
    filled with score strings later. With *do_bed*, the centre is the
    integer midpoint of columns *pos_field* and *pos_field* + 1.
    """
    regions = []
    wanted = {}
    for fields in _iter_fields(path):
        chrom = fields[chr_field]
        strand = fields[strand_field]
        if do_bed:
            pos = (int(fields[pos_field]) + int(fields[pos_field + 1])) // 2
        else:
            pos = int(fields[pos_field])
        left, right = pos - radius, pos + radius
        regions.append((chrom, left, right, strand))
        per_chrom = wanted.setdefault(chrom, {})
        for p in range(left, right):
            per_chrom[p] = ''
    return regions, wanted


def _load_scores(path, wanted):
    """Store column 4 of *path* into *wanted* for positions it already holds.

    Column 1 is the chromosome and column 2 the position; lines for
    chromosomes or positions not present in *wanted* are ignored.
    """
    for fields in _iter_fields(path, report_progress=True):
        per_chrom = wanted.get(fields[0])
        if per_chrom is None:
            continue
        pos = int(fields[1])
        if pos in per_chrom:
            per_chrom[pos] = fields[3]


def _build_matrix(regions, wanted, radius):
    """Collect NMI values keyed by (row offset, column offset) from the centre.

    Each score string is a ``;``-separated list of ``relpos:NMI`` entries for
    the pair (pos, pos + relpos). On ``+`` strands the offsets are
    ``pos - centre``; on ``-`` strands both positions are mirrored about the
    centre (``centre - pos``) and the pair is stored with the partner first,
    so both strands use the same (row, column) orientation. Regions with any
    other strand value contribute nothing.
    """
    matrix = {}
    for chrom, left, right, strand in regions:
        per_chrom = wanted[chrom]
        centre = left + radius
        for pos in range(left, right):
            scores = per_chrom[pos]
            if not scores:
                continue
            for entry in scores.split(';'):
                parts = entry.split(':')
                relpos = int(parts[0])
                nmi = float(parts[1])
                partner = pos + relpos
                # Inclusive bound kept from the original: a partner at
                # `right` maps to +radius on '+' (outside the written range)
                # and to -radius on '-'.
                if partner > right:
                    continue
                if strand == '+':
                    key = (pos - centre, partner - centre)
                elif strand == '-':
                    key = (centre - partner, centre - pos)
                else:
                    continue
                matrix.setdefault(key, []).append(nmi)
    return matrix


def _write_matrix(outfilename, matrix, radius):
    """Write the mean-NMI matrix for offsets ``-radius .. radius-1``.

    The first line is ``#`` followed by the column offsets; each following
    line is the row offset then one tab-separated mean per column, with
    ``nan`` where no value was collected.
    """
    offsets = range(-radius, radius)
    with open(outfilename, 'w') as outfile:
        outfile.write('#' + ''.join('\t' + str(i) for i in offsets) + '\n')
        for i in offsets:
            cells = [str(i)]
            for j in offsets:
                values = matrix.get((i, j))
                cells.append(str(np.mean(values)) if values else 'nan')
            outfile.write('\t'.join(cells) + '\n')


def run():
    """Command-line entry point: aggregate NMI correlations around motifs.

    Arguments (from ``sys.argv``): score file, motif file, chromosome /
    position / strand column indices in the motif file, window radius,
    output file name, and an optional trailing ``-bed`` flag that takes the
    motif centre as the midpoint of the position column and the next one.
    Prints usage and exits with status 1 if fewer than 7 arguments are given.
    """
    if len(sys.argv) < 8:
        print('usage: python %s SingleMoleculeCorrelation-NMI-matrix-C-BAM.bed.bgz motifs chrFieldID posFieldID strandFieldID radius outfilename [-bed]' % sys.argv[0])
        print('Note: it is assumed that only a single chromosome is in the input file')
        sys.exit(1)

    input_path = sys.argv[1]
    bed_path = sys.argv[2]
    chr_field = int(sys.argv[3])
    pos_field = int(sys.argv[4])
    strand_field = int(sys.argv[5])
    radius = int(sys.argv[6])
    outfilename = sys.argv[7]
    # Only look for the flag after the positional arguments.
    do_bed = '-bed' in sys.argv[8:]

    regions, wanted = _read_regions(bed_path, chr_field, pos_field,
                                    strand_field, radius, do_bed)
    print('finished inputting regions')

    _load_scores(input_path, wanted)
    print('finished inputting scores')

    matrix = _build_matrix(regions, wanted, radius)
    _write_matrix(outfilename, matrix, radius)
            
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()

