##################################
#                                #
# Last modified 2025/01/13       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import h5py
import string
import numpy as np

# MEME minimal-format preamble written at the top of every output file:
# DNA alphabet, both strands, uniform background.
_MEME_HEADER = (
    'MEME version 4\n'
    '\n'
    'ALPHABET= ACGT\n'
    '\n'
    'strands: + -\n'
    '\n'
    'Background letter frequencies (from uniform background):\n'
    'A 0.25000 C 0.25000 G 0.25000 T 0.25000 \n'
    '\n'
)


def _write_meme(path, motif_name, rows):
    """Write a single-motif MEME file to ``path``.

    ``rows`` is a list of (a, c, g, t) tuples, one per motif position. Each
    value is written with ``str()``, tab-separated, under a
    ``letter-probability matrix`` header whose ``w=`` is ``len(rows)``.
    """
    lines = [
        _MEME_HEADER,
        'MOTIF ' + motif_name + '\n',
        '\n',
        'letter-probability matrix: alength= 4 w= ' + str(len(rows))
        + ' nsites= 1 E= 0\n',
    ]
    lines.extend('  %s\t  %s\t  %s\t  %s\n' % row for row in rows)
    # 'with' guarantees the handle is closed even if a write fails.
    with open(path, 'w') as fh:
        fh.writelines(lines)


def _export_subpattern(outprefix, pattern, pattern_group, subpattern):
    """Write the PWM and CWM MEME files for one subpattern of ``pattern``.

    Raises ValueError if the contribution matrix has fewer positions than
    the subpattern's sequence matrix.
    """
    sub_group = pattern_group[subpattern]
    sequence = sub_group['sequence'][:]
    print('%s %s' % (pattern, subpattern))
    # First entry of each seqlets/n_seqlets dataset, used only in names.
    n_seqlets = pattern_group['seqlets']['n_seqlets'][:][0]
    n_sub_seqlets = sub_group['seqlets']['n_seqlets'][:][0]

    base = (outprefix + '.' + pattern + '.n_' + str(n_seqlets) + '_'
            + subpattern + '.ns_' + str(n_sub_seqlets))
    # NOTE(review): the MOTIF ID omits the subpattern, so all subpatterns of
    # one pattern share the same ID -- confirm downstream tools tolerate it.
    motif_name = outprefix + '.' + pattern + '.n_' + str(n_seqlets)

    # PWM: the first four columns (A, C, G, T) of the subpattern's sequence.
    pfm = [(row[0], row[1], row[2], row[3]) for row in sequence]
    pwm_path = base + '.PWM.meme'
    print('%s %s' % (n_seqlets, pwm_path))
    _write_meme(pwm_path, motif_name, pfm)

    # NOTE(review): contributions are read from the parent pattern's
    # contrib_scores while the PFM comes from the subpattern's sequence.
    # Confirm this is intended rather than sub_group['contrib_scores'].
    contribs = pattern_group['contrib_scores'][:]
    if len(contribs) < len(pfm):
        raise ValueError(
            '%s/%s: contrib_scores has %d positions but the PFM has %d'
            % (pattern, subpattern, len(contribs), len(pfm)))

    # CWM: each PFM row scaled by the total absolute contribution at that
    # position. Extra contribution rows beyond the PFM width are ignored.
    cwm = []
    for (pa, pc, pg, pt), crow in zip(pfm, contribs):
        weight = (math.fabs(crow[0]) + math.fabs(crow[1])
                  + math.fabs(crow[2]) + math.fabs(crow[3]))
        cwm.append((weight * pa, weight * pc, weight * pg, weight * pt))
    _write_meme(base + '.CWM.meme', motif_name, cwm)


def run():
    """Convert TF-MoDISco positive subpatterns into MEME motif files.

    Usage: python <script> modisco_results.h5 outprefix

    For every ``subpattern_*`` group under each ``pos_patterns/<pattern>``
    in the HDF5 file, write two files named
    ``<outprefix>.<pattern>.n_<N>_<subpattern>.ns_<NS>.{PWM,CWM}.meme``.
    N and NS are the first entries of the pattern's and subpattern's
    ``seqlets/n_seqlets`` datasets.

    The PWM file holds the subpattern's ``sequence`` matrix. The CWM file
    holds that matrix with each row scaled by the position's summed
    absolute ``contrib_scores``. CWM rows are not normalized, even though
    they are written under a ``letter-probability matrix`` header.

    Exits with status 1 when fewer than two arguments are given.
    """
    # Both the HDF5 path and the output prefix are required.
    if len(sys.argv) < 3:
        print('usage: python %s modisco_results.h5 outprefix' % sys.argv[0])
        sys.exit(1)

    modisco_path = sys.argv[1]
    outprefix = sys.argv[2]

    # Context manager closes the HDF5 file even if an export fails.
    with h5py.File(modisco_path, 'r') as results:
        print(results)
        patterns = results['pos_patterns']
        for pattern in patterns.keys():
            pattern_group = patterns[pattern]
            for subpattern in pattern_group.keys():
                if subpattern.startswith('subpattern_'):
                    _export_subpattern(outprefix, pattern, pattern_group,
                                       subpattern)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

