from basepair.imports import *
from basepair.samplers import top_sum_count, top_max_count
from basepair.plot.tracks import plot_tracks, filter_tracks
from kipoi.data_utils import get_dataset_item
# --- Load the trained model directory and its TF-MoDISco results, then
#     inspect the top Oct4 patterns. `ddir`, `Path` and `ModiscoResult`
#     come from the star imports above (notebook-exported script).
model_dir = Path(f"{ddir}/processed//chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / "modisco/by_peak_tasks/weighted/Oct4"
mr = ModiscoResult(modisco_dir / "modisco.h5")
mr.open()
# PSSM plots of two discovered patterns; trim_frac trims low-information flanks.
mr.plot_pssm("metacluster_0", "pattern_0", trim_frac=0.08);
mr.plot_pssm("metacluster_0", "pattern_2", trim_frac=0.08);
pattern_name = 'metacluster_0/pattern_0'
pattern = mr.get_pattern(pattern_name)
pattern.plot(rotate_y=0);
# --- Load contribution (importance) scores computed on the validation set.
im = HDF5Reader(model_dir / "grad.valid.h5")
im.open()
d = im.load_all()
tasks = ['Oct4', 'Sox2', 'Klf4', 'Nanog']
# Average the hypothetical contribution scores over the two entries stored
# under .../0 and .../1 (presumably the two strands or model heads -- TODO confirm).
all_hyp_contrib = (im.f['/grads/Oct4/weighted/0'][:] + im.f['/grads/Oct4/weighted/1'][:]) / 2
all_seq = im.f['/inputs/'][:]
# Actual contributions = hypothetical contributions masked by the observed one-hot sequence.
all_contrib = all_hyp_contrib * all_seq
all_profile = im.f['/targets/profile/Oct4'][:]
# Top-80 examples ranked by the maximal profile count in the central 200 bp window.
idx = top_max_count(all_profile[:, 400:600], 80)
example_idx = 11128
gt = 'weighted'
# Per-task hypothetical contribution scores; `mean` comes from the star
# imports -- presumably averaging the per-strand entries (TODO confirm axis).
hyp_contrib_scores = {f"{task}": mean(d['grads'][task][gt])
                      for task in tasks}
# Actual contributions per task: hypothetical scores masked by the sequence.
contrib_scores = {f"{task}": hyp_contrib_scores[f"{task}"] * all_seq
                  for task in tasks}
targets = {f"t/{task}": d['targets']['profile'][task] for task in tasks}
# Visualize observed profiles + contribution tracks for one example, zoomed to 400-600.
viz_dict = filter_tracks({**get_dataset_item(targets, example_idx), **get_dataset_item(contrib_scores, example_idx)}, xlim=[400, 600])
viz_dict.keys()
plot_tracks(viz_dict);
# Extract the central 200 bp window of the chosen example for inspection.
_win = slice(400, 600)
contrib = all_contrib[example_idx, _win]
hyp_contrib = all_hyp_contrib[example_idx, _win]
seq = all_seq[example_idx, _win]
t = {"contrib": contrib, "hyp_contrib": hyp_contrib, "seq": seq}
# target
plot_tracks({"contrib": contrib, "hyp_contrib": hyp_contrib});
from basepair.modisco.results import *
# Plot the IC-trimmed pattern's sequence and Oct4 contribution tracks.
pattern.trim_seq_ic(0.08).plot(['seq', 'contrib/Oct4', 'hyp_contrib/Oct4'], rotate_y=0, height=2, letter_width=0.5);
pattern = pattern.trim_seq_ic(0.08)
# IPython help magic -- only valid inside IPython/Jupyter.
pattern.scan_importance?
pattern = mr.get_pattern("metacluster_0/pattern_0").trim_seq_ic(0.08)
pattern.plot("seq");
# Scan all examples with this pattern's contribution scores (n_jobs = workers).
match, importance = pattern.scan_importance(contrib_scores, hyp_contrib_scores, tasks, n_jobs=20)
# optional
# PWM scan of the raw sequences.
seq_match = pattern.scan_seq(all_seq, n_jobs=20)
# Call motif instances at 1% FDR.
dfm = pattern.get_instances(match, importance, seq_match, fdr=0.01, verbose=True)
dfm.head()
# --- Scan all named patterns and collect the called motif instances
#     into a single dataframe `dfp` (one row per instance, labelled by TF).
pattern_names = [
    ("Oct4-Sox2", "metacluster_0/pattern_0"),
    ("Errb", "metacluster_0/pattern_1"),
    ("Sox2", "metacluster_0/pattern_2"),
    ("Nanog", "metacluster_0/pattern_3"),
    ("Klf4", "metacluster_2/pattern_0"),
]
dfl = []
for tf, pattern_name in pattern_names:
    pattern = mr.get_pattern(pattern_name).trim_seq_ic(0.08)
    match, importance = pattern.scan_importance(contrib_scores, hyp_contrib_scores, tasks,
                                                n_jobs=20, verbose=True)
    seq_match = pattern.scan_seq(all_seq, n_jobs=20, verbose=True)
    dfm = pattern.get_instances(match, importance, seq_match, fdr=0.01, verbose=True)
    dfm['tf'] = tf
    dfl.append(dfm)
dfp = pd.concat(dfl)
# Jupyter shell magics. NOTE(review): `rmdir` only removes empty directories --
# if hit-scoring.feather is a file, `rm` was probably intended here.
!mkdir -p {ddir}/cache/dev/
!rmdir {ddir}/cache/dev/hit-scoring.feather
dfp.info()
dfp.to_hdf(f"{ddir}/cache/dev/hit-scoring.hdf5", "/scores")
!du -sh {ddir}/cache/dev/hit-scoring.hdf5
dfp.head()
# NOTE(review): out-of-order notebook cell -- `matchf`, `importancef` and
# `idx_list` are only defined further below in this file; this cell assumes
# they already exist in the kernel state.
seq_matchf = seq_match.reshape((-1, 2))
# Pick the better-matching strand for each sampled position.
choose_max_strand = np.argmax(matchf[idx_list], axis=-1)
matchfs = matchf[idx_list, choose_max_strand]
importancefs = importancef[idx_list]
seq_matchfs = seq_matchf[idx_list, choose_max_strand]
# --- Repeat the scan for the Sox2 pattern.
sox2_pattern = mr.get_pattern("metacluster_0/pattern_2").trim_seq_ic(0.08)
sox2_pattern.plot("seq");
sox2_match, sox2_importance = sox2_pattern.scan_importance(contrib_scores, hyp_contrib_scores, tasks, n_jobs=20)
sox2_seq_match = sox2_pattern.scan_seq(all_seq, n_jobs=20)
match.shape
import random
# Random subsample of flattened scan positions for plotting. NOTE(review):
# the `random` import above is unused -- np.random is used instead; the
# hard-coded 19137000 presumably equals the number of flattened positions.
idx_list = np.random.randint(0, 19137000, 100000)
plt.hist(match.flatten()[idx_list], 100);
match.shape
# Flatten the scan scores; the trailing axes are presumably (task=4, strand=2)
# -- TODO confirm against scan_importance's output shape.
# NOTE(review): `matchf` averages over tasks (axis=-2, keeping strand) while
# `importancef` averages over strand (axis=-1, keeping tasks); confirm this
# asymmetry is intended.
matchf = match.reshape((-1, 4, 2)).mean(axis=-2)
importancef = importance.reshape((-1, 4, 2)).mean(axis=-1)
seq_matchf = seq_match.reshape((-1, 2))
# Strand-select the subsampled positions.
choose_max_strand = np.argmax(matchf[idx_list], axis=-1)
matchfs = matchf[idx_list, choose_max_strand]
importancefs = importancef[idx_list]
seq_matchfs = seq_matchf[idx_list, choose_max_strand]
np.sum(matchf.max(axis=-1) > 0.0673)
# FDR-controlled match threshold under a right-tail normal null.
# NOTE(review): `fdr_threshold_norm_right` is only imported further below.
thr = fdr_threshold_norm_right(matchf.max(axis=-1), 99, fdr=0.001)
thr
keep = matchf.max(axis=-1) > thr
thr
plt.hist(importancef[keep,0].max(axis=-1), 100);
plt.plot(matchf[keep].max(axis=-1),importancef[keep,0], "." , alpha=0.05);
plt.ylabel("Importance")
plt.xlabel("Match");
g = sns.jointplot(matchf[keep].max(axis=-1), importancef[keep,0], alpha=.01);
g.set_axis_labels("Match", "Importance");
np.sum(importancef[keep,0] > 0.2)
plt.hist(importancef[keep,0], 100);
plt.hist(importancef.max(axis=-1), 100);
# Count / fraction of positions passing the threshold.
np.sum(matchf.max(axis=-1) > thr)
np.mean(matchf.max(axis=-1) > thr)
# Same flattening for the Sox2 scan. NOTE(review): unlike `importancef` above,
# `sox2_importancef` keeps all three axes (pos, task, strand) -- the strand is
# selected explicitly below instead of being averaged; confirm intended.
sox2_matchf = sox2_match.reshape((-1, 4, 2)).mean(axis=-2)
sox2_importancef = sox2_importance.reshape((-1, 4, 2))
sox2_seq_matchf = sox2_seq_match.reshape((-1, 2))
sox2_choose_max_strand = np.argmax(sox2_matchf[idx_list], axis=-1)
sox2_matchfs = sox2_matchf[idx_list, sox2_choose_max_strand]
sox2_importancefs = sox2_importancef[idx_list, :, sox2_choose_max_strand]
sox2_seq_matchfs = sox2_seq_matchf[idx_list, sox2_choose_max_strand]
seq_matchf.shape
matchf.shape
import seaborn as sns
# Pairwise relationships of match / importance scores across tasks.
# NOTE(review): `matchf` is 2-D after the mean above, so the 3-index
# expression below assumes the earlier (pos, task, strand) layout.
sns.pairplot(pd.DataFrame(matchf[idx_list[:10000], :, 0], columns=tasks))
sns.pairplot(pd.DataFrame(importancefs, columns=tasks))
plt.hist(matchfs, 100);
plt.hist(sox2_matchfs, 100);
plt.hist(importancefs.max(axis=-1), 100);
plt.hist(sox2_importancefs.max(axis=-1), 100);
from basepair.stats import fdr_threshold_norm_right
fdr_threshold_norm_right(matchfs)
# Match vs importance for the strand-selected subsample (Oct4-Sox2 pattern).
g = sns.jointplot(matchfs, importancefs[:, 0], alpha=.1);
g.set_axis_labels("Match", "Importance");
g = sns.jointplot(matchfs, importancefs.max(axis=-1), alpha=.1);
g.set_axis_labels("Match", "Importance");
# Same for the Sox2 pattern.
fdr_threshold_norm_right(sox2_matchfs)
g = sns.jointplot(sox2_matchfs, sox2_importancefs[:, 0], alpha=.1);
g.set_axis_labels("Match", "Importance");
g = sns.jointplot(sox2_matchfs, sox2_importancefs.max(axis=-1), alpha=.1);
g.set_axis_labels("Match", "Importance");
# Contribution-based match vs plain PWM-scan score.
g = sns.jointplot(matchfs, seq_matchfs, alpha=0.05);
g.set_axis_labels("Match", "PWM-scan");
fdr_threshold_norm_right(seq_matchfs)
# Q-Q plot against the normal distribution (`stats` is imported further below).
stats.probplot(seq_matchfs, dist="norm", plot=plt);
plt.title("asd");
Interestingly, the PWM is a very poor predictor.
# NOTE(review): exploratory cells in kernel-execution order; several names
# (`choose_max`, `threshold`, `fdrcorrection`, `skip`, `var`) are not defined
# anywhere in this file -- they must come from star imports or earlier kernel
# state. Verify before re-running.
g = sns.jointplot(matchf[idx_list, :,choose_max].mean(axis=-1), seq_matchf[idx_list,choose_max], alpha=0.05);
g.set_axis_labels("Match", "PWM-scan");
import scipy.stats as stats
# Normality checks of the match-score distributions.
stats.probplot(matchfs, dist="norm", plot=plt);
stats.probplot(sox2_matchfs, dist="norm", plot=plt);
np.sum(matchfs > threshold(matchfs) )
from scipy.stats import norm
# NOTE(review): uses loc/scale before they are assigned below (out-of-order cell).
norm.ppf(0.8, loc=loc, scale=scale)
loc, scale = norm.fit(matchfs)
loc, scale
# Robust normal fit: drop the top 1% before fitting, then compute right-tail
# p-values and Benjamini-Hochberg-adjust them at alpha=0.1.
upper = np.percentile(matchfs,99)
loc, scale = norm.fit(matchfs[(matchfs < upper)])
p = norm.cdf(matchfs, loc, scale)
keep, padj = fdrcorrection(1-p, alpha=0.1)
keep.sum()
keep.mean()
import scipy.stats as stats
#measurements = np.random.normal(loc = 20, scale = 5, size=100)
# Q-Q plot of the non-significant positions -- should look normal if the null fits.
stats.probplot(matchfs[~keep], dist="norm", plot=plt);
import scipy.stats as stats
#measurements = np.random.normal(loc = 20, scale = 5, size=100)
stats.probplot(matchfs[(matchfs < upper)], dist="norm", plot=plt);
skip.mean()
plt.hist(p[skip], 100);
plt.hist(p[p> 0.9], 100);
plt.hist(padj, 100);
var
# Re-derive `keep` via the package's own FDR-threshold helper and re-plot.
keep = matchfs > fdr_threshold_norm_right(matchfs)
plt.hist(matchfs[keep], 100);
g = sns.jointplot(matchfs[keep], importancefs[keep, 0], alpha=.1);
g.set_axis_labels("Match", "Importance");
g = sns.jointplot(matchfs[keep], importancefs[keep, 0], alpha=.1);
g.set_axis_labels("Match", "Importance");
# NOTE(review): the bare `plt.` below is an unfinished statement (SyntaxError
# outside a notebook); the *_scan_* arrays used in this section are not
# defined anywhere in this file.
plt.
hyp_contrib_score_scan.shape
# Sanity check: the combined scan should equal the per-strand maximum.
np.all(contrib_score_scan==np.maximum(contrib_score_scan_fwd, contrib_score_scan_rev))
np.all(hyp_contrib_score_scan==np.maximum(hyp_contrib_score_scan_fwd, hyp_contrib_score_scan_rev))
# Relationships between PWM scan, contribution-scan match and scale scores.
plt.plot(np.ravel(pwm_scan_fwd)[:100000], np.ravel(contrib_score_scan_fwd)[:100000], ".")
plt.plot(np.ravel(pwm_scan_fwd)[:100000], np.ravel(hyp_contrib_score_scan_fwd)[:100000], ".")
plt.hist(np.ravel(pwm_scan_fwd), 100);
plt.xlabel("pwm scan score");
plt.plot(np.ravel(pwm_scan_fwd)[:100000], np.ravel(contrib_score_scan_fwd_scale)[:100000], ".")
plt.plot(np.ravel(contrib_score_scan_fwd)[:100000], np.ravel(contrib_score_scan_fwd_scale)[:100000], ".");
plt.xlabel("Match")
plt.ylabel("scale");
plt.hist(np.ravel(contrib_score_scan_fwd), 100);
plt.plot(np.ravel(contrib_score_scan_fwd)[:100000], np.ravel(hyp_contrib_score_scan_fwd)[:100000], ".")
plt.plot(np.ravel(contrib_score_scan_fwd_scale)[:100000], np.ravel(hyp_contrib_score_scan_fwd_scale)[:100000], ".")
# Track dict for one example: contributions + pattern-scan scores in 400-600.
viz_dict = dict(contrib=all_contrib[example_idx, 400:600],
                hyp_contrib=all_hyp_contrib[example_idx, 400:600],
                pwm_match=seq_match[example_idx, 400:600],
                contrib_match=match[example_idx, 400:600, 0],
                contrib_scale=importance[example_idx, 400:600, 0])
from basepair.modisco.results import Seqlet
# Top-2 scan positions within the window (negated argsort = descending order).
pos_hits = np.argsort(-match[example_idx, 400:600, 0].max(axis=-1))[:2] # top 2 hits
match[example_idx, 400:600, 0].sum()
importance[example_idx, 400:600, 0].sum()
pos_hits
from basepair.utils import halve
def idx2seqlet(pos, pattern):
    """Build a Seqlet centered at ``pos`` spanning the pattern's length.

    TODO: handle the reverse-complementation.
    """
    left, right = halve(len(pattern))
    return Seqlet(None, pos - left, pos + right, name=pattern.name)
# Overlay the top hits as seqlets on the example's tracks.
seqlets = [idx2seqlet(pos, pattern) for pos in pos_hits]
pattern.plot("seq");
plot_tracks(viz_dict, seqlets, rotate_y=0, legend=True);
# --- Benchmarks (IPython %timeit magics). NOTE(review): `qa`, `ta`,
# `ta_all_contrib`, the score_* functions, `jaccard_sim_func` and
# `qa_L1_norm` are not defined in this file.
%timeit score_region_cont_jaccard(qa, ta)
%timeit score_region_cont_jaccard(qa, ta_all_contrib[:100])
%timeit score_region_cont_jaccard(qa, ta_all_contrib[:1000])
%timeit score_region_cont_jaccard(qa, ta_all_contrib[0]) # 1kb
%timeit score_full_regionarr_with_perpos_continjaccard(ta_all_contrib[0], qa) # 1kb
%timeit score_region_cont_jaccard(qa, ta_all_contrib[0][:, :100]) # 100bp
%timeit score_full_regionarr_with_perpos_continjaccard(ta_all_contrib[0][:, :100], qa) # 100bp
# Cross-check that both implementations agree on scores and norms.
a, a_norm = score_full_regionarr_with_perpos_continjaccard(ta_all_contrib[0][:, :1000], qa)
b, b_norm = score_region_cont_jaccard(qa, ta_all_contrib[0][:, :1000]) # 100bp
assert b_norm[0] == a_norm[0]
%timeit np.stack([score_region_cont_jaccard(qa, ta_all_contrib[i]) for i in range(1000)])
# Parallel scan over all regions (joblib Parallel/delayed from star imports).
scanned = np.stack(Parallel(10)(delayed(score_region_cont_jaccard)(qa, ta_all_contrib[i]) for i in tqdm(range(len(ta_all_contrib)))))
ta_all_contrib.shape
# TODO - add unit-tests
# Independent recomputation of the jaccard similarity at two offsets.
other_jacc = jaccard_sim_func(ta_all_contrib[0][:, :16][np.newaxis]*qa_L1_norm / a_norm[0], qa[np.newaxis])[0,0]
np.allclose(other_jacc, a[0])
np.allclose(other_jacc, b[0])
other_jacc = jaccard_sim_func(ta_all_contrib[0][:, 1:17][np.newaxis]*qa_L1_norm / b_norm[1], qa[np.newaxis])[0,0]
np.allclose(other_jacc, a[1])
np.allclose(other_jacc, b[1])
plt.scatter(a_norm, b_norm)
plt.scatter(np.ravel(a), np.ravel(b));
%timeit score_region_cont_jaccard(qa, ta)
%timeit score_full_regionarr_with_perpos_continjaccard(ta, qa)
from basepair.modisco.utils import ic_scale
def scan_pattern(pattern_grp, contrib_scores, hypothetical_contribs, one_hot, task_names, trim_frac=0, verbose=True):
    """Scan sequences and contribution scores with a single modisco pattern.

    Args:
        pattern_grp: HDF5 group of the pattern; must contain 'sequence/fwd'
            (the PWM) and one sub-group per task with 'contrib_scores' /
            'hypothetical_contribs' entries.
        contrib_scores: actual contribution scores to scan.
        hypothetical_contribs: hypothetical contribution scores to scan.
        one_hot: one-hot-encoded sequences scanned with the PWM.
        task_names: task names -- keys into `pattern_grp`.
        trim_frac: information-content fraction for trimming pattern flanks.
        verbose: show progress bars.

    Returns a nested dict:
        - sequence
            - fwd / rev          # PSSM scan score per strand
        - <task>
            - contrib_scores / hypothetical_contribs
                - match: {fwd, rev}   # continuous-jaccard match scores
                - scale: {fwd, rev}   # per-position scale factors
    """
    # TODO - add i and j to the output?
    pwm = pattern_grp['sequence']['fwd'][:]
    # Trim low-information flanks of the pattern.
    trim_ij = trim_pssm_idx(ic_scale(pwm), frac=trim_frac)

    def get_group_match(grp, name):
        # Return the first sub-group whose key contains `name` as a substring.
        for n in grp.keys():
            if name in n:
                return grp[n]
        raise ValueError("{name} doesn't match any group keys: {l}".format(name=name, l=list(grp.keys())))

    def parallel_score_continousjaccard_restructure(qa, ta, **kwargs):
        # Score both strands (rev = reverse-complemented query) and
        # restructure into {match: {fwd, rev}, scale: {fwd, rev}}.
        fwd, fwd_scale = parallel_score_continousjaccard(qa, ta, **kwargs)
        rev, rev_scale = parallel_score_continousjaccard(qa[::-1, ::-1], ta, **kwargs)
        return dict(match=dict(fwd=fwd, rev=rev),
                    scale=dict(fwd=fwd_scale, rev=rev_scale),
                    )

    # NOTE(review): the original computed reverse-complement trim indices
    # (len(pwm) - trim_ij[1], len(pwm) - trim_ij[0]) but never used them --
    # the rev strand is handled by rev-complementing the trimmed fwd pattern
    # below, so the dead locals were removed.
    return dict(sequence=dict(fwd=parallel_score_pssm(pwm[trim_ij[0]:trim_ij[1]], one_hot, verbose=verbose),
                              rev=parallel_score_pssm(pwm[trim_ij[0]:trim_ij[1]][::-1, ::-1], one_hot, verbose=verbose)),
                **{task: {"contrib_scores": parallel_score_continousjaccard_restructure(get_group_match(pattern_grp[task], "contrib_scores")['fwd'][trim_ij[0]:trim_ij[1]], contrib_scores),
                          "hypothetical_contribs": parallel_score_continousjaccard_restructure(get_group_match(pattern_grp[task], "hypothetical_contribs")['fwd'][trim_ij[0]:trim_ij[1]], hypothetical_contribs),}
                   for task in tqdm(task_names, disable=not verbose)}
                )
tasks
# Run the scan for the Oct4-Sox2 pattern and compare per-task match scores.
mp1 = scan_pattern(mr.get_pattern_grp("metacluster_0", "pattern_0"), all_contrib, all_hyp_contrib, all_seq, tasks, trim_frac=0.08)
all_seq.size
# Random 10k-example subsample for scatter plots.
idx = pd.Series(np.arange(all_seq.shape[0])).sample(10000).values
plt.plot(np.ravel(mp1['Oct4']['contrib_scores']['match']['fwd'])[idx], np.ravel(mp1['Sox2']['contrib_scores']['match']['fwd'])[idx], ".")
plt.plot(np.ravel(mp1['Oct4']['contrib_scores']['match']['fwd'])[idx], np.ravel(mp1['Nanog']['contrib_scores']['match']['fwd'])[idx], ".")
plt.plot(np.ravel(mp1['Oct4']['contrib_scores']['match']['fwd'])[idx], np.ravel(mp1['Klf4']['contrib_scores']['match']['fwd'])[idx], ".")
mp1['Oct4']['contrib_scores']['match']['fwd'].shape
Orthogonal validation?
A higher fraction of the transposable elements are not conserved in human.
Simulations -> the profile only gives a sense of what could be going on.
Overlap with the histone data.
Question: