Goal

  • train a decent sequence → (profile + counts) model for ChIP-seq

Resources

  • washu
    • session: 7-w-sox2-oct4-chipseq
In [7]:
# Restrict this kernel to GPUs 0 and 1. This must happen before TensorFlow/Keras
# is imported so CUDA picks it up. Use the canonical comma-separated form with
# no embedded whitespace.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import time
from pathlib import Path
import sys

# Make the repository root (two directories above the notebook's cwd) importable.
_repo_root = str(Path.cwd().parent.parent)
# NOTE(review): machine-specific conda site-packages path; consider dropping it
# once the kernel itself runs inside the `basepair` environment.
_extra_site_packages = '/opt/miniconda3/envs/basepair/lib/python3.6/site-packages'
# Only append missing entries so that re-running this cell is idempotent.
for _path in (_repo_root, _extra_site_packages):
    if _path not in sys.path:
        sys.path.append(_path)
In [8]:
import basepair
In [9]:
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
In [11]:
# Root directory for processed data and model checkpoints. Override it with the
# BASEPAIR_DATA_DIR environment variable; the default is the original path.
ddir = os.environ.get("BASEPAIR_DATA_DIR", '/home/prime/data')
In [12]:
# Root of the Sox2/Oct4 ChIP-seq data. No trailing slash, because the paths
# below already insert "/" (a trailing slash would yield "//" in every path).
bdir = "/data/sox2-oct4-chipseq"

# One TaskSpec per factor: pos/neg count bigWigs plus a peak BED
# (IDR 0.05-filtered, per the file names).
ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                          ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                          )
                         },
              fasta_file="/data/mm10_no_alt_analysis_set_ENCODE.fasta"
             )
In [13]:
def ds2bws(ds):
    """Collect the pos/neg count bigWig paths of every task in a DataSpec.

    Returns:
      dict mapping task name -> {"pos": task_spec.pos_counts,
      "neg": task_spec.neg_counts}, in the iteration order of ds.task_specs.
    """
    bigwigs = {}
    for task_name, spec in ds.task_specs.items():
        bigwigs[task_name] = {"pos": spec.pos_counts, "neg": spec.neg_counts}
    return bigwigs
In [26]:
# Get the training data: train/valid/test splits with peak_width=1000.
# perf_counter is a monotonic clock, which is the right tool for elapsed time.
start = time.perf_counter()
train, valid, test = chip_exo_nexus(ds, peak_width=1000)
elapsed = time.perf_counter() - start
print('Time taken: ' + str(elapsed))
2018-11-14 22:28:46,176 [INFO] extract sequence
2018-11-14 22:28:48,050 [INFO] extract counts
100%|██████████| 2/2 [00:11<00:00,  5.77s/it]
Time taken: 15.623735189437866
In [27]:
import numpy as np
from scipy.ndimage.filters import gaussian_filter1d

