import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/notebooks/reports/"))
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
import motif.read_motifs as read_motifs
import plot.viz_sequence as viz_sequence
from util import motif_similarity_score, purine_rich_motif
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import vdom.helpers as vdomh
from IPython.display import display
# Plotting defaults
# Register the fonts found in the user's font directory (presumably including
# Roboto, which is selected as the default family below).
# `FontManager.addfont` replaces `font_manager.createFontList`, which was
# deprecated in Matplotlib 3.2 and removed in later releases. Like the old
# helper, this skips font files that cannot be read instead of aborting.
for font_path in font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts"):
    try:
        font_manager.fontManager.addfont(font_path)
    except (OSError, RuntimeError):
        # Unreadable/corrupt font file: best-effort registration, skip it
        pass
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold",
    "svg.fonttype": "none"  # Keep text as text (not paths) in saved SVGs
}
plt.rcParams.update(plot_params)
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ipykernel_launcher.py:4: MatplotlibDeprecationWarning: The createFontList function was deprecated in Matplotlib 3.2 and will be removed two minor releases later. Use FontManager.addfont instead. after removing the cwd from sys.path.
# TF to analyze: taken from the TFM_TF_NAME environment variable when it is set,
# otherwise MAX
tf_name = os.environ.get("TFM_TF_NAME", "MAX")
# Output directory for this TF's prevalence figures and data (created if absent)
out_path = "/users/amtseng/tfmodisco/figures/motif_prevalence/motif_prevalence_%s/" % tf_name
os.makedirs(out_path, exist_ok=True)
# Per-TF configuration: (number of tasks, best model type per task). Character i
# of the model-type string is "M" when task i uses the fine-tuned multi-task
# model and "S" when it uses its fine-tuned single-task model.
_tf_task_configs = {
    "E2F6": (2, "MM"),
    "FOXA2": (4, "SSMM"),
    "SPI1": (4, "MSSS"),
    "CEBPB": (7, "MMMMSMM"),
    "MAX": (7, "MMSMMSS"),
    "GABPA": (9, "MMMSMMMMM"),
    "MAFK": (9, "MMMMMMMMM"),
    "JUND": (14, "SMMSMSSSSSSSMS"),
    "NR3C1-reddytime": (16, "MMMSMMSMMMMSMMMM"),
    "REST": (20, "MMMMMMMMMSMMSMMSMMMM"),
}
tf_num_tasks = {tf: n_tasks for tf, (n_tasks, _) in _tf_task_configs.items()}
tf_best_model_types = {tf: list(type_str) for tf, (_, type_str) in _tf_task_configs.items()}

# Settings for the selected TF
num_tasks = tf_num_tasks[tf_name]
best_model_types = tf_best_model_types[tf_name]
# Input files: all TF-MoDISco motifs for this TF, and the best-fold tables of the
# fine-tuned multi-task and single-task models
tfm_motif_file = "/users/amtseng/tfmodisco/results/motifs/tfmodisco/%s_tfmodisco_cpmerged_motifs.h5" % tf_name
multitask_finetune_model_def_tsv = "/users/amtseng/tfmodisco/results/model_stats/multitask_profile_finetune_stats.tsv"
singletask_finetune_model_def_tsv = "/users/amtseng/tfmodisco/results/model_stats/singletask_profile_finetune_stats.tsv"
def get_motif_hit_paths():
    """
    Returns a list of pairs (one per task, in task order), where each pair is
    the count and profile motif hit paths for the task.
    Task i uses the fine-tuned multi-task model when `best_model_types[i]` is
    "M", and its fine-tuned single-task model otherwise. The fold of each model
    is read from the corresponding model-stats TSV rows for `tf_name`.
    Raises ValueError if the TSVs do not define exactly the folds needed.
    """
    # First, import the best fold definitions
    # Finetuned multi-task model: the single row for this TF whose second
    # column equals `num_tasks - 1`
    best_mt_fold = None
    with open(multitask_finetune_model_def_tsv, "r") as f:
        for line in f:
            tokens = line.strip().split("\t")
            if tokens[0] == tf_name and int(tokens[1]) == num_tasks - 1:
                if best_mt_fold is not None:
                    raise ValueError(
                        "Multiple multi-task fold definitions for %s in %s"
                        % (tf_name, multitask_finetune_model_def_tsv)
                    )
                best_mt_fold = int(tokens[2])
    # Without this check, a missing row only surfaces later as a cryptic
    # "%d format: a number is required" TypeError when building the paths
    if best_mt_fold is None and "M" in best_model_types:
        raise ValueError(
            "No multi-task fold definition for %s in %s"
            % (tf_name, multitask_finetune_model_def_tsv)
        )

    # Finetuned single-task models: one row per task for this TF
    # NOTE(review): rows are assumed to appear in task order, since the folds
    # are indexed by position below — confirm against the TSV
    best_st_folds = []
    with open(singletask_finetune_model_def_tsv, "r") as f:
        for line in f:
            tokens = line.strip().split("\t")
            if tokens[0] == tf_name:
                best_st_folds.append(int(tokens[2]))
    if len(best_st_folds) != num_tasks:
        raise ValueError(
            "Expected %d single-task fold definitions for %s in %s, found %d"
            % (num_tasks, tf_name, singletask_finetune_model_def_tsv, len(best_st_folds))
        )

    # Get paths to motif hits; each path template contains a "{0}" placeholder
    # that is filled with the head name ("count" or "profile")
    task_motif_hit_paths = []
    base_path = "/users/amtseng/tfmodisco/results/reports/motif_hits/cache/tfm"
    for task_index, model_type in enumerate(best_model_types):
        if model_type == "M":
            path = os.path.join(
                base_path,
                "multitask_profile_finetune",
                "%s_multitask_profile_finetune_fold%d" % (tf_name, best_mt_fold),
                "%s_multitask_profile_finetune_task%d_fold%d_{0}" % (tf_name, task_index, best_mt_fold),
                "filtered_hits.tsv"
            )
        else:
            path = os.path.join(
                base_path,
                "singletask_profile_finetune",
                "%s_singletask_profile_finetune_fold%d" % (tf_name, best_st_folds[task_index]),
                "task_%d" % task_index,
                "%s_singletask_profile_finetune_task%d_fold%d_{0}" % (tf_name, task_index, best_st_folds[task_index]),
                "filtered_hits.tsv"
            )
        task_motif_hit_paths.append(
            (path.format("count"), path.format("profile"))
        )
    return task_motif_hit_paths
def import_tfmodisco_motifs(motif_file, model_types, motif_type="cwm_trimmed"):
    """
    Reads one kind of motif (trimmed CWMs by default, or any other dataset
    name given as `motif_type`) for every task from an HDF5 file holding all
    TF-MoDISco motifs for a TF.
    Task i's motifs come from the "multitask_finetune" group when
    `model_types[i]` is "M", and from the "singletask_finetune" group otherwise.
    Returns a list with one dictionary per task, mapping "T<i>:<motif key>" to
    the motif array.
    """
    all_task_motifs = []
    with h5py.File(motif_file, "r") as motif_h5:
        multitask_group = motif_h5["multitask_finetune"]
        singletask_group = motif_h5["singletask_finetune"]
        for task_index, model_type in enumerate(model_types):
            source_group = multitask_group if model_type == "M" else singletask_group
            task_group = source_group["task_%d" % task_index]
            # Keep only motifs that are (or are constructed from) the positive
            # metacluster.
            # NOTE(review): this is a plain substring test, so a key such as
            # "C10_0" would also pass — confirm the key format makes this safe
            all_task_motifs.append({
                "T%d:%s" % (task_index, motif_key): task_group[motif_key][motif_type][:]
                for motif_key in task_group.keys()
                if "0_" in motif_key
            })
    return all_task_motifs
tfm_cwm_motifs = import_tfmodisco_motifs(tfm_motif_file, best_model_types, "cwm_trimmed")
tfm_pfm_motifs = import_tfmodisco_motifs(tfm_motif_file, best_model_types, "pfm_trimmed")
# Orient every motif toward its purine-rich strand for easier viewing/clustering.
# The heuristic is imperfect, so automated clustering may get the orientation
# wrong; the final aggregate motifs are built in a reverse-complement-sensitive
# manner to compensate.
# Each TF-MoDISco CWM is re-oriented first, and its PFM is then flipped to match.
for task_cwms, task_pfms in zip(tfm_cwm_motifs, tfm_pfm_motifs):
    for motif_key in list(task_cwms):
        oriented_cwm = purine_rich_motif(task_cwms[motif_key])
        task_cwms[motif_key] = oriented_cwm
        pfm = task_pfms[motif_key]
        pwm = read_motifs.pfm_to_pwm(pfm)
        # Compare the PWM with the re-oriented CWM both as-is and flipped along
        # both axes (a reverse complement, assuming A/C/G/T column order). Flip
        # the PFM only if the flipped version scores strictly higher; a NaN
        # score makes the comparison False, so the PFM is then left unchanged.
        forward_score = motif_similarity_score(oriented_cwm, pwm, mean_normalize=False)
        flipped_score = motif_similarity_score(
            oriented_cwm, np.flip(pwm, axis=(0, 1)), mean_normalize=False
        )
        if flipped_score > forward_score:
            task_pfms[motif_key] = np.flip(pfm, axis=(0, 1))
/users/amtseng/tfmodisco/notebooks/reports/util.py:103: RuntimeWarning: invalid value encountered in true_divide motif_2 = motif_2 / np.sqrt(np.sum(motif_2 * motif_2, axis=1, keepdims=True))
For each motif, compute its prevalence in each task: the fraction of that task's motif hits (in its peaks) that are assigned to the motif.
def import_motif_hits(hits_path):
    """Reads a tab-separated motif-hit table whose first row is the header."""
    return pd.read_csv(hits_path, sep="\t", header=0, index_col=False)

# Import the motif hits for each task. Only the motif key and peak index of each
# hit are kept, separately for the count ("C") and profile ("P") hit tables.
task_motif_hit_paths = get_motif_hit_paths()
task_motif_hits = [
    {
        "C": import_motif_hits(count_path)[["key", "peak_index"]],
        "P": import_motif_hits(profile_path)[["key", "peak_index"]],
    }
    for count_path, profile_path in task_motif_hit_paths
]
def get_hit_prevalence(hit_table, motif_keys):
    """
    Returns the motif prevalence as the number of rows of `hit_table` whose
    "key" column is one of `motif_keys`, i.e. the number of hits given to
    that motif.
    """
    is_motif_hit = np.isin(hit_table["key"], motif_keys)
    return int(np.count_nonzero(is_motif_hit))
# Obtain the motif prevalences: for each task, the fraction of that task's motif
# hits that are assigned to each motif
motif_prevalences = [{} for _ in range(num_tasks)]
for task_index, motif_dict in enumerate(tfm_cwm_motifs):
    task_hits = task_motif_hits[task_index]
    for key in motif_dict:
        # Keys look like "T<i>:<head><motif key>[:<head><motif key>...]", where
        # head is "C" (count) or "P" (profile); a compound key merges motifs
        # from several heads. Group the underlying motif keys by head.
        motif_keys = {}
        for token in key.split(":")[1:]:  # Drop the "T<i>" prefix
            motif_keys.setdefault(token[0], []).append(token[1:])
        # Compute prevalence over the motif keys, taking the average over the
        # count/profile heads
        motif_prevalences[task_index][key] = np.mean([
            get_hit_prevalence(task_hits[head], head_keys)
            for head, head_keys in motif_keys.items()
        ]) if motif_keys else 0
    # Normalize number of hits to get proportions. A task with no hits at all
    # keeps all-zero proportions instead of dividing by zero.
    total = sum(motif_prevalences[task_index].values())
    motif_prevalences[task_index] = {
        k: (v / total if total else 0.0)
        for k, v in motif_prevalences[task_index].items()
    }
# Plot each task's motif prevalences as a bar chart, with motifs sorted from
# most to least prevalent, and save each figure as an SVG
for task_index, prev_dict in enumerate(motif_prevalences):
    fig, ax = plt.subplots(figsize=(20, 6))
    keys = sorted(prev_dict, key=lambda k: -prev_dict[k])
    prevs = [prev_dict[key] for key in keys]
    ax.bar(keys, prevs)
    # Rotate the existing categorical tick labels in place. Calling
    # `set_xticklabels` without fixing the tick locations triggers Matplotlib's
    # "FixedFormatter should only be used together with FixedLocator" warning.
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_ylabel("Proportion of motif instances")
    ax.set_title("Task %d motif prevalences" % task_index)
    fig.savefig(
        os.path.join(out_path, "%s_task%d_motif_prevalences.svg" % (tf_name, task_index)),
        format="svg"
    )
    plt.show()
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/ipykernel_launcher.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator
# Save the prevalences: one HDF5 group per task, holding the motif keys (as byte
# strings) and their prevalence proportions in matching order
out_hdf5 = os.path.join(out_path, "%s_motif_prevalences.h5" % tf_name)
with h5py.File(out_hdf5, "w") as out_file:
    for task_index, prev_dict in enumerate(motif_prevalences):
        group = out_file.create_group("task_%d" % task_index)
        ordered_keys = list(prev_dict)
        ordered_values = [prev_dict[motif_key] for motif_key in ordered_keys]
        group.create_dataset("motif_keys", data=np.array(ordered_keys).astype("S"))
        group.create_dataset("motif_prevalences", data=np.array(ordered_values))
# Display each task's CWMs: a heading per task, then a heading and a sequence
# logo per motif
for task_index, motif_dict in enumerate(tfm_cwm_motifs):
    display(vdomh.h3("Task %d" % task_index))
    for motif_key in motif_dict:
        display(vdomh.h3(motif_key))
        logo_fig = viz_sequence.plot_weights(
            motif_dict[motif_key], subticks_frequency=100, return_fig=True
        )
        logo_fig.tight_layout()
        plt.show()