In [19]:
# Experiment identifier: a comma-separated hyperparameter string that doubles as the
# model directory name. Fields presumably encode dataset/peak-set/tasks and training
# hyperparameters (e.g. 0.004 = learning rate, [1,50] = profile_bias_pool_size used
# below) — TODO confirm against the training script that produced this directory.
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'
# GPU device id for this session (passed to create_tf_session below).
gpu = 0
In [3]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
from basepair.plot.config import get_figsize, paper_config
from basepair.extractors import bw_extract
import basepair
import pandas as pd
import numpy as np
from basepair.cli.schemas import DataSpec, TaskSpec
from pathlib import Path
from keras.models import load_model
from basepair.datasets import StrandedProfile
from basepair.preproc import AppendCounts
from basepair.losses import MultichannelMultinomialNLL
from basepair.config import valid_chr, test_chr
# NOTE(review): `regression_eval` is imported twice — the import on the next line
# shadows this one, so the effective implementation is basepair.plot.evaluate's.
# Keep only one of the two once the intended function is confirmed.
from basepair.plots import regression_eval, plot_loss
from basepair.plot.evaluate import regression_eval
from basepair.cli.evaluate import eval_profile
from basepair import samplers 
from basepair.math import softmax
# NOTE(review): star import — presumably supplies names used later without a visible
# definition (models_dir, ddir, create_tf_session, plt, os, ...); TODO confirm and
# consider importing them explicitly for readability.
from basepair.exp.paper.config import *
import matplotlib.ticker as ticker
import warnings
warnings.filterwarnings("ignore")  # silence library deprecation noise in cell output
# Use matplotlib paper config
paper_config()
Using TensorFlow backend.
In [4]:
# Common paths
# `models_dir` and `ddir` are presumably provided by the star import from
# basepair.exp.paper.config — TODO confirm.
model_dir = models_dir / exp
figures = f"{ddir}/figures/model-evaluation/chipnexus-bpnet"

# Parameters
model_file = model_dir / "model.h5"      # trained Keras model artifact
# NOTE(review): relative path — breaks if the notebook is launched from another cwd.
dataspec_file = "../../chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml"
history_file = model_dir / "history.csv"  # training history CSV (appears unused below)
seq_width = 1000   # width (bp) of the sequence window extracted around each peak
num_workers = 10   # parallel workers for dataset loading

Get predictions

In [5]:
# Load the dataspec (tracks + peak files) and derive the list of tasks.
ds = DataSpec.load(dataspec_file)
# Iterating task_specs presumably yields task names (Oct4, Sox2, Nanog, Klf4, per
# the losses printed below) — TODO confirm it yields names rather than spec objects.
tasks = list(ds.task_specs)
In [6]:
# Pin TensorFlow to the selected GPU (helper presumably from the star-imported config).
create_tf_session(gpu)
Out[6]:
<tensorflow.python.client.session.Session at 0x7f1a764182e8>
In [7]:
from basepair.seqmodel import SeqModel
# Load the trained model (plus metadata) from the experiment directory.
bpnet = SeqModel.from_mdir(model_dir)
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2019-04-20 04:09:35,318 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2019-04-20 04:09:46,665 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [8]:
# Sub-model truncated at the bottleneck: maps one-hot sequence -> bottleneck features.
bottleneck = bpnet.bottleneck_model()
In [9]:
profile_bias_pool_size=[1,50]  # Note: this is specific to the model (matches the `[1,50]` field of `exp`)
In [10]:
# Get the predictions
# The train/validation loaders differ only in their chromosome filter, so build
# them through one helper instead of two copy-pasted constructor calls.

def _load_split(**chromosome_filter):
    """Build a StrandedProfile loader with the shared settings and load all data.

    chromosome_filter: either `incl_chromosomes=[...]` or `excl_chromosomes=[...]`,
        forwarded verbatim to StrandedProfile.
    Returns: (dataloader, loaded-data dict with 'inputs' and 'targets').
    """
    dl = StrandedProfile(ds,
                         peak_width=seq_width,
                         shuffle=False,  # keep peak order stable so rows align across splits
                         taskname_first=True,  # target keys are '{task}/counts'
                         target_transformer=AppendCounts(),
                         profile_bias_pool_size=profile_bias_pool_size,
                         **chromosome_filter)
    return dl, dl.load_all(num_workers=num_workers)