def smooth_profile(profile, binsize=10, mode='gauss', sigma=1.2):
    """Sum a (length, channels) profile into fixed-width bins, then smooth it.

    Args:
      profile: 2-D array of shape (L, C). L must be a multiple of ``binsize``
        and strictly larger than it.
      binsize: number of consecutive positions summed into one bin.
      mode: 'gauss' applies a 1-D Gaussian filter along the binned length axis
        to every channel; None returns the binned profile without smoothing.
      sigma: standard deviation (in bins) of the Gaussian filter.

    Returns:
      float64 array of shape (L // binsize, C). The input is not modified.

    Raises:
      AssertionError: if L is not a multiple of binsize or L <= binsize.
      ValueError: if ``mode`` is neither 'gauss' nor None.
    """
    arr = np.asarray(profile)
    L, C = arr.shape
    assert (L % binsize == 0 and L > binsize)
    # (L, C) -> (L // binsize, binsize, C), then sum within each bin. Cast to
    # float64 so integer count inputs are not truncated by gaussian_filter1d,
    # which keeps the input dtype.
    binned = arr.reshape(L // binsize, binsize, C).sum(axis=1).astype(np.float64)
    if mode is None:
        return binned
    if mode != 'gauss':
        raise ValueError(f"unsupported smoothing mode: {mode!r}")
    # Filter each channel independently along the length axis.
    return gaussian_filter1d(binned, sigma=sigma, axis=0)

def update_dataset(dataset, binsize=10, mode='gauss', sigma=1.2):
    """Bin and smooth every example of an (N, L, C) profile array.

    Each ``dataset[i]`` is passed through ``smooth_profile``.

    Args:
      dataset: array of shape (N, L, C).
      binsize: bin width forwarded to ``smooth_profile``.
      mode: smoothing mode forwarded to ``smooth_profile``.
      sigma: Gaussian sigma forwarded to ``smooth_profile`` (default matches
        ``smooth_profile``'s own default).

    Returns:
      New float64 array of shape (N, L // binsize, C); ``dataset`` is not modified.
    """
    N, L, C = dataset.shape
    new_dataset = np.zeros((N, L // binsize, C))
    for i in range(N):
        new_dataset[i] = smooth_profile(dataset[i], binsize=binsize, mode=mode, sigma=sigma)
    return new_dataset

# Bin (10-position bins) and Gaussian-smooth the profile targets of every split.
# Only "profile/<task>" entries are touched here.
# NOTE: this rebinds train/valid/test[1][...] in place, so re-running this cell
# would bin the already-binned profiles a second time.
for task in [f"profile/{t}" for t in ds.task_specs]:
    start = time.time()
    for split in (train, valid, test):
        split[1][task] = update_dataset(split[1][task])
    end = time.time() - start
    print('%.1f seconds for smoothing train/valid/test for %s' % (end, task))
    print(train[1][task].shape)
8 seconds for smoothing training
(14727, 100, 2)
8 seconds for smoothing training
(14727, 100, 2)

Set up a new model with two kinds of output branches: profile and counts.

In [28]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
In [75]:
import keras.layers as kl
from keras.optimizers import Adam, SGD
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
In [30]:
# Run counter. It is incremented before each training run and appended to the
# model name, so successive checkpoints get distinct file names.
i = 1
In [31]:
def get_model(mfn, mkwargs, fkwargs, i):
    """Build a model by name and derive its run name and checkpoint path.

    Args:
      mfn: name of a model-building function in this notebook's global
        namespace (e.g. "seq_multitask_chipseq"). Dotted names such as
        "module.func" are resolved attribute by attribute.
      mkwargs: hyper-parameters. They are passed to the builder and encoded
        into the run name as "k=v" pairs.
      fkwargs: fixed keyword arguments. They are passed to the builder but not
        encoded into the run name.
      i: run index appended to the name. It is not incremented here; callers
        must bump it themselves.

    Returns:
      (model, name, ckp_file): the built model, the run name, and the .h5
      checkpoint path under ``{ddir}/processed/chipseq/exp/models/count+profile``.

    Raises:
      NameError: if ``mfn``'s leading name is not defined globally.
    """
    mdir = f"{ddir}/processed/chipseq/exp/models/count+profile"
    name = mfn + "_" + ",".join(f'{k}={v}' for k, v in mkwargs.items()) + "." + str(i)
    # Pure-Python equivalent of `mkdir -p`; works outside IPython as well.
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}

    # Resolve the builder by name instead of eval()-ing an arbitrary expression.
    head, *attrs = mfn.split(".")
    try:
        model_fn = globals()[head]
    except KeyError:
        raise NameError(f"name {head!r} is not defined") from None
    for attr in attrs:
        model_fn = getattr(model_fn, attr)
    return model_fn(**all_kwargs), name, ckp_file

Train the model

  • TODO: add a training-curve plot below the training cell
In [88]:
# TODO - play around with this
def seq_multitask_chipseq(filters=21,
                          conv1_kernel_size=21,
                          tconv_kernel_size=25,
                          n_dil_layers=6,
                          lr=0.004,
                          decay=0,
                          c_task_weight=100,
                          use_profile=True,
                          use_counts=True,
                          tasks=['sox2', 'oct4'],
                          seq_len=1000,
                          pool_size=10):
    """Build and compile a sequence -> (profile, counts) multi-task model.

    Architecture:
      1. One wide Conv1D.
      2. ``n_dil_layers`` dilated Conv1D layers, each fed the sum of all
         previous conv outputs.
      3. A Conv2DTranspose producing 2 channels per task.
      4. Optional average pooling.
      5. Per-task output heads.

    Args:
      filters: number of filters in every conv layer.
      conv1_kernel_size: kernel width of the first conv layer.
      tconv_kernel_size: kernel length of the transposed-conv ("de-conv") layer.
      n_dil_layers: number of dilated conv layers (dilation rates 2**1 ... 2**n).
      lr: Adam learning rate.
      decay: Adam learning-rate decay.
      c_task_weight: loss weight applied to each count-prediction head.
      use_profile: add one "profile/<task>" head per task (2 channels each),
        trained with twochannel_multinomial_nll.
      use_counts: add one "counts/<task>" Dense(2) head per task, trained with MSE.
      tasks: task names. Their order fixes the channel layout and output order.
      seq_len: input length; the input shape is (seq_len, 4).
      pool_size: if truthy, average-pool the profile output with this window
        ('valid' padding). Profile heads then have length seq_len // pool_size.

    Returns:
      A compiled keras Model whose outputs are the profile heads (if any)
      followed by the count heads (if any).

    Raises:
      ValueError: if both use_profile and use_counts are False (no outputs).
    """
    if not (use_profile or use_counts):
        raise ValueError("at least one of use_profile / use_counts must be True")

    # TODO - build the reverse complement symmetry into the model
    inp = kl.Input(shape=(seq_len, 4), name='seq')
    first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)

    # Dense residual stack: each dilated layer sees the sum of all earlier outputs.
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers + 1):
        prev_sum = first_conv if i == 1 else kl.add(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)

    # kl.add needs >= 2 inputs; with n_dil_layers=0 there is only first_conv.
    combined_conv = kl.add(prev_layers) if len(prev_layers) > 1 else first_conv

    # De-conv: treat length as image height and map filters to 2 channels per
    # task (presumably the pos/neg strands).
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(2 * len(tasks), kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    out = kl.Reshape((-1, 2 * len(tasks)))(x)

    profile_len = seq_len
    if pool_size:
        out = kl.AvgPool1D(pool_size=pool_size, padding='valid')(out)
        # 'valid' pooling with stride == pool_size shortens the length axis.
        profile_len = seq_len // pool_size

    # Set up the output branches.
    outputs = []
    losses = []
    loss_weights = []

    if use_profile:
        # Task i owns channels [2*i, 2*i + 2) of the shared de-conv output.
        # The declared output_shape must match the (possibly pooled) length.
        outputs += [kl.Lambda(lambda x, i: x[:, :, (2 * i):(2 * i + 2)],
                              output_shape=(profile_len, 2),
                              name="profile/" + task,
                              arguments={"i": i})(out)
                    for i, task in enumerate(tasks)]
        losses += [twochannel_multinomial_nll] * len(tasks)
        loss_weights += [1] * len(tasks)

    if use_counts:
        pooled = kl.GlobalAvgPool1D()(combined_conv)
        outputs += [kl.Dense(2, name="counts/" + task)(pooled) for task in tasks]
        losses += ["mse"] * len(tasks)
        loss_weights += [c_task_weight] * len(tasks)

    model = Model(inp, outputs)
    optimizer = Adam(lr=lr, decay=decay)
    model.compile(optimizer, loss=losses, loss_weights=loss_weights)
    return model
In [97]:
from basepair.models import seq_multitask
# Hyper-parameters for this run: profile heads only, no count heads.
mfn = "seq_multitask_chipseq"
use_profile, use_counts = True, False
mkwargs = {
    "filters": 32,
    "conv1_kernel_size": 21,
    "tconv_kernel_size": 100,
    "n_dil_layers": 6,
    "use_profile": use_profile,
    "use_counts": use_counts,
    "c_task_weight": 10,  # only has an effect when use_counts is True
    "seq_len": train[0].shape[1],
    "lr": 0.004,
    "decay": 0,
    "pool_size": 10,
}
# Passed to the model builder, but not encoded into the run name.
fixed_kwargs = {"tasks": list(ds.task_specs)}
In [98]:
import numpy as np
# Seeds numpy only; TF/Keras weight initialisation is not seeded in this cell.
np.random.seed(20)
i += 1
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)

