Results

In [1]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
from util import figure_to_vdom_image
import plot.viz_sequence as viz_sequence
import numpy as np
import h5py
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import scipy.signal
import scipy.cluster.hierarchy
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
tqdm.tqdm_notebook()
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ipykernel_launcher.py:15: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  from ipykernel import kernelapp as app
Out[1]:
0it [00:00, ?it/s]
In [2]:
# Plotting defaults
font_manager.fontManager.ttflist.extend(
    font_manager.createFontList(
        font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
    )
)
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold"
}
plt.rcParams.update(plot_params)
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ipykernel_launcher.py:4: MatplotlibDeprecationWarning: 
The createFontList function was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use FontManager.addfont instead.
  after removing the cwd from sys.path.

Define constants and paths

In [3]:
# Define parameters/fetch arguments
motif_files = os.environ["TFM_MOTIF_FILES"].split(",")
group_names = os.environ["TFM_GROUP_NAMES"].split(",")
if "TFM_HEATMAP_CACHE" in os.environ:
    tfm_heatmap_cache_dir = os.environ["TFM_HEATMAP_CACHE"]
else:
    tfm_heatmap_cache_dir = None
    
assert len(motif_files) == len(group_names)
assert len(group_names) == len(set(group_names))

print("Motif files: %s" % motif_files)
print("Group names: %s" % group_names)
print("Saved heatmap cache: %s" % tfm_heatmap_cache_dir)
Motif files: ['/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile_finetune/NR3C1-reddytime_multitask_profile_finetune_fold5/NR3C1-reddytime_multitask_profile_finetune_fold5_profile/all_motifs.h5', '/users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile_finetune/NR3C1-reddytime_multitask_profile_finetune_fold5/NR3C1-reddytime_multitask_profile_finetune_fold5_count/all_motifs.h5']
Group names: ['NR3C1-reddytime_F5_P', 'NR3C1-reddytime_F5_C']
Saved heatmap cache: /users/amtseng/tfmodisco/results/reports/motif_heatmaps/cache/multitask_profile_finetune/NR3C1-reddytime_multitask_profile_finetune
In [4]:
# Define constants
cluster_color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
default_cluster_color = "gray"
In [5]:
if tfm_heatmap_cache_dir:
    os.makedirs(tfm_heatmap_cache_dir, exist_ok=True)

Helper functions

For plotting and organizing things

In [6]:
def import_motifs(motif_files, group_names):
    """
    Imports a set of motifs from the saved HDF5 files.
    `group_names` is a list of group names, one for each motif file.
    Returns a list of motifs as L x 4 arrays, a parallel list of
    motif names, and a dictionary mapping group names to lists of
    motif names.
    """
    motifs, motif_names = [], []
    groups = {}
    for motif_file, stem in zip(motif_files, group_names):
        groups[stem] = []
        with h5py.File(motif_file, "r") as f:
            for key in f.keys():
                motif_name = "%s:%s" % (stem, key)
                motif_names.append(motif_name)
                motifs.append(f[key]["cwm_trimmed"][:])
                groups[stem].append(motif_name)
    return motifs, motif_names, groups
In [7]:
def motif_similarity_score(motif_1, motif_2, average=False, align_to_longer=True):
    """
    Computes the motif similarity score between two motifs by
    the summed cosine similarity, maximized over all possible sliding
    windows. Also returns the index relative to the start of `motif_2`
    where `motif_1` should be placed to maximize this score.
    If `average` is True, then use average of similarity of overlap.
    If `align_to_longer` is True, always use the longer motif as the basis
    for the index computation (if tie use `motif_2`). Otherwise, always use
    `motif_2`.
    """
    # Normalize
    motif_1 = motif_1 - np.mean(motif_1, axis=1, keepdims=True)
    motif_2 = motif_2 - np.mean(motif_2, axis=1, keepdims=True)
    motif_1 = motif_1 / np.sqrt(np.sum(motif_1 * motif_1, axis=1, keepdims=True))
    motif_2 = motif_2 / np.sqrt(np.sum(motif_2 * motif_2, axis=1, keepdims=True))
    
    # Always make motif_2 longer
    if align_to_longer and len(motif_1) > len(motif_2):
        motif_1, motif_2 = motif_2, motif_1
    
    # Pad motif_2 by len(motif_1) - 1 on either side
    orig_motif_2_len = len(motif_2)
    pad_size = len(motif_1) - 1
    motif_2 = np.pad(motif_2, ((pad_size, pad_size), (0, 0)))
    
    if average:
        # Compute overlap sizes
        overlap_sizes = np.empty(orig_motif_2_len + pad_size)
        overlap_sizes[:pad_size] = np.arange(1, len(motif_1))
        overlap_sizes[-pad_size:] = np.flip(np.arange(1, len(motif_1)))
        overlap_sizes[pad_size:-pad_size] = len(motif_1)
    
    # Compute similarities across all sliding windows
    scores = np.empty(orig_motif_2_len + pad_size)
    for i in range(orig_motif_2_len + pad_size):
        scores[i] = np.sum(motif_1 * motif_2[i : i + len(motif_1)])
    if average:
        scores = scores / overlap_sizes
    return np.max(scores), np.argmax(scores) - pad_size
