import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
from motif.read_motifs import pfm_to_pwm
from util import figure_to_vdom_image, purine_rich_motif
import plot.viz_sequence as viz_sequence
import numpy as np
import h5py
import pandas as pd
import sklearn.cluster
import scipy.cluster.hierarchy
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import vdom.helpers as vdomh
from IPython.display import display
# Plotting defaults
# Register each custom font file with Matplotlib's font manager.
# `font_manager.createFontList` is deprecated since Matplotlib 3.2 (and removed
# in later releases); `FontManager.addfont` is its documented replacement.
for font_path in font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts"):
    font_manager.fontManager.addfont(font_path)
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold"
}
plt.rcParams.update(plot_params)
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ipykernel_launcher.py:4: MatplotlibDeprecationWarning: The createFontList function was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use FontManager.addfont instead. after removing the cwd from sys.path.
# Define parameters/fetch arguments
tfm_results_cache_dir = os.environ["TFM_RESULTS_CACHE_DIR"]
motif_hits_cache_dir = os.environ["TFM_MOTIF_HITS_CACHE_DIR"]
# Optional comma-separated list of motif keys. Blank entries (from an empty
# value or stray commas) are dropped. An effectively empty setting becomes None,
# so the default key selection is used instead of the bogus key list [""].
if "TFM_MOTIF_KEYS" in os.environ:
    motif_keys = [
        key.strip() for key in os.environ["TFM_MOTIF_KEYS"].split(",") if key.strip()
    ] or None
else:
    motif_keys = None
print("Saved TF-MoDISco results cache: %s" % tfm_results_cache_dir)
print("Saved motif hits cache: %s" % motif_hits_cache_dir)
Saved TF-MoDISco results cache: /users/amtseng/tfmodisco/results/reports/tfmodisco_results/cache/multitask_profile_finetune/REST_multitask_profile_finetune_fold7/REST_multitask_profile_finetune_task0_fold7_count Saved motif hits cache: /users/amtseng/tfmodisco/results/reports/motif_hits/cache/tfm/multitask_profile_finetune/REST_multitask_profile_finetune_fold7/REST_multitask_profile_finetune_task0_fold7_count
motif_file = os.path.join(tfm_results_cache_dir, "all_motifs.h5")
if not motif_keys:
    # No keys were supplied: default to every motif in metacluster 0, ordered
    # numerically by (metacluster index, pattern index) parsed from "<m>_<p>"
    with h5py.File(motif_file, "r") as f:
        numerically_sorted_keys = sorted(
            f.keys(),
            key=lambda k: (int(k.split("_")[0]), int(k.split("_")[1]))
        )
    motif_keys = list(filter(lambda k: k.startswith("0_"), numerically_sorted_keys))
