
In [1]:
import os
import sys
from util import figure_to_vdom_image
import plot.viz_sequence as viz_sequence
import numpy as np
import h5py
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import scipy.signal
import scipy.cluster.hierarchy
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  from ipykernel import kernelapp as app
0it [00:00, ?it/s]
In [2]:
# Plotting defaults
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold"
}
# NOTE(review): "Roboto" must be available to matplotlib's font manager
# (e.g. registered via font_manager.fontManager.addfont) for it to take effect
plt.rcParams.update(plot_params)
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ MatplotlibDeprecationWarning: 
The createFontList function was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use FontManager.addfont instead.
  after removing the cwd from sys.path.

Define constants and paths

In [3]:
# Define parameters/fetch arguments
# Comma-separated list of motif HDF5 files, and one group name per file
motif_files = os.environ["TFM_MOTIF_FILES"].split(",")
group_names = os.environ["TFM_GROUP_NAMES"].split(",")
# Optional directory for cached outputs; None disables saving
if "TFM_HEATMAP_CACHE" in os.environ:
    tfm_heatmap_cache_dir = os.environ["TFM_HEATMAP_CACHE"]
else:
    tfm_heatmap_cache_dir = None
assert len(motif_files) == len(group_names), "Need exactly one group name per motif file"
assert len(group_names) == len(set(group_names)), "Group names must be unique"

print("Motif files: %s" % motif_files)
print("Group names: %s" % group_names)
print("Saved heatmap cache: %s" % tfm_heatmap_cache_dir)
Motif files: ['/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold9/JUND_multitask_profile_fold9_count/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold9/JUND_multitask_profile_fold9_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold8/JUND_multitask_profile_fold8_count/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold8/JUND_multitask_profile_fold8_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold7/JUND_multitask_profile_fold7_count/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold7/JUND_multitask_profile_fold7_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold6/JUND_multitask_profile_fold6_count/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold6/JUND_multitask_profile_fold6_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold5/JUND_multitask_profile_fold5_count/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold5/JUND_multitask_profile_fold5_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold4/JUND_multitask_profile_fold4_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold4/JUND_multitask_profile_fold4_count/all_motifs.h5', 
'/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold3/JUND_multitask_profile_fold3_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold3/JUND_multitask_profile_fold3_count/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold2/JUND_multitask_profile_fold2_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold2/JUND_multitask_profile_fold2_count/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold10/JUND_multitask_profile_fold10_count/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold10/JUND_multitask_profile_fold10_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold1/JUND_multitask_profile_fold1_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile/JUND_multitask_profile_fold1/JUND_multitask_profile_fold1_count/all_motifs.h5']
Group names: ['JUND_F9_C', 'JUND_F9_P', 'JUND_F8_C', 'JUND_F8_P', 'JUND_F7_C', 'JUND_F7_P', 'JUND_F6_C', 'JUND_F6_P', 'JUND_F5_C', 'JUND_F5_P', 'JUND_F4_P', 'JUND_F4_C', 'JUND_F3_P', 'JUND_F3_C', 'JUND_F2_P', 'JUND_F2_C', 'JUND_F10_C', 'JUND_F10_P', 'JUND_F1_P', 'JUND_F1_C']
Saved heatmap cache: /users/amtseng/tfmodisco/results/reports/motif_heatmaps/cache/multitask_profile/JUND_multitask_profile
In [4]:
# Define constants
# Colors cycled through when painting dendrogram clusters (matplotlib's active color cycle)
_prop_cycle_by_key = plt.rcParams["axes.prop_cycle"].by_key()
cluster_color_cycle = _prop_cycle_by_key["color"]
# Dendrogram links joining leaves of different clusters are drawn in this neutral color
default_cluster_color = "gray"
In [5]:
# Ensure the cache directory exists before anything is written into it;
# skipped entirely when no cache directory was configured
cache_enabled = bool(tfm_heatmap_cache_dir)
if cache_enabled:
    os.makedirs(name=tfm_heatmap_cache_dir, exist_ok=True)

Helper functions

For plotting and organizing things