In [8]:
def compute_similarity_matrix(motifs, show_progress=True):
    """
    Computes a similarity matrix over the pairs of motifs using cross
    correlation. `motifs` is a list of N motifs, where each is an L x 4
    array (may be different Ls).
    Returns an N x N array of distances.
    """
    num_motifs = len(motifs)
    sim_matrix = np.empty((num_motifs, num_motifs))
    t_iter = tqdm.notebook.trange(num_motifs) if show_progress else range(num_motifs)
    for i in t_iter:
        for j in range(i, num_motifs):
            sim, _ = motif_similarity_score(motifs[i], motifs[j])
            sim_matrix[i, j] = sim
            sim_matrix[j, i] = sim
    return sim_matrix
In [9]:
def compute_clusters(linkage, goal_clusters, tolerance=(-2, 2), start=50, max_iter=10):
    """
    From a linkage map, computes clusters with a goal of `goal_clusters` clusters.
    Will allow the given tolerance. `start` is what distance threshold to check first.
    `max_iter` is the maximum number of checks to do.
    Returns the clustering.
    """
    clusters = scipy.cluster.hierarchy.fcluster(
        linkage, start, criterion="distance"
    )
    
    num_clusters = len(np.unique(clusters))
    if num_clusters > goal_clusters + tolerance[1] and max_iter:
        return compute_clusters(linkage, goal_clusters, tolerance, start * 2, max_iter - 1)
    elif num_clusters < goal_clusters - tolerance[0] and max_iter:
        return compute_clusters(linkage, goal_clusters, tolerance, start / 2, max_iter - 1)
    else:
        return clusters
In [10]:
def plot_heatmap(sim_matrix, labels, linkage, clusters):
    """
    Given a similariy matrix and labels, plots a heatmap with
    dendrogram. `linkage` is the linkage map computed on the matrix, and
    `clusters` is the cluster ID of each entry.
    Returns the figure and the indices in which to order the entries.
    """
    fig, ax = plt.subplots(
        nrows=2, ncols=2, figsize=(20, 20),
        gridspec_kw={
            "width_ratios": [20, 1],
            "height_ratios": [1, 4],
            "hspace": 0,
            "wspace": 0.1
        }
    )
    
    # Compute the color of every link based on cluster assignments
    # Adapted from https://stackoverflow.com/questions/38153829/custom-cluster-colors-of-scipy-dendrogram-in-python-link-color-func
    leaf_colors = [cluster_color_cycle[i % len(cluster_color_cycle)] for i in clusters]
    link_colors = {}
    for i, i_link in enumerate(linkage[:, :2].astype(int)):
        color_0 = link_colors[i_link[0]] if i_link[0] > len(linkage) else leaf_colors[i_link[0]]
        color_1 = link_colors[i_link[1]] if i_link[1] > len(linkage) else leaf_colors[i_link[1]]
        link_colors[i + 1 + len(linkage)] = color_0 if color_0 == color_1 else default_cluster_color

    dend = scipy.cluster.hierarchy.dendrogram(
        linkage, ax=ax[0, 0], link_color_func=(lambda x: link_colors[x])
    )

    order_inds = dend["leaves"]
    sim_matrix_reordered = sim_matrix[:, order_inds][order_inds, :]
    heatmap = ax[1, 0].imshow(sim_matrix_reordered, aspect="auto", cmap="Blues")
    ax[1, 0].set_yticks([])
    ax[1, 0].set_xticks(range(len(sim_matrix)))
    ax[1, 0].set_xticklabels(np.array(labels)[order_inds], rotation=90, fontsize=10)

    fig.colorbar(heatmap, cax=ax[1, 1])

    ax[0, 0].axis("off")
    ax[0, 1].axis("off")
    
    plt.show()
    return fig, order_inds
In [11]:
def aggregate_motifs(motifs):
    """
    Aggregates a list of L x 4 (not all the same L) motifs into a single
    L x 4 motif.
    """
    # Compute similarity matrix
    sim_matrix = compute_similarity_matrix(motifs, show_progress=False)

    # Sort motifs by how similar it is to everyone else
    inds = np.flip(np.argsort(np.sum(sim_matrix, axis=0)))
    
    # Have the consensus start with the most similar
    consensus = np.zeros_like(motifs[inds[0]])
    consensus = consensus + motifs[inds[0]]
    
    # For each successive motif, add it into the consensus
    for i in inds[1:]:
        motif = motifs[i]
        _, index = motif_similarity_score(motif, consensus, align_to_longer=False)
        if index >= 0:
            start, end = index, index + len(motif)
            consensus[start:end] = consensus[start:end] + motif[:len(consensus) - index]
        else:
            end = len(motif) + index
            consensus[:end] = consensus[:end] + motif[-index:-index + len(consensus)]
    return consensus / len(motifs)

Show motifs clusters

