Goal

Train a basic model predicting DNase cuts.

Tasks

  • [ ] get the training data
  • [ ] train single-task model
  • [ ] evaluate the single-task model
  • [ ] train multi-task model
  • [ ] evaluate the multi-task model
  • [ ] improve the model
    • [ ] larger context
    • [ ] valid padding
    • [ ] U-net structure
    • [ ] automatically balance between counts and the profile
    • [ ] RC parameter sharing or augmentation

get the training data

In [52]:
# Project + numeric imports for loading the DNase training data.
import basepair
import numpy as np
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from pathlib import Path
from basepair.config import create_tf_session, get_data_dir

# Root data directory (project-configured); used later for checkpoint paths.
ddir = get_data_dir()

# Create the TF session; argument 0 presumably selects GPU 0 — confirm in basepair.config.
create_tf_session(0)
Out[52]:
<tensorflow.python.client.session.Session at 0x7fa25abc57f0>
In [54]:
# Location of the raw DNase data on scratch.
dpath = Path("/srv/scratch/avsec/workspace/basepair-workflow/data/DNase")

# Width (bp) of the input sequence / output profile.
SEQ_WIDTH = 1000

# Single-task data specification: strand-specific DNase cut counts + peak calls.
dnase_task = TaskSpec(
    task="DNase",
    pos_counts=dpath / "signal/raw/merged.pos.bw",
    neg_counts=dpath / "signal/raw/merged.neg.bw",
    peaks=dpath / "raw/peaks/ENCFF426TTH.bed.gz",
)

ds = DataSpec(
    task_specs={"DNase": dnase_task},
    fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta",
)
In [55]:
train, valid, test = chip_exo_nexus(ds, peak_width=SEQ_WIDTH)
100%|██████████| 322679/322679 [00:00<00:00, 1210600.49it/s]
2018-08-08 12:55:38,731 [INFO] extract sequence
2018-08-08 12:57:24,806 [INFO] extract counts
100%|██████████| 1/1 [02:31<00:00, 151.88s/it]
In [14]:
# TODO - preproc the data

train single-task model

In [62]:
# Modeling imports, deduplicated: `kl`, `Adam` and `Model` were previously
# imported twice in this cell. All previously bound names are preserved.
import keras.layers as kl
import keras.backend as K
from keras.optimizers import Adam
from keras.models import Model, load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
from basepair.layers import SpatialLifetimeSparsity
from basepair.models import seq_multitask
In [57]:
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
seq (InputLayer)                (None, 1000, 4)      0                                            
__________________________________________________________________________________________________
conv1d_8 (Conv1D)               (None, 1000, 21)     1785        seq[0][0]                        
__________________________________________________________________________________________________
conv1d_9 (Conv1D)               (None, 1000, 21)     1344        conv1d_8[0][0]                   
__________________________________________________________________________________________________
add_7 (Add)                     (None, 1000, 21)     0           conv1d_8[0][0]                   
                                                                 conv1d_9[0][0]                   
__________________________________________________________________________________________________
conv1d_10 (Conv1D)              (None, 1000, 21)     1344        add_7[0][0]                      
__________________________________________________________________________________________________
add_8 (Add)                     (None, 1000, 21)     0           conv1d_8[0][0]                   
                                                                 conv1d_9[0][0]                   
                                                                 conv1d_10[0][0]                  
__________________________________________________________________________________________________
conv1d_11 (Conv1D)              (None, 1000, 21)     1344        add_8[0][0]                      
__________________________________________________________________________________________________
add_9 (Add)                     (None, 1000, 21)     0           conv1d_8[0][0]                   
                                                                 conv1d_9[0][0]                   
                                                                 conv1d_10[0][0]                  
                                                                 conv1d_11[0][0]                  
__________________________________________________________________________________________________
conv1d_12 (Conv1D)              (None, 1000, 21)     1344        add_9[0][0]                      
__________________________________________________________________________________________________
add_10 (Add)                    (None, 1000, 21)     0           conv1d_8[0][0]                   
                                                                 conv1d_9[0][0]                   
                                                                 conv1d_10[0][0]                  
                                                                 conv1d_11[0][0]                  
                                                                 conv1d_12[0][0]                  
__________________________________________________________________________________________________
conv1d_13 (Conv1D)              (None, 1000, 21)     1344        add_10[0][0]                     
__________________________________________________________________________________________________
add_11 (Add)                    (None, 1000, 21)     0           conv1d_8[0][0]                   
                                                                 conv1d_9[0][0]                   
                                                                 conv1d_10[0][0]                  
                                                                 conv1d_11[0][0]                  
                                                                 conv1d_12[0][0]                  
                                                                 conv1d_13[0][0]                  
