import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/notebooks/reports/"))
import util
import motif.read_motifs as read_motifs
import plot.viz_sequence as viz_sequence
from feature.util import one_hot_to_seq
import h5py
import numpy as np
import scipy.signal
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import matplotlib.patches as patches
import tqdm
# Load the `tqdm.notebook` submodule so `tqdm.notebook.tqdm` is usable below
# (`import tqdm` alone does not load it); this replaces the deprecated
# `tqdm.tqdm_notebook(...)` warm-up call
import tqdm.notebook
# Output: /users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ipykernel_launcher.py:17: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0. Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
# Output: <tqdm.notebook.tqdm_notebook at 0x7f70f7b73490>
# Plotting defaults
# Register every font file found in the custom fonts directory (this is where
# "Roboto", used below, is expected to come from). `FontManager.addfont`
# replaces the deprecated `font_manager.createFontList`.
for font_path in font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts"):
    try:
        font_manager.fontManager.addfont(font_path)
    except (OSError, RuntimeError) as err:
        # Skip (but report) files Matplotlib cannot parse, rather than
        # aborting the whole notebook over one bad font
        print("Warning: could not load font %s: %s" % (font_path, err))
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold",
    "svg.fonttype": "none"  # Keep text as text (not paths) in SVG output
}
plt.rcParams.update(plot_params)
# Output: /users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ipykernel_launcher.py:4: MatplotlibDeprecationWarning: The createFontList function was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use FontManager.addfont instead.
#   after removing the cwd from sys.path.
# Directory that receives every figure saved by this notebook; created up
# front (no error if it already exists)
out_path = os.path.join("/users/amtseng/tfmodisco/figures", "motif_hit_comparison")
os.makedirs(out_path, exist_ok=True)
tf_names = [
    "E2F6", "FOXA2", "SPI1", "CEBPB", "MAX", "GABPA", "MAFK", "JUND", "NR3C1-reddytime", "REST"
]
# Names used in figure labels: map dataset-specific TF names to the plain TF
# name by key (rather than patching a hard-coded list position, which breaks
# silently if `tf_names` is reordered)
tf_clean_name_overrides = {"NR3C1-reddytime": "NR3C1"}
tf_names_clean = [tf_clean_name_overrides.get(name, name) for name in tf_names]
# Number of tasks for each TF
tf_num_tasks = {
    "E2F6": 2,
    "FOXA2": 4,
    "SPI1": 4,
    "CEBPB": 7,
    "MAX": 7,
    "GABPA": 9,
    "MAFK": 9,
    "JUND": 14,
    "NR3C1-reddytime": 16,
    "REST": 20
}
# For each TF, one letter per task selecting which model's results to use:
# "M" = multitask model, "S" = single-task model
tf_best_model_types = {
    "E2F6": list("MM"),
    "FOXA2": list("SSMM"),
    "SPI1": list("MSSS"),
    "CEBPB": list("MMMMSMM"),
    "MAX": list("MMSMMSS"),
    "GABPA": list("MMMSMMMMM"),
    "MAFK": list("MMMMMMMMM"),
    "JUND": list("SMMSMSSSSSSSMS"),
    "NR3C1-reddytime": list("MMMSMMSMMMMSMMMM"),
    "REST": list("MMMMMMMMMSMMSMMSMMMM")
}
# Guard against the tables above drifting out of sync with each other
_mismatched_tfs = [
    name for name in tf_names
    if len(tf_best_model_types[name]) != tf_num_tasks[name]
    or not set(tf_best_model_types[name]) <= {"M", "S"}
]
if _mismatched_tfs:
    raise ValueError(
        "tf_best_model_types must have one \"M\"/\"S\" entry per task; check: %s"
        % ", ".join(_mismatched_tfs)
    )
def get_tfmodisco_motifs_path(tf_name):
    """
    Returns the path of the cached TF-MoDISco motifs HDF5 file for the TF
    named `tf_name`.
    """
    motifs_dir = "/users/amtseng/tfmodisco/results/motifs/tfmodisco"
    return f"{motifs_dir}/{tf_name}_tfmodisco_motifs.h5"
def _find_profile_results_dir(base_dir, tf_name, model_type, task_index):
    """
    Returns the single directory under
    `base_dir`/{multitask,singletask}_profile_finetune (chosen by `model_type`)
    whose name starts with "<tf_name>_", contains "task<task_index>_", and
    ends with "_profile".
    Raises ValueError if `model_type` is not "M"/"S" or if not exactly one
    directory matches.
    """
    if model_type not in ("M", "S"):
        raise ValueError("model_type must be \"M\" or \"S\", got %r" % (model_type,))
    subdir = os.path.join(
        base_dir,
        "multitask_profile_finetune" if model_type == "M" else "singletask_profile_finetune"
    )
    prefix = tf_name + "_"
    task_tag = "task%d_" % task_index

    def is_match(name):
        return name.startswith(prefix) and task_tag in name and name.endswith("_profile")

    matches = [path for path, _, _ in os.walk(subdir) if is_match(os.path.basename(path))]
    # Explicit check (not `assert`) so it still runs under `python -O`
    if len(matches) != 1:
        raise ValueError(
            "Expected exactly 1 results directory for %s task %d (model type %s) under %s, found %d"
            % (tf_name, task_index, model_type, subdir, len(matches))
        )
    return matches[0]


def get_tfm_motif_hits_path(tf_name, model_type, task_index, base_dir=None):
    """
    Gets the path to the cached results directory for TF-MoDISco hits
    for the _profile_ head, using the given TF name, the model type
    ("S" or "M"), and a task index.
    `base_dir` overrides the default cache root.
    Raises ValueError on a bad model type or if not exactly one directory
    matches.
    """
    if base_dir is None:
        base_dir = "/users/amtseng/tfmodisco/results/reports/motif_hits/cache/tfm"
    return _find_profile_results_dir(base_dir, tf_name, model_type, task_index)


def get_moods_motif_hits_path(tf_name, model_type, task_index, base_dir=None):
    """
    Gets the path to the MOODS hits for the _profile_ head, using
    the given TF name, the model type ("S" or "M"), and a task index.
    `base_dir` overrides the default MOODS results root.
    Raises ValueError on a bad model type or if not exactly one directory
    matches.
    """
    if base_dir is None:
        base_dir = "/users/amtseng/tfmodisco/results/moods"
    return _find_profile_results_dir(base_dir, tf_name, model_type, task_index)
def get_predictions_impscores_path(
    tf_name, model_type, task_index, preds_base_dir=None, scores_base_dir=None
):
    """
    Gets the paths to the predictions and importance scores HDF5s, using
    the given TF name, the model type ("S" or "M"), and a task index.

    A file matches if its name starts with "<tf_name>_" and ends with ".h5".
    For the importance scores, and for single-task ("S") predictions, the
    name must also contain "task<task_index>_". Multitask ("M") predictions
    are matched on TF name alone.
    `preds_base_dir` / `scores_base_dir` override the default roots.

    Returns (predictions path, importance scores path).
    Raises ValueError for a bad model type or if several files match, and
    FileNotFoundError if no file matches.
    """
    if model_type not in ("M", "S"):
        raise ValueError("model_type must be \"M\" or \"S\", got %r" % (model_type,))
    if preds_base_dir is None:
        preds_base_dir = "/users/amtseng/tfmodisco/results/peak_predictions"
    if scores_base_dir is None:
        scores_base_dir = "/users/amtseng/tfmodisco/results/importance_scores"
    model_subdir = "multitask_profile_finetune" if model_type == "M" else "singletask_profile_finetune"
    prefix = tf_name + "_"
    task_tag = "task%d_" % task_index

    def find_unique_h5(search_dir, require_task):
        found = [
            os.path.join(dir_path, name)
            for dir_path, _, file_names in os.walk(search_dir)
            for name in file_names
            if name.startswith(prefix) and name.endswith(".h5")
            and (not require_task or task_tag in name)
        ]
        # Explicit checks: previously a missing file silently became `None`
        if not found:
            raise FileNotFoundError(
                "No HDF5 for %s task %d found under %s" % (tf_name, task_index, search_dir)
            )
        if len(found) > 1:
            raise ValueError(
                "Multiple HDF5s for %s task %d found under %s: %s"
                % (tf_name, task_index, search_dir, found)
            )
        return found[0]

    preds_path = find_unique_h5(
        os.path.join(preds_base_dir, model_subdir), require_task=(model_type != "M")
    )
    scores_path = find_unique_h5(
        os.path.join(scores_base_dir, model_subdir), require_task=True
    )
    return preds_path, scores_path
def import_tfm_motif_hits(tf_name, model_type, task_index):
    """
    Loads the filtered TF-MoDISco motif hits of the _profile_ head for one
    TF name, model type ("S" or "M"), and task index.
    Returns a DataFrame read from "filtered_hits.tsv" in the TF-MoDISco hits
    results directory; the file's first line supplies the column names.
    """
    hits_dir = get_tfm_motif_hits_path(tf_name, model_type, task_index)
    tsv_path = os.path.join(hits_dir, "filtered_hits.tsv")
    return pd.read_csv(tsv_path, sep="\t", header=0, index_col=False)
def import_moods_motif_hits(tf_name, model_type, task_index):
    """
    Loads the MOODS motif hits of the _profile_ head for one TF name, model
    type ("S" or "M"), and task index, from
    "moods_filtered_scored_collapsed-all.bed" in the MOODS results directory.
    Returns a DataFrame with the fixed column names listed below.
    """
    bed_path = os.path.join(
        get_moods_motif_hits_path(tf_name, model_type, task_index),
        "moods_filtered_scored_collapsed-all.bed"
    )
    column_names = [
        "chrom", "start", "end", "key", "strand", "score", "peak_index",
        "imp_total_score", "imp_frac_score", "imp_ic_avg_score"
    ]
    # NOTE(review): `header=0` together with `names` makes pandas treat the
    # file's first line as a header row and discard it. If this BED file has
    # no header line, its first hit is silently dropped -- confirm the format.
    return pd.read_csv(
        bed_path, sep="\t", header=0, index_col=False, names=column_names
    )
def import_tfm_motif_hits_per_peak(tf_name, model_type, task_index):
"""
From the given TF name, the model type ("S" or "M"), and a task
index, imports the set of number of TF-MoDIsco motif hits per peak,
from the _profile_ head. Uses the hit type specified.
"""
results_dir = get_tfm_motif_hits_path(tf_name, model_type, task_index)
peak_matched_hits_path = os.path.join(results_dir, "peak_matched_hits.tsv")
num_hits_per_peak = []
with open(peak_matched_hits_path, "r") as f:
next(f) # Header
for line in f:
tokens = line.split("\t")
if not tokens[1].strip():
num_hits_per_peak.append(0)
else:
num_hits_per_peak.append(tokens[1].count(",") + 1)
return np.array(num_hits_per_peak)
def import_moods_motif_hits_per_peak(tf_name, model_type, task_index):
"""
From the given TF name, the model type ("S" or "M"), and a task
index, imports the set of number of MOODS motif hits per peak,
from the _profile_ head. Uses the hit type specified.
"""
hit_table = import_moods_motif_hits(tf_name, model_type, task_index)
peak_inds, counts = np.unique(hit_table["peak_index"], return_counts=True)
num_hits_per_peak = np.zeros(np.max(peak_inds) + 1)
num_hits_per_peak[peak_inds] = counts
return num_hits_per_peak
def import_all_motif_hits_per_peak(hit_type, keys=None):
    """
    Imports all motif hit counts for all TFs and tasks, for the
    given hit type ("tfm" for TF-MoDISco or "moods" for MOODS). If
    specified, limits to given keys.
    Returns a dictionary mapping tuple (TF name, task index) to arrays of
    hits per peak, using the best model type recorded for each task.
    Raises ValueError for any other `hit_type` (previously a typo silently
    loaded MOODS hits).
    """
    if hit_type not in ("tfm", "moods"):
        raise ValueError("hit_type must be \"tfm\" or \"moods\", got %r" % (hit_type,))
    import_hits_per_peak = (
        import_tfm_motif_hits_per_peak if hit_type == "tfm" else import_moods_motif_hits_per_peak
    )
    if not keys:
        keys = [
            (tf_name, task_index) for tf_name in tf_names for task_index in range(tf_num_tasks[tf_name])
        ]
    result = {}
    for key in tqdm.notebook.tqdm(keys):
        tf_name, task_index = key
        model_type = tf_best_model_types[tf_name][task_index]
        result[key] = import_hits_per_peak(tf_name, model_type, task_index)
    return result
def import_all_tfmodisco_motifs(keys=None):
    """
    Loads the TF-MoDISco motifs of the _profile_ head for each
    (TF name, task index) key, from the multitask or single-task results
    according to the best model type recorded for that task. If `keys` is
    empty or None, every TF and task is used.

    Each motif's CWM, hCWM, and PFM are trimmed with
    `read_motifs.trim_motif_by_ic`, always using the untrimmed PFM to decide
    the trimming.

    Returns a tuple of three dictionaries (CWMs, hCWMs, PFMs). Each maps a
    (TF name, task index) key to a dictionary of motif key -> trimmed matrix.
    """
    if not keys:
        keys = [
            (tf_name, task_index)
            for tf_name in tf_names
            for task_index in range(tf_num_tasks[tf_name])
        ]
    cwms, hcwms, pfms = {}, {}, {}
    for key in tqdm.notebook.tqdm(keys):
        tf_name, task_index = key
        model_type = tf_best_model_types[tf_name][task_index]
        group_name = "multitask_finetune" if model_type == "M" else "singletask_finetune"
        with h5py.File(get_tfmodisco_motifs_path(tf_name), "r") as motif_file:
            profile_motifs = motif_file[group_name]["task_%d" % task_index]["profile"]
            cwms[key], hcwms[key], pfms[key] = {}, {}, {}
            for motif_key in profile_motifs.keys():
                motif_group = profile_motifs[motif_key]
                raw_cwm = motif_group["cwm_full"][:]
                raw_hcwm = motif_group["hcwm_full"][:]
                raw_pfm = motif_group["pfm_full"][:]
                cwms[key][motif_key] = read_motifs.trim_motif_by_ic(raw_pfm, raw_cwm)
                hcwms[key][motif_key] = read_motifs.trim_motif_by_ic(raw_pfm, raw_hcwm)
                pfms[key][motif_key] = read_motifs.trim_motif_by_ic(raw_pfm, raw_pfm)
    return cwms, hcwms, pfms
def create_box_plot(ax, dist_list, colors, sample=100):
    """
    Draws one box plot per distribution on an existing axes object.

    Arguments:
        ax: axes to draw on
        dist_list: list of vectors, one box per vector (boxes at x = 1, 2, ...)
        colors: list parallel to `dist_list`; the color of each box, its
            whiskers and caps, and its scatter points
        sample: number of points per vector to overlay as a jittered scatter;
            0 draws no points, a negative value draws every point
    Fliers are hidden and medians are drawn in black.
    """
    num_dists = len(dist_list)
    artists = ax.boxplot(dist_list, showfliers=False, widths=(0.8 / num_dists), zorder=0)
    for dist_ind in range(num_dists):
        color = colors[dist_ind]
        artists["boxes"][dist_ind].set_color(color)
        # Caps and whiskers come in pairs (one below, one above) per box
        for part in ("caps", "whiskers"):
            artists[part][2 * dist_ind].set_color(color)
            artists[part][(2 * dist_ind) + 1].set_color(color)
        artists["medians"][dist_ind].set_color("black")

    if sample == 0:
        return
    for dist_ind, points in enumerate(dist_list):
        if sample > 0:
            points = np.random.choice(points, size=min(sample, len(points)), replace=False)
        # Small horizontal jitter around the box's x position
        jitter_x = np.random.normal(dist_ind + 1, 0.04, len(points))
        ax.scatter(jitter_x, points, alpha=0.5, color=colors[dist_ind], zorder=1)
# Import TF-MoDISco motifs
tfm_cwms, tfm_hcwms, tfm_pfms = import_all_tfmodisco_motifs()

# Import TF-MoDISco and MOODS hit counts per peak for every (TF, task) pair
cond_keys = [
    (tf_name, task_index) for tf_name in tf_names for task_index in range(tf_num_tasks[tf_name])
]
tfm_hit_counts = import_all_motif_hits_per_peak("tfm", keys=cond_keys)
# NOTE(review): each MOODS count array is as long as (largest peak index with
# a MOODS hit + 1), so it can be shorter than the matching TF-MoDISco array;
# trailing zero-hit peaks would then be missing from the MOODS boxes.
# Confirm this is acceptable for the comparison.
moods_hit_counts = import_all_motif_hits_per_peak("moods", keys=cond_keys)


