##################################
#                                #
# Last modified 2026/01/16       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import math
import random
import scipy.stats as st
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import logomaker
import pandas as pd

def revComMotif(acgt_array):
    """Return the reverse complement of a motif matrix.

    acgt_array -- iterable of 4-item rows given in (A, C, G, T) order.

    Returns a new list of tuples in which the row order is reversed and,
    within every row, the values are swapped to (T, G, C, A).  The input
    is left untouched.  A row that does not unpack into exactly four
    items raises the usual unpacking error.
    """
    # Snapshot the rows first (this also validates that each has 4 items).
    rows = [(A, C, G, T) for (A, C, G, T) in acgt_array]

    # Walk the snapshot back to front, complementing each position.
    return [(T, G, C, A) for (A, C, G, T) in rows[::-1]]

def _bracketed_text(line):
    """Return the text after the first '[' of *line*, up to the next '[' or the first ']'.

    Raises IndexError when no '[' occurs before the first ']'.
    """
    return line.strip().split(']')[0].split('[')[1]


def _alignment_entries(line):
    """Split the bracketed, comma-separated list on a 'Final Alignment' line.

    Spaces and single quotes are removed before splitting, so the result
    is a list of bare strings (digit strings or '-' in the inputs this
    script expects).
    """
    return _bracketed_text(line).replace(' ', '').replace("'", "").split(',')


def _logo_table(motif, columns, doBits, echo=False):
    """Lay one motif out along the alignment columns as a logomaker table.

    motif   -- list of (A, C, G, T) rows.
    columns -- one entry per alignment column: an int row index into
               *motif*, or the string '-' for a gap.
    doBits  -- when true, every row is multiplied by
               E = 2 + sum(p * log2(p + 1e-6)) over its four values.
    echo    -- when true (and doBits is set), print each raw row before
               scaling it.

    Returns (table, ticks): *table* is a DataFrame with columns A, C, G, T
    indexed 1..len(columns) under the index name 'Pos'; *ticks* holds the
    per-column axis labels, the 1-based motif position or '-' for a gap.
    Gap columns are all-zero rows.
    """
    rows = []
    ticks = []
    for col in columns:
        if col == '-':
            ticks.append('-')
            rows.append((0.0, 0.0, 0.0, 0.0))
            continue
        ticks.append(str(col + 1))
        (AA, CC, GG, TT) = motif[col]
        if doBits:
            if echo:
                print('%s %s %s %s' % (AA, CC, GG, TT))
            # The 1e-6 pseudocount keeps log() defined for zero entries.
            E = 2 - (-AA*math.log(AA+1e-6,2) - CC*math.log(CC+1e-6,2) - GG*math.log(GG+1e-6,2) -TT*math.log(TT+1e-6,2))
            rows.append((AA*E, CC*E, GG*E, TT*E))
        else:
            rows.append((AA, CC, GG, TT))
    table = pd.DataFrame(rows,
                         columns=['A', 'C', 'G', 'T'],
                         index=pd.Index(range(1, len(rows) + 1), name='Pos'),
                         dtype=float)
    return table, ticks


def _save_alignment_figure(top_table, top_ticks, bottom_table, bottom_ticks, png_path, svg_path=None):
    """Draw two stacked logos sharing the alignment columns and save them.

    The figure is always closed afterwards so that figures do not pile up
    when many alignments are rendered in a single run.
    """
    n_cols = len(top_ticks)
    # Float division: 4 inches per 15 columns, without truncating short
    # alignments to a zero-width figure.
    fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(n_cols / 15.0 * 4, 3))
    try:
        logomaker.Logo(top_table, ax=ax1, font_name='Arial Rounded MT Bold', show_spines=False)
        logomaker.Logo(bottom_table, ax=ax2, font_name='Arial Rounded MT Bold', show_spines=False)

        Xticks = list(range(1, n_cols + 1))
        ax1.set_xticks(Xticks)
        ax1.set_xticklabels(top_ticks)
        ax2.set_xticks(Xticks)
        ax2.set_xticklabels(bottom_ticks)

        fig.savefig(png_path)
        if svg_path is not None:
            fig.savefig(svg_path)
    finally:
        plt.close(fig)


