In [1]:
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
Using TensorFlow backend.
In [3]:
# Use GPUs 1 and 2.
# NOTE(review): CUDA_VISIBLE_DEVICES must be a comma-separated list with NO
# spaces — "1, 2" can make CUDA stop parsing at the invalid " 2" entry and
# expose only GPU 1. Also, this must be set *before* TensorFlow initializes
# CUDA; keras (TF backend) was already imported in the first cell, so this
# cell should really run first after a kernel restart.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
In [34]:
# Output directory where model checkpoints are written (consumed by get_model below).
ddir = '/users/amr1/bpnet/basepair/src/chipseq/'
In [8]:
# Base directory holding the per-TF bigWig count tracks and IDR peak files.
bdir = "/users/amr1/bpnet/sox2-oct4-chipseq"

# Per-task file names: (pos-strand bigWig, neg-strand bigWig, IDR-filtered peaks).
_task_files = {
    "Sox2": ("pos.bw", "neg.bw", "Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz"),
    "Oct4": ("pos2.bw", "neg2.bw", "Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz"),
}

# Assemble the DataSpec from the file table above (same paths as listing
# each TaskSpec by hand).
ds = DataSpec(
    task_specs={
        task: TaskSpec(task=task,
                       pos_counts=f"{bdir}/{task}/{pos}",
                       neg_counts=f"{bdir}/{task}/{neg}",
                       peaks=f"{bdir}/{task}/{peaks}")
        for task, (pos, neg, peaks) in _task_files.items()
    },
    fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta",
)
In [9]:
def ds2bws(ds):
    """Map each task in `ds` to its strand-specific bigWig count tracks.

    Args:
      ds: DataSpec-like object whose `task_specs` is a dict of objects
        exposing `pos_counts` / `neg_counts` attributes.

    Returns:
      dict: {task_name: {"pos": pos_counts_path, "neg": neg_counts_path}}
    """
    bigwigs = {}
    for task_name, spec in ds.task_specs.items():
        bigwigs[task_name] = {"pos": spec.pos_counts,
                              "neg": spec.neg_counts}
    return bigwigs
