##################################
#                                #
# Last modified 2025/08/06       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import math
import random
import scipy.stats as st
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import logomaker
import pandas as pd

def revComMotif(acgt_array):
    """
    Return the reverse complement of a motif.

    `acgt_array` is an iterable of 4-item (A, C, G, T) columns. The result is
    a new list with the columns in reverse order and each column rewritten as
    (T, G, C, A). The input is not modified.
    """
    # Materialise first so every column is unpacked (and validated as a
    # 4-tuple) in forward order before anything is reversed.
    columns = [(a, c, g, t) for (a, c, g, t) in acgt_array]
    return [(t, g, c, a) for (a, c, g, t) in reversed(columns)]

# def generalized_align(seq_1, seq_2, score_func=np.d, gap_open_1=-10000, gap_extend_1=-10000,gap_open_2=0,gap_extend_2=-0.1,local_align=False):
def generalized_align(seq_1, seq_2, gap_open_1, gap_extend_1, gap_open_2, gap_extend_2, local_align):
    """
    Align two sequences of numeric vectors using affine gap penalties that are
    set separately for each sequence.

    Arguments:
        `seq_1`: an indexable sequence of vectors to align to `seq_2`
        `seq_2`: an indexable sequence of vectors to align to `seq_1`
        `gap_open_1`, `gap_extend_1`: penalties applied when an element of
            `seq_2` is aligned to a gap (a gap is inserted in `seq_1`);
            must be non-positive
        `gap_open_2`, `gap_extend_2`: penalties applied when an element of
            `seq_1` is aligned to a gap (a gap is inserted in `seq_2`);
            must be non-positive
        `local_align`: if true, every cell score is clamped at zero while the
            matrices are filled. The traceback is the same in both modes.

    The score of pairing `seq_1[i]` with `seq_2[j]` is `np.dot(seq_1[i], seq_2[j])`.

    Returns `(final_score, alignment)`. `final_score` is the largest value
    anywhere in the score matrix, so the alignment may end before the end of
    either sequence. `alignment` is a list of `(index_1, index_2)` pairs in
    left-to-right order with 0-based indices; "-" marks a gap. Elements that
    lie beyond the best-scoring cell are reported as trailing gap pairs, and
    the traceback always runs back to the start of both sequences.

    Raises AssertionError if any gap penalty is positive.
    """
    assert gap_open_1 <= 0
    assert gap_extend_1 <= 0
    assert gap_open_2 <= 0
    assert gap_extend_2 <= 0
    n, m = len(seq_1), len(seq_2)

    # V: best score per cell; I_S: seq_1 element aligned to a gap;
    # I_T: seq_2 element aligned to a gap; M: seq_1 element paired with a
    # seq_2 element.
    V = np.zeros((n + 1, m + 1))
    I_S = np.zeros((n + 1, m + 1))
    I_T = np.zeros((n + 1, m + 1))
    M = np.zeros((n + 1, m + 1))
    # P[i, j] = argmax over (I_S, I_T, M) at that cell: 0, 1 or 2.
    # The builtin int is used because the np.int alias no longer exists in
    # current NumPy releases.
    P = np.zeros((n + 1, m + 1), dtype=int)

    # Base cases. The border pointers guarantee that the traceback only moves
    # up along column 0 and only moves left along row 0.
    for i in range(1, n + 1):
        I_S[i, 0] = gap_open_2 + (i * gap_extend_2)
        P[i, 0] = 0
    for j in range(1, m + 1):
        I_T[0, j] = gap_open_1 + (j * gap_extend_1)
        P[0, j] = 1

    # Fill out the matrices
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            I_S[i, j] = max(I_S[i - 1, j], I_T[i - 1, j] + gap_open_2, M[i - 1, j] + gap_open_2) + gap_extend_2
            I_T[i, j] = max(I_S[i, j - 1] + gap_open_1, I_T[i, j - 1], M[i, j - 1] + gap_open_1) + gap_extend_1
            M[i, j] = max(I_S[i - 1, j - 1], I_T[i - 1, j - 1], M[i - 1, j - 1]) + np.dot(seq_1[i - 1], seq_2[j - 1])

            if local_align:
                # Local alignment: offer the option of resetting
                I_S[i, j] = max(I_S[i, j], 0)
                I_T[i, j] = max(I_T[i, j], 0)
                M[i, j] = max(M[i, j], 0)

            scores = [I_S[i, j], I_T[i, j], M[i, j]]
            best = int(np.argmax(scores))
            P[i, j] = best
            V[i, j] = scores[best]

    # Start the traceback at the first cell holding the overall best score.
    fi, fj = np.unravel_index(np.argmax(V), V.shape)
    fi, fj = int(fi), int(fj)
    final_score = V[fi, fj]

    # The alignment is collected right-to-left and reversed on return.
    alignment = []

    # Anything past the best cell is reported as aligned to a gap.
    for k in range(n - 1, fi - 1, -1):
        alignment.append((k, "-"))
    for k in range(m - 1, fj - 1, -1):
        alignment.append(("-", k))

    # Follow the stored pointers back to the origin, one step per iteration.
    # A diagonal step that lands on the origin is the pair (0, 0) and must be
    # recorded like any other; no step is ever taken from the origin itself.
    while fi > 0 or fj > 0:
        step = P[fi, fj]
        if step == 2:
            # seq_1 element paired with seq_2 element
            fi -= 1
            fj -= 1
            alignment.append((fi, fj))
        elif step == 0:
            # seq_1 element aligned to a gap
            fi -= 1
            alignment.append((fi, "-"))
        else:
            # seq_2 element aligned to a gap
            fj -= 1
            alignment.append(("-", fj))

    return final_score, alignment[::-1]

