{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "48d90caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "import numpy as np\n",
    "import cupy as cp\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import logomaker\n",
    "\n",
    "import MotifCompendium\n",
    "import MotifCompendium.utils.loader as utils_loader\n",
    "import MotifCompendium.utils.motif as utils_motif"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "045e5b1b-d339-4a86-b94e-d154e371d1ac",
   "metadata": {},
   "source": [
    "# Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a58e6d22",
   "metadata": {},
   "outputs": [],
   "source": [
    "###########\n",
    "# FINGERS #\n",
    "###########\n",
    "def unroll_znf_b1h(b1h_ppm):\n",
    "    \"\"\"Split one concatenated B1H matrix into consecutive 3-position finger blocks.\n",
    "\n",
    "    Arguments:\n",
    "        b1h_ppm: array of shape (1, L, 4), with L divisible by 3\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (L // 3, 3, 4); block i holds rows 3*i .. 3*i+2 of the input.\n",
    "    \"\"\"\n",
    "    rank = len(b1h_ppm.shape)\n",
    "    assert rank == 3 and b1h_ppm.shape[0] == 1 and b1h_ppm.shape[2] == 4\n",
    "    concatenated = b1h_ppm[0]  # (L, 4)\n",
    "    total_len = concatenated.shape[0]\n",
    "    assert total_len % 3 == 0\n",
    "    return concatenated.reshape(total_len // 3, 3, 4)\n",
    "\n",
    "\n",
    "#############\n",
    "# UNROLLING #\n",
    "#############\n",
    "def unroll_motif_cp(motif, U):\n",
    "    \"\"\"Expand a (L, K) motif into all of its length-U windows, shape (L-U+1, U, K).\"\"\"\n",
    "    length, channels = motif.shape\n",
    "    num_windows = length - U + 1\n",
    "    selector = _UNRAVELLER(length, U)  # (U, L-U+1, L)\n",
    "    by_offset = _tensor3_matmul_tensor2(selector, motif)  # (U, L-U+1, K)\n",
    "    windows = by_offset.transpose((1, 0, 2))  # window index first, offset second\n",
    "    assert windows.shape == (num_windows, U, channels)\n",
    "    return windows\n",
    "\n",
    "\n",
    "_UNRAVELLER_TENSOR = None\n",
    "def _UNRAVELLER(L, U):\n",
    "    \"\"\"Return the (U, L-U+1, L) selection tensor, reusing the cached one when shapes match.\n",
    "\n",
    "    L = length of input sequence\n",
    "    U = length to unravel to\n",
    "    Slice i carries an identity matrix starting at column i, so multiplying it with a\n",
    "    (L, K) motif picks rows i .. i+L-U; together the slices unravel (L, K) --> (L-U+1, U, K).\n",
    "    \"\"\"\n",
    "    global _UNRAVELLER_TENSOR\n",
    "    num_windows = L - U + 1\n",
    "    wanted_shape = (U, num_windows, L)\n",
    "    # Only one tensor is cached; a request with a different shape replaces it.\n",
    "    if _UNRAVELLER_TENSOR is None or _UNRAVELLER_TENSOR.shape != wanted_shape:\n",
    "        _UNRAVELLER_TENSOR = cp.zeros(wanted_shape)\n",
    "        identity = cp.eye(num_windows)\n",
    "        for offset in range(U):\n",
    "            _UNRAVELLER_TENSOR[offset, :, offset:offset + num_windows] = identity\n",
    "    return _UNRAVELLER_TENSOR\n",
    "\n",
    "\n",
    "def _tensor3_matmul_tensor2(x, y):\n",
    "    \"\"\"Multiply a (N, L, K) tensor by a (K, M) matrix with one flattened 2D matmul.\"\"\"\n",
    "    batch, rows, inner = x.shape\n",
    "    out_cols = y.shape[1]\n",
    "    flat_product = x.reshape(batch * rows, inner) @ y  # (N*L, M)\n",
    "    return flat_product.reshape(batch, rows, out_cols)  # (N, L, M)\n",
    "\n",
    "\n",
    "##############\n",
    "# SIMILARITY #\n",
    "##############\n",
    "def compute_aligned_similarity(x, y):\n",
    "    \"\"\"Computes the similarity of a (N, L, K) tensor with a (M, L, K) tensor.\n",
    "\n",
    "    Each (L, K) slice is divided by its own norm and flattened, so entry [n, m] of the\n",
    "    result is the dot product of the normalized slices x[n] and y[m]. The (N, M) product\n",
    "    is returned through `.get()`. A slice whose norm is zero is divided by zero.\n",
    "    \"\"\"\n",
    "    num_x, length, channels = x.shape\n",
    "    num_y = y.shape[0]\n",
    "    flat_dim = length * channels\n",
    "    x_unit = (x / cp.linalg.norm(x, axis=(1, 2), keepdims=True)).reshape(num_x, flat_dim)\n",
    "    y_unit = (y / cp.linalg.norm(y, axis=(1, 2), keepdims=True)).reshape(num_y, flat_dim)\n",
    "    similarity = x_unit @ y_unit.T  # (N, M)\n",
    "    return similarity.get()\n",
    "\n",
    "\n",
    "###################\n",
    "# SMITH-WATERMAN  #\n",
    "###################\n",
    "def modified_smith_waterman_matrix_triplet(\n",
    "    matrix_f, matrix_l,\n",
    "    sim_for, sim_rev,\n",
    "    scale_factor, sim_threshold,\n",
    "    overlap_penalty_f1, overlap_penalty_f2,\n",
    "    skip_penalty_f, skip_penalty_l,\n",
    "):\n",
    "    '''\n",
    "    Aligns an ordered set of 3-bp finger motifs to a CWM with a modified\n",
    "    Smith-Waterman (local alignment) dynamic program. The program is run once per\n",
    "    similarity matrix and the higher-scoring orientation is returned (a tie goes\n",
    "    to \"rev\", which is evaluated last).\n",
    "\n",
    "    Arguments:\n",
    "        matrix_f: np.array of shape (F, 3, D) - Finger motif matrices\n",
    "        matrix_l: np.array of shape (L, D) - CWM motif matrix\n",
    "        sim_for: np.array of shape (F, L2) - Similarity scores (forward)\n",
    "        sim_rev: np.array of shape (F, L2) - Similarity scores (reverse)\n",
    "        scale_factor: np.array of shape (L2,) - Scaling factors for each CWM position\n",
    "        sim_threshold: float - Minimum similarity to consider a match\n",
    "        overlap_penalty_f1: float - Penalty for 1 bp overlap (finger)\n",
    "        overlap_penalty_f2: float - Penalty for 2 bp overlap (finger)\n",
    "        skip_penalty_f: float - Penalty for skipping a finger\n",
    "        skip_penalty_l: float - Penalty for skipping a CWM position\n",
    "\n",
    "    Returns:\n",
    "        final_score: float - Optimal alignment score\n",
    "        final_align: list of str - Optimal alignment arrangement\n",
    "        final_orient: str - Optimal orientation (\"for\" or \"rev\")\n",
    "        aligned_motif_fs: np.array of shape (F, L+2, D) - Aligned finger motifs; every\n",
    "            matched finger is multiplied by scale_factor at the CWM position it matched\n",
    "        H_final: np.array of shape (F+1, L2+3) - Final score matrix\n",
    "        A_final: list of list of str - Final arrangement matrix\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the channel dimensions differ, the similarity matrices have\n",
    "            different shapes, sim_threshold >= 1, or any penalty is positive.\n",
    "    '''\n",
    "    assert matrix_f.shape[-1] == matrix_l.shape[-1]\n",
    "    assert sim_for.shape == sim_rev.shape\n",
    "    assert sim_threshold < 1\n",
    "    assert overlap_penalty_f1 <= 0\n",
    "    assert overlap_penalty_f2 <= 0\n",
    "    assert skip_penalty_f <= 0\n",
    "    assert skip_penalty_l <= 0\n",
    "\n",
    "    F, L2 = sim_for.shape\n",
    "    L, D = matrix_l.shape\n",
    "\n",
    "    H_final = np.zeros((F+1, L2+3)) # Optimal score (float)\n",
    "    A_final = [[[] for _ in range(L2+3)] for _ in range(F+1)] # Optimal arrangement (str)\n",
    "    final_score = 0\n",
    "    final_align = []\n",
    "    final_orient = None\n",
    "\n",
    "    for (orient, sim) in [(\"for\", sim_for), (\"rev\", sim_rev)]:\n",
    "        # Initialize matrices (one leading row and three leading columns of padding,\n",
    "        # so every look-back below stays inside the arrays)\n",
    "        H = np.zeros((F+1, L2+3)) # Optimal score (float)\n",
    "        A = [[[] for _ in range(L2+3)] for _ in range(F+1)] # Optimal arrangement (str)\n",
    "        best_score = 0\n",
    "        best_align = []\n",
    "\n",
    "        # F: Fingers\n",
    "        for f in range(F):\n",
    "            # L: CWM Lengths\n",
    "            for l in range(L2):\n",
    "                idx_f = f+1 # f index for matrices (shifted by 1)\n",
    "                idx_l = l+3 # l index for matrices (shifted by 3)\n",
    "\n",
    "                ### Match\n",
    "                sim_score = sim[f, l]\n",
    "                match_score = scale_factor[l] * sim_score\n",
    "                diag1_score, diag2_score, diag3_score = 0, 0, 0\n",
    "                diag1_align, diag2_align, diag3_align = [], [], []\n",
    "                if sim_score > sim_threshold:\n",
    "                    ## Diagonal 1: [-1, -1]\n",
    "                    diag1_prev_score = H[idx_f-1, idx_l-1]\n",
    "                    diag1_prev_align = A[idx_f-1][idx_l-1]\n",
    "                    diag1_score = diag1_prev_score + match_score\n",
    "                    diag1_align = diag1_prev_align + [f\"F{f}@match@L{l}\"]\n",
    "\n",
    "                    for align in diag1_prev_align:\n",
    "                        # Overlap: 2 bp\n",
    "                        if align.startswith(\"F\") and align.endswith(f\"@match@L{l-1}\"):\n",
    "                            diag1_score = diag1_prev_score + overlap_penalty_f1 + overlap_penalty_f2 + match_score\n",
    "                            diag1_align = diag1_prev_align + [f\"L{l-2}@overlap\"] + [f\"L{l-1}@overlap\"] + [f\"F{f}@match@L{l}\"]\n",
    "                            break\n",
    "                        # Overlap: 1 bp\n",
    "                        elif align.startswith(\"F\") and align.endswith(f\"@match@L{l-2}\"):\n",
    "                            diag1_score = diag1_prev_score + overlap_penalty_f1 + match_score\n",
    "                            diag1_align = diag1_prev_align + [f\"L{l-1}@overlap\"] + [f\"F{f}@match@L{l}\"]\n",
    "                            break\n",
    "\n",
    "                    ## Diagonal 2: [-1, -2]\n",
    "                    diag2_prev_score = H[idx_f-1, idx_l-2]\n",
    "                    diag2_prev_align = A[idx_f-1][idx_l-2]\n",
    "                    diag2_score = diag2_prev_score + match_score\n",
    "                    diag2_align = diag2_prev_align + [f\"F{f}@match@L{l}\"]\n",
    "\n",
    "                    for align in diag2_prev_align:\n",
    "                        # Overlap: 1 bp\n",
    "                        if align.startswith(\"F\") and align.endswith(f\"@match@L{l-2}\"):\n",
    "                            diag2_score = diag2_prev_score + overlap_penalty_f1 + match_score\n",
    "                            diag2_align = diag2_prev_align + [f\"L{l-1}@overlap\"] + [f\"F{f}@match@L{l}\"]\n",
    "                            break\n",
    "\n",
    "                    ## Diagonal 3: [-1, -3]\n",
    "                    # NOTE(review): this transition was never scored; it stays in the\n",
    "                    # option list as a constant (score 0, empty alignment). A\n",
    "                    # non-overlapping placement is still reachable through diagonal 1\n",
    "                    # after two penalty-free CWM skips. Confirm whether it should be scored.\n",
    "\n",
    "                ## Down: Skip finger\n",
    "                down_prev_score = H[idx_f-1, idx_l]\n",
    "                down_prev_align = A[idx_f-1][idx_l]\n",
    "\n",
    "                down_score = down_prev_score + skip_penalty_f\n",
    "                down_align = down_prev_align + [f\"F{f}@skip\"]\n",
    "\n",
    "                ## Right: Skip CWM position\n",
    "                right_prev_score = H[idx_f, idx_l-1]\n",
    "                right_prev_align = A[idx_f][idx_l-1]\n",
    "\n",
    "                # Consecutive skips: walk the previous arrangement backwards and count\n",
    "                # the run of CWM skips that ends directly before position l\n",
    "                l_skip = 1\n",
    "                for right_align_i in right_prev_align[::-1]:\n",
    "                    if right_align_i == f\"L{l - l_skip}@skip\":\n",
    "                        l_skip += 1\n",
    "                    elif right_align_i.startswith(\"L\") and right_align_i.endswith(\"@skip\"):\n",
    "                        break\n",
    "\n",
    "                # The first two skips of a run add nothing (the term is clipped at 0);\n",
    "                # every later skip adds skip_penalty_l * (l_skip - 2)\n",
    "                right_score = right_prev_score + min(skip_penalty_l * (l_skip - 2), 0)\n",
    "                right_align = right_prev_align + [f\"L{l}@skip\"]\n",
    "\n",
    "                # Consider all options\n",
    "                options_score = np.array([\n",
    "                    diag1_score,  # Option 1-1: Bind finger f @ l (from state: [-1, -1])\n",
    "                    diag2_score,  # Option 1-2: Bind finger f @ l (from state: [-1, -2])\n",
    "                    diag3_score,  # Option 1-3: Bind finger f @ l (from state: [-1, -3])\n",
    "                    down_score,  # Option 2: Keep finger f unused\n",
    "                    right_score  # Option 3: Bind finger f previous to l\n",
    "                ])\n",
    "\n",
    "                options_align = [\n",
    "                    diag1_align,  # Option 1-1: Bind finger f @ l (from state: [-1, -1])\n",
    "                    diag2_align,  # Option 1-2: Bind finger f @ l (from state: [-1, -2])\n",
    "                    diag3_align,  # Option 1-3: Bind finger f @ l (from state: [-1, -3])\n",
    "                    down_align,  # Option 2: Keep finger f unused\n",
    "                    right_align  # Option 3: Bind finger f previous to l\n",
    "                ]\n",
    "\n",
    "                # Select best option (argmax keeps the first option on ties)\n",
    "                bestoption_idx = np.argmax(options_score)\n",
    "                bestoption_score = max(options_score[bestoption_idx], 0) # Local alignment: no negative scores\n",
    "                bestoption_align = options_align[bestoption_idx]\n",
    "\n",
    "                # Record\n",
    "                H[idx_f, idx_l] = bestoption_score\n",
    "                A[idx_f][idx_l] = bestoption_align\n",
    "\n",
    "                if bestoption_score > best_score:\n",
    "                    best_score = bestoption_score\n",
    "                    best_align = bestoption_align\n",
    "\n",
    "        # Record final optimal\n",
    "        best_score = H.max()\n",
    "        if best_score >= final_score:\n",
    "            H_final = H\n",
    "            A_final = A\n",
    "            final_score = best_score\n",
    "            final_align = best_align\n",
    "            final_orient = orient\n",
    "\n",
    "    # Traceback\n",
    "    # NOTE(review): fingers are copied from matrix_f as given, also when final_orient is\n",
    "    # \"rev\" - confirm whether the reverse orientation should flip them for display.\n",
    "    aligned_motif_fs = np.zeros((F, L + 2, D))\n",
    "    for align in final_align:\n",
    "        if not align.startswith(\"F\"):\n",
    "            continue  # CWM skip / overlap entries place no finger\n",
    "        fields = align.split(\"@\")\n",
    "        if fields[1] != \"match\":\n",
    "            continue  # skipped finger: its row stays zero\n",
    "        f_idx = int(fields[0][1:])\n",
    "        l_idx = int(fields[-1][1:])\n",
    "\n",
    "        # Scale by the factor of the position this finger matched in the winning\n",
    "        # alignment. A per-finger value cached during the DP can belong to a different\n",
    "        # position, because it is overwritten by every new running best in that row.\n",
    "        scaled_submatrix_f = matrix_f[f_idx] * scale_factor[l_idx]\n",
    "        aligned_motif_fs[f_idx, l_idx:l_idx + 3, :] = scaled_submatrix_f\n",
    "\n",
    "    return final_score, final_align, final_orient, aligned_motif_fs, H_final, A_final\n",
    "\n",
    "\n",
    "##############\n",
    "# VISUALIZE  #\n",
    "##############\n",
    "def visualize_motif(motif, save_path):\n",
    "    \"\"\"\n",
    "    motif: np.array of shape (L, 4)\n",
    "    save_path: path of the image file to write\n",
    "    Draws the motif logo and saves to save_path.\n",
    "    \"\"\"\n",
    "    L, D = motif.shape\n",
    "    assert D == 4\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        1, 1,\n",
    "        figsize=(L / 2, 2),\n",
    "        squeeze=False\n",
    "    )\n",
    "    # squeeze=False always yields a 2D array of Axes; all Axes methods must be called\n",
    "    # on the single element, not on the array itself.\n",
    "    ax = axes[0, 0]\n",
    "\n",
    "    df = pd.DataFrame(motif, columns=[\"A\", \"C\", \"G\", \"T\"])\n",
    "    logomaker.Logo(df, ax=ax)\n",
    "\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    ymin, ymax = ax.get_ylim()\n",
    "    if ymin == ymax:\n",
    "        ax.set_ylim(0, 1)  # Safeguard against zero positions\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(save_path, bbox_inches=\"tight\", dpi=200)\n",
    "    plt.close(fig)\n",
    "\n",
    "def visualize_motifs_n(motifs, save_path):\n",
    "    \"\"\"\n",
    "    motifs: np.array of shape (N, L, 4)\n",
    "    save_path: path of the image file to write\n",
    "    Stacks one logo per motif vertically (first panel on a light green background).\n",
    "    The first and last non-zero position of every motif is marked with a dashed\n",
    "    vertical line (red = first, blue = last) on every panel.\n",
    "    \"\"\"\n",
    "    N, L, D = motifs.shape\n",
    "    assert D == 4\n",
    "\n",
    "    # (first, last) non-zero position per motif; None when the motif is entirely zero\n",
    "    boundaries = []\n",
    "    for motif in motifs:\n",
    "        nonzero_cols = np.flatnonzero(motif.sum(axis=1))  # per-position sum over A/C/G/T\n",
    "        if nonzero_cols.size == 0:\n",
    "            boundaries.append(None)\n",
    "        else:\n",
    "            boundaries.append((int(nonzero_cols[0]), int(nonzero_cols[-1])))\n",
    "\n",
    "    fig, axes = plt.subplots(N, 1, figsize=(L / 2, 2 * N), squeeze=False)\n",
    "\n",
    "    for i, (ax, motif) in enumerate(zip(axes[:, 0], motifs)):\n",
    "        # Highlight the first motif\n",
    "        if i == 0:\n",
    "            ax.set_facecolor('lightgreen')\n",
    "\n",
    "        logomaker.Logo(pd.DataFrame(motif, columns=[\"A\", \"C\", \"G\", \"T\"]), ax=ax)\n",
    "\n",
    "        # Every panel shows the boundaries of *all* motifs\n",
    "        for bounds in boundaries:\n",
    "            if bounds is None:\n",
    "                continue\n",
    "            first_col, last_col = bounds\n",
    "            ax.axvline(first_col, color=\"red\", linestyle=\"--\", linewidth=1)\n",
    "            ax.axvline(last_col, color=\"blue\", linestyle=\"--\", linewidth=1)\n",
    "\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xlabel(\"\")\n",
    "        ax.set_ylabel(f\"Motif {i+1}\")\n",
    "        y_low, y_high = ax.get_ylim()\n",
    "        if y_low == y_high:\n",
    "            ax.set_ylim(0, 1)  # Safeguard against zero positions\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(save_path, bbox_inches=\"tight\", dpi=200)\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f53d97b",
   "metadata": {},
   "source": [
    "# Test: One motif"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f0c134a-c12f-4e00-8fec-4cbb2a7be150",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Variables\n",
    "tf = \"ZNF143\"\n",
    "encid = \"ENCSR000DZL\"\n",
    "head_type = \"profile\"\n",
    "b1h_dir = \"/oak/stanford/groups/akundaje/marinovg/papers/2023_ZNFs/2023-09-01-final-figures/gencode.v29\"\n",
    "# Root directory only: the experiment id, the modisco sub-folders and the file name are\n",
    "# appended below. Putting the full .h5 path here would append those components twice.\n",
    "modisco_dir = \"/oak/stanford/groups/akundaje/vir/tfatlas/modisco/release_run_1/trim20_flank5_10/meanshap\"\n",
    "output_dir = './'\n",
    "\n",
    "# Input files derived from the settings above\n",
    "b1h_pfm_path = os.path.join(b1h_dir, f\"{tf}-201\", \"results.PFM.meme\")\n",
    "modisco_cwm_path = os.path.join(modisco_dir, encid, \"modisco\", f\"modisco_{head_type}\", f\"{head_type}_scores.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4afc63ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Parameters\n",
    "modisco_region_width = 400\n",
    "\n",
    "# Column names (not referenced by the single-motif test cells shown in this notebook)\n",
    "ref_col = \"reference\"\n",
    "tf_col = \"target\"\n",
    "sw_score_col = \"smith-waterman_score\"\n",
    "sw_align_col = \"smith-waterman_alignment\"\n",
    "sw_logo_col = \"smith-waterman_logo\"\n",
    "\n",
    "# Alignment settings passed to modified_smith_waterman_matrix_triplet;\n",
    "# that function asserts sim_threshold < 1 and every penalty <= 0\n",
    "sim_threshold = 0\n",
    "overlap_penalty_f1 = -0.5\n",
    "overlap_penalty_f2 = -1.0\n",
    "skip_penalty_f = -0.5\n",
    "skip_penalty_l = -0.01\n",
    "\n",
    "# Set compute options\n",
    "max_chunk = 1600\n",
    "max_cpus = 32\n",
    "use_gpu = True\n",
    "# NOTE(review): `safe` is defined here but not forwarded to set_compute_options below - confirm it is still needed\n",
    "safe = False\n",
    "ic_scale = True\n",
    "fast_plot = True\n",
    "\n",
    "MotifCompendium.set_compute_options(\n",
    "    max_chunk=max_chunk,\n",
    "    max_cpus=max_cpus,\n",
    "    use_gpu=use_gpu,\n",
    "    ic_scale=ic_scale,\n",
    "    fast_plotting=fast_plot,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a564aa45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load B1H PFM and split it into per-finger 3-position blocks\n",
    "print(\"loading fingers...\")\n",
    "b1h_ppm, metadata = utils_loader.load_pfm(b1h_pfm_path)  # must be (1, L, 4): unroll_znf_b1h asserts it\n",
    "b1h_ppm_ic = utils_motif.ic_scale(b1h_ppm) # IC scale (NOTE(review): not used by the cells shown below)\n",
    "b1h_ppm_n = unroll_znf_b1h(b1h_ppm) # (F, 3, 4)\n",
    "b1h_ppm_n = utils_motif.ic_scale(b1h_ppm_n)  # IC scale each finger block\n",
    "b1h_ppm_n_cp = cp.asarray(b1h_ppm_n)  # cupy copy for the similarity computation\n",
    "\n",
    "# Load modisco CWM\n",
    "print(\"load modisco...\")\n",
    "(cwm, motif_names, seqlet_counts, posnegs, avgdist_summits) = utils_loader.load_modisco(modisco_cwm_path) # (N, L, 4)\n",
    "cwm = np.abs(cwm) # Abs\n",
    "cwm = utils_motif.ic_scale(cwm) # IC scale\n",
    "\n",
    "# Unroll CWM: every motif becomes its L-2 overlapping 3-position windows\n",
    "print(\"unroll CWM...\")\n",
    "unrolled_cwm_i_cp = [unroll_motif_cp(cp.asarray(x), 3) for x in cwm] # N x (L-2, 3, 4)\n",
    "unrolled_importance = [cp.sqrt((x * x).sum(axis=(1, 2))).get() for x in unrolled_cwm_i_cp] # N x (L-2) (L2 norm, of each 3-mer)\n",
    "\n",
    "# Calculate similarities between every finger and every CWM window\n",
    "# (reverse: position and channel axes of each finger are both flipped)\n",
    "print(\"calculate similarity scores...\")\n",
    "f_orientation_f_sims = [compute_aligned_similarity(b1h_ppm_n_cp, cwm_cp) for cwm_cp in unrolled_cwm_i_cp] # N x (F, L-2)\n",
    "f_orientation_r_sims = [compute_aligned_similarity(b1h_ppm_n_cp[:, ::-1, ::-1], cwm_cp) for cwm_cp in unrolled_cwm_i_cp] # N x (F, L-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4b81933",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the alignment for the first modisco motif (index 0). The similarity matrices and\n",
    "# the per-window scale factors are the index-0 entries of the lists built in the\n",
    "# previous cell, so this cell works on a fresh kernel.\n",
    "(final_score, final_align, final_orient, aligned_motif_fs, H_final, A_final) = modified_smith_waterman_matrix_triplet(\n",
    "    matrix_f=b1h_ppm_n,\n",
    "    matrix_l=cwm[0],\n",
    "    sim_for=f_orientation_f_sims[0],\n",
    "    sim_rev=f_orientation_r_sims[0],\n",
    "    scale_factor=unrolled_importance[0],\n",
    "    sim_threshold=sim_threshold,\n",
    "    overlap_penalty_f1=overlap_penalty_f1,\n",
    "    overlap_penalty_f2=overlap_penalty_f2,\n",
    "    skip_penalty_f=skip_penalty_f,\n",
    "    skip_penalty_l=skip_penalty_l,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cf1ee5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine motifs: CWM (zero-padded at the end to the aligned length) on top, fingers below\n",
    "combined_aligned_motifs = np.vstack([\n",
    "    np.pad(\n",
    "        cwm[0],\n",
    "        ((0, aligned_motif_fs.shape[1] - cwm[0].shape[0]), (0, 0)),\n",
    "        mode='constant',\n",
    "    )[None],\n",
    "    aligned_motif_fs,\n",
    "])\n",
    "\n",
    "# Delimit alignments: one \";\"-joined string per cell of the arrangement matrix\n",
    "A_final_delimited = [list(map(\";\".join, row)) for row in A_final]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f0160f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize: Alignment, actions\n",
    "print(f\"Best score: {final_score}\")\n",
    "print(f\"Best alignment: {final_align}\")\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Logo stack: CWM in the first panel, one panel per aligned finger below it\n",
    "visualize_motifs_n(combined_aligned_motifs, f\"{output_dir}/aligned_motif_fingers.png\")\n",
    "# Score matrix and \";\"-delimited arrangement matrix as tab-separated tables\n",
    "np.savetxt(f\"{output_dir}/scores.tsv\", H_final, delimiter=\"\\t\")\n",
    "A_final_delimited_df = pd.DataFrame(A_final_delimited)\n",
    "A_final_delimited_df.to_csv(f\"{output_dir}/moves.tsv\", sep=\"\\t\", index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "796b4fb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize: Original B1H, CWM\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Unaligned inputs: first CWM, full B1H matrix, then the B1H split per finger\n",
    "visualize_motif(cwm[0], f\"{output_dir}/original_motif.png\")\n",
    "visualize_motif(b1h_ppm[0], f\"{output_dir}/original_finger.png\")\n",
    "visualize_motifs_n(b1h_ppm_n, f\"{output_dir}/original_fingers_n.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be60c3d2-e915-4540-939a-f1c210c0dc34",
   "metadata": {},
   "source": [
    "# Archived"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0af53b23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def smith_waterman_matrix(\n",
    "#         matrix_A, matrix_B, \n",
    "#         score_matrix_for, score_matrix_rev, scale_factor, score_threshold, \n",
    "#         skip_penalty_a, overlap_penalty_a1, overlap_penalty_a2, max_overlap_a,\n",
    "#         skip_penalty_b, extend_penalty_b, max_extend_b):\n",
    "#     '''\n",
    "#     Modified Smith-Waterman algorithm for aligning matrix_A to matrix_B\n",
    "\n",
    "#     --\n",
    "#     Arguments:\n",
    "#         matrix_A: (n, l, d)\n",
    "#         matrix_B: (m, d)\n",
    "#         score_matrix_for: (n, m-2)\n",
    "#         score_matrix_rev: (n, m-2)\n",
    "#         scale_factor: (n,)\n",
    "#         score_threshold: float\n",
    "#         skip_penalty_a: float\n",
    "#         skip_penalty_b: float\n",
    "#         extend_penalty_b: float\n",
    "    \n",
    "#     Returns:\n",
    "#         best_score: float\n",
    "#         align_A: (k, d)\n",
    "#         align_B: (k, d)\n",
    "#     '''\n",
    "#     n, m = matrix_A.shape[0], matrix_B.shape[0]\n",
    "#     d = matrix_A.shape[-1]\n",
    "#     assert matrix_A.shape[-1] == matrix_B.shape[-1]\n",
    "\n",
    "#     assert skip_penalty_a <= 0\n",
    "#     assert skip_penalty_b <= 0\n",
    "#     assert extend_penalty_b <= 0\n",
    "    \n",
    "#     best_orient = None\n",
    "#     best_scores_H = None\n",
    "#     best_moves_M = None\n",
    "#     best_score = 0\n",
    "#     best_pos = None\n",
    "\n",
    "#     # Forward, reverse complement\n",
    "#     for for_rev, score_matrix in enumerate([score_matrix_for, score_matrix_rev]):\n",
    "#         # If reverse complement:\n",
    "#         if for_rev == 1:\n",
    "#             scale_factor = scale_factor[::-1]\n",
    "\n",
    "#         # Initialize DP matrices\n",
    "#         scores_H = np.zeros((n+1, m+1))\n",
    "#         moves_M = np.zeros((n+1, m+1), dtype=int)  # 0: stop, 1: diag, 2: up, 3: left\n",
    "#         high_score = 0\n",
    "#         high_pos = None\n",
    "\n",
    "#         # Dynamic programming\n",
    "#         for i in range(1, n + 1):\n",
    "#             for j in range(3, m + 1):  # Need at least 3 residues to compare\n",
    "#                 # Match: Place finger i at position j\n",
    "#                 match_score = score_matrix[i-1, j-3]\n",
    "#                 if match_score > score_threshold:\n",
    "#                     diag = scores_H[i-1, j-3] + (match_score * scale_factor[j-3])\n",
    "#                 else:\n",
    "#                     diag = 0\n",
    "\n",
    "#                 # Match: Penalty for overlapping fingers\n",
    "#                 if np.any(moves_M[i-1, j-3] == 1):\n",
    "#                     diag += overlap_penalty_a1\n",
    "\n",
    "#                 # Up: Skip in matrix_A (Skip a finger)\n",
    "#                 up = scores_H[i-1, j] + skip_penalty_a\n",
    "\n",
    "#                 # Left: Skip in matrix_B (Skip a CWM base, between fingers)\n",
    "#                 left = scores_H[i, j-3] + skip_penalty_b\n",
    "#                 # Left: Penalty for extending gap between fingers\n",
    "#                 if moves_M[i, j-3] == 3:\n",
    "#                     left += extend_penalty_b\n",
    "\n",
    "#                 # Choose best\n",
    "#                 options = np.array([0, diag, up, left])\n",
    "#                 scores_H[i, j] = options.max()\n",
    "#                 moves_M[i, j] = options.argmax()\n",
    "\n",
    "#                 # Record best score, best position\n",
    "#                 if scores_H[i, j] > high_score:\n",
    "#                     high_score = scores_H[i, j]\n",
    "#                     high_pos = (i, j)\n",
    "        \n",
    "#         # Record forward, reverse complement\n",
    "#         if high_score > best_score:\n",
    "#             best_orient = for_rev\n",
    "#             best_scores_H = scores_H\n",
    "#             best_moves_M = moves_M\n",
    "#             best_score = high_score\n",
    "#             best_pos = high_pos\n",
    "\n",
    "#     ## Align\n",
    "#     # Reverse complement\n",
    "#     if best_orient == 1:\n",
    "#         matrix_B = matrix_B[::-1, ::-1]\n",
    "#     i, j = best_pos\n",
    "#     align_A = np.zeros((0, d))\n",
    "#     align_B = np.zeros((0, d))\n",
    "\n",
    "#     # Fill out unaligned parts (after; forward: left to right)\n",
    "#     added_right_i = 0\n",
    "#     added_right_j = 0\n",
    "#     while i < n + 1:\n",
    "#         align_A = np.vstack([matrix_A[i-1][::-1], align_A])\n",
    "#         i += 1\n",
    "#         added_right_i += 1\n",
    "#     while j < m + 1:\n",
    "#         align_B = np.vstack([matrix_B[j-1], align_B])\n",
    "#         j += 1\n",
    "#         added_right_j += 1\n",
    "    \n",
    "#     # Extend less added with zero\n",
    "#     added_right_diff = 3*added_right_i - added_right_j\n",
    "#     if added_right_diff > 0:  # More added in A\n",
    "#         align_B = np.vstack([align_B, np.zeros((added_right_diff, d))])\n",
    "#     if added_right_diff < 0:  # More added in B\n",
    "#         align_A = np.vstack([align_A, np.zeros((-added_right_diff, d))])\n",
    "\n",
    "#     # Traceback (backwards: right to left)\n",
    "#     i, j = best_pos\n",
    "#     while i > 0 and j > 0 and best_scores_H[i, j] > 0:\n",
    "#         move = best_moves_M[i, j]\n",
    "#         # Match\n",
    "#         if move == 1:\n",
    "#             # scaled_submatrix_A = matrix_A[i-1] * np.linalg.norm(matrix_B[j-3:j]) / np.linalg.norm(matrix_A[i-1])\n",
    "#             scaled_submatrix_A = matrix_A[i-1]\n",
    "#             align_A = np.vstack([align_A, scaled_submatrix_A[::-1]])\n",
    "#             align_B = np.vstack([align_B, matrix_B[j-3:j][::-1]])\n",
    "#             i -= 1\n",
    "#             j -= 3\n",
    "#         # Up: Skip in matrix_A (Skip a finger)\n",
    "#         elif move == 2:\n",
    "#             # scaled_submatrix_A = matrix_A[i-1] * np.linalg.norm(matrix_B[j-3:j]) / np.linalg.norm(matrix_A[i-1])\n",
    "#             scaled_submatrix_A = matrix_A[i-1]\n",
    "#             align_A = np.vstack([align_A, scaled_submatrix_A[::-1]])\n",
    "#             align_B = np.vstack([align_B, np.zeros((3, d))])\n",
    "#             i -= 1\n",
    "#         # Left: Skip in matrix_B (Skip a CWM base, between fingers)\n",
    "#         elif move == 3:\n",
    "#             align_A = np.vstack([align_A, np.zeros((1, d))])\n",
    "#             align_B = np.vstack([align_B, matrix_B[j-1]])\n",
    "#             j -= 1\n",
    "    \n",
    "#     # Fill out remaining unaligned parts (before; backward: right to left)\n",
    "#     added_left_i = 0\n",
    "#     added_left_j = 0\n",
    "#     while i > 0:\n",
    "#         align_A = np.vstack([align_A, matrix_A[i-1][::-1]])\n",
    "#         i -= 1\n",
    "#         added_left_i += 1\n",
    "#     while j > 0:\n",
    "#         align_B = np.vstack([align_B, matrix_B[j-1]])\n",
    "#         j -= 1\n",
    "#         added_left_j += 1\n",
    "\n",
    "#     # Extend to match lengths\n",
    "#     left_diff = align_A.shape[0] - align_B.shape[0]\n",
    "#     if left_diff > 0:  # Longer A\n",
    "#         align_B = np.vstack([align_B, np.zeros((left_diff, d))])\n",
    "#     if left_diff < 0:  # Longer B\n",
    "#         align_A = np.vstack([align_A, np.zeros((-left_diff, d))])\n",
    "\n",
    "#     # Flip order, since traceback goes from end to start\n",
    "#     align_A = align_A[::-1]\n",
    "#     align_B = align_B[::-1]\n",
    "\n",
    "#     return best_score, align_A, align_B, best_orient, best_scores_H, best_moves_M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24a4213f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Run\n",
    "# (best_score, align_A, align_B, best_orient, best_scores_H, best_moves_M) = smith_waterman_matrix(\n",
    "#     matrix_A=b1h_ppm,\n",
    "#     matrix_B=cwm[0],\n",
    "#     score_matrix_for=score_matrix_for,\n",
    "#     score_matrix_rev=score_matrix_rev,\n",
    "#     scale_factor=scale_factor,\n",
    "#     score_threshold=score_threshold,\n",
    "#     skip_penalty_a=skip_penalty_a,\n",
    "#     overlap_penalty_a1=overlap_penalty_a1,\n",
    "#     overlap_penalty_a2=overlap_penalty_a2,\n",
    "#     max_overlap_a=max_overlap_a,\n",
    "#     skip_penalty_b=skip_penalty_b,\n",
    "#     extend_penalty_b=extend_penalty_b,\n",
    "#     max_extend_b=max_extend_b,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2118b609",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def generalized_align(seq_f, seq_l, f_sims, r_sims, imp, min_sim, finger_gap_1, finger_gap_n, cwm_gap_1, cwm_gap_n, align_type='local'):\n",
    "#     \"\"\"\n",
    "#     Computes the optimal local or global alignment between two sequences.\n",
    "    \n",
    "#     Arguments:\n",
    "#         `seq_f`: an indexable sequence of objects to align to `seq_l`\n",
    "#         `seq_l`: an indexable sequence of objects to align to `seq_f`\n",
    "#         `score_func`: a function that takes in an object from `seq_f` and an object from\n",
    "#             `seq_l` (in that order) and returns a scalar alignment score\n",
    "#         `gap_open`: penalty for opening a gap; must be non-positive\n",
    "#             'finger_gap_1': penalty for opening a gap in seq_f; must be non-positive\n",
    "#             'cwm_gap_1': penalty for opening a gap in seq_l; must be non-positive\n",
    "#         `gap_extend`: penalty for extending a gap; must be non-positive\n",
    "#             'finger_gap_n': penalty for extending a gap in seq_f; must be non-positive\n",
    "#             'cwm_gap_n': penalty for extending a gap in seq_l; must be non-positive\n",
    "#         `align_type`: 'local' or 'global'; selects the kind of alignment to perform\n",
    "#             'local': Smith-Waterman algorithm\n",
    "#             'global': Needleman-Wunsch algorithm\n",
    "    \n",
    "#     Returns:\n",
    "#         Returns the score of the best alignment, and the best alignment. The alignment is\n",
    "#         returned as a list of paired indices, denoting which indices are aligned between\n",
    "#         `seq_f` and `seq_l`. \"-\" indicates a gap. The returned indices are 0-indexed.\n",
    "#     \"\"\"\n",
    "#     assert finger_gap_1 <= 0\n",
    "#     assert finger_gap_n <= 0\n",
    "#     assert cwm_gap_1 <= 0\n",
    "#     assert cwm_gap_n <= 0\n",
    "#     assert align_type in ['local', 'global']\n",
    "#     F, L = f_sims.shape\n",
    "\n",
    "#     ## Forward, reverse complement\n",
    "#     for_rev = []\n",
    "#     for sims in [f_sims, r_sims]: # (F, L)\n",
    "    \n",
    "#         # Define matrices\n",
    "#         M_match = np.zeros((F+1, L+3))  # Match\n",
    "#         I_insert_f = np.zeros((F+1, L+3))  # Insertion in seq_f: B1H Fingers\n",
    "#         I_insert_l = np.zeros((F+1, L+3))  # Insertion in seq_l: BPNet CWM\n",
    "        \n",
    "#         V_score = np.zeros((F+1, L+3))  # Final scores\n",
    "#         P_path = np.zeros((F+1, L+3), dtype=int)  # Paths (0: match, 1: insert_f, 2: insert_l)\n",
    "        \n",
    "#         ## Initialize matrix\n",
    "#         # Global alignment: Needleman-Wunsch\n",
    "#         if align_type == 'global':\n",
    "#             for j in range(1, F+1):\n",
    "#                 I_insert_f[0, j] = finger_gap_1+((j-1) * finger_gap_n)\n",
    "#                 P_path[0, j] = 1\n",
    "#             for i in range(1, L+1):\n",
    "#                 I_insert_l[i, 0] = cwm_gap_1+((i-1) * cwm_gap_n)\n",
    "#                 P_path[i, 0] = 2\n",
    "        \n",
    "#         ## Dynamic programming\n",
    "#         # Fingers, F = seq_f\n",
    "#         for f in range(1, F+1):\n",
    "#             # CWM Length, L = seq_l\n",
    "#             for l in range(1, L+1):\n",
    "#                 ## M_match: Match Finger[i] to CWM[j] (Diagonal move): \n",
    "#                 M_match[f, l] = max(\n",
    "#                     (I_insert_f[f-1, l-1]+sims[f-1, l-1]),  # From previous: Finger gap = Match score\n",
    "#                     (I_insert_l[f-1, l-1]+sims[f-1, l-1]),  # From previous: CWM gap = Match score\n",
    "#                     (M_match[f-1, l-1]+sims[f-1, l-1]+finger_overlap_2),  # From previous: Match = Match score+2bp finger overlap penalty\n",
    "#                     (M_match[f-1, l-2]+sims[f-1, l-1]+finger_overlap_1),  # From previous: Match = Match score+1bp finger overlap penalty\n",
    "#                     (M_match[f-1, l-3]+sims[f-1, l-1]),  # From previous: Match = Match score, no finger overlap (no penalty)\n",
    "#                 )\n",
    "\n",
    "#                 ## I_insert_f: Insert gap in Finger (Horizontal move):\n",
    "#                 I_insert_f[f, l] = max(\n",
    "#                     (I_insert_f[f, l-1]+finger_gap_n),  # Previous from: Finger gap = Finger gap penalty limit (far-reaching finger: limit)\n",
    "#                     (I_insert_l[f, l-1]+finger_gap_1),  # Previous from: CWM gap = 1bp Finger gap penalty (?)\n",
    "#                     (M_match[f, l-1]),  # Previous from: Match = 1 bp Within matched finger (normal)\n",
    "#                     (M_match[f, l-2]),  # Previous from: Match = 2 bp Within matched finger (normal)\n",
    "#                     (M_match[f, l-3]+finger_gap_1),  # Previous from: Match = 1bp Outside matched finger (far-reaching finger)\n",
    "#                     (M_match[f, l-4]+(2 * finger_gap_1)),  # Previous from: Match = 2 bp Outside matched finger (far-reaching finger)\n",
    "#                     (M_match[f, l-5]+(3 * finger_gap_1)),  # Previous from: Match = 3 bp Outside matched finger (far-reaching finger)\n",
    "#                 )\n",
    "\n",
    "#                 ## I_insert_l: Insert gap in CWM (Vertical move):\n",
    "#                 I_insert_l[f, l] = max(\n",
    "#                     (I_insert_f[f-1, l]+cwm_gap_1),  # From previous: Finger gap = 1bp CWM gap penalty\n",
    "#                     (I_insert_l[f-1, l]+cwm_gap_n),  # From previous: CWM gap = CWM gap extension penalty\n",
    "#                     (M_match[f-1, l]+cwm_gap_1),  # From previous: Match = 1bp CWM gap (opening) penalty\n",
    "#                 )\n",
    "\n",
    "#                 # Local alignment: Min score is 0 (reset score; start new local alignment)\n",
    "#                 if align_type == 'local':\n",
    "#                     M_match[i, j] = max(M_match[i, j], 0)\n",
    "#                     I_insert_f[i, j] = max(I_insert_f[i, j], 0)\n",
    "#                     I_insert_l[i, j] = max(I_insert_l[i, j], 0)\n",
    "\n",
    "#                 # Select best score and path\n",
    "#                 scores = [M_match[i, j], I_insert_f[i, j], I_insert_l[i, j]]\n",
    "#                 P_path[i, j] = np.argmax(scores)\n",
    "#                 V_score[i, j] = scores[P_path[i, j]]\n",
    "        \n",
    "#         # Store forward, reverse complement\n",
    "#         for_rev.append((V_score, P_path))\n",
    "    \n",
    "#     # Select forward, reverse complement\n",
    "#     if for_rev[0][0].max() >= for_rev[1][0].max():\n",
    "#         # Forward\n",
    "#         V_score, P_path = for_rev[0]\n",
    "#     else:\n",
    "#         # Reverse complement\n",
    "#         V_score, P_path = for_rev[1]\n",
    "\n",
    "#     # Identify best score and starting point\n",
    "#     if align_type == 'local':\n",
    "#         i, j = np.unravel_index(np.argmax(V_score), V_score.shape)\n",
    "#         traceback_done = lambda i, j: V_score[i, j] == 0\n",
    "#     else:\n",
    "#         i, j = F, L\n",
    "#         traceback_done = lambda i, j: i == 0 and j == 0\n",
    "    \n",
    "#     # Trace back best alignment\n",
    "#     final_score = V_score[i, j]\n",
    "#     alignment = []\n",
    "#     aligned_seq_f = []\n",
    "#     aligned_seq_l = []\n",
    "\n",
    "#     while not traceback_done(i, j):\n",
    "#         if P_path[i, j] == 0:\n",
    "#             # Match/mismatch\n",
    "#             i -= 1\n",
    "#             j -= 1\n",
    "#             alignment.append((i, j))\n",
    "#             aligned_seq_f.append(seq_f[i])\n",
    "#             aligned_seq_l.append(seq_l[j])\n",
    "#         elif P_path[i, j] == 1:\n",
    "#             # Insertion in seq_f (gap in seq_l)\n",
    "#             j -= 1\n",
    "#             alignment.append((\"-\", j))\n",
    "#             aligned_seq_f.append(\"-\")\n",
    "#             aligned_seq_l.append(seq_l[j])\n",
    "#         elif P_path[i, j] == 2:\n",
    "#             # Insertion in seq_l (gap in seq_f)\n",
    "#             i -= 1\n",
    "#             alignment.append((i, \"-\"))\n",
    "#             aligned_seq_f.append(seq_f[i])\n",
    "#             aligned_seq_l.append(\"-\")\n",
    "\n",
    "#     # Reverse for correct orientation\n",
    "#     alignment.reverse()\n",
    "#     aligned_seq_f.reverse()\n",
    "#     aligned_seq_l.reverse()\n",
    "\n",
    "#     return final_score, {\n",
    "#         \"index_pairs\": alignment,\n",
    "#         \"aligned_seq_f\": \"\".join(aligned_seq_f),\n",
    "#         \"aligned_seq_l\": \"\".join(aligned_seq_l),\n",
    "#     }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae4e11dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def get_b1h_modisco_similarities(b1h_pfm_path, modisco_cwm_path, output_dir):    \n",
    "#     # Load fingers\n",
    "#     print(\"loading fingers...\")\n",
    "#     b1h_ppm = load_znf_b1h(b1h_pfm_path) # (F, 3, 4)\n",
    "#     b1h_ppm = utils_motif.ic_scale(b1h_ppm) # IC scale\n",
    "#     b1h_ppm_cp = cp.asarray(b1h_ppm)\n",
    "#     # Load modisco\n",
    "#     print(\"load modisco...\")\n",
    "#     motifs, metadata = utils_loader.load_modisco(modisco_cwm_path)\n",
    "#     motifs = np.abs(motifs) # Abs\n",
    "#     motifs = utils_motif.ic_scale(motifs) # IC scale\n",
    "#     motifs /= np.linalg.norm(motifs, axis=(1, 2), keepdims=True) # L2 normalize\n",
    "#     # Unroll motifs\n",
    "#     print(\"unroll motifs...\")\n",
    "#     unrolled_motifs_cp = [unroll_motif_cp(cp.asarray(x), 3) for x in motifs] # (L-2, 3, 4) x N\n",
    "#     # unrolled_motifs_cp = [unrolled_motifs_cp[0]] # NOTE: FOR NOW, JUST MOTIF 0 FROM MODISCO\n",
    "#     motif_importance = [(x*x).sum(axis=(1, 2)).get() for x in unrolled_motifs_cp] # (L-2) x N\n",
    "\n",
    "#     # Forward orientation, Forward finger similarity\n",
    "#     print(\"forward orientation, forward finger similarity...\")\n",
    "#     f_orientation_f_sims = [compute_aligned_similarity(b1h_ppm_cp, x) for x in unrolled_motifs_cp] # (F, L-2) x N\n",
    "#     _show_heatmap(f_orientation_f_sims[1], f\"{output_dir}_heatmap.png\")\n",
    "#     # Forward orientation, Reverse finger similarity\n",
    "#     print(\"forward orientation, reverse finger similarity...\")\n",
    "#     f_orientation_r_sims = [compute_aligned_similarity(b1h_ppm_cp[:, ::-1, ::-1], x) for x in unrolled_motifs_cp] # (F, L-2) x N\n",
    "#     # Reverse orientation, Forward finger similarity\n",
    "#     # print(\"reverse orientation, forward finger similarity...\")\n",
    "#     # r_orientation_f_sims = [x[::-1, :] for x in f_orientation_f_sims] # (F, L-2) x N\n",
    "#     # Reverse orientation, Reverse finger similarity\n",
    "#     # print(\"reverse orientation, reverse finger similarity...\")\n",
    "#     # r_orientation_r_sims = [x[::-1, :] for x in f_orientation_r_sims] # (F, L-2) x N\n",
    "#     # Dynamic program\n",
    "#     for i in range(motifs.shape[0]):\n",
    "#         # Forward DP\n",
    "#         print(\"forward DP\")\n",
    "#         f_alignment = dynamic_alignment(f_orientation_f_sims[i], f_orientation_r_sims[i], motif_importance[i])[-1][-1]\n",
    "#         # Visualize forward alignment\n",
    "#         print(\"visualize fwd\")\n",
    "#         visualize_alignment(motifs[i], b1h_ppm, f_alignment, f\"{output_dir}_motif{i}_align.png\")\n",
    "#         '''\n",
    "#         # Reverse DP\n",
    "#         print(\"reverse DP\")\n",
    "#         r_alignment = dynamic_alignment(r_orientation_f_sims[i], r_orientation_r_sims[i], motif_importance[i])[-1][-1]\n",
    "#         # Modify reverse alignment\n",
    "#         alignment_score, alignment_code = r_alignment\n",
    "#         F = b1h_ppm.shape[0]\n",
    "#         new_alignment_code = []\n",
    "#         for a in alignment_code:\n",
    "#             a_finger, a_loc = a.split(\"@\")\n",
    "#             if a_finger.endswith(\"r\"):\n",
    "#                 new_alignment_code.append(f\"f{F-1-int(a_finger[1:-1])}r@{a_loc}\")\n",
    "#             else:\n",
    "#                 new_alignment_code.append(f\"f{F-1-int(a_finger[1:])}@{a_loc}\")\n",
    "#         r_alignment = (alignment_score, new_alignment_code)\n",
    "#         '''"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