In [17]:
# Build train/valid/test splits with 1000-bp windows; per the usage below,
# each split is a tuple of (inputs, targets-dict, ...) — split logic lives
# in basepair.datasets.chip_exo_nexus and is not visible here.
train, valid, test = chip_exo_nexus(ds, peak_width=1000)
2018-12-10 14:50:23,655 [INFO] extract sequence
2018-12-10 14:50:26,079 [INFO] extract counts
100%|██████████| 2/2 [00:35<00:00, 17.70s/it]
In [18]:
# Sanity-check target shape — (14727, 1000, 2): examples x peak_width x 2
# channels (presumably pos/neg strand — TODO confirm).
train[1]['profile/Sox2'].shape
Out[18]:
(14727, 1000, 2)
In [19]:
# Targets: one 'profile/<task>' and one 'counts/<task>' entry per TF.
train[1].keys()
Out[19]:
dict_keys(['profile/Sox2', 'profile/Oct4', 'counts/Sox2', 'counts/Oct4'])
In [20]:
def seq_multitask_chipseq(filters=21,
                          conv1_kernel_size=21,
                          tconv_kernel_size=25,
                          #tconv_kernel_size2=25,
                          n_dil_layers=6,
                          lr=0.004,
                          c_task_weight=100,
                          use_profile=True,
                          use_counts=True,
                          tasks=['sox2', 'oct4'],
                          seq_len=1000):
    """Multi-task ChIP-seq model: dilated-conv trunk with per-task profile
    (transposed-conv) and total-count (dense) output heads.

    Args:
      filters: number of filters in every conv layer of the trunk
      conv1_kernel_size: kernel width of the first conv layer
      tconv_kernel_size: kernel width of the transposed-conv profile head
      n_dil_layers: number of dilated conv layers (dilation rate 2**i, i=1..n)
      lr: Adam learning rate
      c_task_weight: loss-weight multiplier for the count heads (how much to
        upweight the count-prediction task relative to the profile task)
      use_profile: add per-task 'profile/<task>' outputs trained with
        twochannel_multinomial_nll
      use_counts: add per-task 'counts/<task>' outputs trained with MSE
      tasks: task names used to name the output layers.
        NOTE(review): default is lowercase but the data uses 'Sox2'/'Oct4' —
        always pass `tasks` explicitly (as the cells below do).
      seq_len: input sequence length (one-hot DNA, 4 channels)

    Returns:
      Compiled keras Model with input 'seq' of shape (seq_len, 4) and, per
      task, outputs 'profile/<task>' (seq_len, 2) and/or 'counts/<task>' (2,).
    """
    inp = kl.Input(shape=(seq_len, 4), name='seq')
    first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)

    # Dilated trunk with dense-style skips: layer i consumes the SUM of all
    # previous layer outputs; dilation grows exponentially (2, 4, ... 2**n).
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers + 1):
        if i == 1:
            prev_sum = first_conv
        else:
            prev_sum = kl.add(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)

    combined_conv = kl.add(prev_layers)

    # De-conv
    # Profile head: reshape to 2D so Conv2DTranspose can run along the length
    # axis, producing 2 channels per task (split back out per-task below).
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(2*len(tasks), kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    #x = kl.UpSampling2D((2, 1))(x)
    #x = kl.Conv2DTranspose(len(tasks), kernel_size=(tconv_kernel_size2, 1), padding='same')(x)
    #x = kl.UpSampling2D((2, 1))(x)
    #x = kl.Conv2DTranspose(int(len(tasks)/2), kernel_size=(tconv_kernel_size3, 1), padding='same')(x)
    out = kl.Reshape((-1, 2 * len(tasks)))(x)

    # setup the output branches
    outputs = []
    losses = []
    loss_weights = []

    if use_profile:
        # Slice channel pair (2i, 2i+1) out of the shared deconv output for
        # each task. `i` is bound via `arguments=` so the lambda does NOT
        # close over the loop variable (avoids the late-binding pitfall).
        output = [kl.Lambda(lambda x, i: x[:, :, (2 * i):(2 * i + 2)],
                            output_shape=(seq_len, 2),
                            name="profile/" + task,
                            arguments={"i": i})(out)
                  for i, task in enumerate(tasks)]
        outputs += output
        losses += [twochannel_multinomial_nll] * len(tasks)
        loss_weights += [1] * len(tasks)

    if use_counts:
        # Count head: global average pool over the trunk, then one 2-unit
        # dense layer per task; upweighted by c_task_weight in the total loss.
        pooled = kl.GlobalAvgPool1D()(combined_conv)
        counts = [kl.Dense(2, name="counts/" + task)(pooled)
                  for task in tasks]
        outputs += counts
        losses += ["mse"] * len(tasks)
        loss_weights += [c_task_weight] * len(tasks)

    model = Model(inp, outputs)
    model.compile(Adam(lr=lr), loss=losses, loss_weights=loss_weights)
    return model
In [21]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
from basepair.layers import SpatialLifetimeSparsity
In [22]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
In [23]:
# Run counter; appended to the model name so repeated runs get distinct checkpoints.
i=1
In [35]:
def get_model(mfn, mkwargs, fkwargs, i):
    """Instantiate the model factory named `mfn` and build a unique run
    name plus checkpoint path for it.

    Args:
      mfn: name (str) of a model-building function in this namespace
      mkwargs: hyper-parameters; also serialized into the run name
      fkwargs: fixed kwargs merged into the call but kept out of the name
      i: run index appended to the name to keep checkpoints distinct

    Returns:
      (model, name, ckp_file): the compiled model, its run name, and the
      path of the .h5 checkpoint file to write.
    """
    import os
    mdir = f"{ddir}/"
    name = mfn + "_" + \
            ",".join(f'{k}={v}' for k, v in mkwargs.items()) + \
            "." + str(i)
    # os.makedirs replaces the IPython-only `!mkdir -p` shell escape, so the
    # function also works outside a notebook. (The old `i += 1` here was a
    # no-op: it only rebound the local, never the caller's counter.)
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}
    # SECURITY NOTE: eval() executes arbitrary code if `mfn` is ever
    # attacker-controlled; acceptable interactively, but prefer a dict of
    # factory functions in library code.
    return eval(mfn)(**all_kwargs), name, ckp_file