def _column_information(A, C, G, T):
    """Return 2 minus the base-2 entropy of one (A, C, G, T) probability column."""
    # The 1e-6 offset keeps log() defined when a probability is exactly 0.
    return 2 - (-A*math.log(A+1e-6,2) - C*math.log(C+1e-6,2) - G*math.log(G+1e-6,2) -T*math.log(T+1e-6,2))


def _float_option(flag, default):
    """Return the float that follows `flag` in sys.argv (echoing it), or `default`."""
    if flag not in sys.argv:
        return default
    value = float(sys.argv[sys.argv.index(flag) + 1])
    print('%s = %s' % (flag[1:], value))
    return value


def _read_meme_matrix(path, as_bits):
    """Read the rows following 'letter-probability matrix:' in a MEME file.

    Returns a list of (A, C, G, T) tuples. When `as_bits` is true each
    probability is multiplied by the column's information content.
    """
    motif = []
    in_matrix = False
    with open(path) as handle:
        for line in handle:
            if line.startswith('letter-probability matrix:'):
                in_matrix = True
                continue
            if not in_matrix:
                continue
            fields = line.split()
            if not fields:
                # Tolerate blank lines inside or after the matrix.
                continue
            A, C, G, T = [float(x) for x in fields[:4]]
            if as_bits:
                E = _column_information(A, C, G, T)
                motif.append((A*E, C*E, G*E, T*E))
            else:
                motif.append((A, C, G, T))
    return motif


def _trim_motif(motif, min_bits, already_bits):
    """Drop leading/trailing columns whose information is below `min_bits`.

    Returns (trimmed_motif, start, end) with `end` inclusive. If no column
    qualifies, start == end == len(motif) and the trimmed motif is empty.
    """
    if already_bits:
        info = [A + C + G + T for (A, C, G, T) in motif]
    else:
        info = [_column_information(A, C, G, T) for (A, C, G, T) in motif]
    keep = [i for i, bits in enumerate(info) if bits >= min_bits]
    if keep:
        start, end = keep[0], keep[-1]
    else:
        start = end = len(motif)
    return motif[start:end + 1], start, end


def _shuffle_null_scores(motif_1, motif_2, align_args, rounds=1000):
    """Score `motif_2` against shuffled versions of `motif_1`.

    Returns (single_column_scores, triplet_scores): `rounds` alignment scores
    after shuffling individual columns, then `rounds` scores after shuffling
    consecutive 3-column blocks as units.
    """
    if len(motif_1) % 3 != 0:
        raise ValueError('first motif has %d columns; the 3-bp block shuffle '
                         'needs a multiple of 3' % len(motif_1))

    columns = list(motif_1)
    single_scores = []
    for _ in range(rounds):
        random.shuffle(columns)
        score, _aln = generalized_align(columns, motif_2, *align_args)
        single_scores.append(score)

    triplets = [tuple(motif_1[i:i + 3]) for i in range(0, len(motif_1), 3)]
    triplet_scores = []
    for _ in range(rounds):
        random.shuffle(triplets)
        shuffled = [column for triplet in triplets for column in triplet]
        score, _aln = generalized_align(shuffled, motif_2, *align_args)
        triplet_scores.append(score)

    return single_scores, triplet_scores


