# Experiment identifier: comma-separated hyper-parameter string that names the
# trained model directory (presumably gin/CLI settings baked into the run name
# — TODO confirm against the training pipeline).
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'
# GPU device index used by create_tf_session() below.
gpu = 0
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
from basepair.plot.config import get_figsize, paper_config
from basepair.extractors import bw_extract
import basepair
import pandas as pd
import numpy as np
from basepair.cli.schemas import DataSpec, TaskSpec
from pathlib import Path
from keras.models import load_model
from basepair.datasets import StrandedProfile
from basepair.preproc import AppendCounts
from basepair.losses import MultichannelMultinomialNLL
from basepair.config import valid_chr, test_chr
from basepair.plots import regression_eval, plot_loss
from basepair.plot.evaluate import regression_eval
from basepair.cli.evaluate import eval_profile
from basepair import samplers
from basepair.math import softmax
from basepair.exp.paper.config import *
import matplotlib.ticker as ticker
import warnings
warnings.filterwarnings("ignore")
# Use matplotlib paper config
paper_config()
# Common paths
# `models_dir` and `ddir` are presumably provided by the star-import from
# basepair.exp.paper.config above — verify in a live session.
model_dir = models_dir / exp
figures = f"{ddir}/figures/model-evaluation/chipnexus-bpnet"
# Parameters
model_file = model_dir / "model.h5"
dataspec_file = "../../chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml"
history_file = model_dir / "history.csv"
seq_width = 1000  # peak-centered window width passed to the dataloaders below
num_workers = 10  # parallel workers for load_all()
# Load the dataspec and derive the list of tasks (TFs) from it.
ds = DataSpec.load(dataspec_file)
tasks = list(ds.task_specs)
# Reserve the configured GPU for this session.
create_tf_session(gpu)
from basepair.seqmodel import SeqModel
# Restore the trained BPNet and expose its bottleneck (pre-head) sub-model.
bpnet = SeqModel.from_mdir(model_dir)
bottleneck = bpnet.bottleneck_model()
profile_bias_pool_size=[1,50] # Note: this is specific to the model
# Get the predictions
# Build the train/validation profile dataloaders. Both share every setting
# except the chromosome split, so the common kwargs are factored out once.
_shared_dl_kwargs = dict(
    peak_width=seq_width,
    shuffle=False,
    target_transformer=AppendCounts(),
    taskname_first=True,
    profile_bias_pool_size=profile_bias_pool_size,
)
# Training set: everything except the validation and test chromosomes.
dl_train = StrandedProfile(ds,
                           excl_chromosomes=valid_chr + test_chr,
                           **_shared_dl_kwargs)
train = dl_train.load_all(num_workers=num_workers)
# Validation set: only the held-out validation chromosomes.
dl_valid = StrandedProfile(ds,
                           incl_chromosomes=valid_chr,
                           **_shared_dl_kwargs)
valid = dl_valid.load_all(num_workers=num_workers)
# Compute the bottleneck features
# One-hot sequences -> bottleneck activations of the trained BPNet; these are
# used below as frozen inputs when re-fitting (calibrating) the counts heads.
train_bottlenecks = bottleneck.predict(train['inputs']['seq'])
valid_bottlenecks = bottleneck.predict(valid['inputs']['seq'])
from basepair.seqmodel import SeqModel
from basepair.layers import DilatedConv1D, DeConv1D, GlobalAvgPoolFCN
from basepair.metrics import BPNetMetricSingleProfile
from basepair.heads import ScalarHead, ProfileHead
from gin_train.metrics import ClassificationMetrics, RegressionMetrics
from basepair.losses import mc_multinomial_nll_2, CountsMultinomialNLL
from basepair.exp.paper.config import peak_pred_metric
from basepair.activations import clipped_exp
from basepair.functions import softmax
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from basepair.seqmodel import SeqModel
# Scalar counts head: per-task total-counts regression on top of the frozen
# bottleneck features, with a learned bias-track correction.
head = ScalarHead(target_name='{task}/counts',
net=GlobalAvgPoolFCN(n_tasks=2, batchnorm=False, hidden=[]),  # hidden=[] -> purely linear readout
activation=None,
loss='mse',
bias_input='bias/{task}/counts',
use_bias=True,
bias_shape=(2, ),  # two strands of bias counts
metric=RegressionMetrics(),
)
# Identity body: the model consumes pre-computed bottleneck activations
# directly, so only the head weights are trained here.
counts_model = SeqModel(body=lambda x: x,
heads=[head],
tasks=tasks,
optimizer=Adam(lr=0.004),
input_shape=train_bottlenecks.shape[1:],
input_name='bottleneck'
)
# Fit the counts head on the frozen bottleneck features.
# EarlyStopping monitors the Keras default (val_loss) and restores the best
# weights once 5 epochs pass without improvement.
# Fix: removed the dead statement `a = 1` that sat between fit and predict.
counts_model.model.fit(
    {"bottleneck": train_bottlenecks, **train['inputs']}, train['targets'],
    batch_size=1024,
    epochs=100,
    validation_data=({"bottleneck": valid_bottlenecks, **valid['inputs']},
                     valid['targets']),
    callbacks=[EarlyStopping(patience=5, restore_best_weights=True)]
)
# Calibrated-head predictions and ground truth on the validation set.
y_pred = counts_model.predict({"bottleneck": valid_bottlenecks, **valid['inputs']})
y_true = valid['targets']
# Common paths
model_dir = models_dir / exp  # NOTE(review): re-assigns the same value as earlier — redundant but harmless
# figures = f"{ddir}/figures/model-evaluation/chipnexus-bpnet"
# Output directory for the per-experiment evaluation figures saved below.
fdir = Path(f"{ddir}/figures/model-evaluation/chipnexus-bpnet/{exp}")
# Joint observed-vs-predicted total-counts scatter (log-log), one shared panel
# per task, for the linear (no-hidden-layer) calibrated counts head.
fig, axes = plt.subplots(1, len(tasks),
                         figsize=get_figsize(frac=1, aspect=1 / len(tasks)),
                         sharex=True, sharey=True)