In [36]:
# hyper-parameters
# Name of the model factory (resolved via eval() inside get_model).
mfn = "seq_multitask_chipseq"
use_profile = True
use_counts = True
mkwargs = dict(filters=32, 
               conv1_kernel_size=21,
               tconv_kernel_size=100,
               n_dil_layers=6,
               use_profile=use_profile,
               use_counts=use_counts,
               c_task_weight=10,  # upweight count (MSE) heads vs. profile heads
               seq_len=train[0].shape[1],  # infer input length from the training data
               lr=0.004)
# Passed to the model but excluded from the run name built by get_model.
fixed_kwargs = dict(
    tasks=list(ds.task_specs)  # actual task names ('Sox2', 'Oct4') from the DataSpec
)
In [37]:
import numpy as np
np.random.seed(20)  # seed before model construction for (partial) reproducibility
# NOTE(review): get_model's internal `i+=1` is a no-op on this variable;
# this increment is what actually advances the run counter between runs.
i += 1
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)
history = model.fit(train[0], 
                    train[1],
          batch_size=256, 
          epochs=100,
          validation_data=valid[:2],
          callbacks=[EarlyStopping(patience=5),  # stop after 5 epochs without val_loss improvement
                     History(),
                     ModelCheckpoint(ckp_file, save_best_only=True)]  # keep only best-val weights
         )
