
# coding: utf-8

# ## Generalized sequence alignment
# **Author:** Alex Tseng
# 
# We wish to construct a dynamic programming algorithm that is more general than most implementations of Smith-Waterman or Needleman-Wunsch. Specifically, we would like a local and global alignment algorithm that allows for an arbitrary scoring function between bases. In particular, this can be used for aligning two binding motifs that may have gaps. The gap penalty is restricted to be affine (i.e. there can be a penalty for opening a gap, separate from the extension of an existing gap).
# 
# ### Algorithm
# Let our two strings be $S$ and $T$, of length $n$ and $m$, respectively. $S_{i}$ is the $i$th character of $S$ (1-indexed). The gap opening penalty is $W_{g}$, and the gap extension penalty is $W_{s}$. $\sigma(s, t)$ is a function that returns a scalar similarity score between two bases.
# 
# The score of our alignment is: $\sum\limits_{s, t : \text{aligned}}\sigma(s, t) + \sum\limits_{x:\text{gap}}(W_{g} + \vert x\vert W_{s})$
# 
# Define a score matrix $V \in \mathbb{R}^{(n + 1) \times (m + 1)}$. $V[i,j]$ is defined as the score of the optimal alignment of $S_{1},...,S_{i}$ and $T_{1},...,T_{j}$.
# 
# Furthermore, define matrices $I_{S},I_{T},M \in \mathbb{R}^{(n + 1) \times (m + 1)}$:
# - $I_{S}[i,j]$ is the score of the best alignment of $S_{1},...,S_{i}$ and $T_{1},...,T_{j}$ that ends with $S_{i}$ matched to a gap.
# - $I_{T}[i,j]$ is the score of the best alignment of $S_{1},...,S_{i}$ and $T_{1},...,T_{j}$ that ends with $T_{j}$ matched to a gap.
# - $M[i,j]$ is the score of the best alignment of $S_{1},...,S_{i}$ and $T_{1},...,T_{j}$ that ends with $S_{i}$ matched to $T_{j}$.
# 
# The recurrence relations are as follows:
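# - $V[i,j] = \max\{0, I_{S}[i,j], I_{T}[i,j], M[i,j]\}$ (the $0$ term applies to local alignment only; for global alignment $V[i,j] = \max\{I_{S}[i,j], I_{T}[i,j], M[i,j]\}$)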
# - $I_{S}[i,j] = \max\{I_{S}[i-1,j] + W_{s}, I_{T}[i-1,j] + W_{g} + W_{s}, M[i-1,j] + W_{g} + W_{s}\}$
# - $I_{T}[i,j] = \max\{I_{S}[i,j-1] + W_{g} + W_{s}, I_{T}[i,j-1] + W_{s}, M[i,j-1] + W_{g} + W_{s}\}$
# - $M[i,j] = \max\{I_{S}[i-1,j-1] + \sigma(S_{i},T_{j}), I_{T}[i-1,j-1] + \sigma(S_{i},T_{j}), M[i-1,j-1] + \sigma(S_{i},T_{j})\}$
# 
# The base cases are as follows:
# - $V[0,0] = I_{S}[0,0] = I_{T}[0,0] = M[0,0] = 0$
# - $I_{S}[i, 0] = W_{g} + iW_{s}$ for $i \in \{1,...,n\}$
# - $I_{T}[0, j] = W_{g} + jW_{s}$ for $j \in \{1,...,m\}$
# 
# ### Global alignment vs local alignment
# A global alignment forces the alignment of all of $S$ to all of $T$. A local alignment only aligns the best substring of $S$ to the best substring of $T$ (contiguous substrings).
# 
# Algorithmically, the difference between a global alignment and local alignment is the following:
# 
# 1. Allowing 0s in the matrix maximum computation
# 
# For a local alignment, the alignment, as it is being built, can be "reset" so that previous parts of an alignment can be forgotten; in practice, this can be implemented by allowing $I_S$, $I_T$, and $M$ to take on values of 0 if preferred. For a global alignment, the recurrence relations are the same as listed above.
# 
# 2. Beginning the traceback
# 
# For a local alignment, tracing back the pointers in the alignment can happen starting in any cell (specifically, we start at the cell with the highest score); this allows any tail of an alignment to be ignored. For a global alignment, tracing back the pointers must start at $V[n,m]$ (i.e. the alignment of the last characters of each substring).
# 
# 3. Ending the traceback
# 
# For a local alignment, we stop the traceback whenever we hit a score of 0, to avoid taking any gaps or mismatches that would cause a lower score than just quitting the alignment. For a global alignment, we need to continue the traceback until we hit $V[0,0]$.
# 
# ### Complexity
# $O(nm)$ space and time

# In[1]:


import numpy as np
import pickle 
import os
import pandas as pd
from modisco.visualization import viz_sequence
import numpy as np 


# Background base frequencies. Passed to `viz_sequence.ic_scale` in
# `plot_ic_scaled`, and used as the filler row for gap positions in
# `build_aligned_pwms`. Presumably ordered A, C, G, T -- TODO confirm.
background = np.array([0.27, 0.23, 0.23, 0.27])

def load_chip_exo_pattern(tf,n=0):
    '''Loads the pickled ChIP-exo pattern listed in row `n` of the pattern table for `tf`.

    The table is the first HTML table in `patterns_info.html` under the motif
    directory of `tf`; its `name` column (with "/" replaced by "_") gives the
    pickle file name. The table is presumably sorted by descending count, so
    `n=0` would be the highest-count pattern -- the ordering is not checked here.

    Arguments:
        `tf`: name of the sub-directory holding the patterns
        `n`: 0-based row of the table to load
    Returns `(pattern, pattern_file_name)`. If `n` is past the end of the
    table, prints a message and returns None.
    Raises FileNotFoundError if `patterns_info.html` does not exist.
    '''
    path = '../ChipExo_modeling/motifs/motif_set_12_01_2019/{}'.format(tf)
    info_path = os.path.join(path, 'patterns_info.html')
    # Fail early with a clear message instead of handing a bad path to the HTML parser
    if not os.path.exists(info_path):
        raise FileNotFoundError('Pattern table not found: {}'.format(info_path))
    df = pd.read_html(info_path)[0]
    print(df)
    # Valid rows are 0 .. len(df) - 1, so n == len(df) is already out of range
    if n >= len(df):
        print('Invalid : total rows is {}'.format(len(df)))
        return
    pattern_name = df.iloc[n]['name'].replace('/','_')+'.pickle'
    print(pattern_name)
    pattern_path = os.path.join(path,pattern_name)
    # NOTE: unpickling executes arbitrary code; only load files from a trusted source
    with open(pattern_path,'rb') as handle:
        pat = pickle.load(handle)
    return pat, pattern_name


def load_chip_seq_pattern_profile_model(tf,dataset='ENCODE',n=0):
    '''Loads the pickled ChIP-seq profile-model pattern listed in row `n` of the pattern table for `tf`.

    The table is the first HTML table in `patterns_info.html` under the motif
    directory of `dataset`/`tf`; its `name` column (with "/" replaced by "_")
    gives the pickle file name. The table is presumably sorted by descending
    count, so `n=0` would be the highest-count pattern -- the ordering is not
    checked here.

    Arguments:
        `tf`: name of the sub-directory holding the patterns
        `dataset`: prefix of the `<dataset>_models` directory to read from
        `n`: 0-based row of the table to load
    Returns `(pattern, pattern_file_name)`. If `n` is past the end of the
    table, prints a message and returns None.
    Raises FileNotFoundError if `patterns_info.html` does not exist.
    '''
    path = '../train_profile_models_ChipSeq/{}_models/motifs/Jan27_1_20/{}'.format(dataset,tf)
    info_path = os.path.join(path, 'patterns_info.html')
    # Fail early with a clear message instead of handing a bad path to the HTML parser
    if not os.path.exists(info_path):
        raise FileNotFoundError('Pattern table not found: {}'.format(info_path))
    df = pd.read_html(info_path)[0]
    print(df)
    # Valid rows are 0 .. len(df) - 1, so n == len(df) is already out of range
    if n >= len(df):
        print('Invalid : total rows is {}'.format(len(df)))
        return
    pattern_name = df.iloc[n]['name'].replace('/','_')+'.pickle'
    print(pattern_name)
    pattern_path = os.path.join(path,pattern_name)
    # NOTE: unpickling executes arbitrary code; only load files from a trusted source
    with open(pattern_path,'rb') as handle:
        pat = pickle.load(handle)
    return pat, pattern_name




def generalized_align(seq_1, seq_2, score_func, gap_open_1=-10000, gap_extend_1=-10000,gap_open_2=0,gap_extend_2=-0.1,local_align=False):
    """
    Computes the best global (default) or local alignment between two sequences,
    using affine gap penalties that may differ between the two sequences.
    Arguments:
        `seq_1`: an indexable sequence of objects to align to `seq_2`
        `seq_2`: an indexable sequence of objects to align to `seq_1`
        `score_func`: a function that takes in an object from `seq_1` and an object from
            `seq_2` (in that order) and returns a scalar alignment score
        `gap_open_1`: penalty for opening a gap in `seq_1` (a run of `seq_2` elements
            aligned to "-"); must be non-positive
        `gap_extend_1`: penalty for each position of a gap in `seq_1`; must be
            non-positive
        `gap_open_2`: penalty for opening a gap in `seq_2` (a run of `seq_1` elements
            aligned to "-"); must be non-positive
        `gap_extend_2`: penalty for each position of a gap in `seq_2`; must be
            non-positive
        `local_align`: if true, perform a local alignment instead of a global alignment
    Returns the score of the best alignment, and the best alignment. The alignment is
    returned as a list of paired indices, denoting which indices are aligned between
    `seq_1` and `seq_2`. "-" indicates a gap. The returned indices are 0-indexed.
    Raises AssertionError if any gap penalty is positive.
    """
    assert gap_open_1 <= 0
    assert gap_extend_1 <= 0
    assert gap_open_2 <= 0
    assert gap_extend_2 <= 0
    n, m = len(seq_1), len(seq_2)

    # Define matrices
    shape = (n + 1, m + 1)
    V = np.zeros(shape)
    I_S = np.zeros(shape)
    I_T = np.zeros(shape)
    M = np.zeros(shape)
    # `np.int` was removed in NumPy 1.24; the builtin `int` gives the same dtype
    P = np.zeros(shape, dtype=int)
    # P stores what path was taken in defining V[i,j]: P[i,j] = argmax{I_S[i,j], I_T[i,j], M[i,j]}

    # Base cases
    for i in range(1, n + 1):
        I_S[i, 0] = gap_open_2 + (i * gap_extend_2)
        P[i, 0] = 0
    for j in range(1, m + 1):
        I_T[0, j] = gap_open_1 + (j * gap_extend_1)
        P[0, j] = 1
    # Note, setting P in the base cases is technically only needed for global alignment

    if not local_align:
        # Global alignment: a border cell can only be reached through one leading gap,
        # so every other state on the border is impossible. Leaving those at 0 would
        # let an alignment start mid-border without paying for the leading gap.
        I_S[0, 1:] = -np.inf
        M[0, 1:] = -np.inf
        I_T[1:, 0] = -np.inf
        M[1:, 0] = -np.inf
        # Border scores, so that an empty `seq_1` or `seq_2` reports the gap cost
        V[1:, 0] = I_S[1:, 0]
        V[0, 1:] = I_T[0, 1:]
    # For local alignment the borders stay at 0: an alignment may start anywhere for free

    # Fill out the matrices
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            I_S[i, j] = max(I_S[i - 1, j], I_T[i - 1, j] + gap_open_2, M[i - 1, j] + gap_open_2) + gap_extend_2
            I_T[i, j] = max(I_S[i, j - 1] + gap_open_1, I_T[i, j - 1], M[i, j - 1] + gap_open_1) + gap_extend_1
            M[i, j] = max(I_S[i - 1, j - 1], I_T[i - 1, j - 1], M[i - 1, j - 1]) + score_func(seq_1[i - 1], seq_2[j - 1])

            if local_align:
                # Local alignment: offer the option of resetting
                I_S[i, j] = max(I_S[i, j], 0)
                I_T[i, j] = max(I_T[i, j], 0)
                M[i, j] = max(M[i, j], 0)

            scores = [I_S[i, j], I_T[i, j], M[i, j]]
            best = int(np.argmax(scores))
            P[i, j] = best
            V[i, j] = scores[best]

    # Trace back the best alignment
    if local_align:
        # Start at the (first) highest-scoring cell and stop at the first 0 score
        max_inds = np.where(V == np.max(V))
        i, j = max_inds[0][0], max_inds[1][0]
    else:
        # Start at the last cell and walk all the way back to the origin
        i, j = n, m

    def traceback_done(i, j):
        if local_align:
            return V[i, j] == 0
        return i == 0 and j == 0

    final_score = V[i, j]
    alignment = []

    # The traceback follows the per-cell argmax stored in P; it does not track which
    # of the three matrices the previous step was taken from.
    while not traceback_done(i, j):
        if P[i, j] == 0:
            # Align S[i] to gap
            i -= 1
            alignment.append((i, "-"))
        elif P[i, j] == 1:
            # Align T[j] to gap
            j -= 1
            alignment.append(("-", j))
        else:
            # Align S[i] to T[j]
            i -= 1
            j -= 1
            alignment.append((i, j))
    return final_score, alignment[::-1]




def print_alignment(seq_1, seq_2, alignment):
    """
    Prints an alignment as two rows of text, `seq_1` on top and `seq_2` below.
    Arguments:
        `seq_1`: indexable sequence of strings (e.g. a string) that was aligned
        `seq_2`: indexable sequence of strings (e.g. a string) that was aligned
        `alignment`: list of index pairs `(s, t)` into `seq_1` and `seq_2`, in the
            format returned by `generalized_align`; "-" in either slot marks a gap
    """
    row_1, row_2 = [], []
    for s, t in alignment:
        # Index the function's own arguments; the old body read undefined globals S and T
        row_1.append("-" if s == "-" else seq_1[s])
        row_2.append("-" if t == "-" else seq_2[t])
    print("".join(row_1))
    print("".join(row_2))


def plot_ic_scaled(pwm):
    """
    Plots `pwm` after normalizing each row to sum to 1 and passing the result,
    together with the module-level `background`, through `viz_sequence.ic_scale`.
    Rows that sum to 0 are left as all zeros.
    """
    row_totals = np.sum(pwm, axis=1, keepdims=True)
    # Divide all-zero rows by 1 so they stay zero instead of producing NaN
    np.putmask(row_totals, row_totals == 0, 1)
    normalized = pwm / row_totals
    scaled = viz_sequence.ic_scale(normalized, background)
    viz_sequence.plot_weights(scaled)





def dot_sim(vec_1, vec_2):
    """
    Returns the sum of the elementwise product of `vec_1` and `vec_2`
    (the dot product, for two vectors of equal length).
    """
    elementwise = vec_1 * vec_2
    return np.sum(elementwise)





def build_aligned_pwms(pwm_1, pwm_2, alignment):
    """
    Expands two PWMs into equal-length arrays that follow `alignment`.
    Arguments:
        `pwm_1`: indexable collection of length-4 rows
        `pwm_2`: indexable collection of length-4 rows
        `alignment`: list of index pairs `(i_1, i_2)` into `pwm_1` and `pwm_2`;
            "-" in either slot marks a gap
    Returns two arrays of shape (len(alignment), 4). Each aligned position holds
    the indexed row of the corresponding PWM; gap positions hold the module-level
    `background` row.
    """
    def pick_row(pwm, index):
        # A gap has no source row, so it is filled with the background
        return background if index == "-" else pwm[index]

    length = len(alignment)
    aligned_1 = np.empty((length, 4))
    aligned_2 = np.empty((length, 4))
    for pos, pair in enumerate(alignment):
        index_1, index_2 = pair
        aligned_1[pos] = pick_row(pwm_1, index_1)
        aligned_2[pos] = pick_row(pwm_2, index_2)
    return aligned_1, aligned_2


def build_aligned_pwms_checking_for_RC(b1hrc_motif,invivo_motif,gap_open_1=-10, gap_extend_1=-10,gap_open_2=0,gap_extend_2=-0.1):
    """
    Locally aligns `b1hrc_motif` to `invivo_motif` and to `invivo_motif` reversed
    along both axes, keeps the higher-scoring of the two, and builds the aligned PWMs.

    Reversing both axes is the reverse complement if the columns are ordered
    A, C, G, T -- assumed, not checked here. Both scores are printed. On a tie
    the reversed orientation is used.

    Arguments:
        `b1hrc_motif`: first motif, passed as `seq_1` to `generalized_align`
        `invivo_motif`: second motif; must support `[::-1, ::-1]` indexing
        `gap_open_1`, `gap_extend_1`, `gap_open_2`, `gap_extend_2`: gap penalties
            forwarded to `generalized_align`
    Returns `(best_score, aligned_b1hrc_motif, aligned_invivo_motif)`.
    """
    penalties = (gap_open_1, gap_extend_1, gap_open_2, gap_extend_2)

    score, alignment = generalized_align(b1hrc_motif, invivo_motif, dot_sim, *penalties, local_align=True)
    print('Score : {}'.format(str(score)))

    invivo_rc = invivo_motif[::-1, ::-1]
    score_RC, alignment_RC = generalized_align(b1hrc_motif, invivo_rc, dot_sim, *penalties, local_align=True)
    print('Score_RC : {}'.format(str(score_RC)))

    # Forward orientation wins only when strictly better
    if score > score_RC:
        chosen = (score, alignment, invivo_motif)
    else:
        chosen = (score_RC, alignment_RC, invivo_rc)
    best_score, best_alignment, best_invivo_motif = chosen

    aligned_pair = build_aligned_pwms(b1hrc_motif, best_invivo_motif, best_alignment)
    return best_score, aligned_pair[0], aligned_pair[1]



def plot_ic_scaled(pwm):
    """
    Normalizes each row of `pwm` to sum to 1 (all-zero rows stay zero), scales the
    result with `viz_sequence.ic_scale` against the module-level `background`, and
    plots it with `viz_sequence.plot_weights`.

    This redefines the identical `plot_ic_scaled` that appears earlier in the file.
    """
    totals = np.sum(pwm, axis=1, keepdims=True)
    zero_rows = totals == 0
    # Use a divisor of 1 for empty rows to avoid a 0/0 division
    np.putmask(totals, zero_rows, 1)
    viz_sequence.plot_weights(
        viz_sequence.ic_scale(pwm / totals, background)
    )