def _to_forward_coords(alignment, len_1, len_2, flip_1, flip_2):
    """Map alignment indices of reverse-complemented motifs back to forward indices."""
    rows = []
    for (a, b) in alignment:
        if flip_1 and a != '-':
            a = len_1 - a - 1
        if flip_2 and b != '-':
            b = len_2 - b - 1
        rows.append((a, b))
    return rows


def _logo_track(indices, motif, complement, already_bits):
    """Build tick labels and the logomaker table for one motif track.

    `indices` holds one motif index (or '-') per alignment row. Gap rows get
    an all-zero column. Columns are scaled to bits unless the motif already
    is, and complemented when the motif is shown on the reverse strand.
    """
    ticks = []
    rows = []
    for idx in indices:
        if idx != '-':
            ticks.append(str(idx + 1))
            (AA, CC, GG, TT) = motif[idx]
            if not already_bits:
                E = _column_information(AA, CC, GG, TT)
                (AA, CC, GG, TT) = (AA*E, CC*E, GG*E, TT*E)
            if complement:
                rows.append((TT, GG, CC, AA))
            else:
                rows.append((AA, CC, GG, TT))
        else:
            ticks.append(str(idx))
            rows.append((0, 0, 0, 0))
    table = pd.DataFrame(rows, columns=['A', 'C', 'G', 'T'],
                         index=pd.Index(range(1, len(rows) + 1), name='Pos'))
    return ticks, table


