Goal

  • set up a model with a larger receptive field
In [1]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import basepair
from basepair.imports import *
# hv.extension('bokeh')
Using TensorFlow backend.
In [2]:
# Common paths
# NOTE(review): `ddir` is not defined anywhere in this notebook's visible cells —
# presumably injected by `from basepair.imports import *`; confirm.
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
In [3]:
# NOTE(review): the argument 0 presumably selects GPU/device 0 — confirm in basepair
create_tf_session(0)
Out[3]:
<tensorflow.python.client.session.Session at 0x7f4a2f09f9b0>

Tasks

  • load the data
In [4]:
ls {model_dir}
bottleneck.dataset.pkl  figures/        Intervene_results/  preds.h5
clustering/             grad.all.h5     ism.all.h5          preds.test.bak.pkl
cometml.json            grad.test.2.h5  log/                preds.test.pkl
count-models/           grad.test.h5    model.h5            preds.valid.pkl
dataspec.yaml           grad.valid.h5   model.unscaled.h5   results.html
deeplift.all.h5         history.csv     modisco/            results.ipynb
evaluate/               hparams.yaml    preds.all.h5
In [5]:
from basepair.cli.schemas import HParams
from basepair.datasets import get_StrandedProfile_datasets
from basepair.models import seq_bpnet_cropped, seq_bpnet_cropped_extra_seqlen
In [6]:
# Load the data spec and hyper-parameters that were saved alongside the trained model
ds = DataSpec.load(model_dir / "dataspec.yaml")
hp = HParams.load(model_dir / "hparams.yaml")
In [7]:
# show the hyper-parameters the stored model was trained with
for k,v in hp.model.kwargs.items(): print(f"{k} = {v}")
filters = 64
conv1_kernel_size = 25
tconv_kernel_size = 25
n_dil_layers = 9
lr = 0.004
c_task_weight = 10
In [8]:
# hparams — read directly from the trained model's saved hyper-parameters
# (hp.model.kwargs, printed in the previous cell) instead of hand-copying the
# values, so this cell cannot drift out of sync with hparams.yaml.
# The resulting values are identical to the previous manual copy:
# filters=64, conv1_kernel_size=25, tconv_kernel_size=25, n_dil_layers=9,
# lr=0.004, c_task_weight=10.
model_kwargs = hp.model.kwargs
filters = model_kwargs['filters']
conv1_kernel_size = model_kwargs['conv1_kernel_size']
tconv_kernel_size = model_kwargs['tconv_kernel_size']
n_dil_layers = model_kwargs['n_dil_layers']
lr = model_kwargs['lr']
c_task_weight = model_kwargs['c_task_weight']
# batchnorm is not among the stored kwargs (see the In [7] printout) — default off
batchnorm = model_kwargs.get('batchnorm', False)
In [ ]:
!cat {model_dir}/hparams.yml
In [9]:
# Re-build a single-task (Sox2) cropped BPNet model with the hyper-parameters above.
# NOTE(review): seq_bpnet_cropped is a project helper — kwarg semantics below are
# assumed to mirror hparams.yaml; confirm against basepair.models.
m = seq_bpnet_cropped(['Sox2'],
                      filters=filters,
                      conv1_kernel_size=conv1_kernel_size,
                      tconv_kernel_size=tconv_kernel_size,
                      tconv_n_hidden=0,
                      n_dil_layers=n_dil_layers,
                      lr=lr,
                      batchnorm=batchnorm,
                      c_task_weight=c_task_weight,
                      use_profile=True,     # profile head (per-position output)
                      use_counts=True,      # total-counts head
                      outputs_per_task=2,   # 2 output channels per task (presumably the two strands)
                      task_use_bias=False,
                      profile_loss='mc_multinomial_nll',
                      count_loss='mse',
                      )
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2019-01-22 16:43:16,849 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2019-01-22 16:43:26,691 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [10]:
# compute the additional required sequence length: the extra input flank the
# cropped model consumes, so the input must be (peak_width + add_seqlen) wide
# for the output to still cover the full peak window (verified below: a
# 1000+add_seqlen input yields a 1000-long profile).
add_seqlen = seq_bpnet_cropped_extra_seqlen(conv1_kernel_size, 
                      n_dil_layers,
                      tconv_kernel_size)
add_seqlen
Out[10]:
2092
In [2]:
from basepair.models import seq_bpnet_cropped_extra_seqlen
Using TensorFlow backend.
In [5]:
# compute the additional required sequence length
# NOTE(review): scratch computation for n_dil_layers=12. The execution counts
# (In [2]/In [5] after In [10], with a fresh "Using TensorFlow backend." above)
# indicate this came from a separate kernel session — it overwrites add_seqlen
# and would change the dataset width on a top-to-bottom re-run. Consider removing.
add_seqlen = seq_bpnet_cropped_extra_seqlen(25, 
                      12,
                      25)
add_seqlen
Out[5]:
16428
In [15]:
# compute the additional required sequence length
# NOTE(review): scratch computation for n_dil_layers=11, also out of execution
# order (In [15]). Like the cell above, it clobbers add_seqlen on a linear
# re-run; the model in this notebook uses n_dil_layers=9. Consider removing.
add_seqlen = seq_bpnet_cropped_extra_seqlen(25, 
                      11,
                      25)
