In [1]:
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
Using TensorFlow backend.
2018-07-29 13:51:12,469 [WARNING] git-lfs not installed
2018-07-29 13:51:12,489 [WARNING] git-lfs not installed
In [2]:
# Use GPUs 3 and 5 only.
import os
# CUDA_VISIBLE_DEVICES must be a comma-separated list WITHOUT spaces:
# with "3, 5" the entry " 5" may be treated as invalid by the CUDA runtime,
# silently exposing only device 3.
os.environ["CUDA_VISIBLE_DEVICES"] = "3,5"
In [3]:
# Root data directory from the basepair config; used later to build the
# model checkpoint directory.
ddir = get_data_dir()
In [4]:
bdir = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/"

# DataSpec: one TaskSpec per TF with strand-specific count bigwigs and
# IDR-filtered peak calls; a shared mm10 reference fasta for sequences.
ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                          ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                          )
                         },
              fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta"
             )
In [5]:
def ds2bws(ds):
    """Map each task in a DataSpec to its strand-specific bigwig paths.

    Args:
      ds: DataSpec whose `task_specs` values expose `pos_counts`/`neg_counts`

    Returns:
      dict: {task_name: {"pos": pos_counts_path, "neg": neg_counts_path}}
    """
    bigwigs = {}
    for task_name, spec in ds.task_specs.items():
        bigwigs[task_name] = {"pos": spec.pos_counts, "neg": spec.neg_counts}
    return bigwigs
In [6]:
# Extract 201-bp peak-centered sequences + counts into three splits
# (presumably a chromosome-based train/valid/test split — confirm in
# basepair.datasets.chip_exo_nexus).
train, valid, test = chip_exo_nexus(ds, peak_width=201)
100%|██████████| 23474/23474 [00:00<00:00, 699085.41it/s]
2018-07-29 13:51:15,048 [INFO] extract sequence
2018-07-29 13:51:20,993 [INFO] extract counts
100%|██████████| 2/2 [00:10<00:00,  5.32s/it]
In [7]:
# Pre-process the data
# Fit the total-count transformer on the TRAINING targets only (avoids
# leaking validation/test statistics), then apply it to all three splits.
# NOTE(review): assumes the splits support item assignment (lists, not
# tuples) — confirm chip_exo_nexus's return type.
preproc = AppendTotalCounts()
preproc.fit(train[1])
train[1] = preproc.transform(train[1])
valid[1] = preproc.transform(valid[1])
test[1] = preproc.transform(test[1])
In [8]:
def seq_multitask_chipseq(filters=21,
                          conv1_kernel_size=21,
                          tconv_kernel_size=25,
                          #tconv_kernel_size2=25,
                          n_dil_layers=6,
                          lr=0.004,
                          c_task_weight=100,
                          use_profile=True,
                          use_counts=True,
                          tasks=['sox2', 'oct4'],
                          seq_len=201):
    """Multi-task ChIP-seq model: a dilated-conv trunk with densely-summed
    skip connections, plus per-task profile (per-base, two-strand) and
    total-count output heads.

    Args:
      filters: number of filters in the first conv and each dilated layer
      conv1_kernel_size: kernel size of the first convolution
      tconv_kernel_size: kernel size of the transposed ("de-conv") layer
      n_dil_layers: number of dilated conv layers (dilation rate is 2**i)
      lr: Adam learning rate
      c_task_weight: how much to up-weight the count-prediction loss
        relative to the profile loss (whose weight is 1)
      use_profile: add per-task profile outputs with multinomial NLL loss
      use_counts: add per-task 2-unit count outputs with MSE loss
      tasks: task names; one profile and/or count head per task
      seq_len: input sequence length (one-hot encoded, 4 channels)

    Returns:
      compiled keras Model mapping 'seq' -> list of per-task outputs
    """
    inp = kl.Input(shape=(seq_len, 4), name='seq')
    first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)

    # Dilated stack: layer i receives the element-wise sum of all previous
    # layer outputs (dense skip connections), with dilation rate 2**i.
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers + 1):
        if i == 1:
            prev_sum = first_conv
        else:
            prev_sum = kl.add(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)

    combined_conv = kl.add(prev_layers)

    # De-conv: reshape to 4-D so Conv2DTranspose acts as a 1-D transposed
    # convolution producing 2 channels (one per strand) per task.
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(2*len(tasks), kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    #x = kl.UpSampling2D((2, 1))(x)
    #x = kl.Conv2DTranspose(len(tasks), kernel_size=(tconv_kernel_size2, 1), padding='same')(x)
    #x = kl.UpSampling2D((2, 1))(x)
    #x = kl.Conv2DTranspose(int(len(tasks)/2), kernel_size=(tconv_kernel_size3, 1), padding='same')(x)
    out = kl.Reshape((-1, 2 * len(tasks)))(x)

    # setup the output branches
    outputs = []
    losses = []
    loss_weights = []

    if use_profile:
        # Slice channels (2*i, 2*i+1) of the shared de-conv output as the
        # two-strand profile for task i. `i` is bound via `arguments=`
        # (not closed over), so each Lambda keeps its own value.
        output = [kl.Lambda(lambda x, i: x[:, :, (2 * i):(2 * i + 2)],
                            output_shape=(seq_len, 2),
                            name="profile/" + task,
                            arguments={"i": i})(out)
                  for i, task in enumerate(tasks)]
        outputs += output
        losses += [twochannel_multinomial_nll] * len(tasks)
        loss_weights += [1] * len(tasks)

    if use_counts:
        # Count heads: global average pool of the trunk -> Dense(2) per task
        # (one value per strand), trained with MSE and up-weighted.
        pooled = kl.GlobalAvgPool1D()(combined_conv)
        counts = [kl.Dense(2, name="counts/" + task)(pooled)
                  for task in tasks]
        outputs += counts
        losses += ["mse"] * len(tasks)
        loss_weights += [c_task_weight] * len(tasks)

    model = Model(inp, outputs)
    model.compile(Adam(lr=lr), loss=losses, loss_weights=loss_weights)
    return model
In [9]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
from basepair.layers import SpatialLifetimeSparsity
In [10]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
In [11]:
i=1
In [12]:
def get_model(mfn, mkwargs, fkwargs, i):
    """Instantiate the model builder named `mfn` and derive its run name
    and checkpoint path.

    Args:
      mfn: name (str) of a model-building function in the notebook namespace
      mkwargs: hyper-parameter kwargs; encoded into the model/run name
      fkwargs: fixed kwargs passed to the builder but not part of the name
      i: integer run index appended to the name so repeated runs get
        distinct checkpoint files

    Returns:
      (model, name, ckp_file) tuple
    """
    import os
    mdir = f"{ddir}/processed/chipseq/exp/models/count+profile"
    name = mfn + "_" + \
            ",".join([f'{k}={v}' for k, v in mkwargs.items()]) + \
            "." + str(i)
    # Portable replacement for the IPython-only `!mkdir -p {mdir}` escape.
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}
    # Resolve the builder by name instead of eval(): eval would execute any
    # expression, not just look up a function, which is an injection hazard.
    model_fn = globals()[mfn]
    return model_fn(**all_kwargs), name, ckp_file
In [13]:
# hyper-parameters
# mfn is resolved by name in get_model; mkwargs are encoded into the
# checkpoint file name, fixed_kwargs are not.
mfn = "seq_multitask_chipseq"
use_profile = True
use_counts = True
mkwargs = dict(filters=32, 
               conv1_kernel_size=21,
               tconv_kernel_size=100,
               n_dil_layers=6,
               use_profile=use_profile,
               use_counts=use_counts,
               c_task_weight=10,
               seq_len=train[0].shape[1],  # input length inferred from the data
               lr=0.004)