def run():
    """
    Command-line entry point.

    Reads two single-motif MEME files, aligns the first motif against the
    second in four strand orientations, estimates how unusual the best score
    is from 1000 single-column and 1000 3-column-block shuffles of the first
    motif, then writes `<outfile_prefix>.alignment` and a two-row sequence
    logo `<outfile_prefix>.png` (plus `.svg` with -SVG).
    """
    # Three positional arguments are required, so argv needs at least 4 items.
    if len(sys.argv) < 4:
        print('usage: python %s B1H-RC_meme_1 TF-MoDISCo_meme_2 outfile_prefix [-gap_open_1 value] [-gap_extend_1 value] [-gap_open_2 value] [-gap_extend_2 value] [-local_align] [-trimMOT2 minBits] [-bits] [-SVG]' % sys.argv[0])
        sys.exit(1)

    meme1 = sys.argv[1]
    meme2 = sys.argv[2]
    outprefix = sys.argv[3]

    print('%s %s %s' % (meme1, meme2, outprefix))

    gap_open_1 = _float_option('-gap_open_1', -10)
    gap_extend_1 = _float_option('-gap_extend_1', -10)
    gap_open_2 = _float_option('-gap_open_2', 0)
    gap_extend_2 = _float_option('-gap_extend_2', -0.1)

    local_align = '-local_align' in sys.argv
    if local_align:
        print('local_align = %s' % local_align)

    doTrimMot2 = '-trimMOT2' in sys.argv
    minBits = 0
    if doTrimMot2:
        minBits = float(sys.argv[sys.argv.index('-trimMOT2') + 1])
        print('will trim second motif start and end, requirining minimum %s bits' % minBits)

    doBits = '-bits' in sys.argv
    if doBits:
        print('[-bits]')

    MOT1 = _read_meme_matrix(meme1, doBits)
    print('finished parsing motif file 1')

    MOT2 = _read_meme_matrix(meme2, doBits)
    print('finished parsing motif file 2')
    print(MOT2)

    if doTrimMot2:
        MOT2, newstart, newend = _trim_motif(MOT2, minBits, doBits)
        print('trim positions: %s %s' % (newstart, newend + 1))

    outfilename = outprefix + '.alignment'

    MOT1RC = revComMotif(MOT1)
    MOT2RC = revComMotif(MOT2)

    align_args = (gap_open_1, gap_extend_1, gap_open_2, gap_extend_2, local_align)

    # (label, motif 1, motif 2, motif 1 reversed?, motif 2 reversed?)
    orientations = [
        ('For-For', MOT1, MOT2, False, False),
        ('For-Rev', MOT1, MOT2RC, False, True),
        ('Rev-Rev', MOT1RC, MOT2RC, True, True),
        ('Rev-For', MOT1RC, MOT2, True, False),
    ]

    scores = []
    alignments = []
    for (label, first, second, flip_1, flip_2) in orientations:
        (score, alignment) = generalized_align(first, second, *align_args)
        scores.append(score)
        alignments.append(alignment)

    for (orientation, score) in zip(orientations, scores):
        print('final score %s %s' % (orientation[0], score))

    # Build the shuffle null for the best-scoring orientation. On ties the
    # first one in the order above is used, so the shuffles run only once.
    best_score = max(scores)
    (label, first, second, flip_1, flip_2) = orientations[scores.index(best_score)]
    (total_shuffling_scores, ZF_shuffling_scores) = _shuffle_null_scores(first, second, align_args)

    TSmean = np.mean(total_shuffling_scores)
    TSstd = np.std(total_shuffling_scores)
    TSpval = 1 - st.norm.cdf(best_score, TSmean, TSstd)

    ZFmean = np.mean(ZF_shuffling_scores)
    ZFstd = np.std(ZF_shuffling_scores)
    ZFpval = 1 - st.norm.cdf(best_score, ZFmean, ZFstd)

    print('1-shuffle: %s %s %s %s' % (TSmean, TSstd, TSpval, 1 - TSpval))
    print('3-shuffle: %s %s %s %s' % (ZFmean, ZFstd, ZFpval, 1 - ZFpval))

    # The alignment is deterministic, so the scores are simply echoed again
    # here instead of recomputing all four alignments.
    for (orientation, score) in zip(orientations, scores):
        print('final score %s %s' % (orientation[0], score))

    # Only the orientations with motif 1 on the forward strand are reported:
    # the Rev-Rev and Rev-For scores are forced to 0 for this selection.
    # Ties go to the first orientation so the alignment is written once.
    report_scores = [scores[0], scores[1], 0, 0]
    reported_score = max(report_scores)
    pick = report_scores.index(reported_score)
    (label, first, second, doFMRC, doSMRC) = orientations[pick]

    ForLogo = _to_forward_coords(alignments[pick], len(MOT1), len(MOT2), doFMRC, doSMRC)

    print(ForLogo)

    with open(outfilename, 'w') as outfile:
        for (a, b) in ForLogo:
            outfile.write('##    ' + str(a) + '\t' + str(b) + '\n')
        outfile.write('#comparison\tfinal_score\t1-bp_shuffle_mean\t1-bp_shuffle_std\t1-bp_shuffle_p-val\t3-bp_shuffle_mean\t3-bp_shuffle_std\t3-bp_shuffle_p-val' + '\n')
        summary = [outfilename, reported_score, TSmean, TSstd, min(TSpval, 1 - TSpval),
                   ZFmean, ZFstd, min(ZFpval, 1 - ZFpval)]
        outfile.write('\t'.join([str(item) for item in summary]) + '\n')

    # Diagnostic echo of the alignment rows, retained from the original output.
    for (a, b) in ForLogo:
        print('%s %s' % (a, b))
    (MOT1ticks, MOT1_pd) = _logo_track([a for (a, b) in ForLogo], MOT1, doFMRC, doBits)

    for (a, b) in ForLogo:
        print('%s %s %s %s' % (a, b, len(MOT1), len(MOT2)))
    (MOT2ticks, MOT2_pd) = _logo_track([b for (a, b) in ForLogo], MOT2, doSMRC, doBits)

    # True division: 4 inches of width per 15 alignment rows. Integer
    # division would collapse the width to 0 for fewer than 15 rows.
    fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(len(ForLogo) / 15.0 * 4, 3))

    logomaker.Logo(MOT1_pd, ax = ax1, font_name='Arial Rounded MT Bold', show_spines=False)
    logomaker.Logo(MOT2_pd, ax = ax2, font_name='Arial Rounded MT Bold', show_spines=False)

    Xticks = list(range(1, len(ForLogo) + 1))

    ax1.set_xticks(Xticks)
    ax1.set_xticklabels(MOT1ticks)
    ax2.set_xticks(Xticks)
    ax2.set_xticklabels(MOT2ticks)

    plt.savefig(outprefix + '.png')
    if '-SVG' in sys.argv:
        plt.savefig(outprefix + '.svg')

# Script entry point: execute the whole pipeline. There is no __main__ guard,
# so this call also fires if the file is imported as a module.
run()