add_seqlen
Out[15]:
8236
In [11]:
# sanity check: with the extra flank added to the input, the profile head
# returns exactly the 1000-position target window (output shape (1, 1000, 2))
print(m.predict(np.ones((1, 1000 + add_seqlen, 4)))[0].shape)
(1, 1000, 2)
In [12]:
# architecture overview; sequence-length dims are None (variable input length)
m.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
seq (InputLayer)                (None, None, 4)      0                                            
__________________________________________________________________________________________________
conv_model (Model)              (None, None, 64)     117632      seq[0][0]                        
__________________________________________________________________________________________________
cropped_deconv_1d (Sequential)  (None, None, 2)      3202        conv_model[1][0]                 
__________________________________________________________________________________________________
global_average_pooling1d_1 (Glo (None, 64)           0           conv_model[1][0]                 
__________________________________________________________________________________________________
profile/Sox2 (Lambda)           (None, None, 2)      0           cropped_deconv_1d[1][0]          
__________________________________________________________________________________________________
counts/Sox2 (Dense)             (None, 2)            130         global_average_pooling1d_1[0][0] 
==================================================================================================
Total params: 120,964
Trainable params: 120,964
Non-trainable params: 0
__________________________________________________________________________________________________
In [13]:
# Chromosome-based train/valid/test split; inputs are requested wider than the
# 1000 bp peak by the flank (add_seqlen) that the cropped model consumes.
train, valid, test = get_StrandedProfile_datasets(ds, 
                                                  peak_width=1000,
                                                  seq_width=1000 + add_seqlen,
                                                  shuffle=True,
                                                  valid_chr=['chr2', 'chr3', 'chr4'],
                                                  test_chr=['chr1', 'chr8', 'chr9'])
In [14]:
# Materialize the datasets into memory (20 parallel workers)
train_all = train.load_all(num_workers=20)
valid_all = valid.load_all(num_workers=20)
100%|██████████| 1913/1913 [00:09<00:00, 200.10it/s]
100%|██████████| 599/599 [00:03<00:00, 154.83it/s]

TODO

  • [x] compute the required sequence width given the hyper-parameters (like the dilation)
  • [x] request the dataset with that size
  • [ ] train and evaluate a simple model
  • [ ] try out training the model directly on the counts using the poisson loss
In [22]:
# NOTE(review): hardcoded absolute path — consider deriving it from `ddir`
# like model_dir above for portability
output_dir = Path("/users/avsec/workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/extended_exp")
In [23]:
from uuid import uuid4
In [35]:
from keras.callbacks import EarlyStopping, History, ModelCheckpoint
In [ ]:
# get the experiment directory (random 8-char id so repeated runs don't collide)
exp_dir = output_dir / str(uuid4())[:8]
exp_dir.mkdir(parents=True, exist_ok=True)
ckp_file = str(exp_dir / 'model.h5')
# Train with early stopping (stop after 5 epochs without val_loss improvement);
# ModelCheckpoint(save_best_only=True) keeps only the best-val_loss weights.
history = m.fit(train_all['inputs'], 
                train_all['targets'],
          batch_size=256, 
          epochs=100,
          validation_data=(valid_all['inputs'], valid_all['targets']),
          callbacks=[EarlyStopping(patience=5),
                     History(),
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model
# NOTE(review): load_model is not imported in any visible cell — presumably
# provided by `from basepair.imports import *` (keras.models.load_model); confirm.
model = load_model(ckp_file)
[autoreload of basepair.datasets failed: Traceback (most recent call last):
  File "/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/IPython/extensions/autoreload.py", line 244, in check
    superreload(m, reload, self.old_objects)
  File "/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/IPython/extensions/autoreload.py", line 376, in superreload
    module = reload(module)
  File "/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/imp.py", line 315, in reload
    return importlib.reload(module)
  File "/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/importlib/__init__.py", line 166, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 618, in _exec
  File "<frozen importlib._bootstrap_external>", line 678, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/users/avsec/workspace/basepair/basepair/datasets.py", line 82, in <module>
    test_chr=test_chr):
  File "/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/gin/config.py", line 1129, in configurable
    return perform_decoration(decoration_target)
  File "/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/gin/config.py", line 1126, in perform_decoration
    return _make_configurable(fn_or_cls, name, module, whitelist, blacklist)
ValueError: A configurable matching 'basepair.datasets.chip_exo_nexus' already exists.
]
Train on 61205 samples, validate on 19137 samples
Epoch 1/100
61205/61205 [==============================] - 142s 2ms/step - loss: 601.3879 - profile/Sox2_loss: 597.8852 - counts/Sox2_loss: 0.3503 - val_loss: 607.2508 - val_profile/Sox2_loss: 604.0569 - val_counts/Sox2_loss: 0.3194
Epoch 2/100
61205/61205 [==============================] - 143s 2ms/step - loss: 595.8152 - profile/Sox2_loss: 592.4343 - counts/Sox2_loss: 0.3381 - val_loss: 603.2191 - val_profile/Sox2_loss: 599.8717 - val_counts/Sox2_loss: 0.3347
Epoch 3/100
61205/61205 [==============================] - 143s 2ms/step - loss: 590.4445 - profile/Sox2_loss: 587.1077 - counts/Sox2_loss: 0.3337 - val_loss: 599.2909 - val_profile/Sox2_loss: 596.0872 - val_counts/Sox2_loss: 0.3204
Epoch 4/100
61205/61205 [==============================] - 143s 2ms/step - loss: 585.9269 - profile/Sox2_loss: 582.6367 - counts/Sox2_loss: 0.3290 - val_loss: 594.9118 - val_profile/Sox2_loss: 591.8010 - val_counts/Sox2_loss: 0.3111
Epoch 5/100
30208/61205 [=============>................] - ETA: 1:06 - loss: 586.7940 - profile/Sox2_loss: 583.3527 - counts/Sox2_loss: 0.3441