# Training split: all peaks except validation/test chromosomes.
dl_train, train = _load_split(excl_chromosomes=valid_chr + test_chr)
# Validation split.
dl_valid, valid = _load_split(incl_chromosomes=valid_chr)
100%|██████████| 2935/2935 [01:18<00:00, 37.16it/s]
100%|██████████| 915/915 [00:28<00:00, 32.60it/s]
In [11]:
# Compute the bottleneck features
# Precompute bottleneck activations once so the cheap top (counts) model can be
# trained quickly without re-running the convolutional body.
train_bottlenecks = bottleneck.predict(train['inputs']['seq'])
valid_bottlenecks = bottleneck.predict(valid['inputs']['seq'])

Train the bottleneck model

  • train a counts model on top of the precomputed bottleneck features, with a separate head per task
In [12]:
# Imports for building/training the top (calibration) model.
# NOTE(review): several names here (DilatedConv1D, DeConv1D, ProfileHead,
# ClassificationMetrics, mc_multinomial_nll_2, CountsMultinomialNLL, clipped_exp,
# softmax, BPNetMetricSingleProfile, peak_pred_metric) appear unused in this
# notebook — kept as-is, but candidates for removal.
from basepair.seqmodel import SeqModel
from basepair.layers import DilatedConv1D, DeConv1D, GlobalAvgPoolFCN
from basepair.metrics import BPNetMetricSingleProfile
from basepair.heads import ScalarHead, ProfileHead
from gin_train.metrics import ClassificationMetrics, RegressionMetrics
from basepair.losses import mc_multinomial_nll_2, CountsMultinomialNLL
from basepair.exp.paper.config import peak_pred_metric
from basepair.activations import clipped_exp
from basepair.functions import softmax
In [13]:
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from basepair.seqmodel import SeqModel  # NOTE(review): redundant — already imported above
In [43]:
# Scalar (total-counts) head: a linear readout over the pooled bottleneck features
# (no hidden layers, no batchnorm), with the per-task bias track added as an input.
head = ScalarHead(
    target_name='{task}/counts',
    net=GlobalAvgPoolFCN(n_tasks=2, batchnorm=False, hidden=[]),
    activation=None,
    loss='mse',
    bias_input='bias/{task}/counts',
    use_bias=True,
    bias_shape=(2, ),
    metric=RegressionMetrics(),
)
In [45]:
# Top model: identity body over the precomputed bottleneck features, so only the
# count head weights are trained (lr matches the `0.004` field of `exp`).
counts_model = SeqModel(
    body=lambda x: x,
    heads=[head],
    tasks=tasks,
    optimizer=Adam(lr=0.004),
    input_shape=train_bottlenecks.shape[1:],
    input_name='bottleneck',
)

Train the top model