fixed_kwargs = dict(
    tasks=list(ds.task_specs)  # task names from the DataSpec, e.g. Sox2/Oct4
)
In [14]:
import numpy as np
np.random.seed(20)  # reproducible weight init / shuffling
i += 1  # bump the run index so checkpoints from repeated runs don't collide
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)
history = model.fit(train[0], 
                    train[1],
          batch_size=256, 
          epochs=100,
          validation_data=valid[:2],
          callbacks=[EarlyStopping(patience=5),
                     History(),
                     # keep only the weights with the best validation loss
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model (custom loss must be registered for deserialization)
model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2018-07-29 13:51:32,679 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2018-07-29 13:51:40,021 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
Train on 14727 samples, validate on 4493 samples
Epoch 1/100
14727/14727 [==============================] - 13s 864us/step - loss: 211.7251 - profile/Sox2_loss: 50.6599 - profile/Oct4_loss: 138.3460 - counts/Sox2_loss: 1.0837 - counts/Oct4_loss: 1.1882 - val_loss: 209.7435 - val_profile/Sox2_loss: 50.6220 - val_profile/Oct4_loss: 138.6813 - val_counts/Sox2_loss: 1.0134 - val_counts/Oct4_loss: 1.0306
Epoch 2/100
14727/14727 [==============================] - 2s 156us/step - loss: 206.0029 - profile/Sox2_loss: 50.1324 - profile/Oct4_loss: 135.8749 - counts/Sox2_loss: 0.9997 - counts/Oct4_loss: 0.9998 - val_loss: 209.5826 - val_profile/Sox2_loss: 50.5911 - val_profile/Oct4_loss: 138.5633 - val_counts/Sox2_loss: 1.0126 - val_counts/Oct4_loss: 1.0302
Epoch 3/100
14727/14727 [==============================] - 2s 151us/step - loss: 205.8650 - profile/Sox2_loss: 50.1048 - profile/Oct4_loss: 135.7741 - counts/Sox2_loss: 0.9990 - counts/Oct4_loss: 0.9996 - val_loss: 209.5508 - val_profile/Sox2_loss: 50.5772 - val_profile/Oct4_loss: 138.5548 - val_counts/Sox2_loss: 1.0121 - val_counts/Oct4_loss: 1.0298
Epoch 4/100
14727/14727 [==============================] - 2s 151us/step - loss: 205.7071 - profile/Sox2_loss: 50.0772 - profile/Oct4_loss: 135.6528 - counts/Sox2_loss: 0.9983 - counts/Oct4_loss: 0.9994 - val_loss: 209.3357 - val_profile/Sox2_loss: 50.5631 - val_profile/Oct4_loss: 138.3662 - val_counts/Sox2_loss: 1.0109 - val_counts/Oct4_loss: 1.0297
Epoch 5/100
14727/14727 [==============================] - 2s 147us/step - loss: 205.6089 - profile/Sox2_loss: 50.0599 - profile/Oct4_loss: 135.5871 - counts/Sox2_loss: 0.9969 - counts/Oct4_loss: 0.9993 - val_loss: 209.3418 - val_profile/Sox2_loss: 50.5678 - val_profile/Oct4_loss: 138.3922 - val_counts/Sox2_loss: 1.0093 - val_counts/Oct4_loss: 1.0289
Epoch 6/100
14727/14727 [==============================] - 2s 147us/step - loss: 205.4422 - profile/Sox2_loss: 50.0346 - profile/Oct4_loss: 135.5083 - counts/Sox2_loss: 0.9922 - counts/Oct4_loss: 0.9978 - val_loss: 209.1191 - val_profile/Sox2_loss: 50.5431 - val_profile/Oct4_loss: 138.3010 - val_counts/Sox2_loss: 1.0016 - val_counts/Oct4_loss: 1.0259
Epoch 7/100
14727/14727 [==============================] - 2s 151us/step - loss: 205.2243 - profile/Sox2_loss: 50.0296 - profile/Oct4_loss: 135.4560 - counts/Sox2_loss: 0.9810 - counts/Oct4_loss: 0.9929 - val_loss: 208.7737 - val_profile/Sox2_loss: 50.5580 - val_profile/Oct4_loss: 138.2441 - val_counts/Sox2_loss: 0.9802 - val_counts/Oct4_loss: 1.0169
Epoch 8/100
14727/14727 [==============================] - 2s 153us/step - loss: 204.9289 - profile/Sox2_loss: 50.0205 - profile/Oct4_loss: 135.4003 - counts/Sox2_loss: 0.9671 - counts/Oct4_loss: 0.9837 - val_loss: 208.6394 - val_profile/Sox2_loss: 50.5371 - val_profile/Oct4_loss: 138.1773 - val_counts/Sox2_loss: 0.9758 - val_counts/Oct4_loss: 1.0167
Epoch 9/100
14727/14727 [==============================] - 2s 158us/step - loss: 204.6998 - profile/Sox2_loss: 49.9940 - profile/Oct4_loss: 135.3432 - counts/Sox2_loss: 0.9595 - counts/Oct4_loss: 0.9768 - val_loss: 208.5233 - val_profile/Sox2_loss: 50.5151 - val_profile/Oct4_loss: 138.1403 - val_counts/Sox2_loss: 0.9743 - val_counts/Oct4_loss: 1.0125
Epoch 10/100
14727/14727 [==============================] - 2s 155us/step - loss: 204.6080 - profile/Sox2_loss: 49.9793 - profile/Oct4_loss: 135.2777 - counts/Sox2_loss: 0.9609 - counts/Oct4_loss: 0.9742 - val_loss: 208.2380 - val_profile/Sox2_loss: 50.5073 - val_profile/Oct4_loss: 138.0394 - val_counts/Sox2_loss: 0.9661 - val_counts/Oct4_loss: 1.0031
Epoch 11/100
14727/14727 [==============================] - 2s 160us/step - loss: 204.3595 - profile/Sox2_loss: 49.9644 - profile/Oct4_loss: 135.2102 - counts/Sox2_loss: 0.9532 - counts/Oct4_loss: 0.9653 - val_loss: 207.8059 - val_profile/Sox2_loss: 50.5021 - val_profile/Oct4_loss: 137.8930 - val_counts/Sox2_loss: 0.9548 - val_counts/Oct4_loss: 0.9863
Epoch 12/100
14727/14727 [==============================] - 2s 165us/step - loss: 204.0296 - profile/Sox2_loss: 49.9486 - profile/Oct4_loss: 135.1119 - counts/Sox2_loss: 0.9485 - counts/Oct4_loss: 0.9484 - val_loss: 207.5683 - val_profile/Sox2_loss: 50.4825 - val_profile/Oct4_loss: 137.9090 - val_counts/Sox2_loss: 0.9481 - val_counts/Oct4_loss: 0.9695
Epoch 13/100
14727/14727 [==============================] - 2s 157us/step - loss: 203.4588 - profile/Sox2_loss: 49.9333 - profile/Oct4_loss: 135.0767 - counts/Sox2_loss: 0.9320 - counts/Oct4_loss: 0.9129 - val_loss: 207.0552 - val_profile/Sox2_loss: 50.4708 - val_profile/Oct4_loss: 137.9532 - val_counts/Sox2_loss: 0.9294 - val_counts/Oct4_loss: 0.9337
Epoch 14/100
14727/14727 [==============================] - 2s 159us/step - loss: 203.0511 - profile/Sox2_loss: 49.9173 - profile/Oct4_loss: 135.0544 - counts/Sox2_loss: 0.9199 - counts/Oct4_loss: 0.8881 - val_loss: 207.7127 - val_profile/Sox2_loss: 50.5156 - val_profile/Oct4_loss: 137.9227 - val_counts/Sox2_loss: 0.9605 - val_counts/Oct4_loss: 0.9669
Epoch 15/100
14727/14727 [==============================] - 2s 158us/step - loss: 202.6350 - profile/Sox2_loss: 49.8702 - profile/Oct4_loss: 135.0124 - counts/Sox2_loss: 0.9030 - counts/Oct4_loss: 0.8722 - val_loss: 206.5506 - val_profile/Sox2_loss: 50.4916 - val_profile/Oct4_loss: 137.8980 - val_counts/Sox2_loss: 0.9015 - val_counts/Oct4_loss: 0.9146
Epoch 16/100
14727/14727 [==============================] - 2s 157us/step - loss: 202.2473 - profile/Sox2_loss: 49.8467 - profile/Oct4_loss: 134.9953 - counts/Sox2_loss: 0.8808 - counts/Oct4_loss: 0.8598 - val_loss: 206.3836 - val_profile/Sox2_loss: 50.4584 - val_profile/Oct4_loss: 137.9531 - val_counts/Sox2_loss: 0.8879 - val_counts/Oct4_loss: 0.9094
Epoch 17/100
14727/14727 [==============================] - 2s 161us/step - loss: 202.1328 - profile/Sox2_loss: 49.8097 - profile/Oct4_loss: 134.9885 - counts/Sox2_loss: 0.8706 - counts/Oct4_loss: 0.8629 - val_loss: 207.3946 - val_profile/Sox2_loss: 50.5169 - val_profile/Oct4_loss: 138.0141 - val_counts/Sox2_loss: 0.9089 - val_counts/Oct4_loss: 0.9775
Epoch 18/100
14727/14727 [==============================] - 2s 154us/step - loss: 202.0516 - profile/Sox2_loss: 49.7982 - profile/Oct4_loss: 134.9872 - counts/Sox2_loss: 0.8595 - counts/Oct4_loss: 0.8672 - val_loss: 205.8398 - val_profile/Sox2_loss: 50.3968 - val_profile/Oct4_loss: 137.8694 - val_counts/Sox2_loss: 0.8625 - val_counts/Oct4_loss: 0.8948
Epoch 19/100
14727/14727 [==============================] - 2s 156us/step - loss: 201.7946 - profile/Sox2_loss: 49.7777 - profile/Oct4_loss: 134.9686 - counts/Sox2_loss: 0.8430 - counts/Oct4_loss: 0.8618 - val_loss: 205.7553 - val_profile/Sox2_loss: 50.4140 - val_profile/Oct4_loss: 137.9086 - val_counts/Sox2_loss: 0.8450 - val_counts/Oct4_loss: 0.8983
Epoch 20/100
14727/14727 [==============================] - 2s 163us/step - loss: 201.4055 - profile/Sox2_loss: 49.7362 - profile/Oct4_loss: 134.9607 - counts/Sox2_loss: 0.8178 - counts/Oct4_loss: 0.8531 - val_loss: 206.0948 - val_profile/Sox2_loss: 50.4116 - val_profile/Oct4_loss: 137.9285 - val_counts/Sox2_loss: 0.8703 - val_counts/Oct4_loss: 0.9051
Epoch 21/100
14727/14727 [==============================] - 2s 152us/step - loss: 201.0791 - profile/Sox2_loss: 49.7185 - profile/Oct4_loss: 134.9455 - counts/Sox2_loss: 0.7959 - counts/Oct4_loss: 0.8456 - val_loss: 206.0039 - val_profile/Sox2_loss: 50.3394 - val_profile/Oct4_loss: 138.0364 - val_counts/Sox2_loss: 0.8383 - val_counts/Oct4_loss: 0.9245
Epoch 22/100
14727/14727 [==============================] - 2s 157us/step - loss: 200.8895 - profile/Sox2_loss: 49.6700 - profile/Oct4_loss: 134.9124 - counts/Sox2_loss: 0.7891 - counts/Oct4_loss: 0.8417 - val_loss: 205.6796 - val_profile/Sox2_loss: 50.3912 - val_profile/Oct4_loss: 137.8233 - val_counts/Sox2_loss: 0.8369 - val_counts/Oct4_loss: 0.9096
Epoch 23/100
14727/14727 [==============================] - 2s 160us/step - loss: 200.6299 - profile/Sox2_loss: 49.6667 - profile/Oct4_loss: 134.8794 - counts/Sox2_loss: 0.7739 - counts/Oct4_loss: 0.8345 - val_loss: 205.7111 - val_profile/Sox2_loss: 50.3603 - val_profile/Oct4_loss: 137.8582 - val_counts/Sox2_loss: 0.8275 - val_counts/Oct4_loss: 0.9218
Epoch 24/100
14727/14727 [==============================] - 2s 168us/step - loss: 200.5265 - profile/Sox2_loss: 49.6516 - profile/Oct4_loss: 134.8283 - counts/Sox2_loss: 0.7720 - counts/Oct4_loss: 0.8327 - val_loss: 206.0415 - val_profile/Sox2_loss: 50.4086 - val_profile/Oct4_loss: 137.9434 - val_counts/Sox2_loss: 0.8287 - val_counts/Oct4_loss: 0.9403
Epoch 25/100
14727/14727 [==============================] - 2s 163us/step - loss: 200.3338 - profile/Sox2_loss: 49.6382 - profile/Oct4_loss: 134.8198 - counts/Sox2_loss: 0.7587 - counts/Oct4_loss: 0.8289 - val_loss: 205.7077 - val_profile/Sox2_loss: 50.4090 - val_profile/Oct4_loss: 137.8902 - val_counts/Sox2_loss: 0.8299 - val_counts/Oct4_loss: 0.9109
Epoch 26/100
14727/14727 [==============================] - 2s 155us/step - loss: 200.2845 - profile/Sox2_loss: 49.6482 - profile/Oct4_loss: 134.7832 - counts/Sox2_loss: 0.7596 - counts/Oct4_loss: 0.8257 - val_loss: 205.4978 - val_profile/Sox2_loss: 50.4299 - val_profile/Oct4_loss: 137.8975 - val_counts/Sox2_loss: 0.8066 - val_counts/Oct4_loss: 0.9105
Epoch 27/100
14727/14727 [==============================] - 2s 147us/step - loss: 199.9452 - profile/Sox2_loss: 49.6232 - profile/Oct4_loss: 134.7318 - counts/Sox2_loss: 0.7421 - counts/Oct4_loss: 0.8169 - val_loss: 205.9119 - val_profile/Sox2_loss: 50.4747 - val_profile/Oct4_loss: 137.9082 - val_counts/Sox2_loss: 0.8329 - val_counts/Oct4_loss: 0.9200
Epoch 28/100
14727/14727 [==============================] - 2s 157us/step - loss: 199.8546 - profile/Sox2_loss: 49.6170 - profile/Oct4_loss: 134.6927 - counts/Sox2_loss: 0.7381 - counts/Oct4_loss: 0.8164 - val_loss: 205.4987 - val_profile/Sox2_loss: 50.4194 - val_profile/Oct4_loss: 137.8582 - val_counts/Sox2_loss: 0.8163 - val_counts/Oct4_loss: 0.9058
Epoch 29/100
14727/14727 [==============================] - 2s 158us/step - loss: 199.5740 - profile/Sox2_loss: 49.6044 - profile/Oct4_loss: 134.7188 - counts/Sox2_loss: 0.7245 - counts/Oct4_loss: 0.8006 - val_loss: 205.3942 - val_profile/Sox2_loss: 50.4938 - val_profile/Oct4_loss: 137.8562 - val_counts/Sox2_loss: 0.8031 - val_counts/Oct4_loss: 0.9013
Epoch 30/100
14727/14727 [==============================] - 2s 153us/step - loss: 199.4693 - profile/Sox2_loss: 49.6107 - profile/Oct4_loss: 134.7522 - counts/Sox2_loss: 0.7170 - counts/Oct4_loss: 0.7936 - val_loss: 205.2851 - val_profile/Sox2_loss: 50.4802 - val_profile/Oct4_loss: 137.7654 - val_counts/Sox2_loss: 0.8024 - val_counts/Oct4_loss: 0.9016
Epoch 31/100
14727/14727 [==============================] - 2s 157us/step - loss: 199.2302 - profile/Sox2_loss: 49.5944 - profile/Oct4_loss: 134.6851 - counts/Sox2_loss: 0.7072 - counts/Oct4_loss: 0.7879 - val_loss: 205.5277 - val_profile/Sox2_loss: 50.5052 - val_profile/Oct4_loss: 137.7875 - val_counts/Sox2_loss: 0.8100 - val_counts/Oct4_loss: 0.9135
Epoch 32/100
14727/14727 [==============================] - 2s 156us/step - loss: 199.2301 - profile/Sox2_loss: 49.5698 - profile/Oct4_loss: 134.6291 - counts/Sox2_loss: 0.7147 - counts/Oct4_loss: 0.7884 - val_loss: 205.3241 - val_profile/Sox2_loss: 50.4561 - val_profile/Oct4_loss: 137.8210 - val_counts/Sox2_loss: 0.8040 - val_counts/Oct4_loss: 0.9007
Epoch 33/100
14727/14727 [==============================] - 2s 148us/step - loss: 198.6232 - profile/Sox2_loss: 49.5834 - profile/Oct4_loss: 134.5846 - counts/Sox2_loss: 0.6790 - counts/Oct4_loss: 0.7665 - val_loss: 205.3765 - val_profile/Sox2_loss: 50.4810 - val_profile/Oct4_loss: 137.8340 - val_counts/Sox2_loss: 0.8052 - val_counts/Oct4_loss: 0.9009
Epoch 34/100
14727/14727 [==============================] - 2s 151us/step - loss: 198.6048 - profile/Sox2_loss: 49.5697 - profile/Oct4_loss: 134.6148 - counts/Sox2_loss: 0.6817 - counts/Oct4_loss: 0.7604 - val_loss: 206.9184 - val_profile/Sox2_loss: 50.4923 - val_profile/Oct4_loss: 137.9622 - val_counts/Sox2_loss: 0.8856 - val_counts/Oct4_loss: 0.9608
Epoch 35/100
14727/14727 [==============================] - 2s 160us/step - loss: 198.5092 - profile/Sox2_loss: 49.5456 - profile/Oct4_loss: 134.6765 - counts/Sox2_loss: 0.6731 - counts/Oct4_loss: 0.7556 - val_loss: 205.4375 - val_profile/Sox2_loss: 50.5393 - val_profile/Oct4_loss: 137.7730 - val_counts/Sox2_loss: 0.8051 - val_counts/Oct4_loss: 0.9074
In [15]:
from basepair.eval import evaluate
evaluate(model, valid[0], valid[1])
Out[15]:
{'loss': 205.28515092148007,
 'profile/Sox2_loss': 50.48016872783825,
 'profile/Oct4_loss': 137.76537925640406,
 'counts/Sox2_loss': 0.8023729628532786,
 'counts/Oct4_loss': 0.9015873560469266}
In [16]:
# Base directories for the peak BED files and strand-specific bigwig tracks.
# (The originals were f-strings with no placeholders; plain literals and a
# shared base path express the same values without duplication.)
BED_DIR = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/"
Sox2_BW_DIR = BED_DIR + "Sox2/"
Oct4_BW_DIR = BED_DIR + "Oct4/"
In [17]:
import pandas as pd
import numpy as np
from pybedtools import BedTool, Interval
from basepair.config import get_data_dir
from basepair.preproc import bin_counts
from concise.utils.helper import get_from_module
from tqdm import tqdm
from concise.preprocessing import encodeDNA
from random import Random
import joblib
from basepair.preproc import resize_interval
from genomelake.extractors import FastaExtractor, BigwigExtractor
from kipoi.data_utils import get_dataset_item
import logging
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
In [18]:
def get_chipnexus_data(bed_file=f"{BED_DIR}//Sox2_123b_1_ppr.IDR0.05.filt.summit_centered_200bp.narrowPeak",
                       peak_fasta_file=f"{BED_DIR}//Sox2_123b_1_ppr.IDR0.05.filt.summit_centered_200bp.fasta",
                       bigwigs={"cuts_pos": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.pos_strand.bw",
                                 "cuts_neg": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.neg_strand.bw",
                                }
                       ):
    """Load peak intervals, their sequences and bigwig coverage into a DataFrame.

    Args:
      bed_file: narrowPeak/BED file of summit-centered peak intervals
      peak_fasta_file: fasta with one sequence per interval in `bed_file`,
        in the same order
      bigwigs: {column_name: bigwig_path} coverage tracks to extract per
        interval. NOTE: mutable default — it is only read here, but callers
        must not mutate it.

    Returns:
      pd.DataFrame with columns chr/start/end, one numpy-array column per
      bigwig track, plus seq (upper-cased) and seq_id (fasta header).
    """
    from concise.utils.fasta import read_fasta
    import pyBigWig

    fas = read_fasta(peak_fasta_file)

    bed = BedTool(bed_file)

    # fasta and bed must be parallel: row i of each describes the same peak
    assert len(fas) == len(bed)

    bigwig_obj = {k: pyBigWig.open(v) for k, v in bigwigs.items()}
    try:
        records = []
        for interval in tqdm(bed):
            records.append({"chr": interval.chrom,
                            "start": interval.start,
                            "end": interval.stop,
                            # missing coverage comes back as NaN -> coerce to 0
                            **{k: np.nan_to_num(bw.values(interval.chrom,
                                                          interval.start,
                                                          interval.stop,
                                                          numpy=True))
                               for k, bw in bigwig_obj.items()}
                            })
    finally:
        # close the bigwig handles (the original version leaked them)
        for bw in bigwig_obj.values():
            bw.close()

    dfc = pd.DataFrame(records)

    dfc['seq'] = list(fas.values())
    dfc['seq_id'] = list(fas)

    dfc['seq'] = dfc.seq.str.upper()
    return dfc
In [19]:
def sox2_oct4_peaks_sox2(valid_chr=['chr2', 'chr3', 'chr4'],
                         test_chr=['chr1', 'chr8', 'chr9']):
    """Build (train, valid, test) splits of the Sox2/Oct4 ChIP-nexus data.

    The default chromosome split is roughly 60/20/20. Each split is a tuple
    (one-hot seq, {"sox2": cuts, "oct4": cuts}, metadata DataFrame), where
    `cuts` stacks the pos/neg strand tracks on the last axis.
    """
    # the two held-out sets must not share any chromosome
    assert not any(v in test_chr for v in valid_chr)

    dfc = get_chipnexus_data(
        bigwigs={"sox2_pos": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.pos_strand.bw",
                 "sox2_neg": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.neg_strand.bw",
                 "oct4_pos": f"{Oct4_BW_DIR}/Oct4_1234.pos.bw",
                 "oct4_neg": f"{Oct4_BW_DIR}/Oct4_1234.neg.bw"})
    seq = encodeDNA(dfc.seq)

    # Per-TF signal of shape (n_peaks, seq_len, 2):
    # channel 0 = pos strand, channel 1 = neg strand
    sox2_cuts = np.stack([np.stack(dfc.sox2_pos), np.stack(dfc.sox2_neg)],
                         axis=-1)
    oct4_cuts = np.stack([np.stack(dfc.oct4_pos), np.stack(dfc.oct4_neg)],
                         axis=-1)

    ids = dfc.seq_id

    # chromosome-level boolean masks for the three splits
    is_test = dfc.chr.isin(test_chr)
    is_valid = dfc.chr.isin(valid_chr)
    is_train = ~(is_test | is_valid)

    splits = []
    for mask in (is_train, is_valid, is_test):
        splits.append((seq[mask],                  # x
                       {"sox2": sox2_cuts[mask],   # y
                        "oct4": oct4_cuts[mask]},
                       dfc[mask]))                 # metadata
    return tuple(splits)
In [20]:
# hyper-parameters
# Second model: the library-provided seq_multitask builder, trained on the
# ChIP-nexus data below for comparison with the notebook-defined model.
from basepair.models import seq_multitask
mfn2 = "seq_multitask"
use_profile = True
use_counts = True
mkwargs2 = dict(filters=32, 
               conv1_kernel_size=21,
               tconv_kernel_size=25,
               n_dil_layers=6,
               use_profile=use_profile,
               use_counts=use_counts,
               c_task_weight=10,
               lr=0.004)
In [21]:
data2 = sox2_oct4_peaks_sox2()
100%|██████████| 9396/9396 [00:06<00:00, 1376.41it/s]
In [22]:
from basepair.preproc import transform_data
In [23]:
# Pre-process each split into (x, y) pairs for model.fit, using the
# use_profile / use_counts flags defined in the hyper-parameter cell
train_nex, valid_nex, test_nex = transform_data(data2, use_profile, use_counts)
In [24]:
# NOTE(review): `i` and `fixed_kwargs` come from earlier cells not shown here —
# this cell relies on hidden kernel state and will fail on Restart-&-Run-All.
i += 1
model2, name2, ckp_file2 = get_model(mfn2, mkwargs2, fixed_kwargs, i)
# Train with early stopping (5 epochs without improvement) and keep only the
# best checkpoint on disk at ckp_file2
history2 = model2.fit(train_nex[0], 
                    train_nex[1],
          batch_size=256, 
          epochs=100,
          validation_data=valid_nex[:2],
          callbacks=[EarlyStopping(patience=5),
                     History(),
                     ModelCheckpoint(ckp_file2, save_best_only=True)]
         )
# get the best model; custom_objects supplies the custom loss/layer classes
# Keras needs to deserialize the checkpoint
model2 = load_model(ckp_file2, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll, 
                                             "SpatialLifetimeSparsity": SpatialLifetimeSparsity})
Train on 5561 samples, validate on 1884 samples
Epoch 1/100
5561/5561 [==============================] - 4s 713us/step - loss: 763.1876 - profile/Sox2_loss: 299.5035 - profile/Oct4_loss: 430.2258 - counts/Sox2_loss: 1.0848 - counts/Oct4_loss: 2.2610 - val_loss: 704.6981 - val_profile/Sox2_loss: 280.1534 - val_profile/Oct4_loss: 394.2456 - val_counts/Sox2_loss: 1.0179 - val_counts/Oct4_loss: 2.0120
Epoch 2/100
5561/5561 [==============================] - 1s 163us/step - loss: 711.4473 - profile/Sox2_loss: 283.8509 - profile/Oct4_loss: 398.6259 - counts/Sox2_loss: 1.0038 - counts/Oct4_loss: 1.8933 - val_loss: 694.2177 - val_profile/Sox2_loss: 275.9324 - val_profile/Oct4_loss: 387.9595 - val_counts/Sox2_loss: 1.0208 - val_counts/Oct4_loss: 2.0118
Epoch 3/100
5561/5561 [==============================] - 1s 171us/step - loss: 704.8715 - profile/Sox2_loss: 281.0870 - profile/Oct4_loss: 394.7804 - counts/Sox2_loss: 1.0000 - counts/Oct4_loss: 1.9004 - val_loss: 690.3251 - val_profile/Sox2_loss: 274.3091 - val_profile/Oct4_loss: 385.7162 - val_counts/Sox2_loss: 1.0057 - val_counts/Oct4_loss: 2.0243
Epoch 4/100
5561/5561 [==============================] - 1s 159us/step - loss: 700.0234 - profile/Sox2_loss: 279.0358 - profile/Oct4_loss: 392.0777 - counts/Sox2_loss: 0.9979 - counts/Oct4_loss: 1.8930 - val_loss: 685.7401 - val_profile/Sox2_loss: 272.2211 - val_profile/Oct4_loss: 383.2517 - val_counts/Sox2_loss: 1.0085 - val_counts/Oct4_loss: 2.0182
Epoch 5/100
5561/5561 [==============================] - 1s 165us/step - loss: 693.5023 - profile/Sox2_loss: 276.4113 - profile/Oct4_loss: 388.1862 - counts/Sox2_loss: 0.9994 - counts/Oct4_loss: 1.8911 - val_loss: 679.7696 - val_profile/Sox2_loss: 270.1074 - val_profile/Oct4_loss: 379.4452 - val_counts/Sox2_loss: 1.0102 - val_counts/Oct4_loss: 2.0115
Epoch 6/100
5561/5561 [==============================] - 1s 168us/step - loss: 689.1575 - profile/Sox2_loss: 274.8336 - profile/Oct4_loss: 385.4323 - counts/Sox2_loss: 1.0023 - counts/Oct4_loss: 1.8869 - val_loss: 676.8876 - val_profile/Sox2_loss: 269.2669 - val_profile/Oct4_loss: 377.4304 - val_counts/Sox2_loss: 1.0088 - val_counts/Oct4_loss: 2.0102
Epoch 7/100
5561/5561 [==============================] - 1s 164us/step - loss: 685.0529 - profile/Sox2_loss: 273.5290 - profile/Oct4_loss: 382.6831 - counts/Sox2_loss: 1.0010 - counts/Oct4_loss: 1.8831 - val_loss: 674.2777 - val_profile/Sox2_loss: 268.3599 - val_profile/Oct4_loss: 375.8003 - val_counts/Sox2_loss: 1.0110 - val_counts/Oct4_loss: 2.0007
Epoch 8/100
5561/5561 [==============================] - 1s 148us/step - loss: 682.1885 - profile/Sox2_loss: 272.5930 - profile/Oct4_loss: 380.8218 - counts/Sox2_loss: 1.0010 - counts/Oct4_loss: 1.8764 - val_loss: 670.7889 - val_profile/Sox2_loss: 267.1580 - val_profile/Oct4_loss: 373.5313 - val_counts/Sox2_loss: 1.0078 - val_counts/Oct4_loss: 2.0022
Epoch 9/100
5561/5561 [==============================] - 1s 153us/step - loss: 678.8987 - profile/Sox2_loss: 271.4154 - profile/Oct4_loss: 378.8519 - counts/Sox2_loss: 0.9994 - counts/Oct4_loss: 1.8637 - val_loss: 667.7578 - val_profile/Sox2_loss: 266.0841 - val_profile/Oct4_loss: 371.7947 - val_counts/Sox2_loss: 1.0111 - val_counts/Oct4_loss: 1.9768
Epoch 10/100
5561/5561 [==============================] - 1s 146us/step - loss: 675.7578 - profile/Sox2_loss: 270.0719 - profile/Oct4_loss: 377.1737 - counts/Sox2_loss: 0.9971 - counts/Oct4_loss: 1.8541 - val_loss: 665.7501 - val_profile/Sox2_loss: 265.2902 - val_profile/Oct4_loss: 370.8071 - val_counts/Sox2_loss: 1.0088 - val_counts/Oct4_loss: 1.9565
Epoch 11/100
5561/5561 [==============================] - 1s 170us/step - loss: 672.5629 - profile/Sox2_loss: 268.8062 - profile/Oct4_loss: 375.4808 - counts/Sox2_loss: 0.9942 - counts/Oct4_loss: 1.8334 - val_loss: 663.3388 - val_profile/Sox2_loss: 264.3983 - val_profile/Oct4_loss: 369.3944 - val_counts/Sox2_loss: 1.0028 - val_counts/Oct4_loss: 1.9518
Epoch 12/100
5561/5561 [==============================] - 1s 160us/step - loss: 670.1690 - profile/Sox2_loss: 267.7120 - profile/Oct4_loss: 374.4065 - counts/Sox2_loss: 0.9897 - counts/Oct4_loss: 1.8154 - val_loss: 660.9843 - val_profile/Sox2_loss: 263.6724 - val_profile/Oct4_loss: 368.1142 - val_counts/Sox2_loss: 0.9977 - val_counts/Oct4_loss: 1.9221
Epoch 13/100
5561/5561 [==============================] - 1s 162us/step - loss: 668.1455 - profile/Sox2_loss: 266.9611 - profile/Oct4_loss: 373.4946 - counts/Sox2_loss: 0.9817 - counts/Oct4_loss: 1.7873 - val_loss: 659.9247 - val_profile/Sox2_loss: 263.2502 - val_profile/Oct4_loss: 368.0374 - val_counts/Sox2_loss: 0.9857 - val_counts/Oct4_loss: 1.8780
Epoch 14/100
5561/5561 [==============================] - 1s 152us/step - loss: 666.5748 - profile/Sox2_loss: 266.3696 - profile/Oct4_loss: 372.9276 - counts/Sox2_loss: 0.9714 - counts/Oct4_loss: 1.7564 - val_loss: 657.8648 - val_profile/Sox2_loss: 262.9321 - val_profile/Oct4_loss: 367.0214 - val_counts/Sox2_loss: 0.9679 - val_counts/Oct4_loss: 1.8232
Epoch 15/100
5561/5561 [==============================] - 1s 154us/step - loss: 664.8316 - profile/Sox2_loss: 265.8420 - profile/Oct4_loss: 372.4527 - counts/Sox2_loss: 0.9542 - counts/Oct4_loss: 1.6995 - val_loss: 656.2830 - val_profile/Sox2_loss: 262.3425 - val_profile/Oct4_loss: 366.9115 - val_counts/Sox2_loss: 0.9476 - val_counts/Oct4_loss: 1.7553
Epoch 16/100
5561/5561 [==============================] - 1s 153us/step - loss: 663.0010 - profile/Sox2_loss: 265.3578 - profile/Oct4_loss: 371.9924 - counts/Sox2_loss: 0.9304 - counts/Oct4_loss: 1.6347 - val_loss: 654.4368 - val_profile/Sox2_loss: 262.1156 - val_profile/Oct4_loss: 366.4358 - val_counts/Sox2_loss: 0.9179 - val_counts/Oct4_loss: 1.6706
Epoch 17/100
5561/5561 [==============================] - 1s 148us/step - loss: 661.6837 - profile/Sox2_loss: 265.1359 - profile/Oct4_loss: 371.9570 - counts/Sox2_loss: 0.9015 - counts/Oct4_loss: 1.5576 - val_loss: 653.3463 - val_profile/Sox2_loss: 262.1368 - val_profile/Oct4_loss: 366.5665 - val_counts/Sox2_loss: 0.8853 - val_counts/Oct4_loss: 1.5790
Epoch 18/100
5561/5561 [==============================] - 1s 165us/step - loss: 659.6782 - profile/Sox2_loss: 264.9507 - profile/Oct4_loss: 371.4946 - counts/Sox2_loss: 0.8760 - counts/Oct4_loss: 1.4473 - val_loss: 652.2367 - val_profile/Sox2_loss: 261.9912 - val_profile/Oct4_loss: 367.1882 - val_counts/Sox2_loss: 0.8572 - val_counts/Oct4_loss: 1.4486
Epoch 19/100
5561/5561 [==============================] - 1s 166us/step - loss: 659.2894 - profile/Sox2_loss: 264.5522 - profile/Oct4_loss: 371.5527 - counts/Sox2_loss: 0.8720 - counts/Oct4_loss: 1.4464 - val_loss: 650.7907 - val_profile/Sox2_loss: 261.9057 - val_profile/Oct4_loss: 366.0691 - val_counts/Sox2_loss: 0.8573 - val_counts/Oct4_loss: 1.4243
Epoch 20/100
5561/5561 [==============================] - 1s 168us/step - loss: 658.7922 - profile/Sox2_loss: 264.3513 - profile/Oct4_loss: 371.2292 - counts/Sox2_loss: 0.8913 - counts/Oct4_loss: 1.4299 - val_loss: 649.7325 - val_profile/Sox2_loss: 261.5394 - val_profile/Oct4_loss: 365.7616 - val_counts/Sox2_loss: 0.8496 - val_counts/Oct4_loss: 1.3936
Epoch 21/100
5561/5561 [==============================] - 1s 163us/step - loss: 655.8586 - profile/Sox2_loss: 263.7853 - profile/Oct4_loss: 370.4306 - counts/Sox2_loss: 0.8512 - counts/Oct4_loss: 1.3130 - val_loss: 647.7246 - val_profile/Sox2_loss: 260.9191 - val_profile/Oct4_loss: 364.7850 - val_counts/Sox2_loss: 0.8405 - val_counts/Oct4_loss: 1.3615
Epoch 22/100
5561/5561 [==============================] - 1s 161us/step - loss: 654.6104 - profile/Sox2_loss: 263.3866 - profile/Oct4_loss: 370.2460 - counts/Sox2_loss: 0.8385 - counts/Oct4_loss: 1.2593 - val_loss: 648.3401 - val_profile/Sox2_loss: 260.6118 - val_profile/Oct4_loss: 364.8391 - val_counts/Sox2_loss: 0.8669 - val_counts/Oct4_loss: 1.4221
Epoch 23/100
5561/5561 [==============================] - 1s 155us/step - loss: 655.2971 - profile/Sox2_loss: 263.2229 - profile/Oct4_loss: 370.2912 - counts/Sox2_loss: 0.8647 - counts/Oct4_loss: 1.3135 - val_loss: 647.4606 - val_profile/Sox2_loss: 260.6462 - val_profile/Oct4_loss: 364.8451 - val_counts/Sox2_loss: 0.8473 - val_counts/Oct4_loss: 1.3496
Epoch 24/100
5561/5561 [==============================] - 1s 157us/step - loss: 653.8092 - profile/Sox2_loss: 262.7660 - profile/Oct4_loss: 369.9246 - counts/Sox2_loss: 0.8519 - counts/Oct4_loss: 1.2600 - val_loss: 647.8885 - val_profile/Sox2_loss: 260.7352 - val_profile/Oct4_loss: 365.4378 - val_counts/Sox2_loss: 0.8491 - val_counts/Oct4_loss: 1.3224
Epoch 25/100
5561/5561 [==============================] - 1s 144us/step - loss: 653.4949 - profile/Sox2_loss: 262.3736 - profile/Oct4_loss: 369.8984 - counts/Sox2_loss: 0.8564 - counts/Oct4_loss: 1.2659 - val_loss: 646.6886 - val_profile/Sox2_loss: 260.0860 - val_profile/Oct4_loss: 363.9830 - val_counts/Sox2_loss: 0.8636 - val_counts/Oct4_loss: 1.3984
Epoch 26/100
5561/5561 [==============================] - 1s 155us/step - loss: 651.6279 - profile/Sox2_loss: 262.0433 - profile/Oct4_loss: 369.1362 - counts/Sox2_loss: 0.8374 - counts/Oct4_loss: 1.2074 - val_loss: 645.3668 - val_profile/Sox2_loss: 260.0960 - val_profile/Oct4_loss: 363.8404 - val_counts/Sox2_loss: 0.8386 - val_counts/Oct4_loss: 1.3044
Epoch 27/100
5561/5561 [==============================] - 1s 152us/step - loss: 651.4017 - profile/Sox2_loss: 261.9252 - profile/Oct4_loss: 369.0049 - counts/Sox2_loss: 0.8417 - counts/Oct4_loss: 1.2054 - val_loss: 647.0116 - val_profile/Sox2_loss: 260.1826 - val_profile/Oct4_loss: 363.8670 - val_counts/Sox2_loss: 0.8696 - val_counts/Oct4_loss: 1.4267
Epoch 28/100
5561/5561 [==============================] - 1s 148us/step - loss: 651.6443 - profile/Sox2_loss: 261.7078 - profile/Oct4_loss: 369.0637 - counts/Sox2_loss: 0.8490 - counts/Oct4_loss: 1.2383 - val_loss: 645.6603 - val_profile/Sox2_loss: 259.6499 - val_profile/Oct4_loss: 363.6593 - val_counts/Sox2_loss: 0.8601 - val_counts/Oct4_loss: 1.3750
Epoch 29/100
5561/5561 [==============================] - 1s 162us/step - loss: 650.7555 - profile/Sox2_loss: 261.4733 - profile/Oct4_loss: 368.8061 - counts/Sox2_loss: 0.8399 - counts/Oct4_loss: 1.2077 - val_loss: 644.2766 - val_profile/Sox2_loss: 259.8504 - val_profile/Oct4_loss: 363.2946 - val_counts/Sox2_loss: 0.8353 - val_counts/Oct4_loss: 1.2778
Epoch 30/100
5561/5561 [==============================] - 1s 155us/step - loss: 650.5687 - profile/Sox2_loss: 261.4790 - profile/Oct4_loss: 368.7938 - counts/Sox2_loss: 0.8411 - counts/Oct4_loss: 1.1885 - val_loss: 644.5857 - val_profile/Sox2_loss: 259.6583 - val_profile/Oct4_loss: 363.4537 - val_counts/Sox2_loss: 0.8467 - val_counts/Oct4_loss: 1.3007
Epoch 31/100
5561/5561 [==============================] - 1s 154us/step - loss: 649.4651 - profile/Sox2_loss: 261.2057 - profile/Oct4_loss: 368.3722 - counts/Sox2_loss: 0.8298 - counts/Oct4_loss: 1.1589 - val_loss: 644.2591 - val_profile/Sox2_loss: 259.6206 - val_profile/Oct4_loss: 363.5584 - val_counts/Sox2_loss: 0.8360 - val_counts/Oct4_loss: 1.2720
Epoch 32/100
5561/5561 [==============================] - 1s 150us/step - loss: 649.6294 - profile/Sox2_loss: 261.1612 - profile/Oct4_loss: 368.2828 - counts/Sox2_loss: 0.8373 - counts/Oct4_loss: 1.1813 - val_loss: 645.7333 - val_profile/Sox2_loss: 259.5055 - val_profile/Oct4_loss: 363.5964 - val_counts/Sox2_loss: 0.8767 - val_counts/Oct4_loss: 1.3864
Epoch 33/100
5561/5561 [==============================] - 1s 166us/step - loss: 648.9658 - profile/Sox2_loss: 260.9443 - profile/Oct4_loss: 367.9659 - counts/Sox2_loss: 0.8300 - counts/Oct4_loss: 1.1756 - val_loss: 644.4450 - val_profile/Sox2_loss: 259.3498 - val_profile/Oct4_loss: 363.0446 - val_counts/Sox2_loss: 0.8527 - val_counts/Oct4_loss: 1.3524
Epoch 34/100
5561/5561 [==============================] - 1s 164us/step - loss: 648.3730 - profile/Sox2_loss: 260.8905 - profile/Oct4_loss: 367.8714 - counts/Sox2_loss: 0.8239 - counts/Oct4_loss: 1.1372 - val_loss: 645.7116 - val_profile/Sox2_loss: 259.4803 - val_profile/Oct4_loss: 363.6060 - val_counts/Sox2_loss: 0.8720 - val_counts/Oct4_loss: 1.3905
Epoch 35/100
5561/5561 [==============================] - 1s 147us/step - loss: 648.6406 - profile/Sox2_loss: 260.8203 - profile/Oct4_loss: 367.9702 - counts/Sox2_loss: 0.8298 - counts/Oct4_loss: 1.1552 - val_loss: 643.7155 - val_profile/Sox2_loss: 259.3385 - val_profile/Oct4_loss: 362.7524 - val_counts/Sox2_loss: 0.8515 - val_counts/Oct4_loss: 1.3110
Epoch 36/100
5561/5561 [==============================] - 1s 167us/step - loss: 648.1554 - profile/Sox2_loss: 260.5897 - profile/Oct4_loss: 367.5790 - counts/Sox2_loss: 0.8293 - counts/Oct4_loss: 1.1694 - val_loss: 644.0491 - val_profile/Sox2_loss: 259.0871 - val_profile/Oct4_loss: 362.9778 - val_counts/Sox2_loss: 0.8596 - val_counts/Oct4_loss: 1.3388
Epoch 37/100
5561/5561 [==============================] - 1s 154us/step - loss: 647.0597 - profile/Sox2_loss: 260.4685 - profile/Oct4_loss: 367.3522 - counts/Sox2_loss: 0.8106 - counts/Oct4_loss: 1.1133 - val_loss: 643.0397 - val_profile/Sox2_loss: 259.2197 - val_profile/Oct4_loss: 362.5937 - val_counts/Sox2_loss: 0.8461 - val_counts/Oct4_loss: 1.2765
Epoch 38/100
5561/5561 [==============================] - 1s 159us/step - loss: 646.5270 - profile/Sox2_loss: 260.2998 - profile/Oct4_loss: 367.2337 - counts/Sox2_loss: 0.8038 - counts/Oct4_loss: 1.0955 - val_loss: 642.8960 - val_profile/Sox2_loss: 259.0908 - val_profile/Oct4_loss: 362.9462 - val_counts/Sox2_loss: 0.8306 - val_counts/Oct4_loss: 1.2553
Epoch 39/100
5561/5561 [==============================] - 1s 161us/step - loss: 646.8674 - profile/Sox2_loss: 260.3180 - profile/Oct4_loss: 367.2391 - counts/Sox2_loss: 0.8089 - counts/Oct4_loss: 1.1221 - val_loss: 644.8182 - val_profile/Sox2_loss: 258.9745 - val_profile/Oct4_loss: 363.2238 - val_counts/Sox2_loss: 0.8771 - val_counts/Oct4_loss: 1.3849
Epoch 40/100
5561/5561 [==============================] - 1s 161us/step - loss: 647.0439 - profile/Sox2_loss: 260.3818 - profile/Oct4_loss: 367.4644 - counts/Sox2_loss: 0.8077 - counts/Oct4_loss: 1.1121 - val_loss: 643.4172 - val_profile/Sox2_loss: 259.1207 - val_profile/Oct4_loss: 363.1606 - val_counts/Sox2_loss: 0.8292 - val_counts/Oct4_loss: 1.2844
Epoch 41/100
5561/5561 [==============================] - 1s 160us/step - loss: 646.2409 - profile/Sox2_loss: 260.2414 - profile/Oct4_loss: 367.1762 - counts/Sox2_loss: 0.7985 - counts/Oct4_loss: 1.0839 - val_loss: 642.6148 - val_profile/Sox2_loss: 258.7851 - val_profile/Oct4_loss: 362.6140 - val_counts/Sox2_loss: 0.8329 - val_counts/Oct4_loss: 1.2887
Epoch 42/100
5561/5561 [==============================] - 1s 162us/step - loss: 646.1598 - profile/Sox2_loss: 260.0739 - profile/Oct4_loss: 367.1427 - counts/Sox2_loss: 0.7986 - counts/Oct4_loss: 1.0957 - val_loss: 643.2369 - val_profile/Sox2_loss: 259.0764 - val_profile/Oct4_loss: 362.8208 - val_counts/Sox2_loss: 0.8550 - val_counts/Oct4_loss: 1.2790
Epoch 43/100
5561/5561 [==============================] - 1s 147us/step - loss: 645.7691 - profile/Sox2_loss: 259.9594 - profile/Oct4_loss: 366.8798 - counts/Sox2_loss: 0.8009 - counts/Oct4_loss: 1.0921 - val_loss: 642.1630 - val_profile/Sox2_loss: 258.8588 - val_profile/Oct4_loss: 362.5254 - val_counts/Sox2_loss: 0.8233 - val_counts/Oct4_loss: 1.2546
Epoch 44/100
5561/5561 [==============================] - 1s 157us/step - loss: 645.0132 - profile/Sox2_loss: 259.9233 - profile/Oct4_loss: 366.6893 - counts/Sox2_loss: 0.7816 - counts/Oct4_loss: 1.0585 - val_loss: 643.6229 - val_profile/Sox2_loss: 259.4802 - val_profile/Oct4_loss: 362.9609 - val_counts/Sox2_loss: 0.8441 - val_counts/Oct4_loss: 1.2741
Epoch 45/100
5561/5561 [==============================] - 1s 161us/step - loss: 644.9541 - profile/Sox2_loss: 259.8555 - profile/Oct4_loss: 366.5555 - counts/Sox2_loss: 0.7869 - counts/Oct4_loss: 1.0674 - val_loss: 641.8560 - val_profile/Sox2_loss: 258.8057 - val_profile/Oct4_loss: 362.2098 - val_counts/Sox2_loss: 0.8356 - val_counts/Oct4_loss: 1.2485
Epoch 46/100
5561/5561 [==============================] - 1s 171us/step - loss: 645.1118 - profile/Sox2_loss: 259.7875 - profile/Oct4_loss: 366.5365 - counts/Sox2_loss: 0.7983 - counts/Oct4_loss: 1.0805 - val_loss: 643.3075 - val_profile/Sox2_loss: 258.9488 - val_profile/Oct4_loss: 363.0432 - val_counts/Sox2_loss: 0.8327 - val_counts/Oct4_loss: 1.2989
Epoch 47/100
5561/5561 [==============================] - 1s 164us/step - loss: 644.4460 - profile/Sox2_loss: 259.8098 - profile/Oct4_loss: 366.5597 - counts/Sox2_loss: 0.7688 - counts/Oct4_loss: 1.0388 - val_loss: 641.5317 - val_profile/Sox2_loss: 258.5391 - val_profile/Oct4_loss: 362.2222 - val_counts/Sox2_loss: 0.8220 - val_counts/Oct4_loss: 1.2550
Epoch 48/100
5561/5561 [==============================] - 1s 148us/step - loss: 643.4010 - profile/Sox2_loss: 259.4304 - profile/Oct4_loss: 366.2212 - counts/Sox2_loss: 0.7536 - counts/Oct4_loss: 1.0213 - val_loss: 642.1645 - val_profile/Sox2_loss: 258.8702 - val_profile/Oct4_loss: 362.3011 - val_counts/Sox2_loss: 0.8504 - val_counts/Oct4_loss: 1.2489
Epoch 49/100
5561/5561 [==============================] - 1s 164us/step - loss: 644.1660 - profile/Sox2_loss: 259.5416 - profile/Oct4_loss: 366.2314 - counts/Sox2_loss: 0.7701 - counts/Oct4_loss: 1.0692 - val_loss: 641.9798 - val_profile/Sox2_loss: 258.5265 - val_profile/Oct4_loss: 362.3802 - val_counts/Sox2_loss: 0.8178 - val_counts/Oct4_loss: 1.2895
Epoch 50/100
5561/5561 [==============================] - 1s 159us/step - loss: 643.2476 - profile/Sox2_loss: 259.4182 - profile/Oct4_loss: 366.0651 - counts/Sox2_loss: 0.7569 - counts/Oct4_loss: 1.0195 - val_loss: 641.6686 - val_profile/Sox2_loss: 258.6051 - val_profile/Oct4_loss: 362.3847 - val_counts/Sox2_loss: 0.8233 - val_counts/Oct4_loss: 1.2446
Epoch 51/100
5561/5561 [==============================] - 1s 155us/step - loss: 643.5741 - profile/Sox2_loss: 259.4896 - profile/Oct4_loss: 366.0931 - counts/Sox2_loss: 0.7618 - counts/Oct4_loss: 1.0373 - val_loss: 641.6282 - val_profile/Sox2_loss: 258.7453 - val_profile/Oct4_loss: 362.1996 - val_counts/Sox2_loss: 0.8103 - val_counts/Oct4_loss: 1.2580
Epoch 52/100
5561/5561 [==============================] - 1s 153us/step - loss: 642.9900 - profile/Sox2_loss: 259.3362 - profile/Oct4_loss: 365.9861 - counts/Sox2_loss: 0.7458 - counts/Oct4_loss: 1.0210 - val_loss: 641.4141 - val_profile/Sox2_loss: 258.5274 - val_profile/Oct4_loss: 362.4702 - val_counts/Sox2_loss: 0.8094 - val_counts/Oct4_loss: 1.2322
Epoch 53/100
5561/5561 [==============================] - 1s 170us/step - loss: 642.7884 - profile/Sox2_loss: 259.3254 - profile/Oct4_loss: 365.9150 - counts/Sox2_loss: 0.7407 - counts/Oct4_loss: 1.0141 - val_loss: 641.0079 - val_profile/Sox2_loss: 258.4949 - val_profile/Oct4_loss: 362.0193 - val_counts/Sox2_loss: 0.8128 - val_counts/Oct4_loss: 1.2366
Epoch 54/100
5561/5561 [==============================] - 1s 168us/step - loss: 643.0095 - profile/Sox2_loss: 259.2841 - profile/Oct4_loss: 365.9236 - counts/Sox2_loss: 0.7511 - counts/Oct4_loss: 1.0291 - val_loss: 641.3983 - val_profile/Sox2_loss: 258.4825 - val_profile/Oct4_loss: 362.4983 - val_counts/Sox2_loss: 0.8051 - val_counts/Oct4_loss: 1.2367
Epoch 55/100
5561/5561 [==============================] - 1s 164us/step - loss: 642.0784 - profile/Sox2_loss: 259.1446 - profile/Oct4_loss: 365.7702 - counts/Sox2_loss: 0.7293 - counts/Oct4_loss: 0.9871 - val_loss: 640.9507 - val_profile/Sox2_loss: 258.4056 - val_profile/Oct4_loss: 362.1335 - val_counts/Sox2_loss: 0.8092 - val_counts/Oct4_loss: 1.2320
Epoch 56/100
5561/5561 [==============================] - 1s 164us/step - loss: 642.1487 - profile/Sox2_loss: 259.1180 - profile/Oct4_loss: 365.8502 - counts/Sox2_loss: 0.7228 - counts/Oct4_loss: 0.9952 - val_loss: 641.2518 - val_profile/Sox2_loss: 258.4739 - val_profile/Oct4_loss: 362.1233 - val_counts/Sox2_loss: 0.8153 - val_counts/Oct4_loss: 1.2501
Epoch 57/100
5561/5561 [==============================] - 1s 169us/step - loss: 643.6215 - profile/Sox2_loss: 259.2546 - profile/Oct4_loss: 365.7840 - counts/Sox2_loss: 0.7807 - counts/Oct4_loss: 1.0776 - val_loss: 640.8444 - val_profile/Sox2_loss: 258.5101 - val_profile/Oct4_loss: 362.1608 - val_counts/Sox2_loss: 0.7959 - val_counts/Oct4_loss: 1.2215
Epoch 58/100
5561/5561 [==============================] - 1s 165us/step - loss: 641.7561 - profile/Sox2_loss: 259.0468 - profile/Oct4_loss: 365.5287 - counts/Sox2_loss: 0.7293 - counts/Oct4_loss: 0.9887 - val_loss: 641.0524 - val_profile/Sox2_loss: 258.4442 - val_profile/Oct4_loss: 362.2276 - val_counts/Sox2_loss: 0.8062 - val_counts/Oct4_loss: 1.2319
Epoch 59/100
5561/5561 [==============================] - 1s 155us/step - loss: 641.4600 - profile/Sox2_loss: 258.9666 - profile/Oct4_loss: 365.4104 - counts/Sox2_loss: 0.7187 - counts/Oct4_loss: 0.9896 - val_loss: 640.8397 - val_profile/Sox2_loss: 258.2926 - val_profile/Oct4_loss: 362.2576 - val_counts/Sox2_loss: 0.8027 - val_counts/Oct4_loss: 1.2262
Epoch 60/100
5561/5561 [==============================] - 1s 154us/step - loss: 641.1823 - profile/Sox2_loss: 258.9066 - profile/Oct4_loss: 365.4787 - counts/Sox2_loss: 0.7090 - counts/Oct4_loss: 0.9707 - val_loss: 642.6642 - val_profile/Sox2_loss: 258.2801 - val_profile/Oct4_loss: 362.0169 - val_counts/Sox2_loss: 0.8434 - val_counts/Oct4_loss: 1.3934
Epoch 61/100
5561/5561 [==============================] - 1s 154us/step - loss: 641.1386 - profile/Sox2_loss: 258.9033 - profile/Oct4_loss: 365.2700 - counts/Sox2_loss: 0.7159 - counts/Oct4_loss: 0.9806 - val_loss: 640.6570 - val_profile/Sox2_loss: 258.3713 - val_profile/Oct4_loss: 361.9090 - val_counts/Sox2_loss: 0.8031 - val_counts/Oct4_loss: 1.2346
Epoch 62/100
5561/5561 [==============================] - 1s 158us/step - loss: 641.4627 - profile/Sox2_loss: 258.8020 - profile/Oct4_loss: 365.4253 - counts/Sox2_loss: 0.7192 - counts/Oct4_loss: 1.0043 - val_loss: 642.2461 - val_profile/Sox2_loss: 258.5030 - val_profile/Oct4_loss: 361.9809 - val_counts/Sox2_loss: 0.9272 - val_counts/Oct4_loss: 1.2491
Epoch 63/100
5561/5561 [==============================] - 1s 162us/step - loss: 640.5678 - profile/Sox2_loss: 258.7707 - profile/Oct4_loss: 365.0246 - counts/Sox2_loss: 0.6991 - counts/Oct4_loss: 0.9781 - val_loss: 642.1116 - val_profile/Sox2_loss: 258.5238 - val_profile/Oct4_loss: 362.2396 - val_counts/Sox2_loss: 0.8860 - val_counts/Oct4_loss: 1.2488
Epoch 64/100
5561/5561 [==============================] - 1s 158us/step - loss: 640.2506 - profile/Sox2_loss: 258.7458 - profile/Oct4_loss: 365.1427 - counts/Sox2_loss: 0.6945 - counts/Oct4_loss: 0.9417 - val_loss: 640.7278 - val_profile/Sox2_loss: 258.1557 - val_profile/Oct4_loss: 362.1599 - val_counts/Sox2_loss: 0.7959 - val_counts/Oct4_loss: 1.2453
Epoch 65/100
5561/5561 [==============================] - 1s 149us/step - loss: 639.8939 - profile/Sox2_loss: 258.6437 - profile/Oct4_loss: 365.0163 - counts/Sox2_loss: 0.6831 - counts/Oct4_loss: 0.9403 - val_loss: 642.5044 - val_profile/Sox2_loss: 258.4795 - val_profile/Oct4_loss: 362.1463 - val_counts/Sox2_loss: 0.9320 - val_counts/Oct4_loss: 1.2558
Epoch 66/100
5561/5561 [==============================] - 1s 156us/step - loss: 641.1632 - profile/Sox2_loss: 258.8763 - profile/Oct4_loss: 365.2085 - counts/Sox2_loss: 0.7286 - counts/Oct4_loss: 0.9793 - val_loss: 641.4021 - val_profile/Sox2_loss: 258.4098 - val_profile/Oct4_loss: 362.1405 - val_counts/Sox2_loss: 0.8402 - val_counts/Oct4_loss: 1.2450
In [25]:
# Validation-set losses of the restored best checkpoint
evaluate(model2, valid_nex[0], valid_nex[1])
Out[25]:
{'loss': 640.6569518395037,
 'profile/Sox2_loss': 258.3712948679671,
 'profile/Oct4_loss': 361.9089973595492,
 'counts/Sox2_loss': 0.8031086461053026,
 'counts/Oct4_loss': 1.234557134077554}
In [26]:
from basepair import samplers
from basepair.plots import *
import pandas as pd
from basepair.math import softmax
import numpy as np
import keras.backend as K
from keras.models import Model
from concise.utils.plot import seqlogo_fig, seqlogo
import matplotlib.pyplot as plt
In [27]:
# Inspect the seqlogo signature (shown below: accepts an optional matplotlib Axes)
seqlogo
Out[27]:
<function concise.utils.plot.seqlogo(letter_heights, vocab='DNA', ax=None)>
In [28]:
# Peek at the ChIP-nexus test-split metadata (coordinates, per-strand counts, sequence)
test_nex[2].head()
Out[28]:
chr end oct4_neg oct4_pos sox2_neg sox2_pos start seq seq_id
2 chr1 57780311 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... [0.0, 3.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, ... [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 57780110 GAGAAAAATCATCTGGAATCCAGCTGAGAGTGAAAGGCGAGGCAAA... chr1:57780110-57780311
5 chr8 67966723 [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, ... [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 67966522 TAAACTCTCTTTCCCATATAGCCTCTCTCAGTTGCCCTTGCAATTT... chr8:67966522-67966723
7 chr1 9955454 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ... [0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ... 9955253 ATAGAACCCATCCTTAGGGGAGTATCATGTGTACTTCATAGCCTGC... chr1:9955253-9955454
11 chr9 113438719 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ... [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... 113438518 CAGGTCTTACTTTCTGTCTCTGTAAGCTAACATAGGCCAATTGAGA... chr9:113438518-113438719
12 chr8 111452152 [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 1.0, ... [0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, ... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ... 111451951 TACTGTACAGAATTTAGCATCACAATACATTAGCAACAGATCAAAC... chr8:111451951-111452152
In [29]:
# NOTE(review): Seq2Nexus is presumably a plotting helper defined in an earlier,
# not-shown cell — confirm it exists before a fresh run.
p2 = Seq2Nexus(test_nex[0], test_nex[1], test_nex[2], model2)
# Plot predictions vs. observed profiles, sorted by 'max_sox2', with 'max' seq gradients
p2.plot(sort='max_sox2', seq_grad='max', figsize=(20,12))
In [30]:
# Peak metadata of the ChIP-seq dataset (one interval per row, labelled by task)
test[2].head()
Out[30]:
id chr start end task
0 0 chr1 74957633 74957834 Sox2
5 5 chr1 189805723 189805924 Sox2
13 13 chr9 61249702 61249903 Oct4
18 18 chr1 35220519 35220720 Oct4
19 19 chr8 124810872 124811073 Oct4
In [31]:
from basepair.BPNet import BPNetPredictor
from pybedtools import Interval, BedTool
In [32]:
# First 5 test peak intervals as a BedTool, for the predict_plot calls below
bt = BedTool.from_dataframe(test[2][["chr", "start", "end"]][:5])
In [33]:
# NOTE(review): `model` comes from an earlier cell not shown here — hidden
# kernel state; a fresh Restart-&-Run-All would fail at this line.
bpnet = BPNetPredictor(model, ds.fasta_file, list(ds.task_specs), preproc=preproc)
In [34]:
# Predict and plot profiles for the 5 intervals, with observed bigwig coverage
# supplied via ds2bws(ds) and 'max'-aggregated profile gradients
bpnet.predict_plot(intervals=list(bt), bws = ds2bws(ds), profile_grad="max")
In [35]:
# Same predictor wrapping the ChIP-nexus-trained model2
bpnet2 = BPNetPredictor(model2, ds.fasta_file, list(ds.task_specs), preproc=preproc)
In [36]:
#bpnet2.predict_plot(intervals=list(bt), bws = ds2bws(ds), profile_grad="weighted")
In [37]:
# Difference of input gradients between the nexus-trained (bpnet2) and the
# chip-seq-trained (bpnet) model, projected onto the one-hot input sequence.
# Folded into a single assignment so the cell is idempotent: the original
# `diff = diff * test[0]` follow-up line would multiply by the sequence a
# second time if the cell were re-run.
diff = (bpnet2.input_grad(test[0], seq_grad='weighted')
        - bpnet.input_grad(test[0], seq_grad='weighted')) * test[0]
In [38]:
# (n_examples, seq_len, 4): one gradient-difference value per position and base
diff.shape
Out[38]:
(4254, 201, 4)
In [39]:
# Per-example summaries of the gradient difference: signed total and total
# magnitude, reduced over both the position and base axes at once.
sums = np.sum(diff, axis=(1, 2))
abs_sums = np.sum(np.abs(diff), axis=(1, 2))

# Example indices of interest for the seqlogo panels below
max_diff_index = np.argmax(sums)       # largest positive net difference
min_diff_index = np.argmin(sums)       # largest negative net difference
most_diff_index = np.argmax(abs_sums)  # largest total magnitude
no_diff_index = np.argmin(abs_sums)    # smallest total magnitude
In [40]:
# Four seqlogo panels of the gradient difference, one per example of interest.
# A loop replaces the original four copy-pasted title/seqlogo pairs.
panels = [("max_diff_index", max_diff_index),
          ("min_diff_index", min_diff_index),
          ("most_diff_index", most_diff_index),
          ("no_diff_index", no_diff_index)]

fig, axes = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
for ax, (title, idx) in zip(axes, panels):
    ax.set_title(title)
    seqlogo(diff[idx], ax=ax)
In [41]:
# Same four panels with a shared fixed y-range so letter heights are directly
# comparable across panels. A loop replaces the copy-pasted per-panel code.
scale = 0.2  # y-axis half-range for every panel

panels = [("max_diff_index", max_diff_index),
          ("min_diff_index", min_diff_index),
          ("most_diff_index", most_diff_index),
          ("no_diff_index", no_diff_index)]

fig, axes = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
for ax, (title, idx) in zip(axes, panels):
    ax.set_title(title)
    ax.set_ylim((-1 * scale, scale))
    seqlogo(diff[idx], ax=ax)
In [42]:
# ChIP-seq training-set size: (n_peaks, peak_width, 4) one-hot sequences
np.array(train[0]).shape
Out[42]:
(14727, 201, 4)
In [43]:
# ChIP-nexus training set is smaller: 5561 examples (matches the fit logs above)
np.array(train_nex[0]).shape
Out[43]:
(5561, 201, 4)