def plot_profiles(seqlet_true_profs, seqlet_pred_profs, kmeans_clusters=5, save_path=None):
    """
    Plots the given profiles as mean profiles plus clustered heatmaps.
    Arguments:
        `seqlet_true_profs`: an N x O x 2 NumPy array of true profiles, either as raw
            counts or probabilities (they will be normalized)
        `seqlet_pred_profs`: an N x O x 2 NumPy array of predicted profiles, either as
            raw counts or probabilities (they will be normalized)
        `kmeans_clusters`: the minimum number of k-means clusters used to order
            the profile heatmaps. The number actually used is
            max(`kmeans_clusters`, N // 50), capped at N. With the default of 5
            this equals the previously hard-coded max(5, N // 50).
        `save_path`: if provided, save the mean profiles and the ordered
            profile matrices here with `np.savez_compressed`
    Returns the figure.
    """
    assert len(seqlet_true_profs.shape) == 3
    assert seqlet_true_profs.shape == seqlet_pred_profs.shape
    num_profs = seqlet_true_profs.shape[0]

    # First, normalize the profiles along the output profile dimension
    def normalize(arr, axis=0):
        arr_sum = np.sum(arr, axis=axis, keepdims=True)
        arr_sum[arr_sum == 0] = 1  # If 0, keep 0 as the quotient instead of dividing by 0
        return arr / arr_sum

    true_profs_norm = normalize(seqlet_true_profs, axis=1)
    pred_profs_norm = normalize(seqlet_pred_profs, axis=1)

    # Compute the mean profiles across all examples
    true_profs_mean = np.mean(true_profs_norm, axis=0)
    pred_profs_mean = np.mean(pred_profs_norm, axis=0)

    # Order the profiles by k-means clustering of the predicted profiles, with
    # the strands pooled by flattening O x 2. The clusters themselves are ordered
    # by optimal leaf ordering of a hierarchical clustering of their centers.
    # Honor the argument as a minimum, and never ask for more clusters than
    # there are profiles.
    num_clusters = min(max(kmeans_clusters, num_profs // 50), num_profs)
    if num_clusters < 2:
        # Hierarchical linkage needs at least two centers; keep the input order
        cluster_inds = np.arange(num_profs)
    else:
        kmeans = sklearn.cluster.KMeans(n_clusters=num_clusters)
        cluster_assignments = kmeans.fit_predict(
            np.reshape(pred_profs_norm, (num_profs, -1))
        )
        kmeans_centers = kmeans.cluster_centers_
        cluster_order = scipy.cluster.hierarchy.leaves_list(
            scipy.cluster.hierarchy.optimal_leaf_ordering(
                scipy.cluster.hierarchy.linkage(kmeans_centers, method="centroid"), kmeans_centers
            )
        )
        # Concatenate the members of each cluster, following the optimal cluster ordering
        cluster_inds = np.concatenate([
            np.where(cluster_assignments == cluster_id)[0] for cluster_id in cluster_order
        ])

    # Compute a matrix of profiles, each normalized to its maximum height along
    # the position axis, ordered by clusters
    def make_profile_matrix(flat_profs, order_inds):
        matrix = flat_profs[order_inds]
        maxes = np.max(matrix, axis=1, keepdims=True)
        maxes[maxes == 0] = 1  # If 0, keep 0 as the quotient instead of dividing by 0
        return matrix / maxes

    true_matrix = make_profile_matrix(true_profs_norm, cluster_inds)
    pred_matrix = make_profile_matrix(pred_profs_norm, cluster_inds)

    if save_path:
        # The destination file is the first (required) argument of savez_compressed
        np.savez_compressed(
            save_path,
            true_profs_mean=true_profs_mean, pred_profs_mean=pred_profs_mean,
            true_matrix=true_matrix, pred_matrix=pred_matrix
        )

    # Create a figure with the right dimensions
    mean_height = 4
    heatmap_height = min(num_profs * 0.004, 8)
    fig_height = mean_height + (2 * heatmap_height)
    fig, ax = plt.subplots(
        3, 2, figsize=(16, fig_height), sharex=True,
        gridspec_kw={
            "width_ratios": [1, 1],
            "height_ratios": [mean_height / fig_height, heatmap_height / fig_height, heatmap_height / fig_height]
        }
    )

    # Plot the average profiles; strand 1 is drawn mirrored below zero
    ax[0, 0].plot(true_profs_mean[:, 0], color="darkslateblue")
    ax[0, 0].plot(-true_profs_mean[:, 1], color="darkorange")
    ax[0, 1].plot(pred_profs_mean[:, 0], color="darkslateblue")
    ax[0, 1].plot(-pred_profs_mean[:, 1], color="darkorange")

    # Set axes on average predictions
    max_mean_val = max(np.max(true_profs_mean), np.max(pred_profs_mean))
    mean_ylim = max_mean_val * 1.05  # Make 5% higher
    ax[0, 0].set_title("True profiles")
    ax[0, 0].set_ylabel("Average probability")
    ax[0, 1].set_title("Predicted profiles")
    for j in (0, 1):
        ax[0, j].set_ylim(-mean_ylim, mean_ylim)
        ax[0, j].label_outer()

    # Plot the heatmaps, one row of axes per strand
    ax[1, 0].imshow(true_matrix[:, :, 0], interpolation="nearest", aspect="auto", cmap="Blues")
    ax[1, 1].imshow(pred_matrix[:, :, 0], interpolation="nearest", aspect="auto", cmap="Blues")
    ax[2, 0].imshow(true_matrix[:, :, 1], interpolation="nearest", aspect="auto", cmap="Oranges")
    ax[2, 1].imshow(pred_matrix[:, :, 1], interpolation="nearest", aspect="auto", cmap="Oranges")

    # Set axes on heatmaps
    for i in (1, 2):
        for j in (0, 1):
            ax[i, j].set_yticks([])
            ax[i, j].set_yticklabels([])
            ax[i, j].label_outer()

    # Label x positions relative to the profile center, every `delta` bp
    width = true_matrix.shape[1]
    delta = 100
    num_deltas = (width // 2) // delta
    labels = list(range(max(-width // 2, -num_deltas * delta), min(width // 2, num_deltas * delta) + 1, delta))
    tick_locs = [label + max(width // 2, num_deltas * delta) for label in labels]
    for j in (0, 1):
        ax[2, j].set_xticks(tick_locs)
        ax[2, j].set_xticklabels(labels)
        ax[2, j].set_xlabel("Distance from seqlet center (bp)")

    fig.tight_layout()
    plt.show()
    return fig
def plot_summit_dists(summit_dists):
    """
    Draws a histogram of the signed distances from seqlet centers to the
    nearest peak summit, and shows it.
    Arguments:
        `summit_dists`: the array of distances as returned by
            `get_summit_distances`
    Returns the figure.
    """
    # Roughly one bin per 30 values, but never fewer than 20 bins
    bin_count = max(len(summit_dists) // 30, 20)
    fig = plt.figure(figsize=(8, 6))
    plt.hist(summit_dists, bins=bin_count, color="purple")
    plt.title("Histogram of distance of seqlets to peak summits")
    plt.xlabel("Signed distance from seqlet center to nearest peak summit (bp)")
    plt.show()
    return fig
def cluster_matrix_indices(matrix, num_clusters):
    """
    Clusters the rows of `matrix` (always along the first axis) with
    mini-batch k-means. The clusters are then ordered by optimal leaf ordering
    of a centroid-linkage hierarchical clustering of the cluster centers.
    Arguments:
        `matrix`: a 2D NumPy array whose rows are clustered
        `num_clusters`: requested number of k-means clusters; capped at the
            number of rows
    Returns a NumPy array of row indices that reorders `matrix` so that rows
    in the same cluster are contiguous and clusters follow the optimal
    ordering. When fewer than two clusters are possible (zero or one row, or
    `num_clusters` <= 1), the identity ordering is returned without
    clustering.
    """
    num_rows = len(matrix)
    num_clusters = min(num_clusters, num_rows)
    if num_clusters < 2:
        # Hierarchical linkage needs at least two cluster centers, so there is
        # nothing to reorder; keep the original row order
        return np.arange(num_rows)

    # Perform k-means clustering
    kmeans = sklearn.cluster.MiniBatchKMeans(n_clusters=num_clusters)
    cluster_assignments = kmeans.fit_predict(matrix)

    # Perform hierarchical clustering on the cluster centers to determine optimal ordering
    kmeans_centers = kmeans.cluster_centers_
    cluster_order = scipy.cluster.hierarchy.leaves_list(
        scipy.cluster.hierarchy.optimal_leaf_ordering(
            scipy.cluster.hierarchy.linkage(kmeans_centers, method="centroid"), kmeans_centers
        )
    )

    # Order the rows so that the cluster assignments follow the optimal ordering
    return np.concatenate([
        np.where(cluster_assignments == cluster_id)[0] for cluster_id in cluster_order
    ])
def create_violin_plot(ax, dist_list, colors):
    """
    Creates a violin plot on the given instantiated axes.
    Arguments:
        `ax`: the axes to draw on
        `dist_list`: a list of NumPy arrays (presumably 1D), one per violin
        `colors`: a parallel list of colors, one per violin
    Each violin is drawn from its data after removing outliers by the
    1.5 x IQR rule. A black bar spans the interquartile range (25th to 75th
    percentile), and a white dot marks the median. NaNs are ignored by the
    percentiles and excluded by the outlier filter.
    """
    num_perfs = len(dist_list)
    # Quartiles: Q1 = 25th, median = 50th, Q3 = 75th percentile
    q1, med, q3 = np.stack([
        np.nanpercentile(data, [25, 50, 75], axis=0) for data in dist_list
    ], axis=1)
    iqr = q3 - q1
    lower_outlier = q1 - (1.5 * iqr)
    upper_outlier = q3 + (1.5 * iqr)

    sorted_clipped_data = [  # Remove outliers based on outlier rule
        np.sort(vec[(vec >= lower_outlier[i]) & (vec <= upper_outlier[i])])
        for i, vec in enumerate(dist_list)
    ]

    plot_parts = ax.violinplot(
        sorted_clipped_data, showmeans=False, showmedians=False, showextrema=False
    )
    violin_parts = plot_parts["bodies"]
    for i in range(num_perfs):
        violin_parts[i].set_facecolor(colors[i])
        violin_parts[i].set_edgecolor(colors[i])
        violin_parts[i].set_alpha(0.7)

    # Violins are placed at x = 1..num_perfs
    inds = np.arange(1, num_perfs + 1)
    ax.vlines(inds, q1, q3, color="black", linewidth=5, zorder=1)
    ax.scatter(inds, med, marker="o", color="white", s=30, zorder=2)
Motifs are trimmed based on information content and presented in descending order of the number of supporting seqlets, grouped by metacluster. Each motif is shown as a PWM, a CWM, and an eCWM. We show the forward orientation, defined as the orientation that is richer in purines.
colgroup = vdomh.colgroup(
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "30%"}),
    vdomh.col(style={"width": "30%"}),
    vdomh.col(style={"width": "30%"})
)
header = vdomh.thead(
    vdomh.tr(
        vdomh.th("ID", style={"text-align": "center"}),
        vdomh.th("Seqlets", style={"text-align": "center"}),
        vdomh.th("PWM", style={"text-align": "center"}),
        vdomh.th("CWM", style={"text-align": "center"}),
        vdomh.th("eCWM", style={"text-align": "center"})
    )
)

body = []
motif_file = os.path.join(tfm_results_cache_dir, "all_motifs.h5")
with h5py.File(motif_file, "r") as f:
    for motif_key in motif_keys:
        pfm = f[motif_key]["pfm_trimmed"][:]
        cwm = f[motif_key]["cwm_trimmed"][:]
        hcwm = f[motif_key]["hcwm_trimmed"][:]
        pwm = pfm_to_pwm(pfm)
        # Orientation: columns 0 and 2 are presumably A and G (ACGT order).
        # Unless purines account for more than half of the PWM, show the reverse
        # complement (np.flip reverses both the position and base axes) of all
        # three matrices, so the purine-rich orientation is the one displayed
        if not np.sum(pwm[:, [0, 2]]) > 0.5 * np.sum(pwm):
            pwm, cwm, hcwm = np.flip(pwm), np.flip(cwm), np.flip(hcwm)

        # Render each matrix and close its figure right away, so figures do not
        # accumulate (Matplotlib warns once more than 20 are open)
        motif_images = []
        for matrix in (pwm, cwm, hcwm):
            motif_fig = viz_sequence.plot_weights(matrix, figsize=(20, 4), return_fig=True)
            motif_fig.tight_layout()
            motif_images.append(figure_to_vdom_image(motif_fig))
            plt.close(motif_fig)

        seqlets_file = os.path.join(tfm_results_cache_dir, "%s_seqlets.npz" % motif_key)
        with np.load(seqlets_file) as g:
            num_seqlets = len(g["seqlet_seqs"])

        body.append(
            vdomh.tr(
                vdomh.td(motif_key),
                vdomh.td(str(num_seqlets)),
                *[vdomh.td(image) for image in motif_images]
            )
        )
display(vdomh.table(colgroup, header, vdomh.tbody(*body)))
plt.close("all")
/users/amtseng/tfmodisco/src/plot/viz_sequence.py:152: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). fig = plt.figure(figsize=figsize)
ID | Seqlets | PWM | CWM | eCWM |
---|---|---|---|---|
0_0 | 2821 | |||
0_1 | 743 | |||
0_2 | 198 | |||
0_3 | 137 | |||
0_4 | 118 | |||
0_6 | 59 | |||
0_7 | 52 | |||
0_8 | 45 | |||
0_9 | 41 |
For each motif, we show the binding footprint: the observed and model-predicted profiles surrounding instances of the motif. We also show the distribution of distances between instances of the motif and the nearest called peak summit. For reference, we reproduce each motif's CWM (in its purine-rich orientation) above its plots.
with h5py.File(motif_file, "r") as f:
    for motif_key in motif_keys:
        display(vdomh.h4("Motif %s" % motif_key))

        # Repeat the CWM, in its purine-rich orientation, above the footprints
        cwm = f[motif_key]["cwm_trimmed"][:]
        viz_sequence.plot_weights(purine_rich_motif(cwm), figsize=(10, 2), return_fig=True)
        plt.show()

        # Read every array needed from the seqlets archive while it is open
        seqlets_file = os.path.join(tfm_results_cache_dir, "%s_seqlets.npz" % motif_key)
        with np.load(seqlets_file) as g:
            seqlet_true_profs = g["seqlet_true_profs"]
            seqlet_pred_profs = g["seqlet_pred_profs"]
            summit_dists = g["summit_dists"]

        # Merge the leading two axes so each row is one O x 2 profile (NT x O x 2)
        flat_shape = (-1, seqlet_true_profs.shape[2], seqlet_true_profs.shape[3])
        plot_profiles(
            np.reshape(seqlet_true_profs, flat_shape),
            np.reshape(seqlet_pred_profs, (-1, seqlet_pred_profs.shape[2], seqlet_pred_profs.shape[3]))
        )
        plot_summit_dists(summit_dists)
        plt.show()
We show a cumulative distribution of how many motif instances are found per peak. We also show a bar plot of the proportion of peaks that contain at least one instance of each motif.
hits_path = os.path.join(motif_hits_cache_dir, "filtered_hits.tsv")
peak_hits_path = os.path.join(motif_hits_cache_dir, "peak_matched_hits.tsv")

with h5py.File(motif_file, "r") as f:
    all_motif_keys = sorted(f.keys())

hit_table = pd.read_csv(hits_path, sep="\t", header=0, index_col=0)
# Read the comma-separated hit lists as text. Otherwise, a column where every
# peak has at most one hit is parsed as float ("5.0"), which int() rejects.
# Missing lists stay NaN and become the string "nan".
peak_hit_table = pd.read_csv(
    peak_hits_path, sep="\t", header=0, index_col=False,
    dtype={"filtered_hit_indices": str}
)
peak_hit_table["filtered_hit_indices"] = peak_hit_table["filtered_hit_indices"].astype(str)

motif_key_to_motif_index = {key: i for i, key in enumerate(all_motif_keys)}
hit_table["motif_index"] = hit_table["key"].apply(lambda k: motif_key_to_motif_index[k]).values

# Construct N x M array for N peaks and M motifs, holding the counts of each motif in each peak
# (assumes `peak_index` values are 0-based row positions)
peak_hit_counts = np.zeros((len(peak_hit_table), len(all_motif_keys)), dtype=int)
for _, row in peak_hit_table.iterrows():
    peak_ind, hit_inds = row["peak_index"], row["filtered_hit_indices"]
    if hit_inds != "nan":
        for hit_ind in (int(x) for x in hit_inds.split(",")):
            peak_hit_counts[peak_ind, hit_table.at[hit_ind, "motif_index"]] += 1

# Limit the peak hit counts to only the motifs we care about, with columns in
# the same order as `motif_keys`. `all_motif_keys` is sorted lexicographically
# ("0_10" before "0_2"), while `motif_keys` may be numerically or user ordered,
# and the bar labels below come from `motif_keys`.
keep_inds = np.array([motif_key_to_motif_index[key] for key in motif_keys])
peak_hit_counts = peak_hit_counts[:, keep_inds]

# Number of motif hits per peak
motifs_per_peak = np.sum(peak_hit_counts, axis=1)
fig, ax = plt.subplots(figsize=(10, 10))
# One unit-width bin per count 0..max. A finite last edge (instead of np.inf)
# keeps density * bin width finite, so the cumulative curve has no NaNs.
bins = np.arange(np.max(motifs_per_peak) + 2)
# cumulative=-1 accumulates from the right: the value at k is the proportion of
# peaks with at least k motifs, matching the y-axis label
ax.hist(motifs_per_peak, bins=bins, density=True, histtype="step", cumulative=-1)
ax.set_title("Cumulative distribution of number of motif hits per peak")
ax.set_xlabel("Number of motifs k in peak")
ax.set_ylabel("Proportion of peaks with at least k motifs")
plt.show()

# Number of peaks with each motif
frac_peaks_with_motif = np.sum(peak_hit_counts > 0, axis=0) / len(peak_hit_counts)
labels = np.array(motif_keys)
sorted_inds = np.flip(np.argsort(frac_peaks_with_motif))
frac_peaks_with_motif = frac_peaks_with_motif[sorted_inds]
labels = labels[sorted_inds]

fig, ax = plt.subplots(figsize=(20, 8))
ax.bar(np.arange(len(labels)), frac_peaks_with_motif)
ax.set_title("Proportion of peaks with each motif")
ax.set_xticks(np.arange(len(labels)))
ax.set_xticklabels(labels)
plt.show()
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6662: RuntimeWarning: invalid value encountered in multiply tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
We show the significance of motifs co-occurring with each other in peaks. This can reveal pairs of motifs that tend to co-occur with each other, or a single motif that tends to occur multiple times within the same peak. We show a heatmap of co-occurrence significance across all pairs of motifs. For significantly co-occurring motifs, we compute the distribution of distances between the motifs within peaks.
cooccurrence_file_path = os.path.join(motif_hits_cache_dir, "cooccurrences.h5")
with h5py.File(cooccurrence_file_path, "r") as f:
    pval_matrix = f["pvals"][:]
with h5py.File(motif_file, "r") as f:
    all_motif_keys = sorted(f.keys())

# Limit the matrix to the keys we want. Rows/columns are assumed to follow the
# lexicographically sorted motif keys (as this indexing already requires).
# Select them in the order of `motif_keys`, so `motif_keys_ordered` below labels
# the right rows even when `motif_keys` is not lexicographically sorted
# (e.g. "0_2" vs "0_10").
motif_key_to_row = {key: i for i, key in enumerate(all_motif_keys)}
keep_inds = np.array([motif_key_to_row[key] for key in motif_keys])
pval_matrix = pval_matrix[keep_inds][:, keep_inds]

# Cluster by p-value
num_motifs = len(pval_matrix)
inds = cluster_matrix_indices(pval_matrix, max(5, num_motifs // 4))
pval_matrix_ordered = pval_matrix[inds][:, inds]
motif_keys_ordered = np.array(motif_keys)[inds]

# Plot the p-value matrix
fig_width = max(5, num_motifs)
p_fig, ax = plt.subplots(figsize=(fig_width, fig_width))

# Replace 0s with the minimum nonzero value (they are labeled "Inf" below)
zero_mask = pval_matrix_ordered == 0
non_zeros = pval_matrix_ordered[~zero_mask]
if not len(non_zeros):
    logpval_matrix = np.full(pval_matrix_ordered.shape, np.inf)
else:
    pval_matrix_ordered[zero_mask] = np.min(non_zeros)
    logpval_matrix = -np.log10(pval_matrix_ordered)

hmap = ax.imshow(logpval_matrix)

ax.set_xticks(np.arange(num_motifs))
ax.set_yticks(np.arange(num_motifs))
ax.set_xticklabels(motif_keys_ordered, rotation=45)
ax.set_yticklabels(motif_keys_ordered)

# Loop over data dimensions and create text annotations.
for i in range(num_motifs):
    for j in range(num_motifs):
        if zero_mask[i, j]:
            text = "Inf"
        else:
            text = "%.2f" % np.abs(logpval_matrix[i, j])
        ax.text(j, i, text, ha="center", va="center")
p_fig.colorbar(hmap, orientation="horizontal")

ax.set_title("-log(p) significance of peaks with both motifs")
p_fig.tight_layout()
plt.show()
cooccurrence_dist_path = os.path.join(motif_hits_cache_dir, "intermotif_dists.h5")
if os.path.exists(cooccurrence_dist_path):
    # Dataset names look like "KEY1:KEY2"; map (KEY1, KEY2) -> stored distances
    with h5py.File(cooccurrence_dist_path, "r") as f:
        distance_dict = {tuple(name.split(":")): f[name][:] for name in f.keys()}

    # Create the plot; squeeze=False always yields a 2D array of axes, even for
    # a single motif (1 x 1 grid)
    num_keys = len(motif_keys)
    fig, ax = plt.subplots(
        nrows=num_keys, ncols=num_keys, figsize=(num_keys * 4, num_keys * 4), squeeze=False
    )

    # Map motif key to axis index
    key_to_index = dict(zip(motif_keys_ordered, np.arange(len(motif_keys_ordered))))

    def clean_subplot(subplot_ax):
        # Hide ticks and spines instead of calling axis("off"), which would
        # also remove any axis labels
        subplot_ax.set_yticks([])
        subplot_ax.set_xticks([])
        for side in ("top", "bottom", "left", "right"):
            subplot_ax.spines[side].set_visible(False)

    # Create violins over the upper triangle of key pairs (diagonal included)
    for i in range(num_keys):
        for j in range(i, num_keys):
            first_key, second_key = motif_keys_ordered[i], motif_keys_ordered[j]
            row, col = key_to_index[first_key], key_to_index[second_key]
            if row < col:
                row, col = col, row  # Always plot into the lower triangle

            # Prefer the reversed pair when it is stored
            if (second_key, first_key) in distance_dict:
                found_pair = (second_key, first_key)
            elif (first_key, second_key) in distance_dict:
                found_pair = (first_key, second_key)
            else:
                found_pair = None

            if found_pair is None:
                clean_subplot(ax[row, col])
                clean_subplot(ax[col, row])
                continue

            create_violin_plot(ax[row, col], [distance_dict[found_pair]], ["mediumorchid"])
            ax[row, col].set_xticks([])  # Remove x-axis labels, as they don't mean much
            if row != col:
                # Off the diagonal, blank out the mirrored upper-triangle cell
                clean_subplot(ax[col, row])

    # Make motif labels
    for i in range(num_keys):
        label = motif_keys_ordered[i]
        ax[i, 0].set_ylabel(label)
        ax[-1, i].set_xlabel(label)

    # Remove x-axis labels/ticks
    ax[-1, -1].set_xticks([])

    fig.suptitle("Distance distributions between significantly co-occurring motifs")
    fig.tight_layout(rect=[0, 0.03, 1, 0.98])
    plt.show()