In [ ]:
# Fit the calibration model; early-stop on validation loss, keeping the best weights.
calib_train_inputs = {"bottleneck": train_bottlenecks, **train['inputs']}
calib_valid_inputs = {"bottleneck": valid_bottlenecks, **valid['inputs']}
counts_model.model.fit(
    calib_train_inputs,
    train['targets'],
    batch_size=1024,
    epochs=100,
    validation_data=(calib_valid_inputs, valid['targets']),
    callbacks=[EarlyStopping(patience=5, restore_best_weights=True)],
)
Train on 93904 samples, validate on 29277 samples
Epoch 1/100
93904/93904 [==============================] - 34s 361us/step - loss: 60.1725 - Oct4/counts_loss: 1.2496 - Sox2/counts_loss: 17.9218 - Nanog/counts_loss: 19.9244 - Klf4/counts_loss: 21.0767 - val_loss: 7.6456 - val_Oct4/counts_loss: 0.5335 - val_Sox2/counts_loss: 1.8776 - val_Nanog/counts_loss: 2.4713 - val_Klf4/counts_loss: 2.7632
Epoch 2/100
93904/93904 [==============================] - 32s 342us/step - loss: 3.2756 - Oct4/counts_loss: 0.5317 - Sox2/counts_loss: 0.6413 - Nanog/counts_loss: 0.9914 - Klf4/counts_loss: 1.1112 - val_loss: 2.4862 - val_Oct4/counts_loss: 0.5128 - val_Sox2/counts_loss: 0.4568 - val_Nanog/counts_loss: 0.7511 - val_Klf4/counts_loss: 0.7655
Epoch 3/100
93904/93904 [==============================] - 30s 316us/step - loss: 2.3752 - Oct4/counts_loss: 0.5096 - Sox2/counts_loss: 0.4300 - Nanog/counts_loss: 0.6998 - Klf4/counts_loss: 0.7359 - val_loss: 2.4156 - val_Oct4/counts_loss: 0.4905 - val_Sox2/counts_loss: 0.4464 - val_Nanog/counts_loss: 0.7386 - val_Klf4/counts_loss: 0.7400
Epoch 4/100
93904/93904 [==============================] - 29s 306us/step - loss: 2.3211 - Oct4/counts_loss: 0.4855 - Sox2/counts_loss: 0.4250 - Nanog/counts_loss: 0.6933 - Klf4/counts_loss: 0.7174 - val_loss: 2.3613 - val_Oct4/counts_loss: 0.4646 - val_Sox2/counts_loss: 0.4421 - val_Nanog/counts_loss: 0.7327 - val_Klf4/counts_loss: 0.7220
Epoch 5/100
93904/93904 [==============================] - 28s 298us/step - loss: 2.2644 - Oct4/counts_loss: 0.4612 - Sox2/counts_loss: 0.4192 - Nanog/counts_loss: 0.6863 - Klf4/counts_loss: 0.6978 - val_loss: 2.3033 - val_Oct4/counts_loss: 0.4429 - val_Sox2/counts_loss: 0.4349 - val_Nanog/counts_loss: 0.7241 - val_Klf4/counts_loss: 0.7014
Epoch 6/100
93904/93904 [==============================] - 27s 287us/step - loss: 2.2066 - Oct4/counts_loss: 0.4377 - Sox2/counts_loss: 0.4128 - Nanog/counts_loss: 0.6789 - Klf4/counts_loss: 0.6772 - val_loss: 2.2471 - val_Oct4/counts_loss: 0.4212 - val_Sox2/counts_loss: 0.4281 - val_Nanog/counts_loss: 0.7162 - val_Klf4/counts_loss: 0.6815
Epoch 7/100
93904/93904 [==============================] - 26s 278us/step - loss: 2.1490 - Oct4/counts_loss: 0.4157 - Sox2/counts_loss: 0.4059 - Nanog/counts_loss: 0.6713 - Klf4/counts_loss: 0.6561 - val_loss: 2.1899 - val_Oct4/counts_loss: 0.3994 - val_Sox2/counts_loss: 0.4206 - val_Nanog/counts_loss: 0.7089 - val_Klf4/counts_loss: 0.6610
Epoch 8/100
93904/93904 [==============================] - 26s 272us/step - loss: 2.0930 - Oct4/counts_loss: 0.3955 - Sox2/counts_loss: 0.3986 - Nanog/counts_loss: 0.6637 - Klf4/counts_loss: 0.6351 - val_loss: 2.1339 - val_Oct4/counts_loss: 0.3820 - val_Sox2/counts_loss: 0.4116 - val_Nanog/counts_loss: 0.7005 - val_Klf4/counts_loss: 0.6399
Epoch 9/100
93904/93904 [==============================] - 26s 273us/step - loss: 2.0391 - Oct4/counts_loss: 0.3772 - Sox2/counts_loss: 0.3910 - Nanog/counts_loss: 0.6562 - Klf4/counts_loss: 0.6147 - val_loss: 2.0876 - val_Oct4/counts_loss: 0.3681 - val_Sox2/counts_loss: 0.4050 - val_Nanog/counts_loss: 0.6935 - val_Klf4/counts_loss: 0.6210
Epoch 10/100
93904/93904 [==============================] - 20s 218us/step - loss: 1.9879 - Oct4/counts_loss: 0.3608 - Sox2/counts_loss: 0.3831 - Nanog/counts_loss: 0.6489 - Klf4/counts_loss: 0.5950 - val_loss: 2.0354 - val_Oct4/counts_loss: 0.3523 - val_Sox2/counts_loss: 0.3955 - val_Nanog/counts_loss: 0.6858 - val_Klf4/counts_loss: 0.6019
Epoch 11/100
93904/93904 [==============================] - 20s 218us/step - loss: 1.9393 - Oct4/counts_loss: 0.3460 - Sox2/counts_loss: 0.3751 - Nanog/counts_loss: 0.6419 - Klf4/counts_loss: 0.5763 - val_loss: 1.9892 - val_Oct4/counts_loss: 0.3406 - val_Sox2/counts_loss: 0.3868 - val_Nanog/counts_loss: 0.6782 - val_Klf4/counts_loss: 0.5836
Epoch 12/100
93904/93904 [==============================] - 18s 189us/step - loss: 1.8939 - Oct4/counts_loss: 0.3329 - Sox2/counts_loss: 0.3669 - Nanog/counts_loss: 0.6352 - Klf4/counts_loss: 0.5588 - val_loss: 1.9466 - val_Oct4/counts_loss: 0.3308 - val_Sox2/counts_loss: 0.3779 - val_Nanog/counts_loss: 0.6711 - val_Klf4/counts_loss: 0.5668
Epoch 13/100
93904/93904 [==============================] - 20s 211us/step - loss: 1.8514 - Oct4/counts_loss: 0.3212 - Sox2/counts_loss: 0.3587 - Nanog/counts_loss: 0.6288 - Klf4/counts_loss: 0.5426 - val_loss: 1.9066 - val_Oct4/counts_loss: 0.3206 - val_Sox2/counts_loss: 0.3689 - val_Nanog/counts_loss: 0.6655 - val_Klf4/counts_loss: 0.5516
Epoch 14/100
93904/93904 [==============================] - 18s 195us/step - loss: 1.8117 - Oct4/counts_loss: 0.3107 - Sox2/counts_loss: 0.3506 - Nanog/counts_loss: 0.6228 - Klf4/counts_loss: 0.5276 - val_loss: 1.8708 - val_Oct4/counts_loss: 0.3119 - val_Sox2/counts_loss: 0.3602 - val_Nanog/counts_loss: 0.6599 - val_Klf4/counts_loss: 0.5388
Epoch 15/100
93904/93904 [==============================] - 19s 200us/step - loss: 1.7750 - Oct4/counts_loss: 0.3015 - Sox2/counts_loss: 0.3424 - Nanog/counts_loss: 0.6171 - Klf4/counts_loss: 0.5141 - val_loss: 1.8365 - val_Oct4/counts_loss: 0.3051 - val_Sox2/counts_loss: 0.3525 - val_Nanog/counts_loss: 0.6538 - val_Klf4/counts_loss: 0.5252
Epoch 16/100
93904/93904 [==============================] - 20s 214us/step - loss: 1.7409 - Oct4/counts_loss: 0.2931 - Sox2/counts_loss: 0.3344 - Nanog/counts_loss: 0.6117 - Klf4/counts_loss: 0.5017 - val_loss: 1.8035 - val_Oct4/counts_loss: 0.2993 - val_Sox2/counts_loss: 0.3429 - val_Nanog/counts_loss: 0.6480 - val_Klf4/counts_loss: 0.5133
Epoch 17/100
93904/93904 [==============================] - 18s 189us/step - loss: 1.7093 - Oct4/counts_loss: 0.2858 - Sox2/counts_loss: 0.3264 - Nanog/counts_loss: 0.6066 - Klf4/counts_loss: 0.4905 - val_loss: 1.7744 - val_Oct4/counts_loss: 0.2930 - val_Sox2/counts_loss: 0.3350 - val_Nanog/counts_loss: 0.6433 - val_Klf4/counts_loss: 0.5031
Epoch 18/100
93904/93904 [==============================] - 20s 209us/step - loss: 1.6801 - Oct4/counts_loss: 0.2793 - Sox2/counts_loss: 0.3187 - Nanog/counts_loss: 0.6017 - Klf4/counts_loss: 0.4804 - val_loss: 1.7457 - val_Oct4/counts_loss: 0.2876 - val_Sox2/counts_loss: 0.3263 - val_Nanog/counts_loss: 0.6385 - val_Klf4/counts_loss: 0.4932
Epoch 19/100
93904/93904 [==============================] - 22s 231us/step - loss: 1.6532 - Oct4/counts_loss: 0.2736 - Sox2/counts_loss: 0.3112 - Nanog/counts_loss: 0.5971 - Klf4/counts_loss: 0.4714 - val_loss: 1.7216 - val_Oct4/counts_loss: 0.2841 - val_Sox2/counts_loss: 0.3188 - val_Nanog/counts_loss: 0.6340 - val_Klf4/counts_loss: 0.4848
Epoch 20/100
93904/93904 [==============================] - 17s 178us/step - loss: 1.6284 - Oct4/counts_loss: 0.2686 - Sox2/counts_loss: 0.3038 - Nanog/counts_loss: 0.5927 - Klf4/counts_loss: 0.4633 - val_loss: 1.6969 - val_Oct4/counts_loss: 0.2798 - val_Sox2/counts_loss: 0.3108 - val_Nanog/counts_loss: 0.6293 - val_Klf4/counts_loss: 0.4769
Epoch 21/100
93904/93904 [==============================] - 20s 208us/step - loss: 1.6054 - Oct4/counts_loss: 0.2641 - Sox2/counts_loss: 0.2968 - Nanog/counts_loss: 0.5885 - Klf4/counts_loss: 0.4560 - val_loss: 1.6755 - val_Oct4/counts_loss: 0.2767 - val_Sox2/counts_loss: 0.3033 - val_Nanog/counts_loss: 0.6256 - val_Klf4/counts_loss: 0.4699
Epoch 22/100
93904/93904 [==============================] - 20s 211us/step - loss: 1.5841 - Oct4/counts_loss: 0.2604 - Sox2/counts_loss: 0.2899 - Nanog/counts_loss: 0.5844 - Klf4/counts_loss: 0.4494 - val_loss: 1.6567 - val_Oct4/counts_loss: 0.2734 - val_Sox2/counts_loss: 0.2972 - val_Nanog/counts_loss: 0.6223 - val_Klf4/counts_loss: 0.4638
Epoch 23/100
93904/93904 [==============================] - 18s 188us/step - loss: 1.5646 - Oct4/counts_loss: 0.2572 - Sox2/counts_loss: 0.2834 - Nanog/counts_loss: 0.5805 - Klf4/counts_loss: 0.4435 - val_loss: 1.6370 - val_Oct4/counts_loss: 0.2709 - val_Sox2/counts_loss: 0.2898 - val_Nanog/counts_loss: 0.6187 - val_Klf4/counts_loss: 0.4577
Epoch 24/100
93904/93904 [==============================] - 20s 211us/step - loss: 1.5463 - Oct4/counts_loss: 0.2545 - Sox2/counts_loss: 0.2771 - Nanog/counts_loss: 0.5766 - Klf4/counts_loss: 0.4381 - val_loss: 1.6196 - val_Oct4/counts_loss: 0.2690 - val_Sox2/counts_loss: 0.2831 - val_Nanog/counts_loss: 0.6149 - val_Klf4/counts_loss: 0.4525
Epoch 25/100
93904/93904 [==============================] - 19s 202us/step - loss: 1.5293 - Oct4/counts_loss: 0.2521 - Sox2/counts_loss: 0.2711 - Nanog/counts_loss: 0.5729 - Klf4/counts_loss: 0.4333 - val_loss: 1.6029 - val_Oct4/counts_loss: 0.2672 - val_Sox2/counts_loss: 0.2767 - val_Nanog/counts_loss: 0.6116 - val_Klf4/counts_loss: 0.4475
Epoch 26/100
93904/93904 [==============================] - 19s 199us/step - loss: 1.5137 - Oct4/counts_loss: 0.2502 - Sox2/counts_loss: 0.2653 - Nanog/counts_loss: 0.5693 - Klf4/counts_loss: 0.4289 - val_loss: 1.5888 - val_Oct4/counts_loss: 0.2656 - val_Sox2/counts_loss: 0.2717 - val_Nanog/counts_loss: 0.6084 - val_Klf4/counts_loss: 0.4431
Epoch 27/100
93904/93904 [==============================] - 20s 212us/step - loss: 1.4990 - Oct4/counts_loss: 0.2486 - Sox2/counts_loss: 0.2599 - Nanog/counts_loss: 0.5657 - Klf4/counts_loss: 0.4248 - val_loss: 1.5728 - val_Oct4/counts_loss: 0.2644 - val_Sox2/counts_loss: 0.2649 - val_Nanog/counts_loss: 0.6045 - val_Klf4/counts_loss: 0.4390
Epoch 28/100
93904/93904 [==============================] - 18s 194us/step - loss: 1.4852 - Oct4/counts_loss: 0.2471 - Sox2/counts_loss: 0.2547 - Nanog/counts_loss: 0.5622 - Klf4/counts_loss: 0.4212 - val_loss: 1.5603 - val_Oct4/counts_loss: 0.2641 - val_Sox2/counts_loss: 0.2596 - val_Nanog/counts_loss: 0.6014 - val_Klf4/counts_loss: 0.4353
Epoch 29/100
 8192/93904 [=>............................] - ETA: 13s - loss: 1.4789 - Oct4/counts_loss: 0.2510 - Sox2/counts_loss: 0.2478 - Nanog/counts_loss: 0.5536 - Klf4/counts_loss: 0.4265
