Goal

  • Compare model predictions with the observed count profiles at the 4 strongest motifs

Conclusions

  • The model correctly learns the aggregated pattern
  • The correlation of the total counts is also good, but only on the log scale. On the natural scale, the model predictions don't vary that much.
  • The heatmap visualization may be a bit misleading, as the pattern is not scaled by the total counts. Hence, we don't see the decrease in the total count abundance.
In [27]:
from basepair.imports import *
ddir = get_data_dir()
In [163]:
from basepair.plot.profiles import extract_signal
from basepair.math import softmax
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile
from basepair.plot.profiles import  plot_stranded_profile, multiple_plot_stranded_profile
In [2]:
create_tf_session(0)
Out[2]:
<tensorflow.python.client.session.Session at 0x7f89a5abf400>
In [5]:
ls {model_dir}
clustering/    figures/        grad.test.h5   hparams.yaml        modisco/
cometml.json   grad.all.h5     grad.valid.h5  Intervene_results/  results.html
dataspec.yaml  grad.test.2.h5  history.csv    model.h5            results.ipynb
In [4]:
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
In [8]:
# Load the data
# `d` wraps the `grad.all.h5` file stored in the model directory.
d = HDF5Reader(model_dir / "grad.all.h5")
# The reader is opened and never closed in this notebook: `get_profile` reads
# the observed profiles from `d.f['/targets/profile/<task>']` later on.
d.open()
# Read the complete '/inputs' dataset into memory; it is fed to `model.predict`.
seq = d.f['/inputs'][:]
In [ ]:
# Load the model
model = load_model(model_dir / "model.h5")
In [34]:
ds = DataSpec.load(model_dir / "dataspec.yaml")
In [11]:
# Run the model on all loaded inputs. `get_profile` later indexes the result
# with `ds.task2idx(task, 'profile')` and `ds.task2idx(task, 'counts')`.
preds = model.predict(seq, verbose=1)
98428/98428 [==============================] - 31s 319us/step
In [14]:
# Store the predictions as `preds.all.h5` inside the model directory.
HDF5BatchWriter.dump(model_dir / "preds.all.h5", preds)

Plot scatterplots for the main regions

In [17]:
modisco_pdir = model_dir / "modisco/by_peak_tasks/weighted/"

Functions

In [220]:
def get_profile(task, pattern, preds, ds):
    """Collect observed and predicted count profiles at the seqlets of a pattern.

    Besides its arguments, this function reads two notebook-level globals:
    `modisco_pdir` (directory containing `<task>/modisco.h5` and
    `<task>/kwargs.json`) and `d` (the opened HDF5 reader providing
    `/targets/profile/<task>`).

    Args:
      task: task name, e.g. "Klf4". Selects the modisco run, the model outputs
        (through `ds.task2idx`) and the observed profile dataset.
      pattern: modisco pattern name, e.g. "metacluster_0/pattern_0".
      preds: model predictions, indexed with `ds.task2idx(task, ...)`.
      ds: object providing `task2idx(task, 'profile' | 'counts')`.

    Returns:
      Tuple `(seqlet_profile_obs, pred_counts, sort_idx)`:
        - `seqlet_profile_obs`: observed profile extracted at each seqlet,
        - `pred_counts`: softmax of the predicted profile extracted at each
          seqlet, multiplied by the predicted total counts of the example the
          seqlet comes from,
        - `sort_idx`: indices sorting the seqlets by decreasing observed
          total counts within the seqlet.
    """
    mr = ModiscoResult(modisco_pdir / f"{task}/modisco.h5")
    mr.open()
    try:
        seqlets = mr._get_seqlets(pattern)
    finally:
        # Release the modisco file even if the pattern lookup fails.
        mr.close()

    # Subset of examples referenced by "filter_npy" in the modisco kwargs; the
    # same subset is applied to predictions and observations so that seqlet
    # `seqname` values index into matching rows.
    include_samples = np.load(read_json(modisco_pdir / f"{task}/kwargs.json")["filter_npy"])
    profile_preds = softmax(preds[ds.task2idx(task, 'profile')][include_samples])
    # exp(x) - 1; presumably the counts output is log(1 + counts) -- TODO confirm.
    count_preds = np.expm1(preds[ds.task2idx(task, 'counts')][include_samples])

    profile_obs = d.f[f'/targets/profile/{task}'][:][include_samples]

    seqlet_profile_obs = extract_signal(profile_obs, seqlets)
    seqlet_profile_pred = extract_signal(profile_preds, seqlets)

    # Row (within the filtered examples) that each seqlet was taken from.
    seqlet_idx = np.array([s.seqname for s in seqlets])

    # Observed total counts per seqlet, summed over the two trailing axes.
    total_counts = seqlet_profile_obs.sum(axis=-1).sum(axis=-1)
    sort_idx = np.argsort(-total_counts)

    # Add the counts from the region: scale the predicted profile of every
    # seqlet by the predicted total counts of its source example.
    pred_counts = seqlet_profile_pred * count_preds[seqlet_idx][:, np.newaxis]

    return seqlet_profile_obs, pred_counts, sort_idx

