Goal

  • build a model to predict the exonuclease cut sites given the DNA sequence (fully seq->seq)

Design

Schema

  • input: 200x4 (one-hot encoded DNA sequence)
  • output: 200x2 (ChIP-exo counts, normalized with a softmax across each channel)
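As a concrete illustration of the input schema, a minimal numpy sketch of the one-hot encoding (the notebook itself uses `concise.preprocessing.encodeDNA`; `one_hot_dna` is a hypothetical helper):

```python
import numpy as np

def one_hot_dna(seq):
    """One-hot encode a DNA string into an (L, 4) array, columns A/C/G/T.

    Minimal illustrative sketch; the actual pipeline uses
    concise.preprocessing.encodeDNA for this step.
    """
    mapping = {"A": 0, "C": 1, "G": 2, "T": 3}
    out = np.zeros((len(seq), 4), dtype=np.float32)
    for i, base in enumerate(seq.upper()):
        if base in mapping:  # unknown bases (e.g. N) stay all-zero
            out[i, mapping[base]] = 1.0
    return out

x = one_hot_dna("ACGTN")
```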

Loss function

  • Use multinomial log-likelihood
    • this will control for the amount of reads at each position
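Dropping the count-dependent combinatorial constant, the per-channel loss is -Σᵢ countsᵢ·log pᵢ with p = softmax(logits), so the total read depth only scales the loss rather than distorting the profile. A minimal numpy sketch (hypothetical helper; the model itself uses `basepair.losses.twochannel_multinomial_nll`):

```python
import numpy as np

def multinomial_nll(counts, logits):
    """Multinomial negative log-likelihood for one channel.

    counts: (L,) observed read counts per position
    logits: (L,) unnormalized model outputs; softmax turns them into
            a probability profile over positions.
    The combinatorial constant is dropped, leaving -sum_i counts_i * log(p_i).
    """
    logits = logits - logits.max()                 # numerical stability
    log_p = logits - np.log(np.exp(logits).sum())  # log softmax
    return -(counts * log_p).sum()

counts = np.array([0., 2., 1., 0.])
uniform = np.zeros(4)                  # flat profile
peaked = np.array([0., 5., 3., 0.])    # mass where the counts are
```

A profile that concentrates probability where the reads fall should score a lower NLL than a flat one.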

Expected difficulties

  • the network might be roughly correct but miss the signal by a single base.
    • Surprisingly, the network is able to predict the peaks very accurately, so this is not a real problem

Architecture

  • U-net like architecture with Transposed-Conv/upscale layers

Data split

  • 60% train
  • 20% validation
  • 20% test
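A minimal sketch of such a 60/20/20 shuffle-and-split (illustrative only; the actual split happens inside `basepair.data.seq_inp_exo_out`, with the validation part taken later via Keras' `validation_split`):

```python
import numpy as np

def split_indices(n, frac_train=0.6, frac_valid=0.2, seed=42):
    """Shuffle indices 0..n-1 and split them 60/20/20
    into train/validation/test index arrays."""
    rng = np.random.RandomState(seed)
    idx = rng.permutation(n)
    n_train = int(n * frac_train)
    n_valid = int(n * frac_valid)
    return (idx[:n_train],
            idx[n_train:n_train + n_valid],
            idx[n_train + n_valid:])

train_idx, valid_idx, test_idx = split_indices(100)
```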

TODO

  • [x] get the data
  • [x] split the data
    • write the data function
  • [x] try out the simplest possible architecture
    • one conv layer
    • batch norm?
    • one de-conv layer (with large enough window size)
    • try motif initialization for the two motifs
  • [x] visualize the weights
    • de-conv layer -> à la averaged signal over the motifs?
  • [x] evaluate the accuracy on the test-set
  • [x] visualize the prediction results for:
    • most correct
    • most wrong
    • most uncertain even though it should be correct
  • [ ] evaluate the performance for predicting high-count peaks
    • how many can we predict correctly - with base-pair precision?
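One way to sketch this base-pair-precision check: compare the argmax positions of predicted and true profiles per strand channel, counting a hit when they agree within a tolerance (hypothetical helper, assuming `(N, L, 2)` arrays as used elsewhere in the notebook):

```python
import numpy as np

def peak_match_rate(y_true, y_pred, tolerance=0):
    """Fraction of sequences whose predicted argmax peak position matches
    the true argmax within `tolerance` base pairs, per strand channel.

    y_true, y_pred: (N, L, 2) profile arrays. Returns a (2,) array,
    one match rate per channel. Illustrative sketch only.
    """
    true_pos = y_true.argmax(axis=1)  # (N, 2) peak positions
    pred_pos = y_pred.argmax(axis=1)  # (N, 2)
    return (np.abs(true_pos - pred_pos) <= tolerance).mean(axis=0)

# toy profiles: true peak at position 4, predicted peak at position 5
y = np.zeros((3, 10, 2)); y[:, 4, :] = 1.0
p = np.zeros((3, 10, 2)); p[:, 5, :] = 1.0
```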
In [1]:
%env CUDA_VISIBLE_DEVICES=3
env: CUDA_VISIBLE_DEVICES=3
In [2]:
import os
In [3]:
os.environ['CUDA_VISIBLE_DEVICES']
Out[3]:
'3'
In [4]:
import tensorflow as tf
/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
In [5]:
#tf.enable_eager_execution()
import keras
Using TensorFlow backend.

get the data and split it into train,test

In [16]:
from basepair.config import get_data_dir

import pyBigWig
import matplotlib.pyplot as plt
from concise.preprocessing import encodeDNA
from concise.utils.fasta import read_fasta
from pybedtools import BedTool
from tqdm import tqdm
import pandas as pd
import numpy as np

from basepair.data import get_sox2_data, seq_inp_exo_out
from basepair.math import softmax
In [10]:
keras.backend.get_session().list_devices()
Out[10]:
[_DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 268435456),
 _DeviceAttributes(/job:localhost/replica:0/task:0/device:GPU:0, GPU, 11613519872)]
In [12]:
ddir = get_data_dir()
In [13]:
dfc = get_sox2_data()
100%|██████████| 9396/9396 [00:02<00:00, 3462.11it/s]
In [15]:
train, test = seq_inp_exo_out()
100%|██████████| 9396/9396 [00:02<00:00, 3246.21it/s]
In [17]:
train[0].shape
Out[17]:
(7445, 201, 4)
In [18]:
train[1].shape
Out[18]:
(7445, 201, 2)
In [19]:
test[0].shape
Out[19]:
(1951, 201, 4)
In [20]:
test[1].shape
Out[20]:
(1951, 201, 2)
In [21]:
train[1][2][:,0]
Out[21]:
array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
       0., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.,
       1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)
In [22]:
y_batch = train[1][:64]
In [23]:
from basepair.losses import twochannel_multinomial_nll

try out the simplest possible architecture

  • one conv layer
  • batch norm?
  • one de-conv layer (with large enough window size)
  • try motif initialization for the two motifs
In [35]:
import concise.layers as cl
import concise.initializers as ci
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import Model, load_model
In [25]:
motifs = {"motif1": f"{ddir}/processed/chipnexus/motifs/sox2/homer_200bp/de-novo/motif1.motif",
          "motif2": f"{ddir}/processed/chipnexus/motifs/sox2/homer_200bp/de-novo/motif2.motif"}
In [26]:
from basepair.motif.homer import load_motif, read_motif_hits
pwm_list = [load_motif(fname) for k,fname in motifs.items()]
In [27]:
for i,pwm in enumerate(pwm_list):
    pwm.plotPWMInfo((5,1.5))
    plt.title(f"motif{i+1}")
In [32]:
# TODO - build the reverse complement symmetry into the model
inp = cl.Input(shape=(201, 4))
filters = 32
first_conv = cl.Conv1D(filters, kernel_size=21, padding='same', 
               #kernel_initializer = ci.PSSMKernelInitializer(pwm_list, stddev=0, add_noise_before_Pwm2Pssm=False),
               #bias_initializer = 'zeros',
               activation='relu')(inp)

second_conv = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2)(first_conv)
c_second_conv = kl.add([first_conv, second_conv])
third_conv = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=4)(c_second_conv)
c_third_conv = kl.add([first_conv, second_conv, third_conv])
fourth_conv = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=8)(c_third_conv)
c_fourth_conv = kl.add([first_conv, second_conv, third_conv, fourth_conv])
fifth_conv = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=16)(c_fourth_conv)
c_fifth_conv = kl.add([first_conv, second_conv, third_conv, fourth_conv, fifth_conv])
sixt_conv = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=32)(c_fifth_conv)
c_sixt_conv = kl.add([first_conv, second_conv, third_conv, fourth_conv, fifth_conv, sixt_conv])
seventh_conv = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=64)(c_sixt_conv)
combined_conv = kl.add([first_conv, second_conv, third_conv, fourth_conv, fifth_conv, sixt_conv, seventh_conv])

# Bottleneck layer
combined_conv = kl.Conv1D(3, kernel_size=1, padding='same', activation='relu')(combined_conv)
x = kl.Reshape((-1, 1, 3))(combined_conv)
x = kl.Conv2DTranspose(2, kernel_size=(25, 1), padding='same')(x)
    #kl.Conv2DTranspose(32, kernel_size=(7, 1), padding='same', activation='relu'),
    #kl.Conv2DTranspose(2, kernel_size=(3, 1), padding='same'),
out = kl.Reshape((-1, 2))(x)

model = Model(inp, out)
model.compile(Adam(lr=0.004), loss=twochannel_multinomial_nll)
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.

Notes

Experiments

  • resresnet_nconv=4_filters=256=_lr=0.004.h5 - 108238.6558
    • looks pretty good
  • resnest_allconnect_nconv=4_filters=256=_lr=0.004.h5 - 108588.4661
  • resnest_allconnect_nconv=4_filters=64_lr=0.004.h5 - 108546.4132 - Amazing results
  • resnest_allconnect_nconv=4_filters=32_lr=0.004.h5 108670.3445 - great looking as well
  • resnest_allconnect_nconv=5_filters=32_lr=0.004.h5 107394.5353 - even better
  • resnest_allconnect_nconv=6_filters=32_lr=0.004.h5 107187.9883 - even better
  • resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=False.h5 109961.7913 - no dilation
  • resnest_allconnect_nconv=7_filters=64_lr=0.004_dilated=True.h5 - 107460.7238 - great results
  • resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True-2.h5 - 107225.4570 - amazing results
  • resnest_allconnect_nconv=7_filters=16_lr=0.004_dilated=True.h5 - 107225.4570 - not as good as previous, but still cool

start messing around with output kernel size (before it was 25)

  • resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=13.h5 - 107437.4955 - qualitatively better looking
  • resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=31.h5 - 107032.3784 - not so reproducible
  • resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=21.h5 - 107032.3784 - similarly ok
  • resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=41.h5 - 107032.3784 - similarly ok
  • resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=25.h5 - 106892.1230 - qualitatively ok
  • resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=25,bottleneck=3.h5 - 109486.1274 - performance not as good, but hopefully more interpretable

Dilation helps

In [535]:
!mkdir -p {ddir}/processed/chipnexus/exp/models
In [34]:
# best valid so far: 108238.6558
ckp_file = f"{ddir}/processed/chipnexus/exp/models/resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=25,bottleneck=3.h5"
model.fit(train[0], train[1], 
          batch_size=256, 
          epochs=200,
          validation_split=0.2,
          callbacks=[EarlyStopping(patience=5),
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model
model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})
Train on 5956 samples, validate on 1489 samples
Epoch 1/200
5956/5956 [==============================] - 7s 1ms/step - loss: 61666.5055 - val_loss: 133284.0984
Epoch 2/200
5956/5956 [==============================] - 1s 119us/step - loss: 61171.6427 - val_loss: 129908.5329
Epoch 3/200
5956/5956 [==============================] - 1s 121us/step - loss: 59764.0498 - val_loss: 125667.3962
Epoch 4/200
5956/5956 [==============================] - 1s 123us/step - loss: 59168.5766 - val_loss: 123435.7895
Epoch 5/200
5956/5956 [==============================] - 1s 116us/step - loss: 58567.4618 - val_loss: 119438.2346
Epoch 6/200
5956/5956 [==============================] - 1s 114us/step - loss: 57488.1259 - val_loss: 114553.8533
Epoch 7/200
5956/5956 [==============================] - 1s 112us/step - loss: 56842.4476 - val_loss: 112480.2617
Epoch 8/200
5956/5956 [==============================] - 1s 127us/step - loss: 56480.5747 - val_loss: 111478.7968
Epoch 9/200
5956/5956 [==============================] - 1s 110us/step - loss: 56292.9092 - val_loss: 111658.8188
Epoch 10/200
5956/5956 [==============================] - 1s 113us/step - loss: 56170.8366 - val_loss: 111151.7296
Epoch 11/200
5956/5956 [==============================] - 1s 116us/step - loss: 56070.2711 - val_loss: 110434.3509
Epoch 12/200
5956/5956 [==============================] - 1s 128us/step - loss: 56007.2057 - val_loss: 110654.2199
Epoch 13/200
5956/5956 [==============================] - 1s 111us/step - loss: 55997.6093 - val_loss: 110262.6591
Epoch 14/200
5956/5956 [==============================] - 1s 112us/step - loss: 55867.5115 - val_loss: 110503.7963
Epoch 15/200
5956/5956 [==============================] - 1s 120us/step - loss: 55843.2191 - val_loss: 109903.9520
Epoch 16/200
5956/5956 [==============================] - 1s 115us/step - loss: 55771.9299 - val_loss: 109862.8147
Epoch 17/200
5956/5956 [==============================] - 1s 111us/step - loss: 55763.6046 - val_loss: 109806.6797
Epoch 18/200
5956/5956 [==============================] - 1s 111us/step - loss: 55719.0659 - val_loss: 109849.5385
Epoch 19/200
5956/5956 [==============================] - 1s 116us/step - loss: 55706.9445 - val_loss: 109597.1517
Epoch 20/200
5956/5956 [==============================] - 1s 119us/step - loss: 55719.7071 - val_loss: 109845.6073
Epoch 21/200
5956/5956 [==============================] - 1s 124us/step - loss: 55731.8244 - val_loss: 109625.1842
Epoch 22/200
5956/5956 [==============================] - 1s 112us/step - loss: 55697.6796 - val_loss: 110154.3304
Epoch 23/200
5956/5956 [==============================] - 1s 120us/step - loss: 55615.1307 - val_loss: 109514.2247
Epoch 24/200
5956/5956 [==============================] - 1s 129us/step - loss: 55602.0515 - val_loss: 109486.1274
Epoch 25/200
5956/5956 [==============================] - 1s 124us/step - loss: 55603.2518 - val_loss: 109614.7043
Epoch 26/200
5956/5956 [==============================] - 1s 113us/step - loss: 55539.3342 - val_loss: 109727.2795
Epoch 27/200
5956/5956 [==============================] - 1s 122us/step - loss: 55521.4002 - val_loss: 109715.2905
Epoch 28/200
5956/5956 [==============================] - 1s 121us/step - loss: 55570.8776 - val_loss: 109799.4676
Epoch 29/200
5956/5956 [==============================] - 1s 136us/step - loss: 55536.4603 - val_loss: 109539.7961
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-34-d748d5d01147> in <module>()
      9          )
     10 # get the best model
---> 11 model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})

NameError: name 'load_model' is not defined

evaluate the accuracy on the test-set

visualize the prediction results for

  • most correct
  • most wrong
  • most uncertain even though it should be correct
In [37]:
import basepair
In [38]:
model.evaluate(test[0], test[1])
1951/1951 [==============================] - 1s 347us/step
Out[38]:
7918.008865645823
In [39]:
y_pred = softmax(model.predict(train[0]))
y_true = train[1]
np.allclose(y_pred.sum(axis=1),1)
Out[39]:
True
In [40]:
idx_list = pd.Series(np.arange(len(test[0]))).sample(10)
for idx in idx_list:
    plt.figure(figsize=(10,2))
    plt.subplot(121)
    plt.plot(y_pred[idx,:,0], label='pos,m={}'.format(np.argmax(y_pred[idx,:,0])))
    plt.plot(y_pred[idx,:,1], label='neg,m={}'.format(np.argmax(y_pred[idx,:,1])))
    plt.legend();
    plt.subplot(122)
    plt.plot(y_true[idx,:,0], label='pos,m={}'.format(np.argmax(y_true[idx,:,0])))
    plt.plot(y_true[idx,:,1], label='neg,m={}'.format(np.argmax(y_true[idx,:,1])))
    plt.legend();
In [41]:
print("Pos")
print(np.argmax(y_pred[idx_list,:,0], axis=1))
print(np.argmax(y_true[idx_list,:,0], axis=1))
print("Neg")
print(np.argmax(y_pred[idx_list,:,1], axis=1))
print(np.argmax(y_true[idx_list,:,1], axis=1))
Pos
[112 154  41 139 120  10 143  39  39  31]
[  9 154  42  80  44  10  64   8  23 193]
Neg
[122 165  61 152 133 145 150 116  64 155]
[138  41  76   5  26  32 158  45  17   3]
In [42]:
y_pred = softmax(model.predict(test[0]))
y_true = test[1]
np.allclose(y_pred.sum(axis=1),1)
Out[42]:
True
In [43]:
idx_list = pd.Series(np.arange(len(test[0]))).sample(10)
for idx in idx_list:
    plt.figure(figsize=(10,2))
    plt.subplot(121)
    plt.plot(y_pred[idx,:,0], label='pos,m={}'.format(np.argmax(y_pred[idx,:,0])))
    plt.plot(y_pred[idx,:,1], label='neg,m={}'.format(np.argmax(y_pred[idx,:,1])))
    plt.legend();
    plt.subplot(122)
    plt.plot(y_true[idx,:,0], label='pos,m={}'.format(np.argmax(y_true[idx,:,0])))
    plt.plot(y_true[idx,:,1], label='neg,m={}'.format(np.argmax(y_true[idx,:,1])))
    plt.legend();
In [ ]:
# Inspect the intermediate activations

visualize the weights

  • de-conv layer -> ala averaged signal over the motifs?
In [159]:
from concise.preprocessing.sequence import one_hot2string, DNA
In [315]:
one_hot2string(train[0][0][np.newaxis], DNA)[0]
Out[315]:
'CCAGGAAGCTGCCTCAGGCTAGCCTCCGGAAACACTCCACAGTATCAGAATTCATCCCTACAATCATCCTGAGTATGACTCCAGTTAAGGGTCACTAGGACATTCTTCTATAGCCCAATTCTACCTGCGTCTTACACACTGGCCCTGTAGCAGATACTAAAAAGCATTCTTTAAATCATTATTTCCAGGGAAATACATTAT'
In [316]:
# Add a softmax layer to the model
In [319]:
w=model.layers[-2].get_weights()[0]
In [320]:
w.shape
Out[320]:
(20, 1, 2, 2)
In [348]:
# Plot weights
w=model.layers[-2].get_weights()[0]
plt.figure(figsize=(20,3))
filters = w.shape[-1]
for i in range(filters):
    plt.subplot(1,filters, i+1)
    plt.plot(w[:,0,0,i], label='pos')
    plt.plot(w[:,0,1,i], label='neg')
    plt.title("Motif{}".format(i))
plt.legend();
In [349]:
from concise.utils.plot import seqlogo, seqlogo_fig
In [350]:
from concise.utils.pwm import pssm_array2pwm_array, _pwm2pwm_info
In [351]:
pssm_array2pwm_array
Out[351]:
<function concise.utils.pwm.pssm_array2pwm_array>
In [352]:
pwm_list[0].plotPSSM(figsize=(5,4));
In [353]:
seqlogo_fig(model.layers[0].get_weights()[0], figsize=(10,4), ncol=2);
In [ ]:
# GC gets strongly enhanced -> GC bias?
In [339]:
model.layers[0].plot_weights();
<Figure size 432x288 with 0 Axes>
In [340]:
seqlogo_fig(model.layers[0].get_weights()[0], figsize=(10,2), ncol=2);
In [55]:
l.plot_weights?
Signature: l.plot_weights(index=None, plot_type='motif_raw', figsize=None, ncol=1, **kwargs)
Docstring:
Plot filters as heatmap or motifs

index = can be a particular index or a list of indicies
**kwargs - additional arguments to concise.utils.plot.heatmap
File:      ~/bin/anaconda3/lib/python3.6/site-packages/concise/layers.py
Type:      method
In [ ]:
model.layers[0].plot_weights