In [48]:
# NOTE(review): no-op placeholder (likely left after interrupting the fit above);
# the variable `a` is never used — safe to delete this cell.
a=1

Evaluate

In [51]:
# Predict counts on the validation set with the calibrated top model.
y_pred = counts_model.predict({"bottleneck": valid_bottlenecks, **valid['inputs']})
y_true = valid['targets']

New model

In [52]:
# Common paths
model_dir = models_dir / exp  # NOTE(review): re-defines model_dir with the same value as above
# figures = f"{ddir}/figures/model-evaluation/chipnexus-bpnet"
fdir = Path(f"{ddir}/figures/model-evaluation/chipnexus-bpnet/{exp}")  # per-experiment figure dir
In [53]:
# One panel per task: predicted vs. observed total counts (log-log scatter) with a
# shared axis range and the identity line for reference.
n_tasks = len(tasks)
fig, axes = plt.subplots(1, n_tasks, figsize=get_figsize(frac=1, aspect=1 / n_tasks),
                         sharex=True, sharey=True)
count_range = [10, 1e4]
for panel_idx, (task, ax) in enumerate(zip(tasks, axes)):
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    ax.set_xlim(count_range)
    ax.set_ylim(count_range)
    ax.plot(count_range, count_range, c='grey', alpha=0.2)  # identity line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx > 0:
        ax.set_ylabel("")  # only the leftmost panel keeps its y-label
