from basepair.imports import *
ddir = get_data_dir()
from basepair.plot.profiles import extract_signal
from basepair.math import softmax
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile
from basepair.plot.profiles import plot_stranded_profile, multiple_plot_stranded_profile
# Project helper; presumably the argument selects GPU 0 — confirm in basepair.
create_tf_session(0)

model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")

# Show the contents of the model directory. This replaces the IPython magic
# `ls {model_dir}`, which is not valid Python and ran before `model_dir` existed.
for entry in sorted(model_dir.iterdir()):
    print(entry.name)

# Load the data. `d` is deliberately left open: get_profile() reads
# `/targets/profile/<task>` from `d.f` later on.
d = HDF5Reader(model_dir / "grad.all.h5")
d.open()
seq = d.f['/inputs'][:]

# Load the model and its data specification
model = load_model(model_dir / "model.h5")
ds = DataSpec.load(model_dir / "dataspec.yaml")

# Predict on all inputs and store the predictions next to the model
preds = model.predict(seq, verbose=1)
HDF5BatchWriter.dump(model_dir / "preds.all.h5", preds)

# Root directory of the per-task modisco runs read by get_profile()
modisco_pdir = model_dir / "modisco/by_peak_tasks/weighted/"
def get_profile(task, pattern, preds, ds, *, modisco_dir=None, data=None):
    """Extract observed and predicted profiles at the seqlets of a modisco pattern.

    Args:
      task: task name. Used as the modisco sub-directory, as the key for
        `ds.task2idx`, and in the target path `/targets/profile/<task>`.
      pattern: modisco pattern name, e.g. "metacluster_0/pattern_0".
      preds: model outputs (as returned by `model.predict`), indexed with
        `ds.task2idx(task, 'profile')` and `ds.task2idx(task, 'counts')`.
      ds: DataSpec of the model, providing `task2idx`.
      modisco_dir: directory holding `<task>/modisco.h5` and
        `<task>/kwargs.json`. Defaults to the module-level `modisco_pdir`.
      data: opened HDF5Reader holding the observed targets. Defaults to the
        module-level `d`.

    Returns:
      Tuple `(seqlet_profile_obs, pred_counts, sort_idx)`:
        seqlet_profile_obs: observed profile extracted at every seqlet.
        pred_counts: predicted profile at every seqlet, scaled by the
          predicted counts of the region the seqlet comes from.
        sort_idx: seqlet indices ordered by decreasing total observed signal.
    """
    if modisco_dir is None:
        modisco_dir = modisco_pdir
    if data is None:
        data = d
    task_dir = Path(modisco_dir) / task

    mr = ModiscoResult(task_dir / "modisco.h5")
    mr.open()
    try:
        seqlets = mr._get_seqlets(pattern)
    finally:
        # Release the modisco file even if the pattern lookup fails.
        mr.close()

    # Modisco was run on a filtered subset of the regions. Restrict all arrays
    # to that subset so `seqlet.seqname` indexes into them consistently.
    include_samples = np.load(read_json(task_dir / "kwargs.json")["filter_npy"])

    # Normalize the profile head with basepair's softmax (axis convention
    # defined there).
    profile_preds = softmax(preds[ds.task2idx(task, 'profile')][include_samples])
    # Invert the counts head: expm1(x) == exp(x) - 1, i.e. the inverse of
    # log(1 + x). Presumably the model predicts log(1 + counts) — confirm.
    count_preds = np.expm1(preds[ds.task2idx(task, 'counts')][include_samples])
    profile_obs = data.f[f'/targets/profile/{task}'][:][include_samples]

    seqlet_profile_obs = extract_signal(profile_obs, seqlets)
    seqlet_profile_pred = extract_signal(profile_preds, seqlets)
    seqlet_idx = np.array([s.seqname for s in seqlets])

    # Total observed signal per seqlet, summed over the last two axes
    # (presumably position and strand); strongest seqlets come first.
    total_counts = seqlet_profile_obs.sum(axis=-1).sum(axis=-1)
    sort_idx = np.argsort(-total_counts)

    # Scale the predicted profile by the predicted counts of each seqlet's
    # region. np.newaxis inserts axis 1 so the counts broadcast along it.
    pred_counts = seqlet_profile_pred * count_preds[seqlet_idx][:, np.newaxis]
    return seqlet_profile_obs, pred_counts, sort_idx