# Stop after 5 epochs without val_loss improvement, and checkpoint the
# best-so-far weights to ckp_file.
fit_callbacks = [
    EarlyStopping(patience=5),
    History(),
    ModelCheckpoint(ckp_file, save_best_only=True),
]
history = model.fit(train[0], train[1],
                    batch_size=256,
                    epochs=25,
                    validation_data=valid[:2],
                    callbacks=fit_callbacks)

# Reload the best checkpoint rather than keeping the last epoch's weights.
model = load_model(ckp_file,
                   custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})
Train on 14727 samples, validate on 4493 samples
Epoch 1/25
14727/14727 [==============================] - 12s 800us/step - loss: 251.2555 - profile/Sox2_loss: 78.8574 - profile/Oct4_loss: 172.3981 - val_loss: 238.6741 - val_profile/Sox2_loss: 76.6722 - val_profile/Oct4_loss: 162.0019
Epoch 2/25
14727/14727 [==============================] - 8s 543us/step - loss: 226.4863 - profile/Sox2_loss: 74.8359 - profile/Oct4_loss: 151.6503 - val_loss: 227.1097 - val_profile/Sox2_loss: 74.7448 - val_profile/Oct4_loss: 152.3649
Epoch 3/25
14727/14727 [==============================] - 8s 537us/step - loss: 220.6812 - profile/Sox2_loss: 73.8036 - profile/Oct4_loss: 146.8775 - val_loss: 225.6369 - val_profile/Sox2_loss: 74.5056 - val_profile/Oct4_loss: 151.1313
Epoch 4/25
14727/14727 [==============================] - 8s 537us/step - loss: 218.6760 - profile/Sox2_loss: 73.3532 - profile/Oct4_loss: 145.3229 - val_loss: 223.3579 - val_profile/Sox2_loss: 74.0898 - val_profile/Oct4_loss: 149.2681
Epoch 5/25
14727/14727 [==============================] - 8s 538us/step - loss: 217.1702 - profile/Sox2_loss: 73.0013 - profile/Oct4_loss: 144.1689 - val_loss: 222.2848 - val_profile/Sox2_loss: 73.9115 - val_profile/Oct4_loss: 148.3733
Epoch 6/25
14727/14727 [==============================] - 8s 539us/step - loss: 216.2657 - profile/Sox2_loss: 72.7864 - profile/Oct4_loss: 143.4793 - val_loss: 221.7256 - val_profile/Sox2_loss: 73.8658 - val_profile/Oct4_loss: 147.8599
Epoch 7/25
14727/14727 [==============================] - 8s 538us/step - loss: 215.5402 - profile/Sox2_loss: 72.5929 - profile/Oct4_loss: 142.9473 - val_loss: 220.3350 - val_profile/Sox2_loss: 73.3243 - val_profile/Oct4_loss: 147.0107
Epoch 8/25
14727/14727 [==============================] - 8s 541us/step - loss: 214.7559 - profile/Sox2_loss: 72.4317 - profile/Oct4_loss: 142.3242 - val_loss: 219.4419 - val_profile/Sox2_loss: 73.2325 - val_profile/Oct4_loss: 146.2094
Epoch 9/25
14727/14727 [==============================] - 8s 541us/step - loss: 214.2230 - profile/Sox2_loss: 72.3176 - profile/Oct4_loss: 141.9054 - val_loss: 219.2628 - val_profile/Sox2_loss: 73.2070 - val_profile/Oct4_loss: 146.0557
Epoch 10/25
14727/14727 [==============================] - 8s 540us/step - loss: 213.5571 - profile/Sox2_loss: 72.1762 - profile/Oct4_loss: 141.3808 - val_loss: 218.9824 - val_profile/Sox2_loss: 73.1907 - val_profile/Oct4_loss: 145.7916
Epoch 11/25
14727/14727 [==============================] - 8s 542us/step - loss: 213.1677 - profile/Sox2_loss: 72.0822 - profile/Oct4_loss: 141.0855 - val_loss: 218.6273 - val_profile/Sox2_loss: 73.0634 - val_profile/Oct4_loss: 145.5640
Epoch 12/25
14727/14727 [==============================] - 8s 541us/step - loss: 212.8714 - profile/Sox2_loss: 72.0292 - profile/Oct4_loss: 140.8421 - val_loss: 218.8246 - val_profile/Sox2_loss: 73.1754 - val_profile/Oct4_loss: 145.6492
Epoch 13/25
14727/14727 [==============================] - 8s 542us/step - loss: 212.4273 - profile/Sox2_loss: 71.9447 - profile/Oct4_loss: 140.4825 - val_loss: 218.1817 - val_profile/Sox2_loss: 73.0801 - val_profile/Oct4_loss: 145.1016
Epoch 14/25
14727/14727 [==============================] - 8s 543us/step - loss: 211.9226 - profile/Sox2_loss: 71.8459 - profile/Oct4_loss: 140.0767 - val_loss: 218.1905 - val_profile/Sox2_loss: 73.0889 - val_profile/Oct4_loss: 145.1016
Epoch 15/25
14727/14727 [==============================] - 8s 542us/step - loss: 211.6680 - profile/Sox2_loss: 71.8138 - profile/Oct4_loss: 139.8542 - val_loss: 218.4847 - val_profile/Sox2_loss: 73.1131 - val_profile/Oct4_loss: 145.3716
Epoch 16/25
14727/14727 [==============================] - 8s 543us/step - loss: 211.4565 - profile/Sox2_loss: 71.8032 - profile/Oct4_loss: 139.6533 - val_loss: 217.9530 - val_profile/Sox2_loss: 73.0733 - val_profile/Oct4_loss: 144.8797
Epoch 17/25
14727/14727 [==============================] - 8s 542us/step - loss: 211.0155 - profile/Sox2_loss: 71.6604 - profile/Oct4_loss: 139.3551 - val_loss: 218.1094 - val_profile/Sox2_loss: 73.0679 - val_profile/Oct4_loss: 145.0415
Epoch 18/25
14727/14727 [==============================] - 8s 543us/step - loss: 210.6677 - profile/Sox2_loss: 71.5891 - profile/Oct4_loss: 139.0785 - val_loss: 218.1345 - val_profile/Sox2_loss: 73.0133 - val_profile/Oct4_loss: 145.1212
Epoch 19/25
14727/14727 [==============================] - 8s 544us/step - loss: 210.5158 - profile/Sox2_loss: 71.5791 - profile/Oct4_loss: 138.9367 - val_loss: 217.8914 - val_profile/Sox2_loss: 73.0415 - val_profile/Oct4_loss: 144.8499
Epoch 20/25
14727/14727 [==============================] - 8s 543us/step - loss: 210.2952 - profile/Sox2_loss: 71.5189 - profile/Oct4_loss: 138.7763 - val_loss: 218.2774 - val_profile/Sox2_loss: 73.2038 - val_profile/Oct4_loss: 145.0736
Epoch 21/25
14727/14727 [==============================] - 8s 545us/step - loss: 210.0039 - profile/Sox2_loss: 71.4639 - profile/Oct4_loss: 138.5400 - val_loss: 217.9609 - val_profile/Sox2_loss: 72.9978 - val_profile/Oct4_loss: 144.9631
Epoch 22/25
14727/14727 [==============================] - 8s 545us/step - loss: 209.8051 - profile/Sox2_loss: 71.4172 - profile/Oct4_loss: 138.3879 - val_loss: 218.5823 - val_profile/Sox2_loss: 73.2451 - val_profile/Oct4_loss: 145.3371
Epoch 23/25
14727/14727 [==============================] - 8s 544us/step - loss: 209.6246 - profile/Sox2_loss: 71.3922 - profile/Oct4_loss: 138.2324 - val_loss: 218.0003 - val_profile/Sox2_loss: 73.1456 - val_profile/Oct4_loss: 144.8547
Epoch 24/25
14727/14727 [==============================] - 8s 544us/step - loss: 209.5783 - profile/Sox2_loss: 71.3609 - profile/Oct4_loss: 138.2174 - val_loss: 217.7584 - val_profile/Sox2_loss: 73.0382 - val_profile/Oct4_loss: 144.7203
Epoch 25/25
14727/14727 [==============================] - 8s 543us/step - loss: 209.3527 - profile/Sox2_loss: 71.3142 - profile/Oct4_loss: 138.0386 - val_loss: 218.1531 - val_profile/Sox2_loss: 73.0616 - val_profile/Oct4_loss: 145.0915
In [72]:
from basepair.eval import evaluate
# Evaluate the current model on the validation split (inputs, targets).
evaluate(model, *valid[:2])
Out[72]:
{'loss': 219.49226792702927,
 'profile/Sox2_loss': 73.21527601047212,
 'profile/Oct4_loss': 146.27699150392797}
In [49]:
from basepair.math import softmax
from basepair import samplers
from basepair.preproc import bin_counts
import numpy as np

class Seq2Sox2Oct4:
    """Plot predicted vs. observed profiles side by side for each task.

    Args:
      x: model inputs passed to ``model.predict``.
      y: dict of observed targets keyed by "profile/<task>". Each value is
        indexed as ``[example, position, channel]`` with 2 channels.
      model: trained Keras model. Its first ``len(tasks)`` outputs are assumed
        to be the profile heads, in ``tasks`` order.
      tasks: task names to plot (default: Sox2, Oct4).
    """

    def __init__(self, x, y, model, tasks=("Sox2", "Oct4")):
        self.x = x
        self.y = y
        self.model = model
        self.tasks = list(tasks)
        # Make the prediction and softmax each output.
        # NOTE(review): softmax is applied to every output, which only makes
        # sense for profile heads. Confirm before using this with count heads.
        self.y_pred = [softmax(pred) for pred in model.predict(x)]

    def plot(self, n=10, kind='test', sort='random', figsize=(20, 2), fpath_template=None, binsize=1):
        """Plot predicted and observed profiles for ``n`` selected examples.

        Args:
          n: number of examples to plot.
          kind: unused; kept for backward compatibility.
          sort: 'random', or '<max|sum>_<task>' to pick the examples with the
            largest max / summed observed profile for that task.
          figsize: size of each per-example figure.
          fpath_template: if given, figure ``i`` is saved to
            ``fpath_template.format(i) + '.png'`` and ``'.pdf'`` and then
            closed. Otherwise it is shown.
          binsize: bin width passed to ``bin_counts`` for display.

        Raises:
          ValueError: if ``sort`` cannot be interpreted.
        """
        import matplotlib.pyplot as plt
        if sort == 'random':
            idx_list = samplers.random(self.x, n)
        elif "_" in sort:
            sort_kind, task = sort.split("_", 1)
            if sort_kind == "max":
                idx_list = samplers.top_max_count(self.y[f"profile/{task}"], n)
            elif sort_kind == "sum":
                idx_list = samplers.top_sum_count(self.y[f"profile/{task}"], n)
            else:
                raise ValueError(f"sort={sort}: unknown kind {sort_kind!r}, expected 'max' or 'sum'")
        else:
            raise ValueError(f"sort={sort} couldn't be interpreted")

        # Bin each full array once, outside the per-example loop.
        panels = []
        for t, task_name in enumerate(self.tasks):
            pred = self.y_pred[t]
            obs = self.y[f"profile/{task_name}"]
            panels.append((f"Predicted {task_name}", pred, bin_counts(pred, binsize=binsize)))
            panels.append((f"Observed {task_name}", obs, bin_counts(obs, binsize=binsize)))

        for i, idx in enumerate(idx_list):
            fig = plt.figure(figsize=figsize)
            for p, (title, raw, binned) in enumerate(panels):
                plt.subplot(1, len(panels), p + 1)
                if i == 0:
                    plt.title(title)
                # The legend reports the argmax position in unbinned coordinates.
                plt.plot(binned[idx, :, 0], label='pos,m={}'.format(np.argmax(raw[idx, :, 0])))
                plt.plot(binned[idx, :, 1], label='neg,m={}'.format(np.argmax(raw[idx, :, 1])))
                plt.legend()
            if fpath_template is not None:
                plt.savefig(fpath_template.format(i) + '.png', dpi=600)
                plt.savefig(fpath_template.format(i) + '.pdf', dpi=600)
                plt.close(fig)  # avoid accumulating open figures when saving many
            else:
                plt.show()
In [99]:
# Predictions on the held-out test split (one array per output for a
# multi-output model).
y_pred = model.predict(x=test[0])
In [ ]:
from basepair.plots import regression_eval
In [ ]:
# Count heads exist only when the model was built with use_counts=True. With
# use_counts=False, y_pred holds only profile outputs, so there is no counts
# prediction to compare against.
if use_counts:
    regression_eval(test[1]['counts/Sox2'].mean(-1),
                    y_pred[ds.task2idx("Sox2", 'counts')].mean(-1))
In [ ]:
# Count heads exist only when the model was built with use_counts=True. With
# use_counts=False, y_pred holds only profile outputs, so there is no counts
# prediction to compare against.
if use_counts:
    regression_eval(test[1]['counts/Oct4'].mean(-1),
                    y_pred[ds.task2idx("Oct4", 'counts')].mean(-1))
In [100]:
# Plotting helper over the test split (inputs, targets) and the trained model.
pl = Seq2Sox2Oct4(x=test[0], y=test[1], model=model)
In [101]:
# Plot the 10 test examples selected by samplers.top_sum_count on the observed
# Sox2 profile, without binning.
pl.plot(sort='sum_Sox2', n=10, binsize=1)
In [ ]:
from basepair.BPNet import BPNetPredictor
In [ ]:
# FIXME(review): `preproc` is not defined in any cell shown here (the
# AppendTotalCounts import is never instantiated), so this line raises
# NameError as written. Decide which preprocessor, if any, matches how the
# training targets were prepared before running it.
bpnet = BPNetPredictor(model, ds.fasta_file, list(ds.task_specs), preproc=preproc)
In [ ]:
# Peek at the first rows of the test-split metadata table.
test[2].head(5)
In [ ]:
from pybedtools import Interval, BedTool
In [ ]:
# BED intervals for the first five test regions.
first_regions = test[2][["chr", "start", "end"]].head(5)
bt = BedTool.from_dataframe(first_regions)
In [ ]:
# For some intervals from the genome, plot the observed and predicted profiles
In [ ]:
bpnet.predict_plot(intervals=list(bt),
                   bws=ds2bws(ds),
                   profile_grad="weighted")