fig.subplots_adjust(wspace=0)
plt.minorticks_off()
fig.savefig(fdir / 'calibrated,linear.total-counts.scatter-no-hidden.pdf')
In [54]:
# Individual per-task scatter plots, saved as both PDF and PNG.
os.makedirs(f"{fdir}/scatter", exist_ok=True)  # hoisted: same dir for every task
count_range = [10, 1e4]
for task in tasks:
    fig, ax = plt.subplots(figsize=get_figsize(frac=0.25, aspect=1))
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    ax.set_xlim(count_range)
    ax.set_ylim(count_range)
    ax.plot(count_range, count_range, c='grey', alpha=0.2)  # identity line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    plt.minorticks_off()
    # save the figure
    fig.savefig(f"{fdir}/scatter/calibrated,linear.{task}.pdf")
    fig.savefig(f"{fdir}/scatter/calibrated,linear.{task}.png")

With hidden layer

In [49]:
# Combined scatter for the hidden+batchnorm variant.
# NOTE(review): this cell re-plots whatever y_true/y_pred currently hold — its
# execution count suggests it ran against an earlier model in the session
# (hidden-state caveat); re-check under Restart & Run All.
n_tasks = len(tasks)
fig, axes = plt.subplots(1, n_tasks, figsize=get_figsize(frac=1, aspect=1 / n_tasks),
                         sharex=True, sharey=True)
count_range = [10, 1e4]
for panel_idx, (task, ax) in enumerate(zip(tasks, axes)):
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    ax.set_xlim(count_range)
    ax.set_ylim(count_range)
    ax.plot(count_range, count_range, c='grey', alpha=0.2)  # identity line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx > 0:
        ax.set_ylabel("")
fig.subplots_adjust(wspace=0)
plt.minorticks_off()
fig.savefig(fdir / 'calibrated,hidden+bn.total-counts.scatter.pdf')
In [50]:
# Per-task scatter plots for the hidden+batchnorm variant, saved as PDF and PNG.
os.makedirs(f"{fdir}/scatter", exist_ok=True)  # hoisted: same dir for every task
count_range = [10, 1e4]
for task in tasks:
    fig, ax = plt.subplots(figsize=get_figsize(frac=0.25, aspect=1))
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    ax.set_xlim(count_range)
    ax.set_ylim(count_range)
    ax.plot(count_range, count_range, c='grey', alpha=0.2)  # identity line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    plt.minorticks_off()
    # save the figure
    fig.savefig(f"{fdir}/scatter/calibrated,hidden+bn.{task}.pdf")
    fig.savefig(f"{fdir}/scatter/calibrated,hidden+bn.{task}.png")

Old model

In [26]:
# Same combined scatter for the old model (not saved to disk).
n_tasks = len(tasks)
fig, axes = plt.subplots(1, n_tasks, figsize=get_figsize(frac=1, aspect=1 / n_tasks),
                         sharex=True, sharey=True)