def scatter_counts(obs, pred, sort_idx=None):
    """Scatter-plot predicted vs. observed total signal per seqlet.

    Left panel: raw totals (natural scale). Right panel: log(1 + totals).

    Args:
      obs: observed seqlet profiles; totals are summed over the last two axes.
      pred: predicted seqlet profiles with the same layout as `obs`.
      sort_idx: optional reordering of the seqlets before plotting. It only
        changes the draw order of the points, because obs[i] stays paired with
        pred[i]. The previous version silently read a global `sort_idx`.
    """
    obs_total = obs.sum(axis=-1).sum(axis=-1)
    pred_total = pred.sum(axis=-1).sum(axis=-1)
    if sort_idx is not None:
        obs_total = obs_total[sort_idx]
        pred_total = pred_total[sort_idx]

    plt.figure(figsize=(8, 4))

    # Raw totals -> natural scale (the titles were previously swapped).
    plt.subplot(121)
    plt.plot(pred_total, obs_total, ".", alpha=0.5)
    plt.title("Scatterplot at the natural scale")
    plt.xlabel("Predicted")
    plt.ylabel("Observed")

    # log(1 + x) totals -> log scale; log1p is accurate for small totals.
    plt.subplot(122)
    plt.plot(np.log1p(pred_total), np.log1p(obs_total), ".", alpha=0.5)
    plt.title("Scatterplot at the log-scale")
    plt.xlabel("Predicted")
    plt.ylabel("Observed")

    plt.tight_layout()
# (task, modisco pattern) pairs to inspect. The loop runs at module level so
# `task`, `pattern`, `obs`, `pred` and `sort_idx` remain available as globals
# afterwards, exactly as with the former per-task cells.
TASK_PATTERNS = [
    ("Klf4", "metacluster_0/pattern_0"),
    ("Nanog", "metacluster_0/pattern_2"),
    ("Oct4", "metacluster_0/pattern_0"),
    ("Sox2", "metacluster_0/pattern_1"),
]
# Heatmaps show only the seqlets ranked first by sort_idx
# (highest total observed signal).
N_HEATMAP_SEQLETS = 1000

for task, pattern in TASK_PATTERNS:
    obs, pred, sort_idx = get_profile(task, pattern, preds, ds)

    # Observed and predicted profiles on the same scale
    multiple_plot_stranded_profile({"Observed": obs,
                                    "Predicted": pred})

    # Average profile across all seqlets
    fig, axes = plt.subplots(2, 1, figsize=(4, 4))
    plot_stranded_profile(obs.mean(axis=0), ax=axes[0])
    axes[0].set_title(f"{task}: Observed")
    plot_stranded_profile(pred.mean(axis=0), ax=axes[1])
    axes[1].set_title(f"{task}: Predicted")
    # Lay out only after the titles exist so space is reserved for them.
    fig.tight_layout()

    # Each total-counts curve gets its own figure. Without plt.figure() the
    # pyplot calls would draw into axes[1] of the figure above (outside a
    # notebook there are no cell boundaries to start a new figure).
    plt.figure()
    plt.plot(obs.sum(axis=-1).sum(axis=-1)[sort_idx])
    plt.xlabel("Sort index")
    plt.title(f"{task}: Observed total counts")

    plt.figure()
    plt.plot(pred.sum(axis=-1).sum(axis=-1)[sort_idx])
    plt.xlabel("Sort index")
    plt.title(f"{task}: Predicted total counts at the seqlet region")

    scatter_counts(obs, pred)

    multiple_heatmap_stranded_profile({"Observed": obs,
                                       "Predicted": pred},
                                      sort_idx=sort_idx[:N_HEATMAP_SEQLETS],
                                      figsize=(10, 20))