For all of the aggregated motifs, show the motif clusters.

In [12]:
motifs, motif_names, motif_groups = import_motifs(motif_files, group_names)
In [13]:
# Flip all motifs to be the purine-rich version
for i, motif in enumerate(motifs):
    if np.sum(motif[:, [0, 2]]) < 0.5 * np.sum(motif):
        motifs[i] = np.flip(motif)
In [14]:
# Compute similarity matrix
sim_matrix = compute_similarity_matrix(motifs)

In [15]:
# Compute linkage
linkage = scipy.cluster.hierarchy.linkage(sim_matrix, method="ward")
In [16]:
# Compute clusters
expected_clusters = np.max([len(m) for m in motif_groups.values()])
clusters = compute_clusters(linkage, expected_clusters, max_iter=100)
In [17]:
display(vdomh.h4("Number of motifs: %d" % len(clusters)))
display(vdomh.h4("Number of clusters: %d" % len(np.unique(clusters))))

Number of motifs: 17

Number of clusters: 11

In [18]:
# Plot heatmap
fig, order_inds = plot_heatmap(sim_matrix, motif_names, linkage, clusters)
if tfm_heatmap_cache_dir:
    fig.savefig(os.path.join(tfm_heatmap_cache_dir, "motif_cluster_heatmap.png"))
In [19]:
# Show aggregated and constituent motifs for each cluster
colgroup = vdomh.colgroup(
    vdomh.col(style={"width": "50%"}),
    vdomh.col(style={"width": "50%"})
)

header = vdomh.thead(
    vdomh.tr(
        vdomh.th("Aggregate motif", style={"text-align": "center"}),
        vdomh.th("Constituent motifs", style={"text-align": "center"})
    )
)

all_consensus = {}
cluster_ids, counts = np.unique(clusters, return_counts=True)
cluster_ids = cluster_ids[np.flip(np.argsort(counts))]
for i, cluster_id in enumerate(cluster_ids):
    match_inds = np.where(clusters == cluster_id)[0]
    matches = [motifs[j] for j in match_inds]
    match_names = [motif_names[j] for j in match_inds]
    
    consensus = aggregate_motifs(matches)
    all_consensus[cluster_id] = consensus
    
    display(vdomh.h3("Cluster %d (%d/%d)" % (cluster_id, i + 1, len(cluster_ids))))
    display(vdomh.h4("%d motifs" % len(matches)))
    
    agg_fig = viz_sequence.plot_weights(consensus, figsize=(20, 4), return_fig=True)
    agg_fig.tight_layout()
    const_figs = []
    for motif, motif_name in zip(matches, match_names):
        fig = viz_sequence.plot_weights(motif, figsize=(20, 4), return_fig=True)
        plt.title(motif_name)
        fig.tight_layout()
        const_figs.append(figure_to_vdom_image(fig))

    body = vdomh.tbody(vdomh.tr(vdomh.td(figure_to_vdom_image(agg_fig)), vdomh.td(*const_figs)))
    display(vdomh.table(colgroup, header, body))
    
    if tfm_heatmap_cache_dir:
        agg_fig.savefig(os.path.join(tfm_heatmap_cache_dir, "cluster_%d_aggregate_motif.png" % cluster_id))
        
    plt.close("all")

if tfm_heatmap_cache_dir:
    with h5py.File(os.path.join(tfm_heatmap_cache_dir, "motif_clusters.h5"), "w") as f:
        f.create_dataset("motif_names", data=np.array(motif_names).astype("S"), compression="gzip")
        all_motifs = f.create_group("all_motifs")
        for name, motif in zip(motif_names, motifs):
            all_motifs.create_dataset(name, data=motif, compression="gzip")
        f.create_dataset("similarity_matrix", data=sim_matrix, compression="gzip")
        f.create_dataset("dendrogram_order", data=order_inds, compression="gzip")
        f.create_dataset("clusters", data=clusters, compression="gzip")
        agg_motifs = f.create_group("cluster_motifs")
        for cluster_id, consensus in all_consensus.items():
            agg_motifs.create_dataset(str(cluster_id), data=consensus, compression="gzip")

Cluster 10 (1/11)

2 motifs

Aggregate motifConstituent motifs

Cluster 9 (2/11)

2 motifs

Aggregate motifConstituent motifs

Cluster 8 (3/11)

2 motifs

Aggregate motifConstituent motifs

Cluster 4 (4/11)

2 motifs

Aggregate motifConstituent motifs

Cluster 2 (5/11)

2 motifs

Aggregate motifConstituent motifs

Cluster 1 (6/11)

2 motifs

Aggregate motifConstituent motifs

Cluster 11 (7/11)

1 motifs

Aggregate motifConstituent motifs

Cluster 7 (8/11)

1 motifs

Aggregate motifConstituent motifs

Cluster 6 (9/11)

1 motifs

Aggregate motifConstituent motifs

Cluster 5 (10/11)

1 motifs

Aggregate motifConstituent motifs

Cluster 3 (11/11)

1 motifs

Aggregate motifConstituent motifs