for panel_idx, (task, panel) in enumerate(zip(tasks, axes)):
    # Back-transform log counts (mean over strands) to raw counts.
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    limits = [10, 1e4]
    panel.set_ylim(limits)
    panel.set_xlim(limits)
    panel.plot(limits, limits, c='grey', alpha=0.2)  # identity reference line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=panel, loglog=True)
    for axis in (panel.xaxis, panel.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx:
        panel.set_ylabel("")  # shared y-axis: label only the left-most panel
fig.subplots_adjust(wspace=0)
plt.minorticks_off()
fig.savefig(fdir / 'calibrated,linear.total-counts.scatter-no-hidden.pdf')
# One small observed-vs-predicted counts scatter per task (log-log),
# saved individually as PDF and PNG for the linear calibrated head.
for task in tasks:
    fig, ax = plt.subplots(figsize=get_figsize(frac=0.25, aspect=1))
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    limits = [10, 1e4]
    ax.set_ylim(limits)
    ax.set_xlim(limits)
    ax.plot(limits, limits, c='grey', alpha=0.2)  # identity reference line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    plt.minorticks_off()
    # save the figure
    os.makedirs(f"{fdir}/scatter", exist_ok=True)
    fig.savefig(f"{fdir}/scatter/calibrated,linear.{task}.pdf")
    fig.savefig(f"{fdir}/scatter/calibrated,linear.{task}.png")
# Joint observed-vs-predicted total-counts scatter (log-log), one shared panel
# per task, for the hidden+batchnorm variant of the calibrated counts head.
fig, axes = plt.subplots(1, len(tasks),
                         figsize=get_figsize(frac=1, aspect=1 / len(tasks)),
                         sharex=True, sharey=True)
for panel_idx, (task, panel) in enumerate(zip(tasks, axes)):
    # Back-transform log counts (mean over strands) to raw counts.
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    limits = [10, 1e4]
    panel.set_ylim(limits)
    panel.set_xlim(limits)
    panel.plot(limits, limits, c='grey', alpha=0.2)  # identity reference line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=panel, loglog=True)
    for axis in (panel.xaxis, panel.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx:
        panel.set_ylabel("")  # shared y-axis: label only the left-most panel
fig.subplots_adjust(wspace=0)
plt.minorticks_off()
fig.savefig(fdir / 'calibrated,hidden+bn.total-counts.scatter.pdf')
# One small observed-vs-predicted counts scatter per task (log-log),
# saved individually as PDF and PNG for the hidden+batchnorm variant.
for task in tasks:
    fig, ax = plt.subplots(figsize=get_figsize(frac=0.25, aspect=1))
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    limits = [10, 1e4]
    ax.set_ylim(limits)
    ax.set_xlim(limits)
    ax.plot(limits, limits, c='grey', alpha=0.2)  # identity reference line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    plt.minorticks_off()
    # save the figure
    os.makedirs(f"{fdir}/scatter", exist_ok=True)
    fig.savefig(f"{fdir}/scatter/calibrated,hidden+bn.{task}.pdf")
    fig.savefig(f"{fdir}/scatter/calibrated,hidden+bn.{task}.png")
# Joint observed-vs-predicted total-counts scatter (log-log), one shared panel
# per task. NOTE: this figure is displayed inline only — not saved to disk.
fig, axes = plt.subplots(1, len(tasks),
                         figsize=get_figsize(frac=1, aspect=1 / len(tasks)),
                         sharex=True, sharey=True)
for panel_idx, (task, panel) in enumerate(zip(tasks, axes)):
    # Back-transform log counts (mean over strands) to raw counts.
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    limits = [10, 1e4]
    panel.set_ylim(limits)
    panel.set_xlim(limits)
    panel.plot(limits, limits, c='grey', alpha=0.2)  # identity reference line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=panel, loglog=True)
    for axis in (panel.xaxis, panel.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx:
        panel.set_ylabel("")  # shared y-axis: label only the left-most panel
fig.subplots_adjust(wspace=0)
plt.minorticks_off()
# Notebook exploration cells: inspect the fitted heads and layers.
h = counts_model.heads[0]
# NOTE(review): overwritten immediately — the Oct4-specific head replaces it.
h = counts_model.all_heads["Oct4"][0]
# Last layer of the calibrated model; presumably used interactively to
# discover the dense-layer names hard-coded below — verify in a live session.
l = counts_model.model.layers[-1]
# Keras auto-generated layer names mapping each TF task to its dense (counts)
# and bias-correction layers, both in the re-fitted counts_model
# (dense_21..dense_28) and in the original BPNet graph (dense_1..dense_8).
# NOTE(review): these names depend on Keras' global naming counter and on this
# exact model build — re-verify (e.g. via model.summary()) before re-running.
calibrated_dense_layers = {"Oct4": "dense_21",
"Sox2": "dense_23",
"Nanog": "dense_25",
"Klf4": "dense_27"}
calibrated_bias_layers = {"Oct4": "dense_22",
"Sox2": "dense_24",
"Nanog": "dense_26",
"Klf4": "dense_28"}
orig_dense_layers = {"Oct4": "dense_1",
"Sox2": "dense_3",
"Nanog": "dense_5",
"Klf4": "dense_7"}
orig_bias_layers = {"Oct4": "dense_2",
"Sox2": "dense_4",
"Nanog": "dense_6",
"Klf4": "dense_8"}
# calibrate the model
# Copy the re-fitted (calibrated) dense and bias weights from counts_model
# back into the corresponding layers of the original BPNet graph, then save
# the calibrated model and re-predict on the validation set with it.
for tf in bpnet.tasks:
bpnet.model.get_layer(orig_bias_layers[tf]).set_weights(counts_model.model.get_layer(calibrated_bias_layers[tf]).get_weights())
bpnet.model.get_layer(orig_dense_layers[tf]).set_weights(counts_model.model.get_layer(calibrated_dense_layers[tf]).get_weights())
bpnet.save(str(model_dir / 'calibrated_seqmodel.pkl'))
# Full-model predictions (profile + counts heads) after calibration.
y_pred = bpnet.predict(valid['inputs'])
y_true = valid['targets']
# Joint observed-vs-predicted total-counts scatter (log-log) for the fully
# calibrated BPNet, one shared panel per task. Displayed inline only.
fig, axes = plt.subplots(1, len(tasks),
                         figsize=get_figsize(frac=1, aspect=1 / len(tasks)),
                         sharex=True, sharey=True)
for panel_idx, (task, panel) in enumerate(zip(tasks, axes)):
    # Back-transform log counts (mean over strands) to raw counts.
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    limits = [10, 1e4]
    panel.set_ylim(limits)
    panel.set_xlim(limits)
    panel.plot(limits, limits, c='grey', alpha=0.2)  # identity reference line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=panel, loglog=True)
    for axis in (panel.xaxis, panel.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx:
        panel.set_ylabel("")  # shared y-axis: label only the left-most panel
fig.subplots_adjust(wspace=0)
plt.minorticks_off()
# fig.savefig(fdir / 'calibrated,linear.total-counts.scatter-no-hidden.pdf')
# Repeat of the calibrated-model scatter above (notebook re-run cell);
# displayed inline only, not saved.
fig, axes = plt.subplots(1, len(tasks),
                         figsize=get_figsize(frac=1, aspect=1 / len(tasks)),
                         sharex=True, sharey=True)
for panel_idx, (task, panel) in enumerate(zip(tasks, axes)):
    # Back-transform log counts (mean over strands) to raw counts.
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    limits = [10, 1e4]
    panel.set_ylim(limits)
    panel.set_xlim(limits)
    panel.plot(limits, limits, c='grey', alpha=0.2)  # identity reference line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=panel, loglog=True)
    for axis in (panel.xaxis, panel.yaxis):
        axis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx:
        panel.set_ylabel("")  # shared y-axis: label only the left-most panel
fig.subplots_adjust(wspace=0)
plt.minorticks_off()
# Observed-vs-predicted total-counts scatter for the calibrated model.
# BUGFIX: the y_true lookup used f'counts/{task}', inconsistent with every
# other cell and with the '{task}/counts' keys produced by the dataloader
# (taskname_first=True) — it would raise KeyError. Fixed to f'{task}/counts'.
fig, axes = plt.subplots(1, len(tasks),
                         figsize=get_figsize(frac=1, aspect=1/len(tasks)),
                         sharex=True, sharey=True)
for i, (task, ax) in enumerate(zip(tasks, axes)):
    # Back-transform log counts (mean over strands) to raw counts.
    yt = np.exp(y_true[f'{task}/counts'].mean(-1))
    yp = np.exp(y_pred[f'{task}/counts'].mean(-1))
    xrange = [10, 1e4]
    ax.set_ylim(xrange)
    ax.set_xlim(xrange)
    ax.plot(xrange, xrange, c='grey', alpha=0.2)  # identity reference line
    regression_eval(yt, yp, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if i > 0:
        ax.set_ylabel("")  # shared y-axis: label only the left-most panel
fig.subplots_adjust(wspace=0)
plt.minorticks_off()