Goal

  • train a decent seq -> counts model for ChIP-seq
In [1]:
import basepair
Using TensorFlow backend.
In [2]:
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
2018-07-03 14:53:49,690 [WARNING] git-lfs not installed
2018-07-03 14:53:49,703 [WARNING] git-lfs not installed
In [3]:
# Use gpus 1, 3, 5
import os
# Restrict TF to GPUs 1, 3 and 5. Use a plain comma-separated list with no
# spaces — that is the documented CUDA_VISIBLE_DEVICES format; padded entries
# can be misparsed by some driver versions.
# NOTE(review): this must be set before any CUDA context is created, but
# TensorFlow was already imported in cell 1 ("Using TensorFlow backend") —
# confirm the session is created after this point (see create_tf_session).
os.environ["CUDA_VISIBLE_DEVICES"] = "1,3,5"
In [4]:
ddir = get_data_dir()
In [5]:
ddir
Out[5]:
'/srv/scratch/amr1/chipseq/basepair/basepair/../data'
In [6]:
# Base directory with one sub-directory per TF (strand-separated count bigwigs
# plus filtered IDR peak calls)
bdir = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/"

# Two-task DataSpec (Sox2, Oct4) over the mm10 reference.
# NOTE(review): Oct4 points at pos2.bw/neg2.bw while Sox2 uses pos.bw/neg.bw —
# confirm the "2" files are the intended replicate/merge.
ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                          ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                          )
                         },
              fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta"
             )
In [7]:
def ds2bws(ds):
    """Map each task name in a DataSpec to its {'pos', 'neg'} bigwig count-file paths."""
    bigwigs = {}
    for task_name, spec in ds.task_specs.items():
        bigwigs[task_name] = {"pos": spec.pos_counts, "neg": spec.neg_counts}
    return bigwigs
In [8]:
# Get the training data split into train/valid/test. Based on later cells:
# split[0] = one-hot sequences, split[1] = dict of target tracks,
# split[2] = interval metadata DataFrame (chr/start/end/task) — presumably;
# confirm against chip_exo_nexus.
train, valid, test = chip_exo_nexus(ds, peak_width=1000)
100%|██████████| 23474/23474 [00:00<00:00, 1017854.77it/s]
2018-07-03 14:53:51,752 [INFO] extract sequence
2018-07-03 14:53:54,310 [INFO] extract counts
100%|██████████| 2/2 [00:14<00:00,  7.38s/it]
In [9]:
# Pre-process the data
preproc = AppendTotalCounts()
preproc.fit(train[1])
train[1] = preproc.transform(train[1])
valid[1] = preproc.transform(valid[1])
test[1] = preproc.transform(test[1])
In [10]:
train[1].keys()
Out[10]:
dict_keys(['profile/Sox2', 'profile/Oct4', 'counts/Sox2', 'counts/Oct4'])
In [11]:
# TODO - play around with this
def count_predict(filters=32,
                  conv1_kernel_size=21,
                  lr=0.004,
                  c_task_weight=100,
                  tasks=None,
                  dropout=0.1,
                  seq_len=201,
                  pool_type='avg'):
    """Build and compile a conv-net that predicts per-task (pos, neg) total
    counts from one-hot DNA sequence.

    Args:
        filters: number of filters in the first conv layer (widths scale from this).
        conv1_kernel_size: kernel size of the first conv layer.
        lr: Adam learning rate.
        c_task_weight: loss weight applied to every count head.
        tasks: list of task names; one Dense(2) "counts/<task>" head is added
            per task. Defaults to ['sox2', 'oct4'].
        dropout: dropout rate before the output heads.
        seq_len: input sequence length (bp).
        pool_type: 'max' or 'avg' pooling between conv layers.

    Returns:
        A compiled keras Model mapping 'seq' (seq_len, 4) -> len(tasks) count heads.

    Raises:
        ValueError: if pool_type is not 'max' or 'avg'.
    """
    if tasks is None:
        # Avoid a mutable default argument. NOTE(review): these lowercase
        # defaults differ from the 'Sox2'/'Oct4' names used elsewhere in this
        # notebook; callers here always pass tasks explicitly.
        tasks = ['sox2', 'oct4']

    def get_pool(pool_type):
        # Return a *fresh* pooling layer each call (Keras layers are stateful
        # graph nodes and must not be reused at multiple positions).
        if pool_type == 'max':
            return kl.MaxPool1D()
        elif pool_type == 'avg':
            return kl.AveragePooling1D()
        # Previously fell through and returned None, which crashed later with
        # an opaque TypeError; fail fast with a clear message instead.
        raise ValueError(f"pool_type={pool_type!r} not in ('max', 'avg')")

    inp = kl.Input(shape=(seq_len, 4), name='seq')
    x = kl.Conv1D(filters,
                  kernel_size=conv1_kernel_size,
                  padding='same',
                  activation='relu')(inp)
    x = kl.Conv1D(filters, 1, activation='relu', padding='same')(x)
    x = get_pool(pool_type)(x)
    x = kl.Conv1D(2 * filters, 7, activation='relu', padding='same')(x)
    x = get_pool(pool_type)(x)
    x = kl.Conv1D(2 * filters, 7, activation='relu', padding='same')(x)
    x = kl.GlobalAveragePooling1D()(x)
    x = kl.Dense(8 * filters, activation='relu')(x)
    x = kl.Dropout(dropout)(x)

    # Setup the output branches: one Dense(2) = (pos, neg) count head per task,
    # each trained with MSE and a common task weight.
    outputs = []
    losses = []
    loss_weights = []

    counts = [kl.Dense(2, name="counts/" + task)(x)
              for task in tasks]
    outputs += counts
    losses += ["mse"] * len(tasks)
    loss_weights += [c_task_weight] * len(tasks)

    model = Model(inp, outputs)
    model.compile(Adam(lr=lr),
                  loss=losses, loss_weights=loss_weights)
    return model

Set up a new model with two output branches (one count head per task)

In [12]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
In [13]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
In [14]:
i=1
In [15]:
def get_model(mfn, mkwargs, fkwargs, i):
    """Instantiate the model builder named `mfn` and derive a unique run name
    and checkpoint path.

    Args:
        mfn: name of a model-builder function defined in this notebook's globals
            (e.g. "count_predict").
        mkwargs: hyper-parameters encoded into the run name and passed to the builder.
        fkwargs: extra builder kwargs NOT encoded into the name.
        i: run index appended to the name so repeated runs don't overwrite
            each other's checkpoints.

    Returns:
        (model, name, ckp_file) — the compiled model, the run name, and the
        .h5 checkpoint path.

    Note: reads the notebook-global `ddir`. The original body also did a local
    `i += 1` (a no-op — it rebinds the local only, invisible to the caller)
    and an unused `import datetime`; both removed.
    """
    import os
    mdir = f"{ddir}/processed/chipseq/exp/models/count"
    name = mfn + "_" + \
            ",".join([f'{k}={v}' for k, v in mkwargs.items()]) + \
            "." + str(i)
    # Portable replacement for the IPython-only `!mkdir -p {mdir}` magic.
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}
    # Look the builder up by name instead of eval() — same behavior for valid
    # function names, without evaluating arbitrary expressions.
    model_fn = globals()[mfn]
    return model_fn(**all_kwargs), name, ckp_file

Train the model

  • TODO: add the training-curve plot below the training output
In [16]:
# Hyper-parameters for this run
mfn = "count_predict"  # builder function name, resolved by get_model
mkwargs = dict(filters=32, 
               conv1_kernel_size=21,
               c_task_weight=10,            # down-weighted vs the builder default of 100
               seq_len=train[0].shape[1],   # input length taken from the data (peak_width above)
               lr=0.0004)                   # 10x lower than the builder default of 0.004
# kwargs passed to the builder but not encoded into the run name
fixed_kwargs = dict(
    tasks=list(ds.task_specs)  # ['Sox2', 'Oct4']
)
In [17]:
i += 1  # bump the run index so this run's checkpoint name is unique
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)
history = model.fit(train[0], 
                    train[1],
          batch_size=256, 
          epochs=100,
          validation_data=valid[:2],  # (inputs, targets); drops the metadata element
          callbacks=[EarlyStopping(patience=10),  # stop after 10 epochs with no val_loss improvement
                     History(),  # NOTE(review): redundant — fit() always records a History
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model (ModelCheckpoint kept only the best-val_loss weights)
model = load_model(ckp_file)
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2018-07-03 14:54:11,628 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
Train on 14727 samples, validate on 4493 samples
Epoch 1/100
14727/14727 [==============================] - 11s 736us/step - loss: 20.0070 - counts/Sox2_loss: 1.0006 - counts/Oct4_loss: 1.0001 - val_loss: 20.3089 - val_counts/Sox2_loss: 1.0041 - val_counts/Oct4_loss: 1.0268
Epoch 2/100
14727/14727 [==============================] - 2s 122us/step - loss: 19.9895 - counts/Sox2_loss: 0.9993 - counts/Oct4_loss: 0.9997 - val_loss: 20.2455 - val_counts/Sox2_loss: 0.9984 - val_counts/Oct4_loss: 1.0262
Epoch 3/100
14727/14727 [==============================] - 2s 127us/step - loss: 19.8879 - counts/Sox2_loss: 0.9905 - counts/Oct4_loss: 0.9983 - val_loss: 20.0357 - val_counts/Sox2_loss: 0.9814 - val_counts/Oct4_loss: 1.0222
Epoch 4/100
14727/14727 [==============================] - 2s 118us/step - loss: 19.7420 - counts/Sox2_loss: 0.9799 - counts/Oct4_loss: 0.9943 - val_loss: 20.0133 - val_counts/Sox2_loss: 0.9786 - val_counts/Oct4_loss: 1.0228
Epoch 5/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.6828 - counts/Sox2_loss: 0.9749 - counts/Oct4_loss: 0.9933 - val_loss: 20.2074 - val_counts/Sox2_loss: 0.9904 - val_counts/Oct4_loss: 1.0303
Epoch 6/100
14727/14727 [==============================] - 2s 118us/step - loss: 19.5688 - counts/Sox2_loss: 0.9684 - counts/Oct4_loss: 0.9884 - val_loss: 19.8593 - val_counts/Sox2_loss: 0.9715 - val_counts/Oct4_loss: 1.0144
Epoch 7/100
14727/14727 [==============================] - 2s 120us/step - loss: 19.5668 - counts/Sox2_loss: 0.9683 - counts/Oct4_loss: 0.9884 - val_loss: 19.8382 - val_counts/Sox2_loss: 0.9677 - val_counts/Oct4_loss: 1.0161
Epoch 8/100
14727/14727 [==============================] - 2s 120us/step - loss: 19.4864 - counts/Sox2_loss: 0.9637 - counts/Oct4_loss: 0.9850 - val_loss: 19.7906 - val_counts/Sox2_loss: 0.9653 - val_counts/Oct4_loss: 1.0138
Epoch 9/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.4637 - counts/Sox2_loss: 0.9614 - counts/Oct4_loss: 0.9850 - val_loss: 19.9117 - val_counts/Sox2_loss: 0.9714 - val_counts/Oct4_loss: 1.0198
Epoch 10/100
14727/14727 [==============================] - 2s 117us/step - loss: 19.5119 - counts/Sox2_loss: 0.9644 - counts/Oct4_loss: 0.9868 - val_loss: 19.9884 - val_counts/Sox2_loss: 0.9761 - val_counts/Oct4_loss: 1.0228
Epoch 11/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.4695 - counts/Sox2_loss: 0.9618 - counts/Oct4_loss: 0.9852 - val_loss: 19.8731 - val_counts/Sox2_loss: 0.9680 - val_counts/Oct4_loss: 1.0193
Epoch 12/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.4483 - counts/Sox2_loss: 0.9599 - counts/Oct4_loss: 0.9849 - val_loss: 20.0539 - val_counts/Sox2_loss: 0.9807 - val_counts/Oct4_loss: 1.0247
Epoch 13/100
14727/14727 [==============================] - 2s 123us/step - loss: 19.4378 - counts/Sox2_loss: 0.9594 - counts/Oct4_loss: 0.9844 - val_loss: 19.7727 - val_counts/Sox2_loss: 0.9638 - val_counts/Oct4_loss: 1.0135
Epoch 14/100
14727/14727 [==============================] - 2s 124us/step - loss: 19.4107 - counts/Sox2_loss: 0.9581 - counts/Oct4_loss: 0.9829 - val_loss: 19.7822 - val_counts/Sox2_loss: 0.9635 - val_counts/Oct4_loss: 1.0147
Epoch 15/100
14727/14727 [==============================] - 2s 118us/step - loss: 19.4154 - counts/Sox2_loss: 0.9588 - counts/Oct4_loss: 0.9828 - val_loss: 19.7637 - val_counts/Sox2_loss: 0.9635 - val_counts/Oct4_loss: 1.0129
Epoch 16/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.4136 - counts/Sox2_loss: 0.9585 - counts/Oct4_loss: 0.9828 - val_loss: 19.8800 - val_counts/Sox2_loss: 0.9682 - val_counts/Oct4_loss: 1.0198
Epoch 17/100
14727/14727 [==============================] - 2s 122us/step - loss: 19.4840 - counts/Sox2_loss: 0.9628 - counts/Oct4_loss: 0.9856 - val_loss: 20.0021 - val_counts/Sox2_loss: 0.9780 - val_counts/Oct4_loss: 1.0222
Epoch 18/100
14727/14727 [==============================] - 2s 126us/step - loss: 19.4463 - counts/Sox2_loss: 0.9602 - counts/Oct4_loss: 0.9844 - val_loss: 19.7587 - val_counts/Sox2_loss: 0.9630 - val_counts/Oct4_loss: 1.0129
Epoch 19/100
14727/14727 [==============================] - 2s 120us/step - loss: 19.3801 - counts/Sox2_loss: 0.9566 - counts/Oct4_loss: 0.9814 - val_loss: 19.7472 - val_counts/Sox2_loss: 0.9634 - val_counts/Oct4_loss: 1.0113
Epoch 20/100
14727/14727 [==============================] - 2s 117us/step - loss: 19.4474 - counts/Sox2_loss: 0.9608 - counts/Oct4_loss: 0.9840 - val_loss: 19.7764 - val_counts/Sox2_loss: 0.9628 - val_counts/Oct4_loss: 1.0149
Epoch 21/100
14727/14727 [==============================] - 2s 118us/step - loss: 19.3669 - counts/Sox2_loss: 0.9556 - counts/Oct4_loss: 0.9811 - val_loss: 19.7911 - val_counts/Sox2_loss: 0.9642 - val_counts/Oct4_loss: 1.0149
Epoch 22/100
14727/14727 [==============================] - 2s 125us/step - loss: 19.3586 - counts/Sox2_loss: 0.9540 - counts/Oct4_loss: 0.9818 - val_loss: 19.8988 - val_counts/Sox2_loss: 0.9687 - val_counts/Oct4_loss: 1.0212
Epoch 23/100
14727/14727 [==============================] - 2s 126us/step - loss: 19.3780 - counts/Sox2_loss: 0.9552 - counts/Oct4_loss: 0.9826 - val_loss: 19.7757 - val_counts/Sox2_loss: 0.9636 - val_counts/Oct4_loss: 1.0140
Epoch 24/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.3420 - counts/Sox2_loss: 0.9535 - counts/Oct4_loss: 0.9807 - val_loss: 19.8119 - val_counts/Sox2_loss: 0.9659 - val_counts/Oct4_loss: 1.0152
Epoch 25/100
14727/14727 [==============================] - 2s 120us/step - loss: 19.3324 - counts/Sox2_loss: 0.9522 - counts/Oct4_loss: 0.9810 - val_loss: 19.8315 - val_counts/Sox2_loss: 0.9662 - val_counts/Oct4_loss: 1.0170
Epoch 26/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.3240 - counts/Sox2_loss: 0.9523 - counts/Oct4_loss: 0.9801 - val_loss: 19.7536 - val_counts/Sox2_loss: 0.9656 - val_counts/Oct4_loss: 1.0097
Epoch 27/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.3181 - counts/Sox2_loss: 0.9515 - counts/Oct4_loss: 0.9803 - val_loss: 20.0389 - val_counts/Sox2_loss: 0.9762 - val_counts/Oct4_loss: 1.0277
Epoch 28/100
14727/14727 [==============================] - 2s 124us/step - loss: 19.3220 - counts/Sox2_loss: 0.9517 - counts/Oct4_loss: 0.9805 - val_loss: 19.8652 - val_counts/Sox2_loss: 0.9671 - val_counts/Oct4_loss: 1.0194
Epoch 29/100
14727/14727 [==============================] - 2s 129us/step - loss: 19.3220 - counts/Sox2_loss: 0.9525 - counts/Oct4_loss: 0.9797 - val_loss: 19.7216 - val_counts/Sox2_loss: 0.9625 - val_counts/Oct4_loss: 1.0097
Epoch 30/100
14727/14727 [==============================] - 2s 120us/step - loss: 19.2861 - counts/Sox2_loss: 0.9498 - counts/Oct4_loss: 0.9788 - val_loss: 19.7813 - val_counts/Sox2_loss: 0.9622 - val_counts/Oct4_loss: 1.0159
Epoch 31/100
14727/14727 [==============================] - 2s 120us/step - loss: 19.2786 - counts/Sox2_loss: 0.9497 - counts/Oct4_loss: 0.9781 - val_loss: 19.8198 - val_counts/Sox2_loss: 0.9691 - val_counts/Oct4_loss: 1.0129
Epoch 32/100
14727/14727 [==============================] - 2s 120us/step - loss: 19.2850 - counts/Sox2_loss: 0.9500 - counts/Oct4_loss: 0.9785 - val_loss: 19.7519 - val_counts/Sox2_loss: 0.9640 - val_counts/Oct4_loss: 1.0112
Epoch 33/100
14727/14727 [==============================] - 2s 126us/step - loss: 19.2742 - counts/Sox2_loss: 0.9498 - counts/Oct4_loss: 0.9776 - val_loss: 19.8655 - val_counts/Sox2_loss: 0.9660 - val_counts/Oct4_loss: 1.0206
Epoch 34/100
14727/14727 [==============================] - 2s 121us/step - loss: 19.2783 - counts/Sox2_loss: 0.9493 - counts/Oct4_loss: 0.9786 - val_loss: 19.9126 - val_counts/Sox2_loss: 0.9776 - val_counts/Oct4_loss: 1.0136
Epoch 35/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.2360 - counts/Sox2_loss: 0.9471 - counts/Oct4_loss: 0.9765 - val_loss: 19.7646 - val_counts/Sox2_loss: 0.9649 - val_counts/Oct4_loss: 1.0116
Epoch 36/100
14727/14727 [==============================] - 2s 120us/step - loss: 19.1806 - counts/Sox2_loss: 0.9430 - counts/Oct4_loss: 0.9751 - val_loss: 19.7319 - val_counts/Sox2_loss: 0.9639 - val_counts/Oct4_loss: 1.0093
Epoch 37/100
14727/14727 [==============================] - 2s 119us/step - loss: 19.2011 - counts/Sox2_loss: 0.9449 - counts/Oct4_loss: 0.9752 - val_loss: 19.8248 - val_counts/Sox2_loss: 0.9678 - val_counts/Oct4_loss: 1.0147
Epoch 38/100
14727/14727 [==============================] - 2s 125us/step - loss: 19.1842 - counts/Sox2_loss: 0.9436 - counts/Oct4_loss: 0.9748 - val_loss: 19.9188 - val_counts/Sox2_loss: 0.9721 - val_counts/Oct4_loss: 1.0198
Epoch 39/100
14727/14727 [==============================] - 2s 122us/step - loss: 19.1819 - counts/Sox2_loss: 0.9434 - counts/Oct4_loss: 0.9748 - val_loss: 20.0101 - val_counts/Sox2_loss: 0.9810 - val_counts/Oct4_loss: 1.0201
In [18]:
from basepair.eval import evaluate
# Losses of the best (checkpointed) model on the validation split
evaluate(model, valid[0], valid[1])
Out[18]:
{'loss': 19.721617591054727,
 'counts/Sox2_loss': 0.9624763803468259,
 'counts/Oct4_loss': 1.0096853834309718}
In [19]:
from basepair.math import softmax
from basepair import samplers
from basepair.preproc import bin_counts
import numpy as np

class Seq2Sox2Oct4:
    """Holds (x, y, model), runs prediction once, and plots predicted vs
    observed pos/neg-strand tracks for Sox2 and Oct4 side by side.

    NOTE(review): this class indexes self.y_pred[0]/[1] like per-position
    profiles, but the counts-only model in this notebook outputs (n, 2) count
    heads — this class appears written for a profile model; confirm before use.
    """

    def __init__(self, x, y, model):
        self.x = x
        self.y = y
        self.model = model
        # Predict once up-front; softmax normalizes each output head
        self.y_pred = [softmax(p) for p in model.predict(x)]

    @staticmethod
    def _plot_track(subplot_pos, counts, idx, binsize, title=None):
        """Plot the pos/neg strands of counts[idx] (binned for display) into
        the given subplot; legend labels carry the argmax of the *unbinned*
        track, matching the original per-panel code."""
        import matplotlib.pyplot as plt
        plt.subplot(subplot_pos)
        if title is not None:
            plt.title(title)
        binned = bin_counts(counts, binsize=binsize)
        plt.plot(binned[idx, :, 0],
                 label='pos,m={}'.format(np.argmax(counts[idx, :, 0])))
        plt.plot(binned[idx, :, 1],
                 label='neg,m={}'.format(np.argmax(counts[idx, :, 1])))
        plt.legend()

    def plot(self, n=10, kind='test', sort='random', figsize=(20, 2), fpath_template=None, binsize=1):
        """Plot n examples (4 panels each: pred/obs x Sox2/Oct4).

        Args:
            n: number of examples to plot.
            kind: unused — kept for interface compatibility (it was silently
                overwritten by sort.split in the original as well).
            sort: 'random', or '<max|sum>_<task>' to pick the top-count examples.
            figsize: per-figure size.
            fpath_template: optional 'path_{}' template; when given, each
                figure is saved as .png and .pdf and then closed.
            binsize: bin width used only for display smoothing.
        """
        import matplotlib.pyplot as plt
        if sort == 'random':
            idx_list = samplers.random(self.x, n)
        elif "_" in sort:
            kind, task = sort.split("_")
            if kind == "max":
                idx_list = samplers.top_max_count(self.y[f"profile/{task}"], n)
            elif kind == "sum":
                idx_list = samplers.top_sum_count(self.y[f"profile/{task}"], n)
            else:
                # was ValueError("") — give the reader something actionable
                raise ValueError(f"sort={sort}: prefix must be 'max' or 'sum'")
        else:
            raise ValueError(f"sort={sort} couldn't be interpreted")

        # For visualization we bucketize the tracks; titles only on the first row
        for i, idx in enumerate(idx_list):
            fig = plt.figure(figsize=figsize)
            first = (i == 0)
            self._plot_track(141, self.y_pred[0], idx, binsize,
                             "Predicted Sox2" if first else None)
            self._plot_track(142, self.y["profile/Sox2"], idx, binsize,
                             "Observed Sox2" if first else None)
            self._plot_track(143, self.y_pred[1], idx, binsize,
                             "Predicted Oct4" if first else None)
            self._plot_track(144, self.y["profile/Oct4"], idx, binsize,
                             "Observed Oct4" if first else None)
            if fpath_template is not None:
                plt.savefig(fpath_template.format(i) + '.png', dpi=600)
                plt.savefig(fpath_template.format(i) + '.pdf', dpi=600)
                # The original then called the undefined show_figure(fig)
                # (NameError) followed by plt.show(); just close the saved
                # figure to avoid the crash and the memory blowup.
                plt.close(fig)
In [20]:
y_pred = model.predict(test[0])
In [21]:
from basepair.plots import regression_eval
In [22]:
# NOTE(review): the magic -2 presumably shifts task2idx (which seems to index a
# full profile+counts output list) into this counts-only model's two outputs —
# confirm the mapping before trusting these plots.
regression_eval(test[1]['counts/Sox2'].mean(-1), y_pred[ds.task2idx("Sox2", 'counts')-2].mean(-1))
In [23]:
# NOTE(review): same -2 output-index shift as the Sox2 cell — confirm mapping.
regression_eval(test[1]['counts/Oct4'].mean(-1), y_pred[ds.task2idx("Oct4", 'counts')-2].mean(-1))
In [24]:
test[2].head()
Out[24]:
id chr start end task
0 0 chr1 74957233 74958233 Sox2
5 5 chr1 189805323 189806323 Sox2
13 13 chr9 61249302 61250302 Oct4
18 18 chr1 35220119 35221119 Oct4
19 19 chr8 124810472 124811472 Oct4