def run():
    """Command-line driver: draw a CWM-vs-B1H logo pair for every alignment record.

    Usage: python <script> input_alignment outfile_prefix [-SVG] [-bits]

    The input file is scanned line by line (blank lines are skipped):
      * 'CWM:' / 'B1H:' lines carry a bracketed motif name and select which
        motif the following matrix rows belong to ('=' in the CWM name is
        replaced by '_').
      * 'letter-probability matrix:' starts a matrix; each later line is
        read as four whitespace-separated floats (A, C, G, T) until a
        'Final Score:', 'CWM:' or 'B1H:' line ends it.
      * 'Final Alignment (CWM):' / 'Final Alignment (B1H):' carry bracketed
        lists with one entry per alignment column: a 0-based motif row
        index or '-' for a gap.
      * a line starting with '=============' closes the record, at which
        point <outfile_prefix><CWMname>-vs-<B1Hname>.png is written (CWM
        logo on top, B1H logo below) and the per-record state is cleared.

    With -bits, matrix rows are scaled by their information content; with
    -SVG an SVG copy is also saved.  Exits with status 1 after printing the
    usage line when fewer than two positional arguments are given.

    Assumes every record provides both names and both alignment lines
    before its separator; a record missing them is not detected here.
    """
    # Both positional arguments are read below, so three argv entries are needed.
    if len(sys.argv) < 3:
        print('usage: python %s input_alignment outfile_prefix [-SVG] [-bits]' % sys.argv[0])
        sys.exit(1)

    inpath = sys.argv[1]
    outprefix = sys.argv[2]

    print('%s %s' % (inpath, outprefix))

    doBits = '-bits' in sys.argv
    doSVG = '-SVG' in sys.argv

    CWM = []
    B1H = []

    InCWM = False
    InMotif = False
    InB1H = False

    with open(inpath) as lineslist:
        for line in lineslist:
            if line.strip() == '':
                continue
            if line.startswith('Final Score:'):
                InCWM = False
                InB1H = False
                InMotif = False
            if line.startswith('CWM:'):
                CWMname = _bracketed_text(line).replace('=', '_')
                InCWM = True
                InB1H = False
                InMotif = False
            if line.startswith('B1H:'):
                B1Hname = _bracketed_text(line)
                InCWM = False
                InB1H = True
                # Mirror the CWM header: a new motif header ends any matrix
                # still being read, so the header is not parsed as a row.
                InMotif = False
            if line.startswith('letter-probability matrix:'):
                InMotif = True
                continue
            if InMotif:
                fields = line.split()
                A = float(fields[0])
                C = float(fields[1])
                G = float(fields[2])
                T = float(fields[3])
                if InCWM:
                    CWM.append((A, C, G, T))
                if InB1H:
                    B1H.append((A, C, G, T))
            if line.startswith('Final Alignment (B1H):'):
                B1Halignment = _alignment_entries(line)
            if line.startswith('Final Alignment (CWM):'):
                CWMalignment = _alignment_entries(line)

            if line.startswith('============='):
                basename = outprefix + CWMname + '-vs-' + B1Hname
                print(basename)

                # Pair the two alignments column by column; entries other
                # than '-' are 0-based row indices into the motif matrices.
                ForLogo = []
                for i, a in enumerate(CWMalignment):
                    b = B1Halignment[i]
                    if a != '-':
                        a = int(a)
                    if b != '-':
                        b = int(b)
                    ForLogo.append((a, b))

                MOT1_pd, MOT1ticks = _logo_table(CWM, [a for (a, b) in ForLogo], doBits, echo=True)
                MOT2_pd, MOT2ticks = _logo_table(B1H, [b for (a, b) in ForLogo], doBits)

                # NOTE(review): the SVG path does not include the motif
                # names, so every record overwrites the previous SVG and
                # only the last one survives.  Left unchanged because
                # downstream users may rely on this file name.
                svg_path = outprefix + '.svg' if doSVG else None
                _save_alignment_figure(MOT1_pd, MOT1ticks, MOT2_pd, MOT2ticks,
                                       basename + '.png', svg_path)

                # Start the next record from a clean state.
                CWM = []
                B1H = []

                InCWM = False
                InMotif = False
                InB1H = False

# Script entry point: run() is invoked unconditionally at module level, so it
# executes both when the file is run as a script and when it is imported.
run()