def scatter_counts(obs, pred, sort_idx=None):
    """Scatter predicted against observed total counts on two scales.

    Draws a new 8x4 figure with two panels: the left one shows the raw totals
    (natural scale), the right one shows `log(1 + total)` (log scale). Totals
    are obtained by summing `obs` / `pred` over their two trailing axes.

    Args:
      obs: observed counts, as returned by `get_profile`.
      pred: predicted counts, as returned by `get_profile`.
      sort_idx: optional index array selecting/reordering the entries to plot.
        When omitted, all entries are plotted; the function no longer reads a
        global `sort_idx`.
    """
    obs_total = obs.sum(axis=-1).sum(axis=-1)
    pred_total = pred.sum(axis=-1).sum(axis=-1)
    if sort_idx is not None:
        obs_total = obs_total[sort_idx]
        pred_total = pred_total[sort_idx]

    plt.figure(figsize=(8, 4))

    # Left panel: raw totals.
    plt.subplot(121)
    plt.plot(pred_total, obs_total, ".", alpha=0.5)
    plt.title("Scatterplot at the natural scale")
    plt.xlabel("Predicted")
    plt.ylabel("Observed")

    # Right panel: log(1 + total).
    plt.subplot(122)
    plt.plot(np.log1p(pred_total), np.log1p(obs_total), ".", alpha=0.5)
    plt.title("Scatterplot at the log-scale")
    plt.xlabel("Predicted")
    plt.ylabel("Observed")

    plt.tight_layout()
In [221]:
task = "Klf4"
pattern = "metacluster_0/pattern_0"
In [222]:
obs, pred, sort_idx = get_profile(task, pattern, preds, ds)
In [223]:
# on the same scale
multiple_plot_stranded_profile({"Observed": obs,
                                "Predicted": pred});
In [224]:
# Observed and predicted profiles averaged over axis 0, one panel each.
fig, axes = plt.subplots(2, 1, figsize=(4,4))
plot_stranded_profile(obs.mean(axis=0), ax=axes[0])
axes[0].set_title("Observed");
plot_stranded_profile(pred.mean(axis=0), ax=axes[1])
axes[1].set_title("Predicted");
# Adjust the layout only after drawing, so the titles are taken into account.
fig.tight_layout()
In [225]:
plt.plot(obs.sum(axis=-1).sum(axis=-1)[sort_idx])
plt.xlabel("Sort index")
plt.title("Observed total counts");
In [226]:
plt.plot(pred.sum(axis=-1).sum(axis=-1)[sort_idx])
plt.xlabel("Sort index")
plt.title("Predicted total counts at the seqlet region");
In [227]:
scatter_counts(obs, pred)
In [228]:
multiple_heatmap_stranded_profile({"Observed": obs, 
                                   "Predicted": pred},
                                  sort_idx=sort_idx[:1000], 
                                  figsize=(10, 20));
/users/avsec/workspace/basepair/basepair/plot/heatmaps.py:24: RuntimeWarning: invalid value encountered in true_divide
  snorms = np.minimum(s / p99[:,np.newaxis], 1)
In [229]:
task = "Nanog"
pattern = "metacluster_0/pattern_2"
In [230]:
obs, pred, sort_idx = get_profile(task, pattern, preds, ds)
In [231]:
# on the same scale
multiple_plot_stranded_profile({"Observed": obs,
                                "Predicted": pred});