In [6]:
def import_motifs(motif_files, group_names):
    """
    Imports a set of motifs from the saved HDF5 files.
    `motif_files` is a list of HDF5 paths; every top-level key of each file
    is read as one motif.
    `group_names` is a list of group names, one for each motif file.
    Returns a list of motifs as L x 4 arrays, a parallel list of
    motif names, and a dictionary mapping group names to lists of
    motif names. Each motif name has the form "<group name>:<HDF5 key>".
    Raises ValueError if `motif_files` and `group_names` differ in length
    (rather than silently dropping the unmatched files).
    """
    if len(motif_files) != len(group_names):
        raise ValueError(
            "Got %d motif files but %d group names" % (len(motif_files), len(group_names))
        )
    motifs, motif_names = [], []
    groups = {}
    for motif_file, stem in zip(motif_files, group_names):
        groups[stem] = []
        with h5py.File(motif_file, "r") as f:
            for key in f.keys():
                motif_name = "%s:%s" % (stem, key)
                # NOTE(review): assumes each key is a dataset holding an L x 4 array;
                # `[:]` reads it into memory before the file is closed
                motifs.append(f[key][:])
                motif_names.append(motif_name)
                groups[stem].append(motif_name)
    return motifs, motif_names, groups
In [7]:
def motif_similarity_score(motif_1, motif_2, average=False, align_to_longer=True):
    """
    Computes the motif similarity score between two motifs by
    the summed cosine similarity, maximized over all possible sliding
    windows. Also returns the index relative to the start of `motif_2`
    where `motif_1` should be placed to maximize this score.
    If `average` is True, then use average of similarity of overlap.
    If `align_to_longer` is True, always use the longer motif as the basis
    for the index computation (if tie use `motif_2`). Otherwise, always use
    `motif_2` as the basis.
    Each motif is an L x 4 array (the Ls may differ). Every position (row) is
    mean-centered and scaled to unit norm before comparing. Positions whose
    centered values are all zero stay zero vectors, so they contribute 0
    instead of producing NaN.
    Returns (best score, offset of the aligned motif relative to the start of
    the basis motif; negative if it starts before the basis).
    """
    def _normalize(motif):
        # Center each position and scale it to unit norm
        centered = motif - np.mean(motif, axis=1, keepdims=True)
        norms = np.sqrt(np.sum(centered * centered, axis=1, keepdims=True))
        # Uniform positions have zero norm; avoid 0/0 -> NaN
        return centered / np.where(norms == 0, 1, norms)

    motif_1 = _normalize(motif_1)
    motif_2 = _normalize(motif_2)

    # Always make motif_2 longer
    if align_to_longer and len(motif_1) > len(motif_2):
        motif_1, motif_2 = motif_2, motif_1

    len_1, len_2 = len(motif_1), len(motif_2)
    # Pad motif_2 by len(motif_1) - 1 on either side so every partial overlap is a window
    pad_size = len_1 - 1
    num_windows = len_2 + pad_size
    padded_2 = np.pad(motif_2, ((pad_size, pad_size), (0, 0)))

    # Compute similarities across all sliding windows
    scores = np.empty(num_windows)
    for i in range(num_windows):
        scores[i] = np.sum(motif_1 * padded_2[i : i + len_1])

    if average:
        # Number of real (unpadded) motif_2 positions covered by window i. This
        # holds whichever motif is longer and for length-1 motifs (always >= 1).
        starts = np.arange(num_windows)
        overlap_sizes = (
            np.minimum(starts + len_1, pad_size + len_2) - np.maximum(starts, pad_size)
        )
        scores = scores / overlap_sizes

    return np.max(scores), np.argmax(scores) - pad_size
In [8]:
def compute_similarity_matrix(motifs, show_progress=True):
    """
    Computes a similarity matrix over the pairs of motifs using cross
    correlation. `motifs` is a list of N motifs, where each is an L x 4
    array (may be different Ls).
    Each entry is the score from `motif_similarity_score` with its default
    settings, so larger means more similar. Only the upper triangle is
    computed; it is mirrored into the lower triangle.
    If `show_progress` is True, shows a notebook progress bar over rows.
    Returns an N x N array of similarities (not distances).
    """
    num_motifs = len(motifs)
    sim_matrix = np.empty((num_motifs, num_motifs))
    if show_progress:
        # Import explicitly instead of relying on `tqdm.notebook` having been
        # loaded as a side effect elsewhere (`import tqdm` alone does not load it)
        from tqdm.notebook import trange
        row_iter = trange(num_motifs)
    else:
        row_iter = range(num_motifs)
    for i in row_iter:
        for j in range(i, num_motifs):
            sim, _ = motif_similarity_score(motifs[i], motifs[j])
            sim_matrix[i, j] = sim
            sim_matrix[j, i] = sim
    return sim_matrix
In [9]:
def compute_clusters(linkage, goal_clusters, tolerance=(-2, 2), start=50, max_iter=10):
    """
    From a linkage map, computes clusters with a goal of `goal_clusters` clusters.
    `tolerance` is a pair (low, high) of offsets from `goal_clusters`: any
    clustering with between `goal_clusters + low` and `goal_clusters + high`
    clusters (inclusive) is accepted. `start` is what distance threshold to
    check first. `max_iter` is the maximum number of additional thresholds to
    try. If none is accepted by then, the last clustering tried is returned.
    The threshold is doubled/halved until the goal is bracketed, and then
    bisected. This avoids oscillating between two thresholds forever.
    Returns the clustering (the array returned by `fcluster`).
    """
    min_clusters = goal_clusters + tolerance[0]
    max_clusters = goal_clusters + tolerance[1]
    # Thresholds known to give too many (too_low) / too few (too_high) clusters
    too_low, too_high = None, None
    threshold = start
    clusters = scipy.cluster.hierarchy.fcluster(linkage, threshold, criterion="distance")
    for _ in range(max_iter):
        num_clusters = len(np.unique(clusters))
        if num_clusters > max_clusters:
            too_low = threshold  # Too many clusters: raise the threshold
        elif num_clusters < min_clusters:
            too_high = threshold  # Too few clusters: lower the threshold
        else:
            break
        if too_low is not None and too_high is not None:
            threshold = (too_low + too_high) / 2
        elif too_low is not None:
            threshold = threshold * 2
        else:
            threshold = threshold / 2
        clusters = scipy.cluster.hierarchy.fcluster(linkage, threshold, criterion="distance")
    return clusters
In [10]:
def plot_heatmap(sim_matrix, labels, linkage, clusters):
    """
    Given a similarity matrix and labels, plots a heatmap with
    dendrogram. `linkage` is the linkage map computed on the matrix, and
    `clusters` is the cluster ID of each entry.
    Dendrogram links are colored by cluster. A link takes its children's color
    when both children share it, and `default_cluster_color` otherwise.
    Returns the figure and the indices in which to order the entries.
    """
    fig, ax = plt.subplots(
        nrows=2, ncols=2, figsize=(20, 20),
        gridspec_kw={
            "width_ratios": [20, 1],
            "height_ratios": [1, 4],
            "hspace": 0,
            "wspace": 0.1
        }
    )

    # Compute the color of every link based on cluster assignments
    num_leaves = len(linkage) + 1
    leaf_colors = [cluster_color_cycle[c % len(cluster_color_cycle)] for c in clusters]
    link_colors = {}
    for i, (child_0, child_1) in enumerate(linkage[:, :2].astype(int)):
        # Node IDs >= num_leaves refer to earlier merges (row i creates node num_leaves + i)
        color_0 = link_colors[child_0] if child_0 >= num_leaves else leaf_colors[child_0]
        color_1 = link_colors[child_1] if child_1 >= num_leaves else leaf_colors[child_1]
        link_colors[num_leaves + i] = color_0 if color_0 == color_1 else default_cluster_color

    dend = scipy.cluster.hierarchy.dendrogram(
        linkage, ax=ax[0, 0], link_color_func=(lambda x: link_colors[x])
    )

    # Reorder rows and columns to follow the dendrogram's leaf order
    order_inds = dend["leaves"]
    sim_matrix_reordered = sim_matrix[:, order_inds][order_inds, :]
    heatmap = ax[1, 0].imshow(sim_matrix_reordered, aspect="auto", cmap="Blues")
    ax[1, 0].set_yticks([])
    ax[1, 0].set_xticks(range(len(sim_matrix)))
    ax[1, 0].set_xticklabels(np.array(labels)[order_inds], rotation=90, fontsize=10)

    fig.colorbar(heatmap, cax=ax[1, 1])

    ax[0, 0].axis("off")
    ax[0, 1].axis("off")
    return fig, order_inds
In [11]:
def aggregate_motifs(motifs):
    """
    Aggregates a list of L x 4 (not all the same L) motifs into a single
    L x 4 motif.
    The motif with the highest total similarity to the others seeds the
    consensus. Every other motif is aligned to the running consensus at its
    best-scoring offset and summed into it, and any part hanging off either
    end of the consensus is dropped. The consensus therefore keeps the seed
    motif's length, and the sum is divided by the number of motifs.
    Raises ValueError if `motifs` is empty.
    """
    if len(motifs) == 0:
        raise ValueError("Cannot aggregate an empty list of motifs")

    # Compute similarity matrix
    sim_matrix = compute_similarity_matrix(motifs, show_progress=False)

    # Sort motifs by how similar each is to everyone else (most similar first)
    inds = np.flip(np.argsort(np.sum(sim_matrix, axis=0)))
    # Have the consensus start with a copy of the most similar motif
    consensus = np.copy(motifs[inds[0]])
    consensus_len = len(consensus)
    # For each successive motif, add it into the consensus
    for i in inds[1:]:
        motif = motifs[i]
        # With align_to_longer=False, `index` is relative to the start of `consensus`
        _, index = motif_similarity_score(motif, consensus, align_to_longer=False)
        if index >= 0:
            # Motif starts inside the consensus; drop any tail past its end
            start, end = index, index + len(motif)
            consensus[start:end] = consensus[start:end] + motif[:consensus_len - index]
        else:
            # Motif starts before the consensus; drop the leading overhang
            end = len(motif) + index
            consensus[:end] = consensus[:end] + motif[-index:-index + consensus_len]
    return consensus / len(motifs)