__________________________________________________________________________________________________
conv1d_14 (Conv1D)              (None, 1000, 21)     1344        add_11[0][0]                     
__________________________________________________________________________________________________
add_12 (Add)                    (None, 1000, 21)     0           conv1d_8[0][0]                   
                                                                 conv1d_9[0][0]                   
                                                                 conv1d_10[0][0]                  
                                                                 conv1d_11[0][0]                  
                                                                 conv1d_12[0][0]                  
                                                                 conv1d_13[0][0]                  
                                                                 conv1d_14[0][0]                  
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 1000, 1, 21)  0           add_12[0][0]                     
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 1000, 1, 2)   2144        reshape_3[0][0]                  
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 1000, 2)      0           conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
global_average_pooling1d_2 (Glo (None, 21)           0           add_12[0][0]                     
__________________________________________________________________________________________________
profile/DNase (Lambda)          (None, 1000, 2)      0           reshape_4[0][0]                  
__________________________________________________________________________________________________
counts/DNase (Dense)            (None, 2)            44          global_average_pooling1d_2[0][0] 
==================================================================================================
Total params: 12,037
Trainable params: 12,037
Non-trainable params: 0
__________________________________________________________________________________________________
In [58]:
def get_model(mfn, mkwargs):
    """Instantiate a model builder by name and derive its checkpoint path.

    Args:
        mfn: name of a model-building function available in the notebook
            namespace (e.g. "seq_multitask").
        mkwargs: kwargs forwarded to the builder; also encoded into the
            run name so each hyper-parameter setting gets a distinct file.

    Returns:
        (model, name, ckp_file) where `name` embeds the kwargs and a
        timestamp, and `ckp_file` is the .h5 checkpoint path.
    """
    import datetime
    mdir = Path(f"{ddir}/processed/dnase/exp/models/single-task")
    name = mfn + "_" + \
            ",".join([f'{k}={v}' for k, v in mkwargs.items()]) + \
            "." + str(datetime.datetime.now()).replace(" ", "::")
    # Pure-python `mkdir -p` (replaces the previous `!mkdir -p {mdir}` shell-out,
    # which only works inside IPython).
    mdir.mkdir(parents=True, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    # Look the builder up in the notebook namespace instead of eval():
    # identical for plain function names, without arbitrary-expression evaluation.
    return globals()[mfn](**mkwargs), name, ckp_file

train the model

  • add the training curve plot below training
In [66]:
# Hyper-parameters for the single-task model (passed to `seq_multitask`).
mfn = "seq_multitask"
mkwargs = {
    'filters': 32,
    'conv1_kernel_size': 21,
    'tconv_kernel_size': 51,
    'n_dil_layers': 6,
    'tasks': ['DNase'],
    'seq_len': SEQ_WIDTH,
    'lr': 0.004,
}
In [ ]:
# best valid so far: 
model, name, ckp_file = get_model(mfn, mkwargs)
# Fit with early stopping (patience 5) while checkpointing only the
# best-validation weights to disk.
history = model.fit(train[0], 
                    train[1],
                    batch_size=256, 
                    epochs=100,
                    validation_data=valid[:2],
                    callbacks=[EarlyStopping(patience=5),
                               History(),
                               ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model
# Reload the best checkpoint; custom_objects is required because the model
# uses a custom loss and a custom layer.
model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll, 
                                             "SpatialLifetimeSparsity": SpatialLifetimeSparsity})
Train on 205183 samples, validate on 61670 samples
Epoch 1/100
205183/205183 [==============================] - 52s 253us/step - loss: 1986.6356 - profile/DNase_loss: 1905.8155 - counts/DNase_loss: 0.8082 - val_loss: 1840.8780 - val_profile/DNase_loss: 1769.8586 - val_counts/DNase_loss: 0.7102
Epoch 2/100
205183/205183 [==============================] - 46s 223us/step - loss: 1822.9903 - profile/DNase_loss: 1748.8735 - counts/DNase_loss: 0.7412 - val_loss: 1802.6936 - val_profile/DNase_loss: 1732.6150 - val_counts/DNase_loss: 0.7008
Epoch 3/100
205183/205183 [==============================] - 46s 223us/step - loss: 1782.8926 - profile/DNase_loss: 1711.1990 - counts/DNase_loss: 0.7169 - val_loss: 1770.0586 - val_profile/DNase_loss: 1700.4673 - val_counts/DNase_loss: 0.6959
Epoch 4/100
205183/205183 [==============================] - 45s 219us/step - loss: 1761.2130 - profile/DNase_loss: 1691.2021 - counts/DNase_loss: 0.7001 - val_loss: 1768.1815 - val_profile/DNase_loss: 1694.2280 - val_counts/DNase_loss: 0.7395
Epoch 5/100
205183/205183 [==============================] - 46s 222us/step - loss: 1748.5447 - profile/DNase_loss: 1679.9672 - counts/DNase_loss: 0.6858 - val_loss: 1739.2853 - val_profile/DNase_loss: 1675.1188 - val_counts/DNase_loss: 0.6417
Epoch 6/100
205183/205183 [==============================] - 45s 218us/step - loss: 1737.4753 - profile/DNase_loss: 1669.9106 - counts/DNase_loss: 0.6756 - val_loss: 1741.5786 - val_profile/DNase_loss: 1672.2618 - val_counts/DNase_loss: 0.6932
Epoch 7/100
205183/205183 [==============================] - 44s 216us/step - loss: 1730.9541 - profile/DNase_loss: 1664.3759 - counts/DNase_loss: 0.6658 - val_loss: 1737.2615 - val_profile/DNase_loss: 1665.1475 - val_counts/DNase_loss: 0.7211
Epoch 8/100
205183/205183 [==============================] - 43s 211us/step - loss: 1727.2633 - profile/DNase_loss: 1660.7070 - counts/DNase_loss: 0.6656 - val_loss: 1732.6612 - val_profile/DNase_loss: 1667.7233 - val_counts/DNase_loss: 0.6494
Epoch 9/100
205183/205183 [==============================] - 44s 215us/step - loss: 1726.4051 - profile/DNase_loss: 1660.3897 - counts/DNase_loss: 0.6602 - val_loss: 1739.0704 - val_profile/DNase_loss: 1663.7556 - val_counts/DNase_loss: 0.7531
Epoch 10/100
205183/205183 [==============================] - 44s 213us/step - loss: 1719.7151 - profile/DNase_loss: 1654.7917 - counts/DNase_loss: 0.6492 - val_loss: 1718.9072 - val_profile/DNase_loss: 1658.6610 - val_counts/DNase_loss: 0.6025
Epoch 11/100
205183/205183 [==============================] - 45s 217us/step - loss: 1714.8451 - profile/DNase_loss: 1650.3448 - counts/DNase_loss: 0.6450 - val_loss: 1719.0784 - val_profile/DNase_loss: 1658.6771 - val_counts/DNase_loss: 0.6040
Epoch 12/100
205183/205183 [==============================] - 45s 218us/step - loss: 1715.1256 - profile/DNase_loss: 1651.1297 - counts/DNase_loss: 0.6400 - val_loss: 1720.2006 - val_profile/DNase_loss: 1660.2129 - val_counts/DNase_loss: 0.5999
Epoch 13/100
205183/205183 [==============================] - 44s 216us/step - loss: 1710.4905 - profile/DNase_loss: 1646.4333 - counts/DNase_loss: 0.6406 - val_loss: 1754.5781 - val_profile/DNase_loss: 1690.5193 - val_counts/DNase_loss: 0.6406
Epoch 14/100
205183/205183 [==============================] - 44s 213us/step - loss: 1708.8639 - profile/DNase_loss: 1645.1944 - counts/DNase_loss: 0.6367 - val_loss: 1711.8003 - val_profile/DNase_loss: 1652.9072 - val_counts/DNase_loss: 0.5889
Epoch 15/100
185856/205183 [==========================>...] - ETA: 3s - loss: 1705.7569 - profile/DNase_loss: 1642.5819 - counts/DNase_loss: 0.6317
In [68]:
a=1

evaluate the single-task model

In [106]:
from basepair.preproc import bin_counts

# TODO - write a generic plotter
class Seq2Profile:
    """Plot predicted vs. observed strand-specific profiles for a model.

    NOTE(review): relies on `softmax` (basepair.math) and `samplers`
    (basepair) being imported into the notebook namespace before use —
    in this notebook they are imported in later cells; keep import order.
    """

    def __init__(self, x, y, model):
        self.x = x          # model inputs (sequences)
        self.y = y          # observed targets, dict keyed like "profile/<task>"
        self.model = model
        # Pre-compute per-position probabilities for every model output head.
        self.y_pred = [softmax(p) for p in model.predict(x)]

    def plot(self, n=10, kind='test', sort='random', figsize=(20, 2), binsize=1, fpath_template=None):
        """Plot `n` (predicted, observed) profile pairs side by side.

        sort: 'random', or '<how>_<task>' where <how> is 'max' or 'sum'
            (rank examples by the max / summed observed counts for <task>).
        binsize: bin width used when displaying the profiles.
        kind, fpath_template: currently unused — kept for API compatibility.
        """
        import matplotlib.pyplot as plt
        task = "DNase"  # default; overridden when `sort` encodes a task
        if sort == 'random':
            idx_list = samplers.random(self.x, n)
        elif "_" in sort:
            # Use a dedicated local here — the original shadowed the `kind` argument.
            sort_how, task = sort.split("_")
            if sort_how == "max":
                idx_list = samplers.top_max_count(self.y["profile/" + task], n)
            elif sort_how == "sum":
                idx_list = samplers.top_sum_count(self.y["profile/" + task], n)
            else:
                raise ValueError(f"Unknown sort method {sort_how!r}; expected 'max' or 'sum'")
        else:
            raise ValueError(f"sort={sort} couldn't be interpreted")
        for i, idx in enumerate(idx_list):
            fig = plt.figure(figsize=figsize)
            # Left panel: model prediction (binned for display; argmax on raw profile).
            plt.subplot(121)
            if i == 0:
                plt.title("Predicted " + task)
            yp_binned = bin_counts(self.y_pred[0], binsize=binsize)
            plt.plot(yp_binned[idx, :, 0],
                     label='pos,m={}'.format(np.argmax(self.y_pred[0][idx, :, 0])))
            plt.plot(yp_binned[idx, :, 1],
                     label='neg,m={}'.format(np.argmax(self.y_pred[0][idx, :, 1])))
            plt.legend()
            # Right panel: observed counts.
            plt.subplot(122)
            if i == 0:
                plt.title("Observed " + task)
            yt_binned = bin_counts(self.y["profile/" + task], binsize=binsize)
            plt.plot(yt_binned[idx, :, 0],
                     label='pos,m={}'.format(np.argmax(self.y["profile/" + task][idx, :, 0])))
            plt.plot(yt_binned[idx, :, 1],
                     label='neg,m={}'.format(np.argmax(self.y["profile/" + task][idx, :, 1])))
            plt.legend()
In [70]:
from basepair.plots import regression_eval
from basepair.cli.evaluate import eval_profile
In [71]:
y_pred = model.predict(test[0])
In [78]:
y_true = test[1]
In [72]:
regression_eval(test[1]['counts/DNase'].mean(-1), y_pred[ds.task2idx("DNase", 'counts')].mean(-1))
In [73]:
task = "DNase"
In [76]:
from basepair.math import softmax
In [79]:
# Per-position probabilities (softmax over the profile head) and observed profile.
yp = softmax(y_pred[ds.task2idx(task, "profile")])
yt = y_true["profile/" + task]
In [80]:
df = eval_profile(yt, yp)
In [84]:
df
Out[84]:
auprc binsize frac_ambigous imbalance n_positives random_auprc
0 0.485028 1 0.018350 0.000286 31256 0.000518
1 0.493812 2 0.035386 0.000577 31017 0.001176
2 0.511556 4 0.065397 0.001172 30518 0.002653
3 0.552928 10 0.135557 0.003053 29410 0.008267
In [81]:
a=1
In [107]:
pl = Seq2Profile(test[0], test[1], model)
In [108]:
from basepair import samplers
In [109]:
samplers.top_sum_count(test[1]['profile/DNase'], 10)
Out[109]:
Int64Index([19193, 49091, 29040, 40750, 48092, 28112, 38541, 19386, 51931,
            31378],
           dtype='int64')
In [110]:
pl.plot(n=10, sort='sum_DNase', binsize=1)
In [111]:
pl.plot(n=10, sort='sum_DNase', binsize=10)
In [112]:
pl.plot(n=10, sort='max_DNase', binsize=1)
In [113]:
pl.plot(n=20, sort='random', binsize=1)
In [114]:
pl.plot(n=20, sort='random', binsize=10)

train count-only single-task model

evaluate count-only single-task model

improve the model

larger context

valid padding

U-net structure

RC parameter sharing or augmentation