In [232]:
# Observed and predicted profiles averaged over axis 0, one panel each.
fig, axes = plt.subplots(2, 1, figsize=(4,4))
plot_stranded_profile(obs.mean(axis=0), ax=axes[0])
axes[0].set_title("Observed");
plot_stranded_profile(pred.mean(axis=0), ax=axes[1])
axes[1].set_title("Predicted");
# Adjust the layout only after drawing, so the titles are taken into account.
fig.tight_layout()
In [233]:
plt.plot(obs.sum(axis=-1).sum(axis=-1)[sort_idx])
plt.xlabel("Sort index")
plt.title("Observed total counts");
In [234]:
plt.plot(pred.sum(axis=-1).sum(axis=-1)[sort_idx])
plt.xlabel("Sort index")
plt.title("Predicted total counts at the seqlet region");
In [235]:
scatter_counts(obs, pred)
In [236]:
multiple_heatmap_stranded_profile({"Observed": obs, 
                                   "Predicted": pred},
                                  sort_idx=sort_idx[:1000], 
                                  figsize=(10, 20));
/users/avsec/workspace/basepair/basepair/plot/heatmaps.py:24: RuntimeWarning: invalid value encountered in true_divide
  snorms = np.minimum(s / p99[:,np.newaxis], 1)
In [237]:
task = "Oct4"
pattern = "metacluster_0/pattern_0"
In [238]:
obs, pred, sort_idx = get_profile(task, pattern, preds, ds)
In [239]:
# on the same scale
multiple_plot_stranded_profile({"Observed": obs,
                                "Predicted": pred});
In [240]:
# Observed and predicted profiles averaged over axis 0, one panel each.
fig, axes = plt.subplots(2, 1, figsize=(4,4))
plot_stranded_profile(obs.mean(axis=0), ax=axes[0])
axes[0].set_title("Observed");
plot_stranded_profile(pred.mean(axis=0), ax=axes[1])
axes[1].set_title("Predicted");
# Adjust the layout only after drawing, so the titles are taken into account.
fig.tight_layout()
In [241]:
plt.plot(obs.sum(axis=-1).sum(axis=-1)[sort_idx])
plt.xlabel("Sort index")
plt.title("Observed total counts");
In [242]:
plt.plot(pred.sum(axis=-1).sum(axis=-1)[sort_idx])
plt.xlabel("Sort index")
plt.title("Predicted total counts at the seqlet region");
In [243]:
scatter_counts(obs, pred)
In [244]:
multiple_heatmap_stranded_profile({"Observed": obs, 
                                   "Predicted": pred},
                                  sort_idx=sort_idx[:1000], 
                                  figsize=(10, 20));
In [245]:
task = "Sox2"
pattern = "metacluster_0/pattern_1"
In [246]:
obs, pred, sort_idx = get_profile(task, pattern, preds, ds)
In [247]:
# on the same scale
multiple_plot_stranded_profile({"Observed": obs,
                                "Predicted": pred});
In [248]:
# Observed and predicted profiles averaged over axis 0, one panel each.
fig, axes = plt.subplots(2, 1, figsize=(4,4))
plot_stranded_profile(obs.mean(axis=0), ax=axes[0])
axes[0].set_title("Observed");
plot_stranded_profile(pred.mean(axis=0), ax=axes[1])
axes[1].set_title("Predicted");
# Adjust the layout only after drawing, so the titles are taken into account.
fig.tight_layout()
In [249]:
plt.plot(obs.sum(axis=-1).sum(axis=-1)[sort_idx])
plt.xlabel("Sort index")
plt.title("Observed total counts");
In [250]:
plt.plot(pred.sum(axis=-1).sum(axis=-1)[sort_idx])
plt.xlabel("Sort index")
plt.title("Predicted total counts at the seqlet region");
In [251]:
scatter_counts(obs, pred)
In [252]:
multiple_heatmap_stranded_profile({"Observed": obs, 
                                   "Predicted": pred},
                                  sort_idx=sort_idx[:1000], 
                                  figsize=(10, 20));
/users/avsec/workspace/basepair/basepair/plot/heatmaps.py:24: RuntimeWarning: invalid value encountered in true_divide
  snorms = np.minimum(s / p99[:,np.newaxis], 1)