Goal

  • Train a decent sequence -> profile + counts model for ChIP-seq

Resources

  • washu
    • session: 7-w-sox2-oct4-chipseq
In [7]:
# Use gpus 0,1
# NOTE: this is set before the deep-learning libraries are imported in later cells.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
import time
from pathlib import Path
import sys
# Make the directory two levels above the current working directory importable
# (presumably the repository root that contains `basepair` -- TODO confirm).
sys.path.append(str(Path(os.getcwd()).absolute().parent.parent))
# Also expose the site-packages of the `basepair` conda environment.
# NOTE(review): hard-coded, machine-specific absolute path.
sys.path.append('/opt/miniconda3/envs/basepair/lib/python3.6/site-packages')
In [8]:
import basepair
In [9]:
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
In [11]:
# Root data directory; get_model() stores checkpoints under
# {ddir}/processed/chipseq/exp/models/count+profile.
# NOTE(review): hard-coded, machine-specific absolute path.
ddir = '/home/prime/data'
In [12]:
# Directory holding the per-factor count bigwigs and IDR peak files.
# No trailing slash here: the f-strings below insert the separator themselves,
# so a trailing "/" would produce paths containing "//".
bdir = "/data/sox2-oct4-chipseq"

# One TaskSpec per transcription factor: positive/negative-strand count tracks
# plus the peak file; all tasks share the same reference FASTA.
ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                          ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                          )
                         },
              fasta_file="/data/mm10_no_alt_analysis_set_ENCODE.fasta"
             )
In [13]:
def ds2bws(ds):
    """Map every task in `ds.task_specs` to its pair of strand-specific count files.

    Args:
      ds: object exposing `task_specs`, a mapping task name -> spec with
        `pos_counts` and `neg_counts` attributes

    Returns:
      dict of the form {task: {"pos": <pos_counts>, "neg": <neg_counts>}}
    """
    bigwigs = {}
    for name in ds.task_specs:
        spec = ds.task_specs[name]
        strands = dict(pos=spec.pos_counts)
        strands["neg"] = spec.neg_counts
        bigwigs[name] = strands
    return bigwigs
In [ ]:
# Get the training data
# Builds the train/valid/test splits from the DataSpec using 1000 bp peak windows
# and reports the wall-clock time of the call.
start = time.time()
train, valid, test = chip_exo_nexus(ds, peak_width=1000)
# NOTE: despite its name, `end` holds the elapsed seconds, not an end timestamp.
end = time.time() - start
print('Time taken: ' + str(end))
2018-11-14 22:07:33,062 [INFO] extract sequence
In [ ]:
# Sanity check: shape of the Sox2 profile targets in the training split.
train[1]['profile/Sox2'].shape
In [24]:
# TODO - play around with this
def seq_multitask_chipseq(filters=21,
                          conv1_kernel_size=21,
                          tconv_kernel_size=25,
                          #tconv_kernel_size2=25,
                          n_dil_layers=6,
                          lr=0.004,
                          c_task_weight=100,
                          use_profile=True,
                          use_counts=True,
                          tasks=['sox2', 'oct4'],
                          seq_len=201):
    """
    Build and compile a multi-task Keras model with optional per-task profile
    and counts outputs on top of a shared dilated-convolution body.

    Body: one Conv1D over the (seq_len, 4) input followed by `n_dil_layers`
    dilated Conv1D layers (kernel size 3, dilation 2**i for i = 1..n_dil_layers).
    Each dilated layer receives the element-wise sum of all previous layer
    outputs; the sum of all layer outputs feeds the output heads.

    Args:
      filters: number of filters of every Conv1D layer
      conv1_kernel_size: kernel size of the first convolution
      tconv_kernel_size: kernel size (along the sequence axis) of the
        Conv2DTranspose layer that produces the profile channels
      n_dil_layers: number of dilated convolutional layers
      lr: learning rate handed to the Adam optimizer
      c_task_weight: how to upweight the count-prediction task
        (loss weight of every "counts/<task>" output)
      use_profile: if True, add one "profile/<task>" output per task
        (2 channels each, loss `twochannel_multinomial_nll`, loss weight 1)
      use_counts: if True, add one "counts/<task>" output per task
        (Dense(2) on the globally average-pooled body output, loss "mse")
      tasks: list of task names; they determine the output names and order
      seq_len: length of the input sequence

    Returns:
      Compiled keras Model. Outputs are ordered as all profile outputs
      (if enabled) followed by all counts outputs (if enabled).

    NOTE(review): with use_profile=False and use_counts=False the output list
    is empty -- presumably at least one head is expected to be enabled.
    """
    # TODO - build the reverse complement symmetry into the model
    inp = kl.Input(shape=(seq_len, 4), name='seq')
    first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)

    # Dilated stack: layer i takes the sum of every earlier conv output, so
    # all layers keep `filters` channels and the same length (padding='same').
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers + 1):
        if i == 1:
            # only one tensor so far; nothing to add
            prev_sum = first_conv
        else:
            prev_sum = kl.add(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)

    combined_conv = kl.add(prev_layers)

    # De-conv
    # Insert a dummy axis so the 1D feature map can go through Conv2DTranspose,
    # which emits 2 channels per task. This branch is built even when
    # use_profile=False (it is then simply not connected to any output).
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(2*len(tasks), kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    #x = kl.UpSampling2D((2, 1))(x)
    #x = kl.Conv2DTranspose(len(tasks), kernel_size=(tconv_kernel_size2, 1), padding='same')(x)
    #x = kl.UpSampling2D((2, 1))(x)
    #x = kl.Conv2DTranspose(int(len(tasks)/2), kernel_size=(tconv_kernel_size3, 1), padding='same')(x)
    # Drop the dummy axis again: (length, 2 * n_tasks)
    out = kl.Reshape((-1, 2 * len(tasks)))(x)

    # setup the output branches
    outputs = []
    losses = []
    loss_weights = []

    if use_profile:
        # Task i owns channels [2*i, 2*i + 2) of `out`. The task index is handed
        # to the Lambda through `arguments` rather than captured by the closure.
        output = [kl.Lambda(lambda x, i: x[:, :, (2 * i):(2 * i + 2)],
                            output_shape=(seq_len, 2),
                            name="profile/" + task,
                            arguments={"i": i})(out)
                  for i, task in enumerate(tasks)]
        
        # true counts size is (tasks, 1000, 2)
        outputs += output
        losses += [twochannel_multinomial_nll] * len(tasks)
        loss_weights += [1] * len(tasks)

    if use_counts:
        # Average the body output over positions, then one 2-unit Dense per task.
        pooled = kl.GlobalAvgPool1D()(combined_conv)
        counts = [kl.Dense(2, name="counts/" + task)(pooled)
                  for task in tasks]
        outputs += counts
        losses += ["mse"] * len(tasks)
        loss_weights += [c_task_weight] * len(tasks)

    model = Model(inp, outputs)
    model.compile(Adam(lr=lr), loss=losses, loss_weights=loss_weights)
    return model

Set up a new model with two output branches

In [25]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
In [26]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
In [27]:
# Run counter: get_model() appends it to the model name, and the training cell
# increments it before every run.
i=1
In [28]:
def get_model(mfn, mkwargs, fkwargs, i):
    """Instantiate a model and derive its run name and checkpoint path.

    Args:
      mfn: name of the model-building function (a string); it is resolved
        with `eval`, so it must be defined in the notebook namespace
      mkwargs: hyper-parameters passed to the model function; they are also
        encoded into the run name as comma-separated `key=value` pairs
      fkwargs: additional keyword arguments passed to the model function that
        are not part of the run name; on key collisions they override `mkwargs`
      i: run index appended to the name as `.<i>`

    Returns:
      tuple (model, name, ckp_file) where `name` is
      `<mfn>_<k1=v1,k2=v2,...>.<i>` and `ckp_file` is `<mdir>/<name>.h5`

    Side effects:
      Creates the checkpoint directory below the notebook-level `ddir` if it
      does not exist yet.
    """
    # Function-scope import keeps this cell runnable on its own.
    import os

    mdir = f"{ddir}/processed/chipseq/exp/models/count+profile"
    hparams = ",".join([f'{k}={v}' for k, v in mkwargs.items()])
    name = mfn + "_" + hparams + "." + str(i)
    # Pure-Python replacement for the IPython-only `!mkdir -p {mdir}`.
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}
    # NOTE(review): eval() executes arbitrary code; only pass trusted,
    # hard-coded function names as `mfn`.
    return eval(mfn)(**all_kwargs), name, ckp_file

train the model

  • add the training curve plot below training
In [29]:
# hyper-parameters
# `mfn` is the *name* of the model-building function; get_model() resolves it.
mfn = "seq_multitask_chipseq"
# Train the profile heads only. With use_counts=False the model has no
# "counts/<task>" outputs and `c_task_weight` below has no effect.
use_profile = True
use_counts = False
# Keyword arguments that are both passed to the model function and encoded
# into the run name / checkpoint file name by get_model().
mkwargs = dict(filters=32, 
               conv1_kernel_size=21,
               tconv_kernel_size=100,
               n_dil_layers=6,
               use_profile=use_profile,
               use_counts=use_counts,
               c_task_weight=10,
               seq_len=train[0].shape[1],
               lr=0.004)
# Keyword arguments passed to the model function but kept out of the run name.
fixed_kwargs = dict(
    tasks=list(ds.task_specs)
)
In [30]:
import numpy as np
# NOTE(review): this seeds NumPy only; no other random number generator is
# seeded in this cell, so runs are presumably not fully reproducible.
np.random.seed(20)
# Bump the run counter so that every execution of this cell gets its own
# model name / checkpoint file.
i += 1
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)
# Train with early stopping (patience of 5 epochs) and keep only the best
# checkpoint on disk (save_best_only=True; Keras monitors val_loss by default).
history = model.fit(train[0], 
                    train[1],
          batch_size=256, 
          epochs=100,
          validation_data=valid[:2],
          callbacks=[EarlyStopping(patience=5),
                     History(),
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model
# The custom loss has to be supplied again so Keras can deserialize the model.
model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})
Train on 14727 samples, validate on 4493 samples
Epoch 1/100
14727/14727 [==============================] - 57s 4ms/step - loss: 588.3137 - profile/Sox2_loss: 167.3501 - profile/Oct4_loss: 420.9636 - val_loss: 580.2910 - val_profile/Sox2_loss: 165.8053 - val_profile/Oct4_loss: 414.4857
Epoch 2/100
14727/14727 [==============================] - 8s 536us/step - loss: 561.3448 - profile/Sox2_loss: 162.9528 - profile/Oct4_loss: 398.3920 - val_loss: 566.5297 - val_profile/Sox2_loss: 163.6218 - val_profile/Oct4_loss: 402.9079
Epoch 3/100
14727/14727 [==============================] - 8s 536us/step - loss: 554.8441 - profile/Sox2_loss: 161.7578 - profile/Oct4_loss: 393.0863 - val_loss: 564.2947 - val_profile/Sox2_loss: 163.2263 - val_profile/Oct4_loss: 401.0684
Epoch 4/100
14727/14727 [==============================] - 8s 537us/step - loss: 552.5118 - profile/Sox2_loss: 161.2581 - profile/Oct4_loss: 391.2537 - val_loss: 562.4303 - val_profile/Sox2_loss: 162.8647 - val_profile/Oct4_loss: 399.5657
Epoch 5/100
14727/14727 [==============================] - 8s 537us/step - loss: 551.3399 - profile/Sox2_loss: 160.9652 - profile/Oct4_loss: 390.3747 - val_loss: 563.0924 - val_profile/Sox2_loss: 162.8784 - val_profile/Oct4_loss: 400.2141
Epoch 6/100
14727/14727 [==============================] - 8s 537us/step - loss: 550.5449 - profile/Sox2_loss: 160.8030 - profile/Oct4_loss: 389.7419 - val_loss: 561.3682 - val_profile/Sox2_loss: 162.5769 - val_profile/Oct4_loss: 398.7913
Epoch 7/100
14727/14727 [==============================] - 8s 538us/step - loss: 549.5884 - profile/Sox2_loss: 160.5382 - profile/Oct4_loss: 389.0502 - val_loss: 560.9944 - val_profile/Sox2_loss: 162.4371 - val_profile/Oct4_loss: 398.5573
Epoch 8/100
14727/14727 [==============================] - 8s 539us/step - loss: 548.8029 - profile/Sox2_loss: 160.3406 - profile/Oct4_loss: 388.4623 - val_loss: 559.6991 - val_profile/Sox2_loss: 162.1395 - val_profile/Oct4_loss: 397.5596
Epoch 9/100
14727/14727 [==============================] - 8s 538us/step - loss: 548.1519 - profile/Sox2_loss: 160.1504 - profile/Oct4_loss: 388.0015 - val_loss: 559.2430 - val_profile/Sox2_loss: 162.0213 - val_profile/Oct4_loss: 397.2217
Epoch 10/100
14727/14727 [==============================] - 8s 540us/step - loss: 547.3458 - profile/Sox2_loss: 159.9642 - profile/Oct4_loss: 387.3815 - val_loss: 558.9180 - val_profile/Sox2_loss: 162.1047 - val_profile/Oct4_loss: 396.8132
Epoch 11/100
14727/14727 [==============================] - 8s 538us/step - loss: 547.0454 - profile/Sox2_loss: 159.8739 - profile/Oct4_loss: 387.1715 - val_loss: 558.6106 - val_profile/Sox2_loss: 161.9104 - val_profile/Oct4_loss: 396.7002
Epoch 12/100
14727/14727 [==============================] - 8s 539us/step - loss: 546.6044 - profile/Sox2_loss: 159.7603 - profile/Oct4_loss: 386.8441 - val_loss: 558.7073 - val_profile/Sox2_loss: 161.9030 - val_profile/Oct4_loss: 396.8043
Epoch 13/100
14727/14727 [==============================] - 8s 540us/step - loss: 546.4705 - profile/Sox2_loss: 159.7316 - profile/Oct4_loss: 386.7389 - val_loss: 557.9292 - val_profile/Sox2_loss: 161.8015 - val_profile/Oct4_loss: 396.1277
Epoch 14/100
14727/14727 [==============================] - 8s 541us/step - loss: 545.8430 - profile/Sox2_loss: 159.5894 - profile/Oct4_loss: 386.2536 - val_loss: 557.8163 - val_profile/Sox2_loss: 161.7264 - val_profile/Oct4_loss: 396.0899
Epoch 15/100
14727/14727 [==============================] - 8s 541us/step - loss: 545.5738 - profile/Sox2_loss: 159.5321 - profile/Oct4_loss: 386.0417 - val_loss: 558.1341 - val_profile/Sox2_loss: 161.7891 - val_profile/Oct4_loss: 396.3450
Epoch 16/100
14727/14727 [==============================] - 8s 541us/step - loss: 545.2633 - profile/Sox2_loss: 159.4971 - profile/Oct4_loss: 385.7662 - val_loss: 557.7888 - val_profile/Sox2_loss: 161.9035 - val_profile/Oct4_loss: 395.8853
Epoch 17/100
14727/14727 [==============================] - 8s 540us/step - loss: 544.9362 - profile/Sox2_loss: 159.3860 - profile/Oct4_loss: 385.5501 - val_loss: 557.6650 - val_profile/Sox2_loss: 161.7854 - val_profile/Oct4_loss: 395.8795
Epoch 18/100
14727/14727 [==============================] - 8s 542us/step - loss: 544.8210 - profile/Sox2_loss: 159.3489 - profile/Oct4_loss: 385.4721 - val_loss: 557.4584 - val_profile/Sox2_loss: 161.7004 - val_profile/Oct4_loss: 395.7581
Epoch 19/100
14727/14727 [==============================] - 8s 543us/step - loss: 544.4866 - profile/Sox2_loss: 159.3221 - profile/Oct4_loss: 385.1645 - val_loss: 557.4753 - val_profile/Sox2_loss: 161.8241 - val_profile/Oct4_loss: 395.6512
Epoch 20/100
14727/14727 [==============================] - 8s 543us/step - loss: 544.0890 - profile/Sox2_loss: 159.1974 - profile/Oct4_loss: 384.8916 - val_loss: 557.5801 - val_profile/Sox2_loss: 161.7531 - val_profile/Oct4_loss: 395.8270
Epoch 21/100
14727/14727 [==============================] - 8s 544us/step - loss: 543.7610 - profile/Sox2_loss: 159.1372 - profile/Oct4_loss: 384.6237 - val_loss: 557.2528 - val_profile/Sox2_loss: 161.7513 - val_profile/Oct4_loss: 395.5014
Epoch 22/100
14727/14727 [==============================] - 8s 544us/step - loss: 543.6001 - profile/Sox2_loss: 159.0847 - profile/Oct4_loss: 384.5154 - val_loss: 557.9419 - val_profile/Sox2_loss: 161.9925 - val_profile/Oct4_loss: 395.9495
Epoch 23/100
14727/14727 [==============================] - 8s 544us/step - loss: 543.5109 - profile/Sox2_loss: 159.0871 - profile/Oct4_loss: 384.4238 - val_loss: 558.0156 - val_profile/Sox2_loss: 161.9692 - val_profile/Oct4_loss: 396.0464
Epoch 24/100
14727/14727 [==============================] - 8s 544us/step - loss: 543.2256 - profile/Sox2_loss: 159.0091 - profile/Oct4_loss: 384.2165 - val_loss: 557.0675 - val_profile/Sox2_loss: 161.7348 - val_profile/Oct4_loss: 395.3327
Epoch 25/100
14727/14727 [==============================] - 8s 545us/step - loss: 543.0635 - profile/Sox2_loss: 158.9751 - profile/Oct4_loss: 384.0884 - val_loss: 557.4430 - val_profile/Sox2_loss: 161.7478 - val_profile/Oct4_loss: 395.6952
Epoch 26/100
14727/14727 [==============================] - 8s 546us/step - loss: 542.7712 - profile/Sox2_loss: 158.9071 - profile/Oct4_loss: 383.8641 - val_loss: 557.8977 - val_profile/Sox2_loss: 162.0949 - val_profile/Oct4_loss: 395.8028
Epoch 27/100
14727/14727 [==============================] - 8s 546us/step - loss: 542.7368 - profile/Sox2_loss: 158.9135 - profile/Oct4_loss: 383.8233 - val_loss: 557.5992 - val_profile/Sox2_loss: 161.8781 - val_profile/Oct4_loss: 395.7211
Epoch 28/100
14727/14727 [==============================] - 8s 546us/step - loss: 542.4012 - profile/Sox2_loss: 158.8555 - profile/Oct4_loss: 383.5458 - val_loss: 558.1440 - val_profile/Sox2_loss: 161.8703 - val_profile/Oct4_loss: 396.2736
Epoch 29/100
14727/14727 [==============================] - 8s 547us/step - loss: 542.5821 - profile/Sox2_loss: 158.8776 - profile/Oct4_loss: 383.7045 - val_loss: 558.2814 - val_profile/Sox2_loss: 161.8959 - val_profile/Oct4_loss: 396.3855
In [31]:
from basepair.eval import evaluate
# Metrics of the reloaded (best-checkpoint) model on the validation split.
evaluate(model, valid[0], valid[1])
Out[31]:
{'loss': 557.0675261697232,
 'profile/Sox2_loss': 161.73476931726697,
 'profile/Oct4_loss': 395.3327579392162}
In [32]:
from basepair.math import softmax
from basepair import samplers
from basepair.preproc import bin_counts
import numpy as np

class Seq2Sox2Oct4:
    """Plot predicted next to observed Sox2 / Oct4 profiles for selected examples.

    The predictions are computed once at construction time.

    NOTE(review): `plot` reads `y_pred[0]` as the Sox2 profile and `y_pred[1]`
    as the Oct4 profile, i.e. it assumes the first two model outputs are
    "profile/Sox2" and "profile/Oct4" in that order. `softmax` is applied to
    every model output -- confirm this is intended if the model also has
    counts outputs.
    """

    def __init__(self, x, y, model):
        """Store the data and pre-compute the softmax-normalised predictions.

        Args:
          x: model input
          y: observed targets, indexable by "profile/Sox2" and "profile/Oct4"
          model: object with a `predict(x)` method returning one array per output
        """
        self.x = x
        self.y = y
        self.model = model
        # Make the prediction
        self.y_pred = [softmax(out) for out in model.predict(x)]

    def plot(self, n=10, kind='test', sort='random', figsize=(20, 2), fpath_template=None, binsize=1):
        """Draw one four-panel figure per selected example.

        Panels, left to right: predicted Sox2, observed Sox2, predicted Oct4,
        observed Oct4. Each panel shows channel 0 ("pos") and channel 1
        ("neg"); the legend reports the argmax position of the un-binned track.

        Args:
          n: number of examples to plot
          kind: unused; kept for backward compatibility
          sort: how to pick the examples: 'random', 'max_<task>' or
            'sum_<task>', where <task> selects `self.y["profile/<task>"]`
          figsize: size of every figure
          fpath_template: if given, each figure is saved to
            `fpath_template.format(i) + '.png'` and `... + '.pdf'`, shown and
            then closed
          binsize: bin width passed to `bin_counts` for the plotted tracks

        Raises:
          ValueError: if `sort` can't be interpreted
        """
        import matplotlib.pyplot as plt

        if sort == 'random':
            idx_list = samplers.random(self.x, n)
        elif "_" in sort:
            # Use a separate local (instead of overwriting the `kind` argument)
            # and split only once so that task names may contain underscores.
            stat, task = sort.split("_", 1)
            #task_id = {"Sox2": 0, "Oct4": 1}[task]
            if stat == "max":
                idx_list = samplers.top_max_count(self.y[f"profile/{task}"], n)
            elif stat == "sum":
                idx_list = samplers.top_sum_count(self.y[f"profile/{task}"], n)
            else:
                raise ValueError(f"sort={sort} couldn't be interpreted: "
                                 "expected 'random', 'max_<task>' or 'sum_<task>'")
        else:
            raise ValueError(f"sort={sort} couldn't be interpreted")

        # (subplot position, title, un-binned track) for every panel
        panels = [
            (141, "Predicted Sox2", self.y_pred[0]),
            (142, "Observed Sox2", self.y["profile/Sox2"]),
            (143, "Predicted Oct4", self.y_pred[1]),
            (144, "Observed Oct4", self.y["profile/Oct4"]),
        ]
        # for visualization, we use bucketize
        # Binning does not depend on the plotted example, so do it once per
        # track instead of once per plotted line.
        binned_tracks = [bin_counts(track, binsize=binsize) for _, _, track in panels]

        for i, idx in enumerate(idx_list):
            fig = plt.figure(figsize=figsize)
            for (position, title, track), binned in zip(panels, binned_tracks):
                plt.subplot(position)
                if i == 0:
                    # only the first row of figures carries the panel titles
                    plt.title(title)
                for channel, strand in enumerate(("pos", "neg")):
                    plt.plot(binned[idx, :, channel],
                             label='{},m={}'.format(strand, np.argmax(track[idx, :, channel])))
                plt.legend()
            if fpath_template is not None:
                plt.savefig(fpath_template.format(i) + '.png', dpi=600)
                plt.savefig(fpath_template.format(i) + '.pdf', dpi=600)
                # Show first, then close. The previous code closed the figure
                # and then called `show_figure`, which is not defined anywhere
                # in this notebook and therefore raised a NameError.
                plt.show()
                plt.close(fig)    # close the figure
In [ ]:
# Raw model outputs on the test split (one array per model output).
y_pred = model.predict(test[0])
In [ ]:
from basepair.plots import regression_eval
In [ ]:
# Observed vs. predicted Sox2 counts, each averaged over the last axis.
# NOTE(review): this needs a model built with use_counts=True; the
# hyper-parameter cell sets use_counts=False, so the trained model has no
# "counts/Sox2" output -- confirm before running.
regression_eval(test[1]['counts/Sox2'].mean(-1), y_pred[ds.task2idx("Sox2", 'counts')].mean(-1))
In [ ]:
# Observed vs. predicted Oct4 counts, each averaged over the last axis.
# NOTE(review): this needs a model built with use_counts=True; the
# hyper-parameter cell sets use_counts=False, so the trained model has no
# "counts/Oct4" output -- confirm before running.
regression_eval(test[1]['counts/Oct4'].mean(-1), y_pred[ds.task2idx("Oct4", 'counts')].mean(-1))
In [ ]:
# Plot helper for the test split; the constructor runs model.predict(test[0]).
pl = Seq2Sox2Oct4(test[0], test[1], model)
In [ ]:
# Plot the 10 examples selected by samplers.top_sum_count on the observed
# Sox2 profile, with tracks binned at binsize=10.
pl.plot(n=10, sort='sum_Sox2', binsize=10)
In [ ]:
from basepair.BPNet import BPNetPredictor
In [ ]:
# NOTE(review): `preproc` is not assigned in any cell visible in this notebook;
# confirm it is defined before running this cell.
bpnet = BPNetPredictor(model, ds.fasta_file, list(ds.task_specs), preproc=preproc)
In [ ]:
# Peek at the third element of the test split; the next cells read its
# "chr", "start" and "end" columns.
test[2].head()
In [ ]:
from pybedtools import Interval, BedTool
In [ ]:
# First five test intervals (chr, start, end) as a BedTool.
bt = BedTool.from_dataframe(test[2][["chr", "start", "end"]][:5])
In [ ]:
# For some intervals from the genome, plot the observed and predicted profiles
In [ ]:
# ds2bws(ds) supplies the per-task {"pos", "neg"} bigwig paths for the
# observed tracks.
bpnet.predict_plot(intervals=list(bt), bws = ds2bws(ds), profile_grad="weighted")