##################################
#                                #
# Last modified 2025/06/06       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import numpy as np
import gzip
from sets import Set

def run():

    if len(sys.argv) < 11:
        print 'usage: python %s focus_motifs_file chrFieldID strandFieldID motifFieldID generalMotifFile chrFieldID strandFieldID motifFieldID minCount radius outfile' % sys.argv[0]
        print '\tNote: peak_list format: label <tab> file <tab> chrFieldID'
        print '\tNote: a tabix-index motif file is assumed'
        sys.exit(1)

    motifsFocus = sys.argv[1]
    motifsFocusChrFieldID = int(sys.argv[2])
    motifsFocusStrandFieldID = int(sys.argv[3])
    motifsFocusMotifFieldID = int(sys.argv[4])
    motifsGeneral = sys.argv[5]
    motifsGeneralChrFieldID = int(sys.argv[6])
    motifsGeneralStrandFieldID = int(sys.argv[7])
    motifsGeneralMotifFieldID = int(sys.argv[8])
    minCount = int(sys.argv[9])
    radius = int(sys.argv[10])
    outprefix = sys.argv[11]

    FocusMotifDict = {}

    if motifsFocus.endswith('gz'):
        peaksfiles = gzip.open(motifsFocus)
    else:
        peaksfiles = open(motifsFocus)
    for peakline in peaksfiles:
        fields = peakline .strip().split('\t')
        chr = fields[motifsFocusChrFieldID]
        left = int(fields[motifsFocusChrFieldID + 1])
        right = int(fields[motifsFocusChrFieldID + 2])
        strand = fields[motifsFocusStrandFieldID]
        motif = fields[motifsFocusMotifFieldID]
        if FocusMotifDict.has_key(chr):
            pass
        else:
            FocusMotifDict[chr] = []
        FocusMotifDict[chr].append((left,right,strand,motif))

    GeneralMotifDict = {}
    MotLenDict = {}

    if motifsGeneral.endswith('gz'):
        lineslist = gzip.open(motifsGeneral)
    else:
        lineslist = open(motifsGeneral)
    for line in lineslist:
        fields = line.strip().split('\t')
        chr = fields[motifsGeneralChrFieldID]
        if FocusMotifDict.has_key(chr):
            pass
        else:
            continue
        left = int(fields[motifsGeneralChrFieldID + 1])
        right = int(fields[motifsGeneralChrFieldID + 2])
        strand = fields[motifsGeneralStrandFieldID]
        motif = fields[motifsGeneralMotifFieldID]
        MotLenDict[motif] = right-left
        if GeneralMotifDict.has_key(chr):
            pass
        else:
            GeneralMotifDict[chr] = {}
        if GeneralMotifDict[chr].has_key(motif):
            pass
        else:
            GeneralMotifDict[chr][motif] = {}
        GeneralMotifDict[chr][motif][left] = strand

    MotifPairsDict = {}

    for chr in FocusMotifDict.keys():
        for (left,right,strand1,motif1) in FocusMotifDict[chr]:
            if MotifPairsDict.has_key(motif1):
                pass
            else:
                MotifPairsDict[motif1] = {}
            pos = int((left + right)/2)
            if strand1 == '+':
                for i in range(pos-radius,pos + radius):
                    for motif2 in GeneralMotifDict[chr].keys():
                        if GeneralMotifDict[chr][motif2].has_key(i):
                            strand2 = GeneralMotifDict[chr][motif2][i]
                            mot2Len = MotLenDict[motif2]
                            if (i > right and i + mot2Len < pos + radius):
                                distance = i-right
                                if strand2 == '+':
                                    orientation = 'tandem'
                                if strand2 == '-':
                                    orientation = 'convergent'
                                if MotifPairsDict[motif1].has_key((motif2,distance,orientation)):
                                    pass
                                else:
                                    MotifPairsDict[motif1][(motif2,distance,orientation)] = []
                                MotifPairsDict[motif1][(motif2,distance,orientation)].append((chr,left,right,strand1,motif1))
            if strand1 == '-':
                for i in range(pos - radius,pos + radius):
                    for motif2 in GeneralMotifDict[chr].keys():
                        if GeneralMotifDict[chr][motif2].has_key(i):
                            strand2 = GeneralMotifDict[chr][motif2][i]
                            mot2Len = MotLenDict[motif2]
                            if (i > pos - radius and i + mot2Len < left):
                                distance = left - (i + mot2Len)
                                if strand2 == '-':
                                    orientation = 'tandem'
                                if strand2 == '+':
                                    continue
                                if MotifPairsDict[motif1].has_key((motif2,distance,orientation)):
                                    pass
                                else:
                                    MotifPairsDict[motif1][(motif2,distance,orientation)] = []
                                MotifPairsDict[motif1][(motif2,distance,orientation)].append((chr,left,right,strand1,motif1))
                            if (i > right and i + mot2Len < pos + radius):
                                distance = i - right
                                if strand2 == '+':
                                    orientation = 'divergent'
                                if strand2 == '-':
                                    continue
                                if MotifPairsDict[motif1].has_key((motif2,distance,orientation)):
                                    pass
                                else:
                                    MotifPairsDict[motif1][(motif2,distance,orientation)] = []
                                MotifPairsDict[motif1][(motif2,distance,orientation)].append((chr,left,right,strand1,motif1))

    for motif1 in MotifPairsDict.keys():
        M2s = MotifPairsDict[motif1].keys()
        M2s.sort()
        for (motif2,distance,orientation) in M2s:
            print motif1, motif2, distance, orientation, len(MotifPairsDict[motif1][(motif2,distance,orientation)])
            if len(MotifPairsDict[motif1][(motif2,distance,orientation)]) < minCount:
                continue
            outfile = open(outprefix + '.' + motif1 + '-' + motif2 + '-' + orientation + '-' + str(distance) + 'bp_apart' + '.n_' + str(len(MotifPairsDict[motif1][(motif2,distance,orientation)])), 'w')
            MotifPairsDict[motif1][(motif2,distance,orientation)].sort()
            for (chr,left,right,strand1,motif1) in MotifPairsDict[motif1][(motif2,distance,orientation)]:
                outline = chr + '\t' + str(left) +  '\t' + str(right) + '\t' + strand1 + '\t' + motif1 + '\t' + motif2 + '\t' + str(distance) + '\t' + orientation
                outfile.write(outline + '\n')
            outfile.close()

# Script entry point: run() is called unconditionally at module level, so it
# executes both when the file is run directly and when it is imported.
run()
