import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
from util import figure_to_vdom_image
import plot.viz_sequence as viz_sequence
import numpy as np
import h5py
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import scipy.signal
import scipy.cluster.hierarchy
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
import tqdm.notebook
# Plotting defaults
# Register custom fonts with matplotlib
for font_path in font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts"):
    font_manager.fontManager.addfont(font_path)
plot_params = {
"figure.titlesize": 22,
"axes.titlesize": 22,
"axes.labelsize": 20,
"legend.fontsize": 18,
"xtick.labelsize": 16,
"ytick.labelsize": 16,
"font.family": "Roboto",
"font.weight": "bold"
}
plt.rcParams.update(plot_params)
# Define parameters/fetch arguments
motif_files = os.environ["TFM_MOTIF_FILES"].split(",")
group_names = os.environ["TFM_GROUP_NAMES"].split(",")
if "TFM_HEATMAP_CACHE" in os.environ:
tfm_heatmap_cache_dir = os.environ["TFM_HEATMAP_CACHE"]
else:
tfm_heatmap_cache_dir = None
assert len(motif_files) == len(group_names)
assert len(group_names) == len(set(group_names))
print("Motif files: %s" % motif_files)
print("Group names: %s" % group_names)
print("Saved heatmap cache: %s" % tfm_heatmap_cache_dir)
# Define constants
cluster_color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
default_cluster_color = "gray"
if tfm_heatmap_cache_dir:
os.makedirs(tfm_heatmap_cache_dir, exist_ok=True)
# Helper functions for plotting and organizing motifs
def import_motifs(motif_files, group_names):
"""
Imports a set of motifs from the saved HDF5 files.
`group_names` is a list of group names, one for each motif file.
Returns a list of motifs as L x 4 arrays, a parallel list of
motif names, and a dictionary mapping group names to lists of
motif names.
"""
motifs, motif_names = [], []
groups = {}
for motif_file, stem in zip(motif_files, group_names):
groups[stem] = []
with h5py.File(motif_file, "r") as f:
for key in f.keys():
motif_name = "%s:%s" % (stem, key)
motif_names.append(motif_name)
motifs.append(f[key]["cwm_trimmed"][:])
groups[stem].append(motif_name)
return motifs, motif_names, groups
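# A note on the assumed HDF5 layout (inferred from the reads above): each
# motif file holds one group per motif key, and each group contains an L x 4
# "cwm_trimmed" dataset with the trimmed contribution-weight matrix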
def motif_similarity_score(motif_1, motif_2, average=False, align_to_longer=True):
"""
Computes the motif similarity score between two motifs by
the summed cosine similarity, maximized over all possible sliding
windows. Also returns the index relative to the start of `motif_2`
where `motif_1` should be placed to maximize this score.
    If `average` is True, the score is averaged over the overlap length
    instead of summed.
    If `align_to_longer` is True, use the longer motif as the basis for the
    index computation (on ties, use `motif_2`). Otherwise, always use
    `motif_2`.
"""
# Normalize
motif_1 = motif_1 - np.mean(motif_1, axis=1, keepdims=True)
motif_2 = motif_2 - np.mean(motif_2, axis=1, keepdims=True)
motif_1 = motif_1 / np.sqrt(np.sum(motif_1 * motif_1, axis=1, keepdims=True))
motif_2 = motif_2 / np.sqrt(np.sum(motif_2 * motif_2, axis=1, keepdims=True))
    # If aligning to the longer motif, ensure motif_2 is the longer one
if align_to_longer and len(motif_1) > len(motif_2):
motif_1, motif_2 = motif_2, motif_1
# Pad motif_2 by len(motif_1) - 1 on either side
orig_motif_2_len = len(motif_2)
pad_size = len(motif_1) - 1
motif_2 = np.pad(motif_2, ((pad_size, pad_size), (0, 0)))
if average:
# Compute overlap sizes
overlap_sizes = np.empty(orig_motif_2_len + pad_size)
overlap_sizes[:pad_size] = np.arange(1, len(motif_1))
overlap_sizes[-pad_size:] = np.flip(np.arange(1, len(motif_1)))
overlap_sizes[pad_size:-pad_size] = len(motif_1)
# Compute similarities across all sliding windows
scores = np.empty(orig_motif_2_len + pad_size)
for i in range(orig_motif_2_len + pad_size):
scores[i] = np.sum(motif_1 * motif_2[i : i + len(motif_1)])
if average:
scores = scores / overlap_sizes
return np.max(scores), np.argmax(scores) - pad_size
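# A minimal sanity check of the scoring (toy one-hot motif, not part of the
# pipeline): a motif aligned against itself attains one unit of cosine
# similarity per position, at offset 0
_toy_motif = np.array([
    [1., 0., 0., 0.],
    [0., 0., 1., 0.],
    [0., 0., 0., 1.]
])
_toy_score, _toy_offset = motif_similarity_score(_toy_motif, _toy_motif)
assert np.isclose(_toy_score, 3.0) and _toy_offset == 0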
def compute_similarity_matrix(motifs, show_progress=True):
"""
    Computes a similarity matrix over all pairs of motifs using the
    sliding-window cosine similarity above. `motifs` is a list of N motifs,
    where each is an L x 4 array (the Ls may differ).
    Returns an N x N array of similarities.
"""
num_motifs = len(motifs)
sim_matrix = np.empty((num_motifs, num_motifs))
t_iter = tqdm.notebook.trange(num_motifs) if show_progress else range(num_motifs)
for i in t_iter:
for j in range(i, num_motifs):
sim, _ = motif_similarity_score(motifs[i], motifs[j])
sim_matrix[i, j] = sim
sim_matrix[j, i] = sim
return sim_matrix
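# Sanity check: the matrix is symmetric by construction (each pair is scored
# once and mirrored), with each motif's self-similarity on the diagonal
_toy_sim = compute_similarity_matrix([_toy_motif, _toy_motif[:2]], show_progress=False)
assert np.allclose(_toy_sim, _toy_sim.T)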
def compute_clusters(linkage, goal_clusters, tolerance=(-2, 2), start=50, max_iter=10):
"""
    From a linkage matrix, computes flat clusters, aiming for `goal_clusters`
    clusters within the given `tolerance` (offsets below and above the goal).
    `start` is the distance threshold to try first; it is doubled or halved
    on each retry. `max_iter` is the maximum number of retries.
Returns the clustering.
"""
clusters = scipy.cluster.hierarchy.fcluster(
linkage, start, criterion="distance"
)
num_clusters = len(np.unique(clusters))
    if num_clusters > goal_clusters + tolerance[1] and max_iter:
        # Too many clusters: raise the distance threshold
        return compute_clusters(linkage, goal_clusters, tolerance, start * 2, max_iter - 1)
    elif num_clusters < goal_clusters + tolerance[0] and max_iter:
        # Too few clusters: lower the distance threshold (note tolerance[0]
        # is a negative offset)
        return compute_clusters(linkage, goal_clusters, tolerance, start / 2, max_iter - 1)
else:
return clusters
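# Illustration on a small random linkage (hypothetical data): the search
# doubles the threshold when there are too many clusters and halves it when
# there are too few, stopping within tolerance or after `max_iter` retries
_toy_rng = np.random.RandomState(0)
_toy_linkage = scipy.cluster.hierarchy.linkage(_toy_rng.randn(10, 4), method="ward")
assert len(compute_clusters(_toy_linkage, goal_clusters=3)) == 10  # one ID per observation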
def plot_heatmap(sim_matrix, labels, linkage, clusters):
"""
    Given a similarity matrix and labels, plots a heatmap with a
    dendrogram. `linkage` is the linkage matrix computed on the matrix, and
`clusters` is the cluster ID of each entry.
Returns the figure and the indices in which to order the entries.
"""
fig, ax = plt.subplots(
nrows=2, ncols=2, figsize=(20, 20),
gridspec_kw={
"width_ratios": [20, 1],
"height_ratios": [1, 4],
"hspace": 0,
"wspace": 0.1
}
)
# Compute the color of every link based on cluster assignments
# Adapted from https://stackoverflow.com/questions/38153829/custom-cluster-colors-of-scipy-dendrogram-in-python-link-color-func
leaf_colors = [cluster_color_cycle[i % len(cluster_color_cycle)] for i in clusters]
link_colors = {}
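    # In scipy's linkage encoding, leaves are nodes 0 to n - 1 and row i of
    # `linkage` merges two nodes into internal node n + i (where n, the number
    # of leaves, is len(linkage) + 1); a child index > len(linkage) therefore
    # refers to an already-colored internal link rather than a leaf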
for i, i_link in enumerate(linkage[:, :2].astype(int)):
color_0 = link_colors[i_link[0]] if i_link[0] > len(linkage) else leaf_colors[i_link[0]]
color_1 = link_colors[i_link[1]] if i_link[1] > len(linkage) else leaf_colors[i_link[1]]
link_colors[i + 1 + len(linkage)] = color_0 if color_0 == color_1 else default_cluster_color
dend = scipy.cluster.hierarchy.dendrogram(
linkage, ax=ax[0, 0], link_color_func=(lambda x: link_colors[x])
)
order_inds = dend["leaves"]
sim_matrix_reordered = sim_matrix[:, order_inds][order_inds, :]
heatmap = ax[1, 0].imshow(sim_matrix_reordered, aspect="auto", cmap="Blues")
ax[1, 0].set_yticks([])
ax[1, 0].set_xticks(range(len(sim_matrix)))
ax[1, 0].set_xticklabels(np.array(labels)[order_inds], rotation=90, fontsize=10)
fig.colorbar(heatmap, cax=ax[1, 1])
ax[0, 0].axis("off")
ax[0, 1].axis("off")
plt.show()
return fig, order_inds
def aggregate_motifs(motifs):
"""
    Aggregates a list of L x 4 motifs (the Ls may differ) into a single
    consensus motif, whose length matches the seed (most broadly similar)
    motif.
"""
# Compute similarity matrix
sim_matrix = compute_similarity_matrix(motifs, show_progress=False)
    # Sort motifs by total similarity to all other motifs, descending
    inds = np.flip(np.argsort(np.sum(sim_matrix, axis=0)))
    # Seed the consensus with the most broadly similar motif
consensus = np.zeros_like(motifs[inds[0]])
consensus = consensus + motifs[inds[0]]
# For each successive motif, add it into the consensus
for i in inds[1:]:
motif = motifs[i]
_, index = motif_similarity_score(motif, consensus, align_to_longer=False)
if index >= 0:
start, end = index, index + len(motif)
consensus[start:end] = consensus[start:end] + motif[:len(consensus) - index]
else:
end = len(motif) + index
consensus[:end] = consensus[:end] + motif[-index:-index + len(consensus)]
return consensus / len(motifs)
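# Sanity check: aggregating identical copies of a motif should return that
# motif unchanged, since every copy aligns to the consensus at offset 0
assert np.allclose(aggregate_motifs([_toy_motif, _toy_motif]), _toy_motif)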
motifs, motif_names, motif_groups = import_motifs(motif_files, group_names)
# Orient all motifs to the purine-rich strand; flipping an ACGT-ordered array
# along both axes yields its reverse complement
for i, motif in enumerate(motifs):
if np.sum(motif[:, [0, 2]]) < 0.5 * np.sum(motif):
motifs[i] = np.flip(motif)
# Compute similarity matrix
sim_matrix = compute_similarity_matrix(motifs)
# Compute Ward linkage, treating each row of the similarity matrix as that
# motif's feature vector
linkage = scipy.cluster.hierarchy.linkage(sim_matrix, method="ward")
# Compute clusters
expected_clusters = np.max([len(m) for m in motif_groups.values()])
clusters = compute_clusters(linkage, expected_clusters, max_iter=100)
display(vdomh.h4("Number of motifs: %d" % len(clusters)))
display(vdomh.h4("Number of clusters: %d" % len(np.unique(clusters))))
# Plot heatmap
fig, order_inds = plot_heatmap(sim_matrix, motif_names, linkage, clusters)
if tfm_heatmap_cache_dir:
fig.savefig(os.path.join(tfm_heatmap_cache_dir, "motif_cluster_heatmap.png"))
# Show aggregated and constituent motifs for each cluster
colgroup = vdomh.colgroup(
vdomh.col(style={"width": "50%"}),
vdomh.col(style={"width": "50%"})
)
header = vdomh.thead(
vdomh.tr(
vdomh.th("Aggregate motif", style={"text-align": "center"}),
vdomh.th("Constituent motifs", style={"text-align": "center"})
)
)
all_consensus = {}
cluster_ids, counts = np.unique(clusters, return_counts=True)
cluster_ids = cluster_ids[np.flip(np.argsort(counts))]
for i, cluster_id in enumerate(cluster_ids):
match_inds = np.where(clusters == cluster_id)[0]
matches = [motifs[j] for j in match_inds]
match_names = [motif_names[j] for j in match_inds]
consensus = aggregate_motifs(matches)
all_consensus[cluster_id] = consensus
display(vdomh.h3("Cluster %d (%d/%d)" % (cluster_id, i + 1, len(cluster_ids))))
display(vdomh.h4("%d motifs" % len(matches)))
agg_fig = viz_sequence.plot_weights(consensus, figsize=(20, 4), return_fig=True)
agg_fig.tight_layout()
const_figs = []
for motif, motif_name in zip(matches, match_names):
fig = viz_sequence.plot_weights(motif, figsize=(20, 4), return_fig=True)
plt.title(motif_name)
fig.tight_layout()
const_figs.append(figure_to_vdom_image(fig))
body = vdomh.tbody(vdomh.tr(vdomh.td(figure_to_vdom_image(agg_fig)), vdomh.td(*const_figs)))
display(vdomh.table(colgroup, header, body))
if tfm_heatmap_cache_dir:
agg_fig.savefig(os.path.join(tfm_heatmap_cache_dir, "cluster_%d_aggregate_motif.png" % cluster_id))
plt.close("all")
if tfm_heatmap_cache_dir:
with h5py.File(os.path.join(tfm_heatmap_cache_dir, "motif_clusters.h5"), "w") as f:
f.create_dataset("motif_names", data=np.array(motif_names).astype("S"), compression="gzip")
all_motifs = f.create_group("all_motifs")
for name, motif in zip(motif_names, motifs):
all_motifs.create_dataset(name, data=motif, compression="gzip")
f.create_dataset("similarity_matrix", data=sim_matrix, compression="gzip")
f.create_dataset("dendrogram_order", data=order_inds, compression="gzip")
f.create_dataset("clusters", data=clusters, compression="gzip")
agg_motifs = f.create_group("cluster_motifs")
for cluster_id, consensus in all_consensus.items():
agg_motifs.create_dataset(str(cluster_id), data=consensus, compression="gzip")