# get the best model
# Reload the checkpoint so `model` carries the best-validation weights rather
# than the weights from the final (possibly worse) epoch; the custom loss
# must be passed explicitly for deserialization.
model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})
Train on 14727 samples, validate on 4493 samples
Epoch 1/100
14727/14727 [==============================] - 6s 405us/step - loss: 616.3043 - profile/Sox2_loss: 168.1242 - profile/Oct4_loss: 425.4187 - counts/Sox2_loss: 1.0797 - counts/Oct4_loss: 1.1964 - val_loss: 611.3059 - val_profile/Sox2_loss: 167.3084 - val_profile/Oct4_loss: 423.6573 - val_counts/Sox2_loss: 1.0032 - val_counts/Oct4_loss: 1.0308
Epoch 2/100
14727/14727 [==============================] - 4s 257us/step - loss: 590.2564 - profile/Sox2_loss: 164.2604 - profile/Oct4_loss: 405.9931 - counts/Sox2_loss: 0.9991 - counts/Oct4_loss: 1.0012 - val_loss: 590.7519 - val_profile/Sox2_loss: 164.4099 - val_profile/Oct4_loss: 406.0800 - val_counts/Sox2_loss: 1.0014 - val_counts/Oct4_loss: 1.0248
Epoch 3/100
14727/14727 [==============================] - 4s 251us/step - loss: 577.4715 - profile/Sox2_loss: 162.3749 - profile/Oct4_loss: 395.1156 - counts/Sox2_loss: 0.9975 - counts/Oct4_loss: 1.0006 - val_loss: 586.8555 - val_profile/Sox2_loss: 163.7865 - val_profile/Oct4_loss: 402.8258 - val_counts/Sox2_loss: 1.0000 - val_counts/Oct4_loss: 1.0243
Epoch 4/100
14727/14727 [==============================] - 3s 237us/step - loss: 575.1964 - profile/Sox2_loss: 161.9651 - profile/Oct4_loss: 393.2963 - counts/Sox2_loss: 0.9937 - counts/Oct4_loss: 0.9998 - val_loss: 585.1812 - val_profile/Sox2_loss: 163.3500 - val_profile/Oct4_loss: 401.6507 - val_counts/Sox2_loss: 0.9948 - val_counts/Oct4_loss: 1.0233
Epoch 5/100
14727/14727 [==============================] - 4s 257us/step - loss: 573.5990 - profile/Sox2_loss: 161.5765 - profile/Oct4_loss: 392.1314 - counts/Sox2_loss: 0.9913 - counts/Oct4_loss: 0.9978 - val_loss: 583.7922 - val_profile/Sox2_loss: 162.9198 - val_profile/Oct4_loss: 400.6958 - val_counts/Sox2_loss: 0.9932 - val_counts/Oct4_loss: 1.0245
Epoch 6/100
14727/14727 [==============================] - 4s 246us/step - loss: 572.2051 - profile/Sox2_loss: 161.0577 - profile/Oct4_loss: 391.3040 - counts/Sox2_loss: 0.9887 - counts/Oct4_loss: 0.9957 - val_loss: 582.0773 - val_profile/Sox2_loss: 162.4752 - val_profile/Oct4_loss: 399.5147 - val_counts/Sox2_loss: 0.9884 - val_counts/Oct4_loss: 1.0203
Epoch 7/100
14727/14727 [==============================] - 4s 254us/step - loss: 571.2587 - profile/Sox2_loss: 160.8014 - profile/Oct4_loss: 390.6697 - counts/Sox2_loss: 0.9848 - counts/Oct4_loss: 0.9940 - val_loss: 581.2158 - val_profile/Sox2_loss: 162.1971 - val_profile/Oct4_loss: 399.0075 - val_counts/Sox2_loss: 0.9816 - val_counts/Oct4_loss: 1.0195
Epoch 8/100
14727/14727 [==============================] - 4s 259us/step - loss: 570.3831 - profile/Sox2_loss: 160.6253 - profile/Oct4_loss: 390.0540 - counts/Sox2_loss: 0.9804 - counts/Oct4_loss: 0.9900 - val_loss: 580.9574 - val_profile/Sox2_loss: 162.2492 - val_profile/Oct4_loss: 398.7572 - val_counts/Sox2_loss: 0.9800 - val_counts/Oct4_loss: 1.0151
Epoch 9/100
14727/14727 [==============================] - 4s 260us/step - loss: 569.4968 - profile/Sox2_loss: 160.4502 - profile/Oct4_loss: 389.4437 - counts/Sox2_loss: 0.9750 - counts/Oct4_loss: 0.9853 - val_loss: 579.7177 - val_profile/Sox2_loss: 162.1709 - val_profile/Oct4_loss: 397.7184 - val_counts/Sox2_loss: 0.9758 - val_counts/Oct4_loss: 1.0071
Epoch 10/100
14727/14727 [==============================] - 4s 255us/step - loss: 569.0671 - profile/Sox2_loss: 160.3743 - profile/Oct4_loss: 389.1812 - counts/Sox2_loss: 0.9703 - counts/Oct4_loss: 0.9809 - val_loss: 580.0222 - val_profile/Sox2_loss: 162.2003 - val_profile/Oct4_loss: 398.0500 - val_counts/Sox2_loss: 0.9711 - val_counts/Oct4_loss: 1.0061
Epoch 11/100
14727/14727 [==============================] - 4s 238us/step - loss: 568.5817 - profile/Sox2_loss: 160.3136 - profile/Oct4_loss: 388.8081 - counts/Sox2_loss: 0.9680 - counts/Oct4_loss: 0.9780 - val_loss: 579.7106 - val_profile/Sox2_loss: 162.1655 - val_profile/Oct4_loss: 397.9457 - val_counts/Sox2_loss: 0.9584 - val_counts/Oct4_loss: 1.0016
Epoch 12/100
14727/14727 [==============================] - 4s 263us/step - loss: 567.8535 - profile/Sox2_loss: 160.1345 - profile/Oct4_loss: 388.4216 - counts/Sox2_loss: 0.9590 - counts/Oct4_loss: 0.9708 - val_loss: 578.9320 - val_profile/Sox2_loss: 162.0278 - val_profile/Oct4_loss: 397.5353 - val_counts/Sox2_loss: 0.9503 - val_counts/Oct4_loss: 0.9866
Epoch 13/100
14727/14727 [==============================] - 4s 244us/step - loss: 567.1428 - profile/Sox2_loss: 160.0022 - profile/Oct4_loss: 388.0339 - counts/Sox2_loss: 0.9483 - counts/Oct4_loss: 0.9624 - val_loss: 578.0977 - val_profile/Sox2_loss: 161.8977 - val_profile/Oct4_loss: 396.9713 - val_counts/Sox2_loss: 0.9441 - val_counts/Oct4_loss: 0.9788
Epoch 14/100
14727/14727 [==============================] - 4s 257us/step - loss: 567.1830 - profile/Sox2_loss: 159.9881 - profile/Oct4_loss: 388.0255 - counts/Sox2_loss: 0.9530 - counts/Oct4_loss: 0.9639 - val_loss: 580.1684 - val_profile/Sox2_loss: 162.0504 - val_profile/Oct4_loss: 397.8619 - val_counts/Sox2_loss: 0.9982 - val_counts/Oct4_loss: 1.0274
Epoch 15/100
14727/14727 [==============================] - 4s 242us/step - loss: 566.8313 - profile/Sox2_loss: 159.9450 - profile/Oct4_loss: 387.6475 - counts/Sox2_loss: 0.9562 - counts/Oct4_loss: 0.9677 - val_loss: 577.8240 - val_profile/Sox2_loss: 161.7873 - val_profile/Oct4_loss: 396.7120 - val_counts/Sox2_loss: 0.9519 - val_counts/Oct4_loss: 0.9805
Epoch 16/100
14727/14727 [==============================] - 4s 258us/step - loss: 566.0107 - profile/Sox2_loss: 159.7577 - profile/Oct4_loss: 387.3083 - counts/Sox2_loss: 0.9395 - counts/Oct4_loss: 0.9550 - val_loss: 577.6982 - val_profile/Sox2_loss: 161.8321 - val_profile/Oct4_loss: 396.7618 - val_counts/Sox2_loss: 0.9395 - val_counts/Oct4_loss: 0.9710
Epoch 17/100
14727/14727 [==============================] - 3s 235us/step - loss: 565.9663 - profile/Sox2_loss: 159.7618 - profile/Oct4_loss: 387.2787 - counts/Sox2_loss: 0.9391 - counts/Oct4_loss: 0.9535 - val_loss: 580.9153 - val_profile/Sox2_loss: 162.3516 - val_profile/Oct4_loss: 397.6583 - val_counts/Sox2_loss: 1.0167 - val_counts/Oct4_loss: 1.0738
Epoch 18/100
14727/14727 [==============================] - 3s 235us/step - loss: 565.9336 - profile/Sox2_loss: 159.7685 - profile/Oct4_loss: 387.1091 - counts/Sox2_loss: 0.9482 - counts/Oct4_loss: 0.9574 - val_loss: 577.1707 - val_profile/Sox2_loss: 161.6850 - val_profile/Oct4_loss: 396.4513 - val_counts/Sox2_loss: 0.9320 - val_counts/Oct4_loss: 0.9714
Epoch 19/100
14727/14727 [==============================] - 4s 257us/step - loss: 564.9943 - profile/Sox2_loss: 159.6331 - profile/Oct4_loss: 386.6012 - counts/Sox2_loss: 0.9295 - counts/Oct4_loss: 0.9465 - val_loss: 576.7848 - val_profile/Sox2_loss: 161.7297 - val_profile/Oct4_loss: 396.0803 - val_counts/Sox2_loss: 0.9312 - val_counts/Oct4_loss: 0.9662
Epoch 20/100
14727/14727 [==============================] - 4s 263us/step - loss: 564.6574 - profile/Sox2_loss: 159.5712 - profile/Oct4_loss: 386.4945 - counts/Sox2_loss: 0.9207 - counts/Oct4_loss: 0.9385 - val_loss: 577.5595 - val_profile/Sox2_loss: 161.7426 - val_profile/Oct4_loss: 396.5655 - val_counts/Sox2_loss: 0.9623 - val_counts/Oct4_loss: 0.9629
Epoch 21/100
14727/14727 [==============================] - 4s 246us/step - loss: 564.8206 - profile/Sox2_loss: 159.5770 - profile/Oct4_loss: 386.5343 - counts/Sox2_loss: 0.9277 - counts/Oct4_loss: 0.9432 - val_loss: 577.6107 - val_profile/Sox2_loss: 161.9086 - val_profile/Oct4_loss: 396.6038 - val_counts/Sox2_loss: 0.9312 - val_counts/Oct4_loss: 0.9787
Epoch 22/100
14727/14727 [==============================] - 3s 235us/step - loss: 564.1826 - profile/Sox2_loss: 159.4838 - profile/Oct4_loss: 386.2453 - counts/Sox2_loss: 0.9135 - counts/Oct4_loss: 0.9319 - val_loss: 577.1205 - val_profile/Sox2_loss: 161.7578 - val_profile/Oct4_loss: 396.2209 - val_counts/Sox2_loss: 0.9468 - val_counts/Oct4_loss: 0.9674
Epoch 23/100
14727/14727 [==============================] - 4s 245us/step - loss: 564.0137 - profile/Sox2_loss: 159.4730 - profile/Oct4_loss: 386.0992 - counts/Sox2_loss: 0.9141 - counts/Oct4_loss: 0.9300 - val_loss: 577.2051 - val_profile/Sox2_loss: 161.8083 - val_profile/Oct4_loss: 396.2086 - val_counts/Sox2_loss: 0.9359 - val_counts/Oct4_loss: 0.9829
Epoch 24/100
14727/14727 [==============================] - 4s 274us/step - loss: 563.7202 - profile/Sox2_loss: 159.4214 - profile/Oct4_loss: 385.9509 - counts/Sox2_loss: 0.9078 - counts/Oct4_loss: 0.9270 - val_loss: 576.3659 - val_profile/Sox2_loss: 161.7943 - val_profile/Oct4_loss: 396.0562 - val_counts/Sox2_loss: 0.9105 - val_counts/Oct4_loss: 0.9410
Epoch 25/100
14727/14727 [==============================] - 4s 265us/step - loss: 563.4709 - profile/Sox2_loss: 159.3572 - profile/Oct4_loss: 385.8967 - counts/Sox2_loss: 0.9009 - counts/Oct4_loss: 0.9208 - val_loss: 576.3726 - val_profile/Sox2_loss: 162.0489 - val_profile/Oct4_loss: 395.6539 - val_counts/Sox2_loss: 0.9187 - val_counts/Oct4_loss: 0.9483
Epoch 26/100
14727/14727 [==============================] - 4s 266us/step - loss: 563.4591 - profile/Sox2_loss: 159.3731 - profile/Oct4_loss: 385.6763 - counts/Sox2_loss: 0.9101 - counts/Oct4_loss: 0.9309 - val_loss: 575.7679 - val_profile/Sox2_loss: 161.7525 - val_profile/Oct4_loss: 395.5991 - val_counts/Sox2_loss: 0.9066 - val_counts/Oct4_loss: 0.9350
Epoch 27/100
14727/14727 [==============================] - 4s 258us/step - loss: 562.8053 - profile/Sox2_loss: 159.2699 - profile/Oct4_loss: 385.5152 - counts/Sox2_loss: 0.8894 - counts/Oct4_loss: 0.9126 - val_loss: 576.5256 - val_profile/Sox2_loss: 161.7623 - val_profile/Oct4_loss: 395.8993 - val_counts/Sox2_loss: 0.9301 - val_counts/Oct4_loss: 0.9563
Epoch 28/100
14727/14727 [==============================] - 4s 257us/step - loss: 563.3680 - profile/Sox2_loss: 159.3614 - profile/Oct4_loss: 385.7241 - counts/Sox2_loss: 0.9043 - counts/Oct4_loss: 0.9239 - val_loss: 576.0801 - val_profile/Sox2_loss: 161.6792 - val_profile/Oct4_loss: 396.0294 - val_counts/Sox2_loss: 0.8984 - val_counts/Oct4_loss: 0.9388
Epoch 29/100
14727/14727 [==============================] - 4s 259us/step - loss: 562.5983 - profile/Sox2_loss: 159.2231 - profile/Oct4_loss: 385.3777 - counts/Sox2_loss: 0.8884 - counts/Oct4_loss: 0.9114 - val_loss: 575.3644 - val_profile/Sox2_loss: 161.7302 - val_profile/Oct4_loss: 395.5075 - val_counts/Sox2_loss: 0.8897 - val_counts/Oct4_loss: 0.9230
Epoch 30/100
14727/14727 [==============================] - 4s 244us/step - loss: 562.7200 - profile/Sox2_loss: 159.2439 - profile/Oct4_loss: 385.5345 - counts/Sox2_loss: 0.8856 - counts/Oct4_loss: 0.9086 - val_loss: 576.0619 - val_profile/Sox2_loss: 161.6343 - val_profile/Oct4_loss: 395.8055 - val_counts/Sox2_loss: 0.9087 - val_counts/Oct4_loss: 0.9535
Epoch 31/100
14727/14727 [==============================] - 4s 256us/step - loss: 561.9716 - profile/Sox2_loss: 159.1604 - profile/Oct4_loss: 385.0812 - counts/Sox2_loss: 0.8739 - counts/Oct4_loss: 0.8991 - val_loss: 575.5868 - val_profile/Sox2_loss: 161.7151 - val_profile/Oct4_loss: 395.8610 - val_counts/Sox2_loss: 0.8735 - val_counts/Oct4_loss: 0.9276
Epoch 32/100
14727/14727 [==============================] - 4s 241us/step - loss: 561.6637 - profile/Sox2_loss: 159.1840 - profile/Oct4_loss: 385.0422 - counts/Sox2_loss: 0.8588 - counts/Oct4_loss: 0.8849 - val_loss: 575.2367 - val_profile/Sox2_loss: 161.7039 - val_profile/Oct4_loss: 395.7971 - val_counts/Sox2_loss: 0.8638 - val_counts/Oct4_loss: 0.9097
Epoch 33/100
14727/14727 [==============================] - 4s 261us/step - loss: 561.2623 - profile/Sox2_loss: 159.1619 - profile/Oct4_loss: 384.8515 - counts/Sox2_loss: 0.8474 - counts/Oct4_loss: 0.8775 - val_loss: 575.7079 - val_profile/Sox2_loss: 161.9301 - val_profile/Oct4_loss: 395.9835 - val_counts/Sox2_loss: 0.8730 - val_counts/Oct4_loss: 0.9064
Epoch 34/100
14727/14727 [==============================] - 4s 257us/step - loss: 561.1369 - profile/Sox2_loss: 159.1470 - profile/Oct4_loss: 384.8371 - counts/Sox2_loss: 0.8430 - counts/Oct4_loss: 0.8722 - val_loss: 575.7895 - val_profile/Sox2_loss: 161.8789 - val_profile/Oct4_loss: 396.2020 - val_counts/Sox2_loss: 0.8633 - val_counts/Oct4_loss: 0.9076
Epoch 35/100
14727/14727 [==============================] - 4s 244us/step - loss: 561.1009 - profile/Sox2_loss: 159.1508 - profile/Oct4_loss: 384.7104 - counts/Sox2_loss: 0.8471 - counts/Oct4_loss: 0.8769 - val_loss: 579.0939 - val_profile/Sox2_loss: 162.1038 - val_profile/Oct4_loss: 397.2945 - val_counts/Sox2_loss: 0.9833 - val_counts/Oct4_loss: 0.9862
Epoch 36/100
14727/14727 [==============================] - 4s 256us/step - loss: 561.2505 - profile/Sox2_loss: 159.1358 - profile/Oct4_loss: 384.7608 - counts/Sox2_loss: 0.8548 - counts/Oct4_loss: 0.8806 - val_loss: 575.6743 - val_profile/Sox2_loss: 161.7469 - val_profile/Oct4_loss: 395.6983 - val_counts/Sox2_loss: 0.8744 - val_counts/Oct4_loss: 0.9485
Epoch 37/100
14727/14727 [==============================] - 4s 246us/step - loss: 560.9825 - profile/Sox2_loss: 159.0951 - profile/Oct4_loss: 384.5865 - counts/Sox2_loss: 0.8545 - counts/Oct4_loss: 0.8756 - val_loss: 575.0987 - val_profile/Sox2_loss: 161.8750 - val_profile/Oct4_loss: 395.5772 - val_counts/Sox2_loss: 0.8582 - val_counts/Oct4_loss: 0.9065
Epoch 38/100
14727/14727 [==============================] - 3s 236us/step - loss: 560.3417 - profile/Sox2_loss: 159.0262 - profile/Oct4_loss: 384.4705 - counts/Sox2_loss: 0.8262 - counts/Oct4_loss: 0.8583 - val_loss: 575.8286 - val_profile/Sox2_loss: 161.8450 - val_profile/Oct4_loss: 395.5239 - val_counts/Sox2_loss: 0.9058 - val_counts/Oct4_loss: 0.9402
Epoch 39/100
14727/14727 [==============================] - 4s 262us/step - loss: 560.2087 - profile/Sox2_loss: 159.0254 - profile/Oct4_loss: 384.3639 - counts/Sox2_loss: 0.8240 - counts/Oct4_loss: 0.8579 - val_loss: 577.3658 - val_profile/Sox2_loss: 161.9570 - val_profile/Oct4_loss: 396.2150 - val_counts/Sox2_loss: 0.9351 - val_counts/Oct4_loss: 0.9843
Epoch 40/100
14727/14727 [==============================] - 4s 264us/step - loss: 560.2889 - profile/Sox2_loss: 159.0011 - profile/Oct4_loss: 384.3148 - counts/Sox2_loss: 0.8329 - counts/Oct4_loss: 0.8644 - val_loss: 575.4233 - val_profile/Sox2_loss: 161.7493 - val_profile/Oct4_loss: 395.7647 - val_counts/Sox2_loss: 0.8665 - val_counts/Oct4_loss: 0.9244
Epoch 41/100
14727/14727 [==============================] - 4s 238us/step - loss: 559.8002 - profile/Sox2_loss: 158.9799 - profile/Oct4_loss: 384.1745 - counts/Sox2_loss: 0.8164 - counts/Oct4_loss: 0.8482 - val_loss: 575.3028 - val_profile/Sox2_loss: 161.8802 - val_profile/Oct4_loss: 396.0020 - val_counts/Sox2_loss: 0.8432 - val_counts/Oct4_loss: 0.8988
Epoch 42/100
14727/14727 [==============================] - 4s 238us/step - loss: 559.4482 - profile/Sox2_loss: 158.9706 - profile/Oct4_loss: 384.0877 - counts/Sox2_loss: 0.8005 - counts/Oct4_loss: 0.8385 - val_loss: 577.2816 - val_profile/Sox2_loss: 162.1758 - val_profile/Oct4_loss: 396.3797 - val_counts/Sox2_loss: 0.9303 - val_counts/Oct4_loss: 0.9423
In [38]:
from basepair.eval import evaluate
# Final metrics of the best checkpoint on the validation split.
evaluate(model, valid[0], valid[1])
Out[38]:
{'loss': 575.0987344381408,
 'profile/Sox2_loss': 161.87499498052745,
 'profile/Oct4_loss': 395.5772104643247,
 'counts/Sox2_loss': 0.8582015076538992,
 'counts/Oct4_loss': 0.9064518255193011}
In [ ]: