from basepair.imports import *
from basepair.samplers import top_sum_count, top_max_count
from basepair.plot.tracks import plot_tracks, filter_tracks
from kipoi.data_utils import get_dataset_item
# Paths to the trained multi-task ChIP-nexus model and its TF-MoDISco results.
# NOTE(review): `ddir`, `Path` and `ModiscoResult` are presumably provided by
# the star imports above — confirm.
model_dir = Path(f"{ddir}/processed//chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / "modisco/by_peak_tasks/weighted/Oct4"
mr = ModiscoResult(modisco_dir / "modisco.h5")
mr.open()
# Tasks (TFs) the model was trained on
tasks = ['Oct4', 'Sox2', 'Nanog', 'Klf4']
# Quick look at two discovered motifs (trailing ';' suppresses notebook output)
mr.plot_pssm("metacluster_0", "pattern_0", trim_frac=0.08);
mr.plot_pssm("metacluster_0", "pattern_2", trim_frac=0.08);
# Inspect one pattern in detail
pattern_name = 'metacluster_0/pattern_0'
pattern = mr.get_pattern(pattern_name)
pattern.plot(rotate_y=0);
def load_scores(h5_file, importance='weighted'):
    """Load sequences, contribution scores and profiles from a gradients HDF5 file.

    Args:
        h5_file: path to the HDF5 file written during model interpretation.
        importance: gradient type to read (group name under each task's grads).

    Returns:
        Tuple ``(seq, contrib, hyp_contrib, profile)`` where
          - seq: input sequences (one-hot encoded — TODO confirm),
          - contrib: per-task hypothetical contributions masked by `seq`,
          - hyp_contrib: per-task hypothetical contribution scores
            (``mean`` over the stored gradient entries),
          - profile: per-task measured profiles, keyed as ``'t/<task>'``.
    """
    reader = HDF5Reader(h5_file)
    reader.open()
    # Ensure the file handle is released even if loading fails
    try:
        data = reader.load_all()
        seq = data['inputs']
        task_names = list(data['grads'])
        # NOTE(review): `mean` is assumed to come from the star imports
        hyp_contrib = {task: mean(data['grads'][task][importance])
                       for task in task_names}
        contrib = {task: hyp_contrib[task] * seq for task in task_names}
        profile = {f"t/{task}": data['targets']['profile'][task]
                   for task in task_names}
    finally:
        reader.close()
    return seq, contrib, hyp_contrib, profile
# subset
def get_scores(im, idx, tasks):
    """Read scores for a subset of examples directly from an open HDF5 reader.

    Unlike `load_scores`, only the rows selected by `idx` are read, avoiding
    loading the full dataset into memory.

    Args:
        im: open reader whose ``.f`` attribute supports h5py-style path
            indexing (e.g. ``im.f['/inputs/']``).
        idx: integer index (or index array) of the example(s) to read.
        tasks: task names to read.

    Returns:
        Tuple ``(seq, contrib, hyp_contrib, profile)`` with per-task dicts.
    """
    # Average the two stored gradient entries ('0' and '1') per task
    hyp_contrib = {task: (im.f[f'/grads/{task}/weighted/0'][idx] +
                          im.f[f'/grads/{task}/weighted/1'][idx]) / 2
                   for task in tasks}
    seq = im.f['/inputs/'][idx]
    # Mask hypothetical contributions by the observed sequence
    contrib = {task: hyp_contrib[task] * seq for task in tasks}
    profile = {task: im.f[f'/targets/profile/{task}'][idx] for task in tasks}
    return seq, contrib, hyp_contrib, profile
# Load sequences, contributions and profiles for the validation set
seq, contrib, hyp_contrib, profile = load_scores(model_dir / "grad.valid.h5")
# Indices of the 80 examples with the highest Oct4 profile counts in the
# central 400-600 window
idx = top_max_count(profile['t/Oct4'][:, 400:600], 80)
seq.shape
example_idx = 11128
# Profile + contribution tracks for one example, zoomed to positions 400-600
viz_dict = filter_tracks({**get_dataset_item(profile, example_idx),
**get_dataset_item(contrib, example_idx)}, xlim=[400, 600])
plot_tracks(viz_dict);
# Crop the Oct4 tracks of this example to the same window
t = dict(contrib=contrib['Oct4'][example_idx, 400:600],
hyp_contrib=hyp_contrib['Oct4'][example_idx, 400:600],
seq=seq[example_idx, 400:600])
# target
plot_tracks(dict(contrib=t['contrib'], hyp_contrib=t['hyp_contrib']));
from basepair.modisco.results import *
# Plot the IC-trimmed pattern with sequence and Oct4 contribution tracks
pattern.trim_seq_ic(0.08).plot(['seq', 'contrib/Oct4', 'hyp_contrib/Oct4'],
rotate_y=0, height=1, letter_width=0.2);
# (human-readable TF label, modisco pattern id) pairs to scan for
pattern_names = [
("Oct4-Sox2", "metacluster_0/pattern_0"),
("Errb", "metacluster_0/pattern_1"),
("Sox2", "metacluster_0/pattern_2"),
("Nanog", "metacluster_0/pattern_3"),
("Klf4", "metacluster_2/pattern_0"),
]
# Hit-scoring results are cached on disk; set cached=False to recompute below
cached = True
cache_dir = Path(Path(ddir) / "cache/dev/")
cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / "hit-scoring.hdf5"
dfp = pd.read_hdf(cache_file, "/scores")
# IPython shell magic (notebook cell): show the cache file size on disk
!du -sh {cache_file}
#dfp.to_hdf(cache_file, "/scores")
# Recompute motif hit scores when no cached result is available.
# For every (tf, pattern) pair: scan the contribution scores and the raw
# sequence for pattern matches, extract instance tables, and concatenate them.
if not cached:
    dfl = []
    match_l, importance_l, seq_match_l = [], [], []
    for tf, pattern_name in pattern_names:
        pattern = mr.get_pattern(pattern_name).trim_seq_ic(0.08)
        match, importance = pattern.scan_importance(contrib, hyp_contrib, tasks,
                                                    n_jobs=20, verbose=True)
        seq_match = pattern.scan_seq(seq, n_jobs=20, verbose=True)
        dfm = pattern.get_instances(tasks, match, importance, seq_match, fdr=0.01, verbose=True)
        dfm['tf'] = tf
        dfl.append(dfm)
        # Keep the raw per-position score arrays for later stacking
        match_l.append(match)
        importance_l.append(importance)
        seq_match_l.append(seq_match)
    dfp = pd.concat(dfl)
    # cache the result so subsequent runs can skip the expensive scan
    cached = True
    dfp.to_hdf(cache_file, "/scores")
# Stable row identifier used later for wide->long merging
dfp['idx'] = np.arange(len(dfp))
# Keep only instances with a positive sequence-match score
dfp = dfp[dfp.score_seq_match > 0]
# get also the match scores
# NOTE(review): match_l/importance_l/seq_match_l only exist when the
# recompute branch ran (cached=False) — these lines fail on a cached run
match = np.stack(match_l, axis=-1)
importance = np.stack(importance_l, axis=-1)
seq_match = np.stack(seq_match_l, axis=-1)
dfp.head()
def melt_prefix(dfp, var='imp', task_names=None):
    """Convert wide-format ``<var>/<task>`` columns into long format.

    Args:
        dfp: DataFrame with an 'idx' column plus one ``<var>/<task>``
            column per task.
        var: column-name prefix to melt (e.g. 'imp' or 'match').
        task_names: task names to melt; defaults to the module-level
            ``tasks`` list for backward compatibility.

    Returns:
        DataFrame with columns ``['idx', var, 'task']``.
    """
    if task_names is None:
        task_names = tasks  # module-level task list
    prefix = var + "/"
    long_df = dfp[['idx'] + [prefix + t for t in task_names]].melt(
        id_vars='idx', value_name=var)
    # Slice the prefix off instead of str.replace: avoids regex
    # interpretation of the prefix (pandas-version-dependent default) and
    # only strips the leading occurrence
    long_df['task'] = long_df.variable.str[len(prefix):]
    del long_df['variable']
    return long_df
# Merge importance and match long tables into a single long-format frame
dfpm = melt_prefix(dfp, 'imp').merge(melt_prefix(dfp, 'match'), on=['idx', 'task'])
dfpl = dfp[[c for c in dfp.columns
if not c.startswith("imp/")
and not c.startswith("match/")]].merge(dfpm, on='idx')
# log1p-transform importance for plotting
dfpl['log_imp'] = np.log(1+dfpl.imp)
dfp['log_imp_weighted'] = np.log(1+dfp['imp_weighted'])
# NOTE(review): plain log here (no 1+) unlike the importance columns —
# confirm match_weighted is strictly positive
dfp['log_match_weighted'] = np.log(dfp['match_weighted'])
# partition the scores into 4 categories (according to the median)
dfp_medians = pd.DataFrame({'log_match_weighted': dfp.groupby("tf").log_match_weighted.median(),
'log_imp_weighted': dfp.groupby("tf").log_imp_weighted.median()}).reset_index()
dfp
# Per-row per-TF medians used as low/high cutoffs
log_imp_weighted_median = dfp.tf.map(dfp.groupby("tf").log_imp_weighted.median())
log_match_weighted_median = dfp.tf.map(dfp.groupby("tf").log_match_weighted.median())
dfp['imp_cat'] = (dfp.log_imp_weighted > log_imp_weighted_median).map({False: 'low', True: 'high'})
dfp['match_cat'] = (dfp.log_match_weighted > log_match_weighted_median).map({False: 'low', True: 'high'})
# dfp['match_max_task'] = dfp.match_max__task
# dfpl['match_max_task'] = dfpl.match_max__task
dfpl.head()
from plotnine import *
import plotnine
dfpl.shape
# Down-sample to keep the scatter plots fast
dfpl_subset = dfpl.sample(10000) # focus on the subset of the data
dfp_subset = dfp.sample(10000) # focus on the subset of the data
dfpl_subset.head()
import warnings
# Silence plotnine/pandas warnings in the notebook output
warnings.filterwarnings("ignore")
plotnine.options.figure_size = (12, 7)
# Match vs. log-importance, faceted by motif (tf) and model task
ggplot(aes(x="match", y="log_imp"), dfpl_subset) + geom_point(alpha=0.1) + \
facet_grid("tf~task", labeller='label_both') + theme_bw()
plotnine.options.figure_size = (12, 7)
ggplot(aes(x="match_weighted", y="log_imp"), dfpl_subset) + geom_point(alpha=0.1) + \
facet_grid("tf~task", labeller='label_both') + theme_bw()
plotnine.options.figure_size = (12, 3)
# Weighted match vs. weighted importance per motif
ggplot(aes(x="log_match_weighted", y="log_imp_weighted"), dfp_subset) + geom_point(alpha=0.1) + \
facet_grid(".~tf", labeller='label_both') + theme_bw()
plotnine.options.figure_size = (12, 3)
ggplot(aes(x="score_seq_match", y="log_match_weighted"), dfp_subset) + geom_point(alpha=0.1) + \
facet_grid(".~tf", labeller='label_both') + theme_bw()
plotnine.options.figure_size = (12, 3)
ggplot(aes(x="score_seq_match", y="log_imp_weighted"), dfp_subset) + geom_point(alpha=0.1) + \
facet_grid(".~tf", labeller='label_both') + theme_bw()
ggplot(aes(x='log_match_weighted'), dfp_subset) + geom_histogram() + facet_grid(".~tf", labeller='label_both') + theme_bw()
plotnine.options.figure_size = (12, 3)
# Same scatter with per-TF median guide lines (the low/high category cutoffs)
ggplot(aes(x="log_match_weighted", y="log_imp_weighted"), dfp_subset) + geom_point(alpha=0.1) + \
geom_vline(aes(xintercept='log_match_weighted'), data=dfp_medians, alpha=1, color='orange', linetype='dashed') + \
geom_hline(aes(yintercept='log_imp_weighted'), data=dfp_medians, alpha=1, color='orange', linetype='dashed') + \
facet_grid(".~tf", labeller='label_both') + theme_bw()
plotnine.options.figure_size = (12, 3)
# NOTE(review): identical to the previous plot — likely a duplicated cell
ggplot(aes(x="log_match_weighted", y="log_imp_weighted"), dfp_subset) + geom_point(alpha=0.1) + \
geom_vline(aes(xintercept='log_match_weighted'), data=dfp_medians, alpha=1, color='orange', linetype='dashed') + \
geom_hline(aes(yintercept='log_imp_weighted'), data=dfp_medians, alpha=1, color='orange', linetype='dashed') + \
facet_grid(".~tf", labeller='label_both') + theme_bw()
# Median weighted log-importance of the subset.
# Fix: the original `dfp_subset.log_imp.weighted.median()` raises
# AttributeError (a Series has no `.weighted`); the column is named
# 'log_imp_weighted'.
dfp_subset.log_imp_weighted.median()
# Distribution of the maximum importance per instance
np.log(1+dfp.imp_max).plot.hist(200);
#plt.xlim([0, 20])
plt.xlabel("Maximum importance plot");
plotnine.options.figure_size = (7, 7)
# Match score distribution per motif, colored by task
ggplot(aes(x='tf', color='task', y='match'), dfpl) + geom_boxplot() + theme_bw()
plotnine.options.figure_size = (7, 7)
ggplot(aes(x='tf', color='task', y='log_imp'), dfpl) + geom_boxplot() + theme_bw()
dfpl.groupby(['tf', 'task']).imp.mean().unstack().plot.bar()
plt.ylabel("Average importance");
dfpl.groupby(['tf', 'task']).match.mean().unstack().plot.bar()
plt.title("Average match per task");
plt.ylabel("Average Match");
# How often each task attains the maximum importance/match per motif
dfp.groupby("tf").imp_max_task.value_counts().unstack().plot.bar();
plt.ylabel("Frequency")
plt.title("Number of times a particular task has max importance ");
dfp.groupby("tf").match_max_task.value_counts().unstack().plot.bar();
plt.ylabel("Frequency")
plt.title("Number of times a particular task has max match ");
# NOTE(review): duplicate of the pattern_names list defined earlier
pattern_names = [
("Oct4-Sox2", "metacluster_0/pattern_0"),
("Errb", "metacluster_0/pattern_1"),
("Sox2", "metacluster_0/pattern_2"),
("Nanog", "metacluster_0/pattern_3"),
("Klf4", "metacluster_2/pattern_0"),
]
pattern.plot("contrib");
tasks
# Relative per-task importance of the current pattern (normalized to sum 1)
task_imp = np.array([np.abs(pattern.contrib[t]).mean() for t in tasks])
task_imp = task_imp / task_imp.sum()
# NOTE(review): re-trims the pattern only after task_imp was computed on the
# untrimmed one — confirm this ordering is intended
pattern = mr.get_pattern(pattern.name).trim_seq_ic(0.08)
# Instances where the Klf4 task has high importance under the Oct4-Sox2 motif
dfpl_subset[(dfpl_subset.task == 'Klf4') & (dfpl_subset.tf == 'Oct4-Sox2')& (dfpl_subset.imp > 40)]
# NOTE(review): uses example_idx from the previous cell; it is reassigned to
# 1288 only on the next line — confirm the intended execution order
viz_dict = filter_tracks({**get_dataset_item(profile, example_idx),
**get_dataset_item(contrib, example_idx)}, xlim=[400, 600])
example_idx = 1288
# Motif instances of this example falling inside the plotted 400-600 window
dfp_subset_idx = dfp[(dfp.example_idx == example_idx) & (dfp.pattern_start > 400) & (dfp.pattern_end < 600)]
dfp_subset_idx
np.log(1+dfp_subset_idx.imp_max)
# Seqlet coordinates shifted by -400 to match the xlim-cropped tracks
seqlets = [Seqlet(row.example_idx, row.pattern_start-400, row.pattern_end-400, row.tf, row.strand)
for i,row in dfp_subset_idx.iterrows()]
# TODO - is this thing normalized by importance?
plot_tracks(filter_tracks({**get_dataset_item(profile, example_idx),
**get_dataset_item(contrib, example_idx)}, xlim=[400, 600]),
seqlets, rotate_y=0, legend=True);
# Distribution of the number of strong, in-window hits per example
dfp[(np.log(1+dfp.imp_max) > 2) & (dfp.pattern_start > 400) & (dfp.pattern_end < 600)].groupby('example_idx').size().plot.hist()
def prefix_dict(d, prefix):
    """Return a copy of `d` with `prefix` prepended to every key.

    Fix: the original comprehension reused `d` as the loop variable,
    shadowing the parameter — it only worked because the iterable is
    evaluated in the enclosing scope. Renamed for clarity.
    """
    return {prefix + key: value for key, value in d.items()}
# Pick a random example and show its high-importance, high-match motif hits
example_idx = dfp.example_idx.sample(1).iloc[0]
print(example_idx)
dfp_subset_idx = dfp[#(np.log(1+dfp.imp_max) > 2) &
(dfp.imp_cat == 'high') &
(dfp.match_cat == 'high') &
#(dfp.tf == "Errb") &
(dfp.example_idx == example_idx) &
(dfp.pattern_start > 400) &
(dfp.pattern_end < 600)]
# Seqlet coordinates shifted by -400 to match the xlim-cropped tracks
seqlets = [Seqlet(row.example_idx, row.pattern_start-400, row.pattern_end-400, row.tf, row.strand)
for i,row in dfp_subset_idx.iterrows()]
# Contribution + hypothetical-contribution ("h/") tracks with hits overlaid
plot_tracks(filter_tracks({**get_dataset_item(contrib, example_idx),
**get_dataset_item(prefix_dict(hyp_contrib, "h/"), example_idx)}, xlim=[400, 600]),
seqlets, rotate_y=0, legend=True);
dfp_subset_idx[['tf', 'pattern_center', 'strand', 'match_weighted', 'imp_weighted', 'score_seq_match']].sort_values('pattern_center')
# Plot each trimmed pattern and its reverse complement with the sequence
# logo and per-task contribution tracks.
for tf, pattern_name in pattern_names:
    p = mr.get_pattern(pattern_name).trim_seq_ic(0.08)
    # Forward orientation, labelled "+<tf>"
    p.name = "+" + tf
    p.plot(['seq'] + [f'contrib/{t}' for t in tasks], rotate_y=0)
    # Reverse-complement orientation, labelled "-<tf>"
    p.name = "-" + tf
    p.rc().plot(['seq'] + [f'contrib/{t}' for t in tasks], rotate_y=0)
# Repeat the visualization for another random example
example_idx = dfp.example_idx.sample(1).iloc[0]
print(example_idx)
dfp_subset_idx = dfp[#(np.log(1+dfp.imp_max) > 2) &
(dfp.imp_cat == 'high') &
(dfp.match_cat == 'high') &
#(dfp.tf == "Errb") &
(dfp.example_idx == example_idx) &
(dfp.pattern_start > 400) &
(dfp.pattern_end < 600)]
seqlets = [Seqlet(row.example_idx, row.pattern_start-400, row.pattern_end-400, row.tf, row.strand)
for i,row in dfp_subset_idx.iterrows()]
plot_tracks(filter_tracks({**get_dataset_item(contrib, example_idx),
**get_dataset_item(prefix_dict(hyp_contrib, "h/"), example_idx)}, xlim=[400, 600]),
seqlets, rotate_y=0, legend=True);
dfp_subset_idx[['tf', 'pattern_center', 'strand', 'match_weighted', 'imp_weighted', 'score_seq_match']].sort_values('pattern_center')
# Same example, restricted to Sox2 motif hits
dfp_subset_idx = dfp[#(np.log(1+dfp.imp_max) > 2) &
(dfp.imp_cat == 'high') &
(dfp.match_cat == 'high') &
(dfp.tf == "Sox2") &
(dfp.example_idx == example_idx) &
(dfp.pattern_start > 400) &
(dfp.pattern_end < 600)]
seqlets = [Seqlet(row.example_idx, row.pattern_start-400, row.pattern_end-400, row.tf, row.strand)
for i,row in dfp_subset_idx.iterrows()]
plot_tracks(filter_tracks({**get_dataset_item(contrib, example_idx),
**get_dataset_item(prefix_dict(hyp_contrib, "h/"), example_idx)}, xlim=[400, 600]),
seqlets, rotate_y=0, legend=True);
# Same example, restricted to Oct4-Sox2 motif hits
dfp_subset_idx = dfp[#(np.log(1+dfp.imp_max) > 2) &
(dfp.imp_cat == 'high') &
(dfp.match_cat == 'high') &
(dfp.tf == "Oct4-Sox2") &
(dfp.example_idx == example_idx) &
(dfp.pattern_start > 400) &
(dfp.pattern_end < 600)]
seqlets = [Seqlet(row.example_idx, row.pattern_start-400, row.pattern_end-400, row.tf, row.strand)
for i,row in dfp_subset_idx.iterrows()]
plot_tracks(filter_tracks({**get_dataset_item(contrib, example_idx),
**get_dataset_item(prefix_dict(hyp_contrib, "h/"), example_idx)}, xlim=[400, 600]),
seqlets, rotate_y=0, legend=True);
# Same example, restricted to Nanog motif hits
dfp_subset_idx = dfp[#(np.log(1+dfp.imp_max) > 2) &
(dfp.imp_cat == 'high') &
(dfp.match_cat == 'high') &
(dfp.tf == "Nanog") &
(dfp.example_idx == example_idx) &
(dfp.pattern_start > 400) &
(dfp.pattern_end < 600)]
seqlets = [Seqlet(row.example_idx, row.pattern_start-400, row.pattern_end-400, row.tf, row.strand)
for i,row in dfp_subset_idx.iterrows()]
# Profile + contribution tracks
# NOTE(review): seqlets are built above but not passed to this first plot —
# confirm whether the hit overlay was intended here as well
plot_tracks(filter_tracks({**get_dataset_item(contrib, example_idx),
**get_dataset_item(profile, example_idx)}, xlim=[400, 600]),
rotate_y=0, legend=True);
# Contribution + hypothetical-contribution tracks with the Nanog hits overlaid
plot_tracks(filter_tracks({**get_dataset_item(contrib, example_idx),
**get_dataset_item(prefix_dict(hyp_contrib, "h/"), example_idx)}, xlim=[400, 600]),
seqlets, rotate_y=0, legend=True);