##################################
#                                #
# Last modified 2024/01/23       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import h5py
import string
import numpy as np

# Fixed preamble written at the top of every MEME file produced below
# (uniform background, both strands). The motif-specific lines follow it.
_MEME_HEADER_LINES = (
    'MEME version 4',
    '',
    'ALPHABET= ACGT',
    '',
    'strands: + -',
    '',
    'Background letter frequencies (from uniform background):',
    'A 0.25000 C 0.25000 G 0.25000 T 0.25000 ',
    '',
)


def _write_meme(path, motif_name, rows):
    """Write a single motif to *path* as a minimal MEME version 4 file.

    Args:
        path: output file path (overwritten).
        motif_name: text placed after 'MOTIF ' on the motif line.
        rows: sequence of (a, c, g, t) tuples, one per motif position. Each
            value is written with str(), separated by a tab plus two spaces.
    """
    lines = list(_MEME_HEADER_LINES)
    lines.append('MOTIF ' + motif_name)
    lines.append('')
    lines.append('letter-probability matrix: alength= 4 w= ' + str(len(rows))
                 + ' nsites= 1 E= 0')
    for row in rows:
        lines.append('  ' + '\t  '.join(str(value) for value in row))
    with open(path, 'w') as out:
        out.write('\n'.join(lines) + '\n')


def _contribution_weighted_matrix(pfm, contribs):
    """Scale each PFM row by the total absolute contribution at that position.

    Args:
        pfm: list of (a, c, g, t) tuples.
        contribs: per-position rows whose first four entries are the A, C, G,
            T contribution scores (assumed to use the same column order as
            *pfm* -- the original script indexed both the same way).

    Returns:
        List of (a, c, g, t) tuples, each equal to |contrib| summed over the
        four bases times the corresponding PFM row.

    Raises:
        ValueError: if *pfm* and *contribs* have different lengths.
    """
    if len(pfm) != len(contribs):
        raise ValueError('contrib_scores has %d positions but sequence has %d'
                         % (len(contribs), len(pfm)))
    cwm = []
    for probs, contrib in zip(pfm, contribs):
        # math.fabs yields Python floats; keeps the written numbers' type and
        # formatting the same as the original implementation.
        weight = sum(math.fabs(score) for score in contrib[:4])
        cwm.append(tuple(weight * p for p in probs))
    return cwm


def run():
    """Export the positive patterns of a TF-MoDISco HDF5 file as MEME motifs.

    Usage: python <script> modisco_results.h5 outprefix

    For every pattern under 'pos_patterns', with N read from the pattern's
    seqlets/n_seqlets dataset, two files are written:
      <outprefix>.<pattern>.n_<N>.PWM.meme -- the pattern's 'sequence' matrix
      <outprefix>.<pattern>.n_<N>.CWM.meme -- each 'sequence' row scaled by the
          summed absolute 'contrib_scores' at that position
    Exits with status 1 (after printing usage) if either argument is missing.
    """
    # sys.argv[0] is the script itself, so both positional arguments require
    # at least three entries.
    if len(sys.argv) < 3:
        print('usage: python %s modisco_results.h5 outprefix' % sys.argv[0])
        sys.exit(1)

    modisco_path = sys.argv[1]
    outprefix = sys.argv[2]

    with h5py.File(modisco_path, 'r') as modisco:
        print(modisco)
        patterns = modisco['pos_patterns']
        for pattern in patterns.keys():
            group = patterns[pattern]
            print(pattern)
            n_seqlets = group['seqlets']['n_seqlets'][:][0]
            motif_name = outprefix + '.' + pattern + '.n_' + str(n_seqlets)

            pfm = [tuple(row[:4]) for row in group['sequence'][:]]
            pwm_path = motif_name + '.PWM.meme'
            # Report the path that is actually written.
            print('%s %s' % (n_seqlets, pwm_path))
            _write_meme(pwm_path, motif_name, pfm)

            cwm = _contribution_weighted_matrix(pfm, group['contrib_scores'][:])
            _write_meme(motif_name + '.CWM.meme', motif_name, cwm)

# Only run the export when executed as a script, not when imported.
if __name__ == '__main__':
    run()