Show motif clusters

Cluster all of the imported motifs, and for each cluster show its aggregate motif alongside its constituent motifs.

In [12]:
# Load every motif, its "group:key" name, and each group's list of motif names
motifs, motif_names, motif_groups = import_motifs(
    motif_files=motif_files, group_names=group_names
)
In [13]:
# Flip all motifs to be the purine-rich version.
# np.flip reverses both the position and base axes. Assuming ACGT column order,
# this is the reverse complement, and columns 0 and 2 are the purines A and G.
motifs[:] = [
    np.flip(m) if np.sum(m[:, [0, 2]]) < 0.5 * np.sum(m) else m
    for m in motifs
]
In [14]:
# Pairwise similarity between all motifs (with a progress bar)
sim_matrix = compute_similarity_matrix(motifs, show_progress=True)

In [15]:
# Compute linkage.
# NOTE(review): a square 2-D array is treated by scipy as N observation vectors,
# not as a distance matrix. Each motif's row of similarities is its feature
# vector, and Ward clusters on Euclidean distances between those rows.
linkage = scipy.cluster.hierarchy.linkage(
    sim_matrix, method="ward", metric="euclidean"
)
In [16]:
# Compute clusters, aiming for as many clusters as the largest group has motifs
group_sizes = list(map(len, motif_groups.values()))
expected_clusters = np.max(group_sizes)
clusters = compute_clusters(linkage, expected_clusters, max_iter=100)
In [17]:
num_motifs = len(clusters)
num_clusters = len(np.unique(clusters))
display(vdomh.h4("Number of motifs: %d" % num_motifs))
display(vdomh.h4("Number of clusters: %d" % num_clusters))

Number of motifs: 304

Number of clusters: 71

In [18]:
# Plot heatmap, keeping the dendrogram leaf order so it can be saved with the clusters
fig, order_inds = plot_heatmap(
    sim_matrix, motif_names, linkage, clusters
)
if tfm_heatmap_cache_dir:
    heatmap_path = os.path.join(tfm_heatmap_cache_dir, "motif_cluster_heatmap.png")
    fig.savefig(heatmap_path)
In [19]:
# Show aggregated and constituent motifs for each cluster
colgroup = vdomh.colgroup(
    vdomh.col(style={"width": "50%"}),
    vdomh.col(style={"width": "50%"})
)
header = vdomh.thead(
    vdomh.tr(
        vdomh.th("Aggregate motif", style={"text-align": "center"}),
        vdomh.th("Constituent motifs", style={"text-align": "center"})
    )
)

all_consensus = {}
cluster_ids, counts = np.unique(clusters, return_counts=True)
# Show the largest clusters first
cluster_ids = cluster_ids[np.flip(np.argsort(counts))]
for i, cluster_id in enumerate(cluster_ids):
    match_inds = np.where(clusters == cluster_id)[0]
    matches = [motifs[j] for j in match_inds]
    match_names = [motif_names[j] for j in match_inds]
    consensus = aggregate_motifs(matches)
    all_consensus[cluster_id] = consensus
    display(vdomh.h3("Cluster %d (%d/%d)" % (cluster_id, i + 1, len(cluster_ids))))
    display(vdomh.h4("%d motifs" % len(matches)))

    agg_fig = viz_sequence.plot_weights(consensus, figsize=(20, 4), return_fig=True)
    const_figs = []
    for motif, motif_name in zip(matches, match_names):
        fig = viz_sequence.plot_weights(motif, figsize=(20, 4), return_fig=True)
        const_figs.append(vdomh.p(motif_name))
        const_figs.append(figure_to_vdom_image(fig))
        # Close each figure once rendered; pyplot otherwise keeps every one open
        plt.close(fig)

    body = vdomh.tbody(
        vdomh.tr(
            vdomh.td(figure_to_vdom_image(agg_fig)),
            vdomh.td(*const_figs)
        )
    )
    display(vdomh.table(colgroup, header, body))
    if tfm_heatmap_cache_dir:
        agg_fig.savefig(os.path.join(tfm_heatmap_cache_dir, "cluster_%d_aggregate_motif.png" % cluster_id))
    plt.close(agg_fig)