count_range = [10, 1e4]
for panel_idx, (task, ax) in enumerate(zip(tasks, axes)):
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    ax.set_xlim(count_range)
    ax.set_ylim(count_range)
    ax.plot(count_range, count_range, c='grey', alpha=0.2)  # identity line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx > 0:
        ax.set_ylabel("")
fig.subplots_adjust(wspace=0)
plt.minorticks_off()

TODO

  • overwrite the parameters in the old model
In [59]:
# NOTE(review): exploratory scratch cell — `h` is immediately re-assigned in the
# next code cell; safe to delete.
h = counts_model.heads[0]
In [67]:
# Inspect the first head for the Oct4 task (exploratory; `h` is not used afterwards).
h = counts_model.all_heads["Oct4"][0]
In [82]:
# Grab the last layer for inspection (exploratory; `l` is not used afterwards).
l = counts_model.model.layers[-1]
In [113]:
# Task -> name of the calibrated dense (counts) layer inside `counts_model`.
# NOTE(review): these are auto-generated Keras names — they depend on how many
# layers were built earlier in the kernel session. Fragile: verify with
# counts_model.model.summary() before running the calibration loop below.
calibrated_dense_layers = {"Oct4": "dense_21",
               "Sox2": "dense_23",
               "Nanog": "dense_25",
               "Klf4": "dense_27"}
In [114]:
# Task -> name of the calibrated bias layer inside `counts_model`.
# NOTE(review): auto-generated Keras names, session-dependent — verify with
# counts_model.model.summary() before use.
calibrated_bias_layers = {"Oct4": "dense_22",
               "Sox2": "dense_24",
               "Nanog": "dense_26",
               "Klf4": "dense_28"}
In [115]:
# Task -> name of the dense (counts) layer inside the original `bpnet` model.
# NOTE(review): auto-generated Keras names, session-dependent — verify with
# bpnet.model.summary() before use.
orig_dense_layers = {"Oct4": "dense_1",
               "Sox2": "dense_3",
               "Nanog": "dense_5",
               "Klf4": "dense_7"}
In [116]:
# Task -> name of the bias layer inside the original `bpnet` model.
# NOTE(review): auto-generated Keras names, session-dependent — verify with
# bpnet.model.summary() before use.
orig_bias_layers = {"Oct4": "dense_2",
               "Sox2": "dense_4",
               "Nanog": "dense_6",
               "Klf4": "dense_8"}
In [121]:
# Calibrate: copy the trained top-model weights (bias layer first, then the dense
# counts layer, per task) back into the corresponding layers of the original model.
# Loop variable renamed from `tf` to avoid colliding with the usual tensorflow alias.
for task in bpnet.tasks:
    layer_pairs = [(orig_bias_layers[task], calibrated_bias_layers[task]),
                   (orig_dense_layers[task], calibrated_dense_layers[task])]
    for dst_name, src_name in layer_pairs:
        calibrated_weights = counts_model.model.get_layer(src_name).get_weights()
        bpnet.model.get_layer(dst_name).set_weights(calibrated_weights)
In [123]:
# Persist the calibrated model alongside the original experiment artifacts.
bpnet.save(str(model_dir / 'calibrated_seqmodel.pkl'))

Test predictions

In [125]:
# Sanity-check: predictions from the full calibrated BPNet on the validation set.
y_pred = bpnet.predict(valid['inputs'])
y_true = valid['targets']
In [126]:
# Combined scatter verifying that the calibrated full model reproduces the
# top-model predictions (figure intentionally not saved — see commented line).
n_tasks = len(tasks)
fig, axes = plt.subplots(1, n_tasks, figsize=get_figsize(frac=1, aspect=1 / n_tasks),
                         sharex=True, sharey=True)
count_range = [10, 1e4]
for panel_idx, (task, ax) in enumerate(zip(tasks, axes)):
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    ax.set_xlim(count_range)
    ax.set_ylim(count_range)
    ax.plot(count_range, count_range, c='grey', alpha=0.2)  # identity line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx > 0:
        ax.set_ylabel("")
fig.subplots_adjust(wspace=0)
plt.minorticks_off()
# fig.savefig(fdir / 'calibrated,linear.total-counts.scatter-no-hidden.pdf')

New model

In [70]:
# Combined scatter for the new model (not saved to disk).
n_tasks = len(tasks)
fig, axes = plt.subplots(1, n_tasks, figsize=get_figsize(frac=1, aspect=1 / n_tasks),
                         sharex=True, sharey=True)
count_range = [10, 1e4]
for panel_idx, (task, ax) in enumerate(zip(tasks, axes)):
    observed = np.exp(y_true[f'{task}/counts'].mean(-1))
    predicted = np.exp(y_pred[f'{task}/counts'].mean(-1))
    ax.set_xlim(count_range)
    ax.set_ylim(count_range)
    ax.plot(count_range, count_range, c='grey', alpha=0.2)  # identity line
    regression_eval(observed, predicted, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if panel_idx > 0:
        ax.set_ylabel("")
fig.subplots_adjust(wspace=0)
plt.minorticks_off()

Old model

In [14]:
# Combined scatter for the old model (not saved to disk).
fig, axes = plt.subplots(1, len(tasks), figsize=get_figsize(frac=1, aspect=1/len(tasks)),
                         sharex=True, sharey=True)
for i, (task, ax) in enumerate(zip(tasks, axes)):
    # BUGFIX: the key here was f'counts/{task}', but the data loader is built with
    # taskname_first=True, so `valid['targets']` is keyed '{task}/counts' (as used
    # in every other plotting cell) — the old key would raise a KeyError on a
    # fresh Restart & Run All.
    yt = np.exp(y_true[f'{task}/counts'].mean(-1))
    yp = np.exp(y_pred[f'{task}/counts'].mean(-1))
    xrange = [10, 1e4]
    ax.set_ylim(xrange)
    ax.set_xlim(xrange)

    ax.plot(xrange, xrange, c='grey', alpha=0.2)  # identity line
    regression_eval(yt, 
                    yp, alpha=.1, task=task, ax=ax, loglog=True)
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=4))
    if i > 0:
        ax.set_ylabel("")  # only the leftmost panel keeps its y-label
fig.subplots_adjust(wspace=0)
plt.minorticks_off()