def _plot_hits_per_peak(count_pairs, labels, figsize, file_name):
    """
    Draws one panel per (TF-MoDISco counts, MOODS counts) pair as two box
    plots (cornflowerblue = TF-MoDISco, lightcoral = MOODS), labels each
    panel, and saves the figure as SVG `file_name` in `out_path`.
    """
    fig, axes = plt.subplots(
        ncols=len(count_pairs), sharey=True, figsize=figsize, squeeze=False
    )
    axes = axes[0]
    for panel_ax, (tfm_counts, moods_counts), label in zip(axes, count_pairs, labels):
        create_box_plot(
            panel_ax, [tfm_counts, moods_counts], ["cornflowerblue", "lightcoral"], sample=0
        )
        panel_ax.set_xticks([])  # Remove x-axis labels, as they don't mean much
        panel_ax.set_xlabel(label, rotation=90)
    axes[0].set_ylabel("Number of motifs per peak")
    fig.suptitle("Motifs found per peak")
    fig.savefig(os.path.join(out_path, file_name), format="svg")
    plt.show()


# Box plots with all cell types (tasks) separately
_plot_hits_per_peak(
    [(tfm_hit_counts[key], moods_hit_counts[key]) for key in cond_keys],
    ["%s task %d" % key for key in cond_keys],
    (len(cond_keys), 6),
    "motifs_per_peak_separate_cell_types.svg"
)

# Collapse cell types: pool all tasks belonging to each TF
tfm_hit_counts_collapsed, moods_hit_counts_collapsed = {}, {}
for tf_name in tf_names:
    tf_cond_keys = [key for key in cond_keys if key[0] == tf_name]
    tfm_hit_counts_collapsed[tf_name] = np.concatenate([tfm_hit_counts[key] for key in tf_cond_keys])
    moods_hit_counts_collapsed[tf_name] = np.concatenate([moods_hit_counts[key] for key in tf_cond_keys])
_plot_hits_per_peak(
    [(tfm_hit_counts_collapsed[name], moods_hit_counts_collapsed[name]) for name in tf_names],
    tf_names_clean,
    (len(tf_names) * 1.5, 6),
    "motifs_per_peak_collapsed_cell_types.svg"
)
def pwm_score_track(pwm, one_hot_seq):
    """
    Slides the PWM along the one-hot encoded sequence and returns the summed
    log-odds score at each position, as an N-array with the same length as
    `one_hot_seq` (the ends are zero-padded, via mode="same").
    """
    # Assuming both arrays have 4 columns (one per base), column 2 of the
    # "same"-mode 2D correlation is the zero shift along the base axis, i.e.
    # each PWM base lined up with the same sequence base
    aligned_base_col = 2
    correlation = scipy.signal.correlate(one_hot_seq, pwm, mode="same")
    return correlation[:, aligned_base_col]
def plot_example_hits(
chrom, start, end, tfm_hits_table, moods_hits_table, profiles_hdf5_path,
imp_scores_hdf5_path, task_index, pfm_dict, prof_center_size=700, score_center_size=100,
hyp_score_key="profile_hyp_scores", save_path=None
):
"""
For a given region, plots the true/predicted profiles, importance scores, and the
TFM and MOODS hits at that region on each.
"""
mid = (start + end) // 2
prof_start = mid - (prof_center_size // 2)
prof_end = prof_start + prof_center_size
score_start = mid - (score_center_size // 2)
score_end = score_start + score_center_size
with h5py.File(profiles_hdf5_path, "r") as f:
# Need to use the coordinates of the profiles themselves
prof_len = f["predictions"]["log_pred_profs"].shape[2]
prof_coords_chrom = f["coords"]["coords_chrom"][:].astype(str)
prof_coords_start = f["coords"]["coords_start"][:]
prof_coords_end = f["coords"]["coords_end"][:]
mid = (prof_coords_start + prof_coords_end) // 2
prof_coords_start = mid - (prof_len // 2)
prof_coords_end = prof_coords_start + prof_len
match_inds = np.where(
(prof_coords_chrom == chrom) &
(prof_coords_start <= prof_start) &
(prof_coords_end >= prof_end)
)[0]
if not match_inds.size:
print("Warning: did not find sufficiently large prediction track for %s:%d-%s" % (chrom, prof_start, prof_end))
return
match_ind = match_inds[0]
coord_start, coord_end = prof_coords_start[match_ind], prof_coords_end[match_ind]
cut_start = prof_start - coord_start
cut_end = cut_start + prof_center_size
if f["predictions"]["log_pred_profs"].shape[1] == 1:
task_index = 0
pred_profs = np.exp(f["predictions"]["log_pred_profs"][match_ind][task_index][cut_start:cut_end])
true_profs = f["predictions"]["true_profs"][match_ind][task_index][cut_start:cut_end]
with h5py.File(imp_scores_hdf5_path, "r") as f:
match_inds = np.where(
(f["coords_chrom"][:].astype(str) == chrom) &
(f["coords_start"][:] <= score_start) &
(f["coords_end"][:] >= score_end)
)[0]
if not match_inds.size:
print("Warning: did not find sufficiently large importance score track for %s:%d-%s" % (chrom, score_start, score_end))
return
match_ind = match_inds[0]
coord_start, coord_end = f["coords_start"][match_ind], f["coords_end"][match_ind]
hyp_scores = f[hyp_score_key][match_ind]
one_hot_seq = f["input_seqs"][match_ind]
cut_start = score_start - coord_start
cut_end = cut_start + score_center_size
hyp_scores = hyp_scores[cut_start:cut_end]
one_hot_seq = one_hot_seq[cut_start:cut_end]
tfm_hits = tfm_hits_table[
(tfm_hits_table["chrom"] == chrom) &
(tfm_hits_table["start"] < end) &
(tfm_hits_table["end"] > start)
]
moods_hits = moods_hits_table[
(moods_hits_table["chrom"] == chrom) &
(moods_hits_table["start"] < end) &
(moods_hits_table["end"] > start)
]
prof_fig, ax = plt.subplots(nrows=3, sharex=True, figsize=(20, 8))
# Draw profiles
ax[0].plot(true_profs[:, 0], color="darkslateblue")
ax[0].plot(-true_profs[:, 1], color="darkorange")
ax[0].set_title("True ChIP-seq profiles")
ax[1].plot(pred_profs[:, 0], color="darkslateblue")
ax[1].plot(-pred_profs[:, 1], color="darkorange")
ax[1].set_title("Predicted ChIP-seq profiles")
# Draw motif hits
for i, (hit_table, color) in enumerate([(tfm_hits, "blue"), (moods_hits, "red")]):
for _, row in hit_table.iterrows():
start_pos = max(row["start"] - prof_start, 0)
end_pos = min(row["end"] - prof_start, prof_center_size)
ax[2].add_patch(patches.Rectangle(
xy=(start_pos, i * 0.5), width=(end_pos - start_pos), height=0.5, color=color, fill=False
))
ax[2].text(
x=((end_pos + start_pos) // 2), y=((i * 0.5) + 0.25), s=row["key"]
)
# Draw vertical lines that denote the portion with importance scores
for i in range(2):
ax[i].axvline(score_start - prof_start, color="gray")
ax[i].axvline(score_end - prof_start, color="gray")
if save_path:
plt.savefig(
os.path.join(save_path + "_profiles.svg"), format="svg"
)
plt.show()
motif_keys = np.unique(np.concatenate([tfm_hits["key"], moods_hits["key"]]))
fig, ax = plt.subplots(nrows=len(motif_keys), figsize=(20, 2 * len(motif_keys)))
if len(motif_keys) == 1:
ax = [ax]
for i, motif_key in enumerate(motif_keys):
pfm = pfm_dict[motif_key]
pwm = np.log2((pfm + 0.0001) / read_motifs.BACKGROUND_FREQS[None])
match_track = pwm_score_track(pwm, one_hot_seq)
match_track[match_track < 0] = 0
ax[i].plot(match_track)
ax[i].set_title(motif_key)
fig.suptitle("PWM log-odds scores across the sequence")
fig.tight_layout()
if save_path:
plt.savefig(
os.path.join(save_path + "_pwm_logodds.svg"), format="svg"
)
plt.show()
highlights = {}
for i, (hit_table, color) in enumerate([(tfm_hits, "blue"), (moods_hits, "red")]):
highlights[color] = []
for _, row in hit_table.iterrows():
if row["start"] >= score_end or row["end"] <= score_start:
continue
start_pos = max(row["start"] - score_start, 0)
end_pos = min(row["end"] - score_start, score_center_size)
highlights[color].append((start_pos, end_pos))
print(row["key"], start_pos, end_pos)
score_fig = viz_sequence.plot_weights(
hyp_scores * one_hot_seq, figsize=(20, 4), subticks_frequency=score_center_size, highlight=highlights, return_fig=True
)
score_fig.tight_layout()
if save_path:
plt.savefig(
os.path.join(save_path + "_impscores.svg"), format="svg"
)
plt.show()
print(one_hot_to_seq(one_hot_seq))
# Example TF/task whose hits are plotted in detail
tf_name = "E2F6"
task_index = 0
model_type = tf_best_model_types[tf_name][task_index]
tfm_hits_table = import_tfm_motif_hits(tf_name, model_type, task_index)
moods_hits_table = import_moods_motif_hits(tf_name, model_type, task_index)
profiles_hdf5_path, imp_scores_hdf5_path = get_predictions_impscores_path(tf_name, model_type, task_index)
pfm_dict = tfm_pfms[(tf_name, task_index)]

# Get a set of peak indices that match the desired criteria: the peak has a
# hit to motif "0_0", 3-6 TF-MoDISco hits, and between 1 and
# (number of TF-MoDISco hits - 2) MOODS hits
all_peak_inds = np.unique(tfm_hits_table[tfm_hits_table["key"] == "0_0"]["peak_index"])  # Require this motif
# Count hits per peak once, instead of re-filtering both tables for every peak
tfm_counts_by_peak = tfm_hits_table["peak_index"].value_counts()
moods_counts_by_peak = moods_hits_table["peak_index"].value_counts()
passed_peak_inds = []
for peak_ind in tqdm.notebook.tqdm(all_peak_inds):
    num_tfm_hits = tfm_counts_by_peak.get(peak_ind, 0)
    num_moods_hits = moods_counts_by_peak.get(peak_ind, 0)
    if 3 <= num_tfm_hits <= 6 and 1 <= num_moods_hits <= num_tfm_hits - 2:
        passed_peak_inds.append(peak_ind)
print("Found %d peaks that pass criteria" % len(passed_peak_inds))

# Reproducible random sample of passing peaks, for browsing candidates
num_to_take = min(20, len(passed_peak_inds))
seed = 20211207
rng = np.random.RandomState(seed)
sampled_peak_inds = rng.choice(passed_peak_inds, size=num_to_take, replace=False)

# Peaks to plot; NOTE(review): presumably hand-picked after browsing the
# sample above. Only peaks in `peaks_to_save` get their figures saved.
example_peak_inds = [12462, 8287]
peaks_to_save = {12462, 8287}
for peak_ind in example_peak_inds:
    tfm_hits = tfm_hits_table[tfm_hits_table["peak_index"] == peak_ind]
    if tfm_hits.empty:
        continue
    chrom = tfm_hits["chrom"].values[0]
    start = np.min(tfm_hits["start"])
    end = np.max(tfm_hits["end"])
    print("%s:%d-%d (index %d)" % (chrom, start, end, peak_ind))
    if peak_ind in peaks_to_save:
        save_path = os.path.join(out_path, "motif_hit_example_%s_task%d_peak%d" % (tf_name, task_index, peak_ind))
    else:
        save_path = None
    plot_example_hits(
        chrom, start, end, tfm_hits_table, moods_hits_table, profiles_hdf5_path,
        imp_scores_hdf5_path, task_index, pfm_dict, save_path=save_path
    )
    print("")
# Output:
# Found 3915 peaks that pass criteria
# chrX:54440400-54440450 (index 12462)
# 0_2 67 75
# 0_1 48 56
# 0_0 25 32
# 0_0 24 32
# CAGAGGGATTGGGCCGCCGGGACGTCACGTGGACTGGGGCCGGATAATGGCGGGCGCTGCAGAAGATGCGCGAGCTCTTTTCCGGGCTGGGGTCTGCGCG
#
# chr8:11651635-11651674 (index 8287)
# 0_2 31 39
# 0_5 59 70
# 0_1 43 51
# 0_1 43 51
# TTCAGCACAGCAGCCCCTCCCTCTGTGCACACACCGCTAGGACTTCCCGCCTCCACACCCCGCCACGTGGATGGATACACCCATGTGTGGACATACACAC