if tfm_heatmap_cache_dir:
    # Save motifs, similarity matrix, dendrogram order, cluster assignments,
    # and each cluster's aggregate motif into one HDF5 file
    with h5py.File(os.path.join(tfm_heatmap_cache_dir, "motif_clusters.h5"), "w") as f:
        f.create_dataset("motif_names", data=np.array(motif_names).astype("S"), compression="gzip")
        all_motifs = f.create_group("all_motifs")
        for name, motif in zip(motif_names, motifs):
            all_motifs.create_dataset(name, data=motif, compression="gzip")
        f.create_dataset("similarity_matrix", data=sim_matrix, compression="gzip")
        f.create_dataset("dendrogram_order", data=order_inds, compression="gzip")
        f.create_dataset("clusters", data=clusters, compression="gzip")
        agg_motifs = f.create_group("cluster_motifs")
        for cluster_id, consensus in all_consensus.items():
            agg_motifs.create_dataset(str(cluster_id), data=consensus, compression="gzip")

Cluster 34 (1/71)

21 motifs

/users/amtseng/tfmodisco/src/plot/ RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  fig = plt.figure(figsize=figsize)
Aggregate motifConstituent motifs

Cluster 27 (2/71)

17 motifs

Aggregate motifConstituent motifs

Cluster 31 (3/71)

17 motifs

Aggregate motifConstituent motifs

Cluster 51 (4/71)

16 motifs

Aggregate motifConstituent motifs

Cluster 47 (5/71)

15 motifs

Aggregate motifConstituent motifs

Cluster 66 (6/71)

11 motifs

Aggregate motifConstituent motifs

Cluster 50 (7/71)

10 motifs

Aggregate motifConstituent motifs

Cluster 49 (8/71)

9 motifs

Aggregate motifConstituent motifs

Cluster 65 (9/71)

8 motifs

Aggregate motifConstituent motifs

Cluster 54 (10/71)

7 motifs

Aggregate motifConstituent motifs

Cluster 25 (11/71)

7 motifs

Aggregate motifConstituent motifs

Cluster 46 (12/71)

7 motifs

Aggregate motifConstituent motifs

Cluster 45 (13/71)

7 motifs

Aggregate motifConstituent motifs

Cluster 57 (14/71)

6 motifs

Aggregate motifConstituent motifs

Cluster 64 (15/71)

6 motifs

Aggregate motifConstituent motifs

Cluster 69 (16/71)

6 motifs

Aggregate motifConstituent motifs

Cluster 1 (17/71)

6 motifs

Aggregate motifConstituent motifs

Cluster 48 (18/71)

5 motifs

Aggregate motifConstituent motifs

Cluster 42 (19/71)

5 motifs

Aggregate motifConstituent motifs

Cluster 55 (20/71)

5 motifs

Aggregate motifConstituent motifs

Cluster 61 (21/71)

5 motifs

Aggregate motifConstituent motifs

Cluster 52 (22/71)

4 motifs

Aggregate motifConstituent motifs

Cluster 56 (23/71)

4 motifs

Aggregate motifConstituent motifs

Cluster 14 (24/71)

4 motifs

Aggregate motifConstituent motifs

Cluster 44 (25/71)

4 motifs

Aggregate motifConstituent motifs

Cluster 32 (26/71)

4 motifs

Aggregate motifConstituent motifs

Cluster 24 (27/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 23 (28/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 12 (29/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 30 (30/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 4 (31/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 3 (32/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 36 (33/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 58 (34/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 68 (35/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 40 (36/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 43 (37/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 53 (38/71)

3 motifs

Aggregate motifConstituent motifs

Cluster 5 (39/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 13 (40/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 62 (41/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 15 (42/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 9 (43/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 17 (44/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 8 (45/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 37 (46/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 7 (47/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 21 (48/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 71 (49/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 67 (50/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 26 (51/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 28 (52/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 29 (53/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 41 (54/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 2 (55/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 35 (56/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 70 (57/71)

2 motifs

Aggregate motifConstituent motifs

Cluster 6 (58/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 11 (59/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 10 (60/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 38 (61/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 63 (62/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 16 (63/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 18 (64/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 20 (65/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 22 (66/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 60 (67/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 59 (68/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 33 (69/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 39 (70/71)

1 motifs

Aggregate motifConstituent motifs

Cluster 19 (71/71)

1 motifs

Aggregate motifConstituent motifs