In [1]:
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
Using TensorFlow backend.
2018-09-24 17:49:10,707 [WARNING] git-lfs not installed
2018-09-24 17:49:10,724 [WARNING] git-lfs not installed
In [2]:
# Use GPUs 3, 5, 7 (must be set before TensorFlow initializes a session)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3, 5, 7"
In [3]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
from basepair.layers import SpatialLifetimeSparsity
In [4]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
In [5]:
def get_model(mfn, mkwargs, fkwargs, i):
    """Instantiate a model-builder function by name and derive its checkpoint path.

    Args:
      mfn: name of a model-builder function resolvable in this module's
        globals (e.g. "seq_multitask").
      mkwargs: model hyper-parameters; also serialized into the run name.
      fkwargs: fixed keyword arguments merged on top of `mkwargs`
        (fkwargs wins on key collisions).
      i: integer suffix distinguishing repeated runs with identical
        hyper-parameters.

    Returns:
      (model, name, ckp_file): the built model, the run name encoding
      mfn + hyper-parameters + i, and the .h5 checkpoint path.

    NOTE(review): relies on the module-level `ddir` being assigned before
    this is called (done in a later cell) — confirm cell ordering.
    """
    import os
    mdir = f"{ddir}/processed/chipseq/exp/models/count+profile"
    name = mfn + "_" + \
        ",".join([f'{k}={v}' for k, v in mkwargs.items()]) + \
        "." + str(i)
    # os.makedirs replaces the previous `!mkdir -p` shell magic: pure Python,
    # so the function also works outside an IPython kernel.
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}
    # globals() lookup replaces eval(): same name resolution for a plain
    # identifier, without arbitrary-code-execution risk.
    return globals()[mfn](**all_kwargs), name, ckp_file
In [6]:
# Peak BED files and strand-specific bigwig directories for the
# Sox2/Oct4 ChIP experiment (plain literals — no f-string needed).
BED_DIR = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/"
Sox2_BW_DIR = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/Sox2/"
Oct4_BW_DIR = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/Oct4/"
In [7]:
import pandas as pd
import numpy as np
from pybedtools import BedTool, Interval
from basepair.config import get_data_dir
from basepair.preproc import bin_counts
from concise.utils.helper import get_from_module
from tqdm import tqdm
from concise.preprocessing import encodeDNA
from random import Random
import joblib
from basepair.preproc import resize_interval
from genomelake.extractors import FastaExtractor, BigwigExtractor
from kipoi.data_utils import get_dataset_item
import logging
# Module-level logger; NullHandler suppresses "No handler found" warnings
# when the embedding application has not configured logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
In [8]:
def get_chipnexus_data(bed_file=f"{BED_DIR}//Sox2_123b_1_ppr.IDR0.05.filt.summit_centered_200bp.narrowPeak",
                       peak_fasta_file=f"{BED_DIR}//Sox2_123b_1_ppr.IDR0.05.filt.summit_centered_200bp.fasta",
                       bigwigs={"cuts_pos": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.pos_strand.bw",
                                 "cuts_neg": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.neg_strand.bw",
                                }
                       ):
    """Extract per-peak bigwig coverage tracks into a DataFrame.

    Args:
      bed_file: narrowPeak/BED file of peak intervals.
      peak_fasta_file: fasta with one sequence per interval in `bed_file`
        (assumed to be in the same order — only the length is asserted).
      bigwigs: dict mapping output column name -> bigwig file path; one
        coverage array per interval is extracted for each entry.

    Returns:
      pd.DataFrame with columns chr/start/end, one numpy-array column per
      `bigwigs` key, plus `seq` (upper-cased) and `seq_id` from the fasta.
    """
    from concise.utils.fasta import read_fasta
    import pyBigWig

    fas = read_fasta(peak_fasta_file)

    bed = BedTool(bed_file)

    # fasta and bed must describe the same number of intervals
    assert len(fas) == len(bed)

    bigwig_obj = {k: pyBigWig.open(v) for k, v in bigwigs.items()}
    try:
        records = []
        for interval in tqdm(bed):
            records.append({"chr": interval.chrom,
                            "start": interval.start,
                            "end": interval.stop,
                            # nan_to_num: bigwigs report missing coverage as NaN
                            **{k: np.nan_to_num(bw.values(interval.chrom,
                                                          interval.start,
                                                          interval.stop,
                                                          numpy=True))
                               for k, bw in bigwig_obj.items()}
                            })
    finally:
        # close the bigwig handles (previously leaked)
        for bw in bigwig_obj.values():
            bw.close()

    dfc = pd.DataFrame(records)

    # dict iteration order aligns keys with values, so seq/seq_id stay paired
    dfc['seq'] = list(fas.values())
    dfc['seq_id'] = list(fas)

    dfc['seq'] = dfc.seq.str.upper()
    return dfc
In [9]:
def sox2_oct4_peaks_sox2(valid_chr=['chr2', 'chr3', 'chr4'],
                         test_chr=['chr1', 'chr8', 'chr9']):
    """Build (train, valid, test) splits for joint Sox2/Oct4 profile modelling.

    Each split is (seq one-hot, {"sox2": cuts, "oct4": cuts}, metadata df);
    cut arrays stack the pos/neg strand tracks on the last axis.
    The default chromosome split is roughly 60/20/20.
    """
    # validation and test chromosomes must be disjoint
    for chrom in valid_chr:
        assert chrom not in test_chr

    dfc = get_chipnexus_data(
        bigwigs={"sox2_pos": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.pos_strand.bw",
                 "sox2_neg": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.neg_strand.bw",
                 "oct4_pos": f"{Oct4_BW_DIR}/Oct4_1234.pos.bw",
                 "oct4_neg": f"{Oct4_BW_DIR}/Oct4_1234.neg.bw"})
    seq = encodeDNA(dfc.seq)

    # Stack per-peak tracks into (n_peaks, seq_len, 2) arrays;
    # last axis = [pos strand, neg strand]
    sox2_cuts = np.stack([np.stack(dfc.sox2_pos), np.stack(dfc.sox2_neg)], axis=-1)
    oct4_cuts = np.stack([np.stack(dfc.oct4_pos), np.stack(dfc.oct4_neg)], axis=-1)

    ids = dfc.seq_id  # kept for parity with the original cell (unused below)

    # Boolean masks for the chromosome-based split
    is_test = dfc.chr.isin(test_chr)
    is_valid = dfc.chr.isin(valid_chr)
    is_train = ~(is_test | is_valid)

    splits = []
    for mask in [is_train, is_valid, is_test]:
        splits.append((seq[mask],                       # x
                       {"sox2": sox2_cuts[mask],        # y
                        "oct4": oct4_cuts[mask]},
                       dfc[mask]))                      # metadata
    return tuple(splits)
In [10]:
# hyper-parameters
from basepair.models import seq_multitask
# Model-builder name plus shared output-head flags for this experiment
mfn2 = "seq_multitask"
use_profile = True
use_counts = True
# Hyper-parameters for the seq_multitask run
mkwargs2 = {
    "filters": 32,
    "conv1_kernel_size": 21,
    "tconv_kernel_size": 25,
    "n_dil_layers": 6,
    "use_profile": use_profile,
    "use_counts": use_counts,
    "c_task_weight": 10,
    "lr": 0.004,
}
In [11]:
# Expensive (~7 min, 9396 peaks — see tqdm bar below): extracts bigwig
# coverage and builds the (train, valid, test) splits.
data2 = sox2_oct4_peaks_sox2()
100%|██████████| 9396/9396 [07:06<00:00, 22.05it/s]
In [12]:
from basepair.preproc import transform_data
In [13]:
# Convert the raw (x, y, metadata) splits into model-ready arrays, keeping
# only the outputs enabled by use_profile / use_counts
train_nex, valid_nex, test_nex = transform_data(data2, use_profile, use_counts)
In [20]:
ddir = get_data_dir()  # module-level; used by get_model() for the checkpoint dir

# Reuse the module-level constant instead of repeating the same path literal
bdir = BED_DIR

# Declarative description of the two ChIP tasks: strand-specific count
# bigwigs plus IDR-filtered peak calls, over the mm10 genome fasta.
ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                          ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                          )
                         },
              fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta"
             )
In [21]:
# Keyword arguments shared by every model configuration; the task list
# comes from the DataSpec defined above.
fixed_kwargs = {"tasks": list(ds.task_specs)}
# Run counter used to suffix model names / checkpoint files
i = 1
In [22]:
# Bump the run counter so this training run gets a fresh name/checkpoint
i += 1
model2, name2, ckp_file2 = get_model(mfn2, mkwargs2, fixed_kwargs, i)
# Early-stops after 5 epochs without val_loss improvement; the best weights
# (lowest val_loss) are written to ckp_file2 by ModelCheckpoint.
history2 = model2.fit(train_nex[0], 
                    train_nex[1],
          batch_size=256, 
          epochs=100,
          validation_data=valid_nex[:2],
          callbacks=[EarlyStopping(patience=5),
                     History(),
                     ModelCheckpoint(ckp_file2, save_best_only=True)]
         )
# Reload the best checkpoint; custom losses/layers must be registered
# explicitly for Keras deserialization.
model2 = load_model(ckp_file2, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll, 
                                             "SpatialLifetimeSparsity": SpatialLifetimeSparsity})
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2018-09-24 18:01:09,446 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2018-09-24 18:01:15,916 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
Train on 5561 samples, validate on 1884 samples
Epoch 1/100
5561/5561 [==============================] - 11s 2ms/step - loss: 759.2075 - profile/Sox2_loss: 297.5083 - profile/Oct4_loss: 430.4526 - counts/Sox2_loss: 1.0468 - counts/Oct4_loss: 2.0779 - val_loss: 705.0051 - val_profile/Sox2_loss: 280.0431 - val_profile/Oct4_loss: 394.5018 - val_counts/Sox2_loss: 1.0142 - val_counts/Oct4_loss: 2.0319
Epoch 2/100
5561/5561 [==============================] - 1s 138us/step - loss: 709.7728 - profile/Sox2_loss: 283.0120 - profile/Oct4_loss: 397.7276 - counts/Sox2_loss: 1.0054 - counts/Oct4_loss: 1.8979 - val_loss: 691.4600 - val_profile/Sox2_loss: 274.7748 - val_profile/Oct4_loss: 386.2711 - val_counts/Sox2_loss: 1.0122 - val_counts/Oct4_loss: 2.0293
Epoch 3/100
5561/5561 [==============================] - 1s 141us/step - loss: 700.3242 - profile/Sox2_loss: 279.0755 - profile/Oct4_loss: 392.2231 - counts/Sox2_loss: 1.0043 - counts/Oct4_loss: 1.8982 - val_loss: 684.4440 - val_profile/Sox2_loss: 271.7966 - val_profile/Oct4_loss: 382.1573 - val_counts/Sox2_loss: 1.0125 - val_counts/Oct4_loss: 2.0365
Epoch 4/100
5561/5561 [==============================] - 1s 139us/step - loss: 692.4646 - profile/Sox2_loss: 275.8147 - profile/Oct4_loss: 387.6496 - counts/Sox2_loss: 1.0051 - counts/Oct4_loss: 1.8949 - val_loss: 678.4818 - val_profile/Sox2_loss: 269.4294 - val_profile/Oct4_loss: 378.6340 - val_counts/Sox2_loss: 1.0149 - val_counts/Oct4_loss: 2.0270
Epoch 5/100
5561/5561 [==============================] - 1s 146us/step - loss: 686.3894 - profile/Sox2_loss: 273.5289 - profile/Oct4_loss: 383.9204 - counts/Sox2_loss: 1.0055 - counts/Oct4_loss: 1.8885 - val_loss: 674.4888 - val_profile/Sox2_loss: 268.1489 - val_profile/Oct4_loss: 375.9884 - val_counts/Sox2_loss: 1.0138 - val_counts/Oct4_loss: 2.0214
Epoch 6/100
5561/5561 [==============================] - 1s 142us/step - loss: 682.1938 - profile/Sox2_loss: 272.0254 - profile/Oct4_loss: 381.3407 - counts/Sox2_loss: 1.0062 - counts/Oct4_loss: 1.8766 - val_loss: 671.0906 - val_profile/Sox2_loss: 266.8637 - val_profile/Oct4_loss: 374.1942 - val_counts/Sox2_loss: 1.0155 - val_counts/Oct4_loss: 1.9878
Epoch 7/100
5561/5561 [==============================] - 1s 137us/step - loss: 679.3557 - profile/Sox2_loss: 270.8776 - profile/Oct4_loss: 379.7415 - counts/Sox2_loss: 1.0083 - counts/Oct4_loss: 1.8654 - val_loss: 668.6042 - val_profile/Sox2_loss: 265.9410 - val_profile/Oct4_loss: 372.7856 - val_counts/Sox2_loss: 1.0104 - val_counts/Oct4_loss: 1.9773
Epoch 8/100
5561/5561 [==============================] - 1s 142us/step - loss: 674.8869 - profile/Sox2_loss: 269.3763 - profile/Oct4_loss: 377.0702 - counts/Sox2_loss: 1.0037 - counts/Oct4_loss: 1.8404 - val_loss: 665.6921 - val_profile/Sox2_loss: 265.0239 - val_profile/Oct4_loss: 371.1480 - val_counts/Sox2_loss: 1.0111 - val_counts/Oct4_loss: 1.9410
Epoch 9/100
5561/5561 [==============================] - 1s 138us/step - loss: 672.0182 - profile/Sox2_loss: 268.3358 - profile/Oct4_loss: 375.5975 - counts/Sox2_loss: 0.9993 - counts/Oct4_loss: 1.8092 - val_loss: 663.8514 - val_profile/Sox2_loss: 264.1862 - val_profile/Oct4_loss: 370.3957 - val_counts/Sox2_loss: 1.0098 - val_counts/Oct4_loss: 1.9171
Epoch 10/100
5561/5561 [==============================] - 1s 143us/step - loss: 669.7962 - profile/Sox2_loss: 267.5871 - profile/Oct4_loss: 374.5216 - counts/Sox2_loss: 0.9923 - counts/Oct4_loss: 1.7765 - val_loss: 661.1955 - val_profile/Sox2_loss: 263.8456 - val_profile/Oct4_loss: 368.9675 - val_counts/Sox2_loss: 0.9928 - val_counts/Oct4_loss: 1.8454
Epoch 11/100
5561/5561 [==============================] - 1s 143us/step - loss: 667.0221 - profile/Sox2_loss: 266.8370 - profile/Oct4_loss: 373.2011 - counts/Sox2_loss: 0.9790 - counts/Oct4_loss: 1.7194 - val_loss: 661.4092 - val_profile/Sox2_loss: 263.7145 - val_profile/Oct4_loss: 369.2962 - val_counts/Sox2_loss: 0.9901 - val_counts/Oct4_loss: 1.8497
Epoch 12/100
5561/5561 [==============================] - 1s 145us/step - loss: 665.2298 - profile/Sox2_loss: 266.4547 - profile/Oct4_loss: 372.7772 - counts/Sox2_loss: 0.9566 - counts/Oct4_loss: 1.6432 - val_loss: 656.5736 - val_profile/Sox2_loss: 262.9196 - val_profile/Oct4_loss: 367.4937 - val_counts/Sox2_loss: 0.9450 - val_counts/Oct4_loss: 1.6710
Epoch 13/100
5561/5561 [==============================] - 1s 142us/step - loss: 662.8659 - profile/Sox2_loss: 265.9402 - profile/Oct4_loss: 372.2614 - counts/Sox2_loss: 0.9269 - counts/Oct4_loss: 1.5395 - val_loss: 657.6663 - val_profile/Sox2_loss: 263.6503 - val_profile/Oct4_loss: 368.1491 - val_counts/Sox2_loss: 0.9392 - val_counts/Oct4_loss: 1.6475
Epoch 14/100
5561/5561 [==============================] - 1s 137us/step - loss: 660.9163 - profile/Sox2_loss: 265.7454 - profile/Oct4_loss: 371.8555 - counts/Sox2_loss: 0.9032 - counts/Oct4_loss: 1.4283 - val_loss: 652.2147 - val_profile/Sox2_loss: 262.5748 - val_profile/Oct4_loss: 366.4732 - val_counts/Sox2_loss: 0.8781 - val_counts/Oct4_loss: 1.4385
Epoch 15/100
5561/5561 [==============================] - 1s 138us/step - loss: 657.8369 - profile/Sox2_loss: 265.1549 - profile/Oct4_loss: 371.0640 - counts/Sox2_loss: 0.8679 - counts/Oct4_loss: 1.2939 - val_loss: 652.4196 - val_profile/Sox2_loss: 262.4391 - val_profile/Oct4_loss: 366.0477 - val_counts/Sox2_loss: 0.9093 - val_counts/Oct4_loss: 1.4840
Epoch 16/100
5561/5561 [==============================] - 1s 149us/step - loss: 656.4947 - profile/Sox2_loss: 264.6679 - profile/Oct4_loss: 370.6945 - counts/Sox2_loss: 0.8603 - counts/Oct4_loss: 1.2529 - val_loss: 650.0787 - val_profile/Sox2_loss: 261.7504 - val_profile/Oct4_loss: 365.5641 - val_counts/Sox2_loss: 0.8829 - val_counts/Oct4_loss: 1.3935
Epoch 17/100
5561/5561 [==============================] - 1s 144us/step - loss: 655.9837 - profile/Sox2_loss: 264.2288 - profile/Oct4_loss: 370.4364 - counts/Sox2_loss: 0.8638 - counts/Oct4_loss: 1.2681 - val_loss: 652.8544 - val_profile/Sox2_loss: 261.9077 - val_profile/Oct4_loss: 365.4995 - val_counts/Sox2_loss: 0.9641 - val_counts/Oct4_loss: 1.5806
Epoch 18/100
5561/5561 [==============================] - 1s 143us/step - loss: 656.3376 - profile/Sox2_loss: 263.9934 - profile/Oct4_loss: 370.4965 - counts/Sox2_loss: 0.8776 - counts/Oct4_loss: 1.3071 - val_loss: 649.0483 - val_profile/Sox2_loss: 261.4256 - val_profile/Oct4_loss: 365.2357 - val_counts/Sox2_loss: 0.8678 - val_counts/Oct4_loss: 1.3709
Epoch 19/100
5561/5561 [==============================] - 1s 142us/step - loss: 654.2747 - profile/Sox2_loss: 263.6690 - profile/Oct4_loss: 369.9854 - counts/Sox2_loss: 0.8511 - counts/Oct4_loss: 1.2110 - val_loss: 649.0594 - val_profile/Sox2_loss: 261.4125 - val_profile/Oct4_loss: 365.2539 - val_counts/Sox2_loss: 0.8598 - val_counts/Oct4_loss: 1.3795
Epoch 20/100
5561/5561 [==============================] - 1s 141us/step - loss: 653.1843 - profile/Sox2_loss: 263.0770 - profile/Oct4_loss: 369.6362 - counts/Sox2_loss: 0.8466 - counts/Oct4_loss: 1.2005 - val_loss: 647.2415 - val_profile/Sox2_loss: 260.7196 - val_profile/Oct4_loss: 364.7562 - val_counts/Sox2_loss: 0.8472 - val_counts/Oct4_loss: 1.3293
Epoch 21/100
5561/5561 [==============================] - 1s 140us/step - loss: 653.3272 - profile/Sox2_loss: 262.6746 - profile/Oct4_loss: 369.6625 - counts/Sox2_loss: 0.8539 - counts/Oct4_loss: 1.2451 - val_loss: 650.1726 - val_profile/Sox2_loss: 261.0584 - val_profile/Oct4_loss: 365.3939 - val_counts/Sox2_loss: 0.9025 - val_counts/Oct4_loss: 1.4696
Epoch 22/100
5561/5561 [==============================] - 1s 141us/step - loss: 652.7307 - profile/Sox2_loss: 262.6056 - profile/Oct4_loss: 369.5839 - counts/Sox2_loss: 0.8521 - counts/Oct4_loss: 1.2020 - val_loss: 646.3634 - val_profile/Sox2_loss: 260.4200 - val_profile/Oct4_loss: 364.6710 - val_counts/Sox2_loss: 0.8469 - val_counts/Oct4_loss: 1.2803
Epoch 23/100
5561/5561 [==============================] - 1s 141us/step - loss: 651.0230 - profile/Sox2_loss: 261.9445 - profile/Oct4_loss: 369.1204 - counts/Sox2_loss: 0.8331 - counts/Oct4_loss: 1.1627 - val_loss: 646.0062 - val_profile/Sox2_loss: 260.1354 - val_profile/Oct4_loss: 364.5090 - val_counts/Sox2_loss: 0.8391 - val_counts/Oct4_loss: 1.2970
Epoch 24/100
5561/5561 [==============================] - 1s 151us/step - loss: 650.5344 - profile/Sox2_loss: 261.7289 - profile/Oct4_loss: 368.8925 - counts/Sox2_loss: 0.8290 - counts/Oct4_loss: 1.1623 - val_loss: 646.4804 - val_profile/Sox2_loss: 259.8761 - val_profile/Oct4_loss: 364.3715 - val_counts/Sox2_loss: 0.8664 - val_counts/Oct4_loss: 1.3569
Epoch 25/100
5561/5561 [==============================] - 1s 138us/step - loss: 649.8453 - profile/Sox2_loss: 261.4682 - profile/Oct4_loss: 368.6670 - counts/Sox2_loss: 0.8296 - counts/Oct4_loss: 1.1414 - val_loss: 646.4076 - val_profile/Sox2_loss: 260.5963 - val_profile/Oct4_loss: 364.4816 - val_counts/Sox2_loss: 0.8410 - val_counts/Oct4_loss: 1.2920
Epoch 26/100
5561/5561 [==============================] - 1s 141us/step - loss: 649.1630 - profile/Sox2_loss: 261.1932 - profile/Oct4_loss: 368.4076 - counts/Sox2_loss: 0.8192 - counts/Oct4_loss: 1.1370 - val_loss: 647.1062 - val_profile/Sox2_loss: 259.8813 - val_profile/Oct4_loss: 364.3948 - val_counts/Sox2_loss: 0.8708 - val_counts/Oct4_loss: 1.4122
Epoch 27/100
5561/5561 [==============================] - 1s 139us/step - loss: 648.5595 - profile/Sox2_loss: 260.9540 - profile/Oct4_loss: 368.2138 - counts/Sox2_loss: 0.8140 - counts/Oct4_loss: 1.1251 - val_loss: 645.9803 - val_profile/Sox2_loss: 259.8484 - val_profile/Oct4_loss: 364.2243 - val_counts/Sox2_loss: 0.8729 - val_counts/Oct4_loss: 1.3179
Epoch 28/100
5561/5561 [==============================] - 1s 140us/step - loss: 647.9295 - profile/Sox2_loss: 260.6528 - profile/Oct4_loss: 368.0184 - counts/Sox2_loss: 0.8084 - counts/Oct4_loss: 1.1174 - val_loss: 647.4447 - val_profile/Sox2_loss: 259.6509 - val_profile/Oct4_loss: 364.1011 - val_counts/Sox2_loss: 0.8822 - val_counts/Oct4_loss: 1.4871
Epoch 29/100
5561/5561 [==============================] - 1s 141us/step - loss: 647.5582 - profile/Sox2_loss: 260.5805 - profile/Oct4_loss: 367.9920 - counts/Sox2_loss: 0.8011 - counts/Oct4_loss: 1.0974 - val_loss: 645.6124 - val_profile/Sox2_loss: 259.6432 - val_profile/Oct4_loss: 364.8525 - val_counts/Sox2_loss: 0.8358 - val_counts/Oct4_loss: 1.2759
Epoch 30/100
5561/5561 [==============================] - 1s 140us/step - loss: 646.8070 - profile/Sox2_loss: 260.4437 - profile/Oct4_loss: 367.8710 - counts/Sox2_loss: 0.7897 - counts/Oct4_loss: 1.0596 - val_loss: 644.7841 - val_profile/Sox2_loss: 259.4430 - val_profile/Oct4_loss: 364.3710 - val_counts/Sox2_loss: 0.8276 - val_counts/Oct4_loss: 1.2694
Epoch 31/100
5561/5561 [==============================] - 1s 150us/step - loss: 646.6688 - profile/Sox2_loss: 260.2703 - profile/Oct4_loss: 367.7807 - counts/Sox2_loss: 0.7866 - counts/Oct4_loss: 1.0752 - val_loss: 644.4129 - val_profile/Sox2_loss: 259.2332 - val_profile/Oct4_loss: 364.2772 - val_counts/Sox2_loss: 0.8314 - val_counts/Oct4_loss: 1.2588
Epoch 32/100
5561/5561 [==============================] - 1s 147us/step - loss: 647.0625 - profile/Sox2_loss: 260.2153 - profile/Oct4_loss: 367.7033 - counts/Sox2_loss: 0.7997 - counts/Oct4_loss: 1.1147 - val_loss: 644.4601 - val_profile/Sox2_loss: 259.2357 - val_profile/Oct4_loss: 364.0821 - val_counts/Sox2_loss: 0.8204 - val_counts/Oct4_loss: 1.2938
Epoch 33/100
5561/5561 [==============================] - 1s 142us/step - loss: 646.3374 - profile/Sox2_loss: 260.2045 - profile/Oct4_loss: 367.4287 - counts/Sox2_loss: 0.7881 - counts/Oct4_loss: 1.0823 - val_loss: 643.4771 - val_profile/Sox2_loss: 259.0958 - val_profile/Oct4_loss: 363.7151 - val_counts/Sox2_loss: 0.8239 - val_counts/Oct4_loss: 1.2427
Epoch 34/100
5561/5561 [==============================] - 1s 143us/step - loss: 645.9404 - profile/Sox2_loss: 259.9618 - profile/Oct4_loss: 367.2447 - counts/Sox2_loss: 0.7858 - counts/Oct4_loss: 1.0876 - val_loss: 644.0175 - val_profile/Sox2_loss: 259.0713 - val_profile/Oct4_loss: 364.2430 - val_counts/Sox2_loss: 0.8283 - val_counts/Oct4_loss: 1.2420
Epoch 35/100
5561/5561 [==============================] - 1s 140us/step - loss: 646.1490 - profile/Sox2_loss: 260.0808 - profile/Oct4_loss: 367.3990 - counts/Sox2_loss: 0.7905 - counts/Oct4_loss: 1.0764 - val_loss: 644.1146 - val_profile/Sox2_loss: 259.0461 - val_profile/Oct4_loss: 363.8776 - val_counts/Sox2_loss: 0.8392 - val_counts/Oct4_loss: 1.2799
Epoch 36/100
5561/5561 [==============================] - 1s 144us/step - loss: 644.7566 - profile/Sox2_loss: 259.7835 - profile/Oct4_loss: 367.0205 - counts/Sox2_loss: 0.7672 - counts/Oct4_loss: 1.0280 - val_loss: 643.5260 - val_profile/Sox2_loss: 259.0089 - val_profile/Oct4_loss: 363.8105 - val_counts/Sox2_loss: 0.8187 - val_counts/Oct4_loss: 1.2519
Epoch 37/100
5561/5561 [==============================] - 1s 142us/step - loss: 644.3900 - profile/Sox2_loss: 259.6067 - profile/Oct4_loss: 366.8250 - counts/Sox2_loss: 0.7629 - counts/Oct4_loss: 1.0329 - val_loss: 644.0730 - val_profile/Sox2_loss: 259.0439 - val_profile/Oct4_loss: 363.9569 - val_counts/Sox2_loss: 0.8260 - val_counts/Oct4_loss: 1.2812
Epoch 38/100
5561/5561 [==============================] - 1s 138us/step - loss: 645.3824 - profile/Sox2_loss: 259.6383 - profile/Oct4_loss: 366.9207 - counts/Sox2_loss: 0.7816 - counts/Oct4_loss: 1.1007 - val_loss: 643.4739 - val_profile/Sox2_loss: 258.9350 - val_profile/Oct4_loss: 363.9020 - val_counts/Sox2_loss: 0.8163 - val_counts/Oct4_loss: 1.2474
Epoch 39/100
5561/5561 [==============================] - 1s 142us/step - loss: 644.3495 - profile/Sox2_loss: 259.5391 - profile/Oct4_loss: 366.6670 - counts/Sox2_loss: 0.7700 - counts/Oct4_loss: 1.0443 - val_loss: 643.1004 - val_profile/Sox2_loss: 259.0135 - val_profile/Oct4_loss: 363.6081 - val_counts/Sox2_loss: 0.8108 - val_counts/Oct4_loss: 1.2371
Epoch 40/100
5561/5561 [==============================] - 1s 139us/step - loss: 643.2975 - profile/Sox2_loss: 259.3842 - profile/Oct4_loss: 366.6264 - counts/Sox2_loss: 0.7410 - counts/Oct4_loss: 0.9877 - val_loss: 643.8331 - val_profile/Sox2_loss: 258.9730 - val_profile/Oct4_loss: 364.1977 - val_counts/Sox2_loss: 0.8269 - val_counts/Oct4_loss: 1.2393
Epoch 41/100
5561/5561 [==============================] - 1s 141us/step - loss: 644.9920 - profile/Sox2_loss: 259.5460 - profile/Oct4_loss: 366.8010 - counts/Sox2_loss: 0.7758 - counts/Oct4_loss: 1.0887 - val_loss: 644.3303 - val_profile/Sox2_loss: 259.2844 - val_profile/Oct4_loss: 364.0157 - val_counts/Sox2_loss: 0.8251 - val_counts/Oct4_loss: 1.2779
Epoch 42/100
5561/5561 [==============================] - 1s 143us/step - loss: 643.7163 - profile/Sox2_loss: 259.4487 - profile/Oct4_loss: 366.5987 - counts/Sox2_loss: 0.7507 - counts/Oct4_loss: 1.0162 - val_loss: 642.9670 - val_profile/Sox2_loss: 258.8030 - val_profile/Oct4_loss: 363.7975 - val_counts/Sox2_loss: 0.8051 - val_counts/Oct4_loss: 1.2316
Epoch 43/100
5561/5561 [==============================] - 1s 141us/step - loss: 642.7882 - profile/Sox2_loss: 259.2908 - profile/Oct4_loss: 366.3841 - counts/Sox2_loss: 0.7299 - counts/Oct4_loss: 0.9814 - val_loss: 643.2026 - val_profile/Sox2_loss: 258.9167 - val_profile/Oct4_loss: 363.6172 - val_counts/Sox2_loss: 0.8149 - val_counts/Oct4_loss: 1.2519
Epoch 44/100
5561/5561 [==============================] - 1s 144us/step - loss: 642.8367 - profile/Sox2_loss: 259.1346 - profile/Oct4_loss: 366.2595 - counts/Sox2_loss: 0.7337 - counts/Oct4_loss: 1.0105 - val_loss: 643.2406 - val_profile/Sox2_loss: 259.0736 - val_profile/Oct4_loss: 363.6932 - val_counts/Sox2_loss: 0.7997 - val_counts/Oct4_loss: 1.2477
Epoch 45/100
5561/5561 [==============================] - 1s 139us/step - loss: 642.8782 - profile/Sox2_loss: 259.2050 - profile/Oct4_loss: 366.2901 - counts/Sox2_loss: 0.7339 - counts/Oct4_loss: 1.0045 - val_loss: 642.7344 - val_profile/Sox2_loss: 258.7856 - val_profile/Oct4_loss: 363.5616 - val_counts/Sox2_loss: 0.8045 - val_counts/Oct4_loss: 1.2342
Epoch 46/100
5561/5561 [==============================] - 1s 139us/step - loss: 642.9312 - profile/Sox2_loss: 259.2078 - profile/Oct4_loss: 366.2723 - counts/Sox2_loss: 0.7347 - counts/Oct4_loss: 1.0104 - val_loss: 644.2177 - val_profile/Sox2_loss: 258.6007 - val_profile/Oct4_loss: 363.6781 - val_counts/Sox2_loss: 0.8437 - val_counts/Oct4_loss: 1.3502
Epoch 47/100
5561/5561 [==============================] - 1s 144us/step - loss: 643.3315 - profile/Sox2_loss: 259.1302 - profile/Oct4_loss: 366.3686 - counts/Sox2_loss: 0.7426 - counts/Oct4_loss: 1.0407 - val_loss: 643.9598 - val_profile/Sox2_loss: 258.9861 - val_profile/Oct4_loss: 364.6564 - val_counts/Sox2_loss: 0.8010 - val_counts/Oct4_loss: 1.2307
Epoch 48/100
5561/5561 [==============================] - 1s 143us/step - loss: 642.1726 - profile/Sox2_loss: 258.9850 - profile/Oct4_loss: 366.0466 - counts/Sox2_loss: 0.7252 - counts/Oct4_loss: 0.9889 - val_loss: 643.3012 - val_profile/Sox2_loss: 258.7832 - val_profile/Oct4_loss: 363.8290 - val_counts/Sox2_loss: 0.8224 - val_counts/Oct4_loss: 1.2465
Epoch 49/100
5561/5561 [==============================] - 1s 141us/step - loss: 641.3844 - profile/Sox2_loss: 258.8733 - profile/Oct4_loss: 365.7894 - counts/Sox2_loss: 0.7102 - counts/Oct4_loss: 0.9619 - val_loss: 643.5725 - val_profile/Sox2_loss: 258.8187 - val_profile/Oct4_loss: 364.1278 - val_counts/Sox2_loss: 0.8029 - val_counts/Oct4_loss: 1.2597
Epoch 50/100
5561/5561 [==============================] - 1s 140us/step - loss: 641.4855 - profile/Sox2_loss: 258.8605 - profile/Oct4_loss: 365.8449 - counts/Sox2_loss: 0.7077 - counts/Oct4_loss: 0.9703 - val_loss: 643.0985 - val_profile/Sox2_loss: 258.7356 - val_profile/Oct4_loss: 363.9948 - val_counts/Sox2_loss: 0.8040 - val_counts/Oct4_loss: 1.2328
In [24]:
from basepair.eval import evaluate
# Validation-set losses of the reloaded best model (matches the epoch-45
# val_loss = 642.7344 in the training log above)
evaluate(model2, valid_nex[0], valid_nex[1])
Out[24]:
{'loss': 642.7344364239152,
 'profile/Sox2_loss': 258.78561213467026,
 'profile/Oct4_loss': 363.5616598918939,
 'counts/Sox2_loss': 0.8045008476119638,
 'counts/Oct4_loss': 1.2342151269285035}
In [137]:
def seq_multitask_newloss(filters=21,
                  conv1_kernel_size=21,
                  tconv_kernel_size=25,
                  n_dil_layers=6,
                  lr=0.004,
                  c_task_weight=100,
                  use_profile=True,
                  use_counts=True,
                  tasks=['sox2', 'oct4'],
                  outputs_per_task=2,
                  task_use_bias=False,
                  seq_len=201):  # TODO - automatically infer sequence length
    """Multi-task profile + counts Keras model with a dilated-conv body.

    Variant of `seq_multitask` whose profile heads use per-task
    `mc_multinomial_nll_<n>` losses looked up via `basepair.losses.get`.

    Args:
      filters: number of filters in the first conv and each dilated conv layer.
      conv1_kernel_size: kernel size of the first convolution.
      tconv_kernel_size: kernel size of the transposed-conv profile decoder.
      n_dil_layers: number of dilated conv layers (dilation rate 2**i, i=1..n).
      lr: Adam learning rate.
      c_task_weight: how much to upweight the count-prediction task's loss.
      use_profile: attach per-task profile output heads.
      use_counts: attach per-task total-count output heads.
      tasks: task names; one profile and/or counts head is built per task.
      outputs_per_task (int or list of int): output channels per task
        (an int is broadcast to all tasks).
      task_use_bias (bool or a list of bools): if True, a
        bias term is assumed to be provided at the input.
      seq_len: input sequence length.

    Returns:
      A compiled keras Model with inputs [seq] + any bias inputs, and the
      profile and/or counts heads (in task order) as outputs.

    NOTE(review): the profile-loss lookup below requires module-level
    `import basepair` to have run before this function is called — here it
    only happens in a later cell, so cell order matters.
    """
    # TODO - split the body of this model into multiple subparts:
    # - encoder
    # - profile_decoder
    # - profile_decoder_w_bias
    # - counts_decoder
    # - counts_decoder_w_bias
    # Normalize outputs_per_task / task_use_bias to per-task lists
    if isinstance(outputs_per_task, int):
        outputs_per_task = [outputs_per_task for i in range(len(tasks))]
    else:
        assert len(tasks) == len(outputs_per_task)
    if isinstance(task_use_bias, bool):
        task_use_bias = [task_use_bias for i in range(len(tasks))]
    else:
        assert len(tasks) == len(task_use_bias)

    # TODO - build the reverse complement symmetry into the model
    inp = kl.Input(shape=(seq_len, 4), name='seq')
    first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)

    # Optional bias inputs: one per task where task_use_bias[i] is True
    bias_profile_inputs = {task: kl.Input(shape=(seq_len, outputs_per_task[i]), name=f"bias/profile/{task}")
                           for i, task in enumerate(tasks) if task_use_bias[i]}
    bias_counts_inputs = [kl.Input(shape=(outputs_per_task[i], ), name=f"bias/counts/{task}")
                          for i, task in enumerate(tasks) if task_use_bias[i]]
    # Dilated conv stack: layer i consumes the sum of all previous layer
    # outputs (dense skip connections)
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers + 1):
        if i == 1:
            prev_sum = first_conv
        else:
            prev_sum = kl.add(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)

    combined_conv = kl.add(prev_layers)

    # De-conv: Conv2DTranspose over a (seq_len, 1) "image" acts as a 1D
    # transposed convolution producing all tasks' profile channels at once
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(sum(outputs_per_task), kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    out = kl.Reshape((-1, sum(outputs_per_task)))(x)
    # batch x seqlen x tasks*2 array
    # need another array of the same length

    # setup the output branches
    outputs = []
    losses = []
    loss_weights = []

    if use_profile:
        # TODO - use a different loss function for the same profiles
        # Per-task channel slice boundaries within the shared profile tensor
        start_idx = np.cumsum([0] + outputs_per_task[:-1])
        end_idx = np.cumsum(outputs_per_task)

        def get_output_name(task):
            # Tasks with a profile bias get an intermediate "lambda/..." name;
            # their final "profile/<task>" name is taken by the bias-merge conv below.
            if task in bias_profile_inputs:
                return "lambda/profile/" + task
            else:
                return "profile/" + task
        # Loop variables are passed via `arguments=` so each Lambda binds its
        # own i/sidx/eidx (avoids the late-binding closure pitfall)
        output = [kl.Lambda(lambda x, i, sidx, eidx: x[:, :, sidx:eidx],
                            output_shape=(seq_len, outputs_per_task[i]),
                            name=get_output_name(task),
                            arguments={"i": i, "sidx": start_idx[i], "eidx": end_idx[i]})(out)
                  for i, task in enumerate(tasks)]
        for i, task in enumerate(tasks):
            if task in bias_profile_inputs:
                # Merge the raw profile with its bias track via a 1x1 conv
                output_with_bias = kl.concatenate([output[i],
                                                   bias_profile_inputs[task]], axis=-1)  # batch x seqlen x (2+2)
                output[i] = kl.Conv1D(outputs_per_task[i],
                                      1,
                                      name="profile/" + task)(output_with_bias)

        outputs += output
        # One multinomial NLL per task, matched to that task's channel count
        losses += [basepair.losses.get(f"mc_multinomial_nll_{nt}") for nt in outputs_per_task]
        loss_weights += [1] * len(tasks)

    if use_counts:
        # Counts head: global average pool of the conv body, one Dense per task
        pooled = kl.GlobalAvgPool1D()(combined_conv)
        if bias_counts_inputs:
            pooled = kl.concatenate([pooled] + bias_counts_inputs, axis=-1)  # add bias as additional features
        counts = [kl.Dense(outputs_per_task[i], name="counts/" + task)(pooled)
                  for i, task in enumerate(tasks)]
        outputs += counts
        losses += ["mae"] * len(tasks)
        loss_weights += [c_task_weight] * len(tasks)

    model = Model([inp] + list(bias_profile_inputs.values()) + bias_counts_inputs, outputs)
    model.compile(Adam(lr=lr), loss=losses, loss_weights=loss_weights)
    return model
In [138]:
# Hyper-parameters for the new-loss variant (same settings as mkwargs2)
mfn = "seq_multitask_newloss"
mkwargs = {
    "filters": 32,
    "conv1_kernel_size": 21,
    "tconv_kernel_size": 25,
    "n_dil_layers": 6,
    "use_profile": use_profile,
    "use_counts": use_counts,
    "c_task_weight": 10,
    "lr": 0.004,
}
In [139]:
# NOTE(review): mid-notebook import — should live in the top imports cell so
# the notebook survives Restart & Run All.
import basepair
# Advance the run counter so this run gets a distinct name/checkpoint file.
# NOTE(review): `i` must already exist in the kernel (defined in an earlier,
# unseen cell) — this raises NameError on a fresh kernel. get_model() also
# does `i+=1` internally, but that only rebinds its local copy; this line is
# the one that actually moves the counter.
i += 1
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)
# Train with early stopping (patience=5 on val_loss by default); the weights
# best on validation are persisted to ckp_file by save_best_only=True.
history = model.fit(train_nex[0], 
                    train_nex[1],
          batch_size=256, 
          epochs=100,
          validation_data=valid_nex[:2],
          callbacks=[EarlyStopping(patience=5),
                     History(),  # NOTE(review): redundant — fit() always attaches a History callback and returns it
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model
# Reload the best checkpoint; custom_objects is required so Keras can
# deserialize the custom loss and the custom layer used by the saved model.
model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll, 
                                             "SpatialLifetimeSparsity": SpatialLifetimeSparsity})
Train on 5561 samples, validate on 1884 samples
Epoch 1/100
5561/5561 [==============================] - 8s 1ms/step - loss: 742.2577 - profile/Sox2_loss: 296.2518 - profile/Oct4_loss: 425.9215 - counts/Sox2_loss: 0.8264 - counts/Oct4_loss: 1.1820 - val_loss: 689.1856 - val_profile/Sox2_loss: 278.4897 - val_profile/Oct4_loss: 391.4255 - val_counts/Sox2_loss: 0.7732 - val_counts/Oct4_loss: 1.1538
Epoch 2/100
5561/5561 [==============================] - 1s 160us/step - loss: 697.1065 - profile/Sox2_loss: 281.9431 - profile/Oct4_loss: 396.2219 - counts/Sox2_loss: 0.7760 - counts/Oct4_loss: 1.1181 - val_loss: 680.4753 - val_profile/Sox2_loss: 274.8912 - val_profile/Oct4_loss: 386.2270 - val_counts/Sox2_loss: 0.7755 - val_counts/Oct4_loss: 1.1602
Epoch 3/100
5561/5561 [==============================] - 1s 160us/step - loss: 690.6432 - profile/Sox2_loss: 279.2718 - profile/Oct4_loss: 392.4481 - counts/Sox2_loss: 0.7758 - counts/Oct4_loss: 1.1165 - val_loss: 675.8279 - val_profile/Sox2_loss: 272.8461 - val_profile/Oct4_loss: 383.7597 - val_counts/Sox2_loss: 0.7746 - val_counts/Oct4_loss: 1.1476
Epoch 4/100
5561/5561 [==============================] - 1s 160us/step - loss: 685.3920 - profile/Sox2_loss: 276.9115 - profile/Oct4_loss: 389.5713 - counts/Sox2_loss: 0.7754 - counts/Oct4_loss: 1.1155 - val_loss: 670.7238 - val_profile/Sox2_loss: 270.7903 - val_profile/Oct4_loss: 380.7447 - val_counts/Sox2_loss: 0.7738 - val_counts/Oct4_loss: 1.1451
Epoch 5/100
5561/5561 [==============================] - 1s 160us/step - loss: 679.3926 - profile/Sox2_loss: 274.4484 - profile/Oct4_loss: 386.0467 - counts/Sox2_loss: 0.7761 - counts/Oct4_loss: 1.1137 - val_loss: 665.9683 - val_profile/Sox2_loss: 268.7482 - val_profile/Oct4_loss: 378.0335 - val_counts/Sox2_loss: 0.7768 - val_counts/Oct4_loss: 1.1419
Epoch 6/100
5561/5561 [==============================] - 1s 159us/step - loss: 675.0967 - profile/Sox2_loss: 273.0213 - profile/Oct4_loss: 383.2072 - counts/Sox2_loss: 0.7759 - counts/Oct4_loss: 1.1109 - val_loss: 662.7390 - val_profile/Sox2_loss: 267.8483 - val_profile/Oct4_loss: 375.7364 - val_counts/Sox2_loss: 0.7740 - val_counts/Oct4_loss: 1.1415
Epoch 7/100
5561/5561 [==============================] - 1s 161us/step - loss: 671.8429 - profile/Sox2_loss: 271.7429 - profile/Oct4_loss: 381.2481 - counts/Sox2_loss: 0.7753 - counts/Oct4_loss: 1.1099 - val_loss: 660.0199 - val_profile/Sox2_loss: 266.8182 - val_profile/Oct4_loss: 374.0404 - val_counts/Sox2_loss: 0.7735 - val_counts/Oct4_loss: 1.1426
Epoch 8/100
5561/5561 [==============================] - 1s 161us/step - loss: 668.0757 - profile/Sox2_loss: 270.3013 - profile/Oct4_loss: 378.9649 - counts/Sox2_loss: 0.7740 - counts/Oct4_loss: 1.1069 - val_loss: 658.0019 - val_profile/Sox2_loss: 266.0656 - val_profile/Oct4_loss: 372.7948 - val_counts/Sox2_loss: 0.7744 - val_counts/Oct4_loss: 1.1398
Epoch 9/100
5561/5561 [==============================] - 1s 158us/step - loss: 665.4285 - profile/Sox2_loss: 269.2124 - profile/Oct4_loss: 377.4461 - counts/Sox2_loss: 0.7735 - counts/Oct4_loss: 1.1035 - val_loss: 655.2673 - val_profile/Sox2_loss: 264.9415 - val_profile/Oct4_loss: 371.2359 - val_counts/Sox2_loss: 0.7717 - val_counts/Oct4_loss: 1.1373
Epoch 10/100
5561/5561 [==============================] - 1s 161us/step - loss: 663.2125 - profile/Sox2_loss: 268.2533 - profile/Oct4_loss: 376.2110 - counts/Sox2_loss: 0.7728 - counts/Oct4_loss: 1.1020 - val_loss: 654.1410 - val_profile/Sox2_loss: 264.7488 - val_profile/Oct4_loss: 370.3339 - val_counts/Sox2_loss: 0.7714 - val_counts/Oct4_loss: 1.1344
Epoch 11/100
5561/5561 [==============================] - 1s 160us/step - loss: 661.6653 - profile/Sox2_loss: 267.6083 - profile/Oct4_loss: 375.3373 - counts/Sox2_loss: 0.7722 - counts/Oct4_loss: 1.0998 - val_loss: 652.6350 - val_profile/Sox2_loss: 264.1517 - val_profile/Oct4_loss: 369.4531 - val_counts/Sox2_loss: 0.7721 - val_counts/Oct4_loss: 1.1309
Epoch 12/100
5561/5561 [==============================] - 1s 160us/step - loss: 659.8739 - profile/Sox2_loss: 266.8691 - profile/Oct4_loss: 374.3250 - counts/Sox2_loss: 0.7712 - counts/Oct4_loss: 1.0968 - val_loss: 651.1165 - val_profile/Sox2_loss: 263.4699 - val_profile/Oct4_loss: 368.6147 - val_counts/Sox2_loss: 0.7708 - val_counts/Oct4_loss: 1.1324
Epoch 13/100
5561/5561 [==============================] - 1s 158us/step - loss: 658.3667 - profile/Sox2_loss: 266.1954 - profile/Oct4_loss: 373.5290 - counts/Sox2_loss: 0.7697 - counts/Oct4_loss: 1.0945 - val_loss: 650.6006 - val_profile/Sox2_loss: 263.0993 - val_profile/Oct4_loss: 368.5548 - val_counts/Sox2_loss: 0.7684 - val_counts/Oct4_loss: 1.1262
Epoch 14/100
5561/5561 [==============================] - 1s 158us/step - loss: 657.1740 - profile/Sox2_loss: 265.6693 - profile/Oct4_loss: 372.9029 - counts/Sox2_loss: 0.7690 - counts/Oct4_loss: 1.0912 - val_loss: 649.4128 - val_profile/Sox2_loss: 262.5916 - val_profile/Oct4_loss: 367.8636 - val_counts/Sox2_loss: 0.7677 - val_counts/Oct4_loss: 1.1281
Epoch 15/100
5561/5561 [==============================] - 1s 163us/step - loss: 655.8277 - profile/Sox2_loss: 265.0451 - profile/Oct4_loss: 372.2552 - counts/Sox2_loss: 0.7659 - counts/Oct4_loss: 1.0868 - val_loss: 649.3726 - val_profile/Sox2_loss: 262.4880 - val_profile/Oct4_loss: 367.9007 - val_counts/Sox2_loss: 0.7679 - val_counts/Oct4_loss: 1.1305
Epoch 16/100
5561/5561 [==============================] - 1s 160us/step - loss: 654.9063 - profile/Sox2_loss: 264.6028 - profile/Oct4_loss: 371.8047 - counts/Sox2_loss: 0.7648 - counts/Oct4_loss: 1.0850 - val_loss: 647.6851 - val_profile/Sox2_loss: 261.8810 - val_profile/Oct4_loss: 367.0492 - val_counts/Sox2_loss: 0.7637 - val_counts/Oct4_loss: 1.1118
Epoch 17/100
5561/5561 [==============================] - 1s 162us/step - loss: 654.1534 - profile/Sox2_loss: 264.2273 - profile/Oct4_loss: 371.4882 - counts/Sox2_loss: 0.7626 - counts/Oct4_loss: 1.0812 - val_loss: 647.4460 - val_profile/Sox2_loss: 261.9295 - val_profile/Oct4_loss: 366.8463 - val_counts/Sox2_loss: 0.7605 - val_counts/Oct4_loss: 1.1065
Epoch 18/100
5561/5561 [==============================] - 1s 155us/step - loss: 653.2772 - profile/Sox2_loss: 263.8458 - profile/Oct4_loss: 371.0654 - counts/Sox2_loss: 0.7602 - counts/Oct4_loss: 1.0764 - val_loss: 646.8169 - val_profile/Sox2_loss: 261.8199 - val_profile/Oct4_loss: 366.3469 - val_counts/Sox2_loss: 0.7614 - val_counts/Oct4_loss: 1.1036
Epoch 19/100
5561/5561 [==============================] - 1s 163us/step - loss: 652.8543 - profile/Sox2_loss: 263.6535 - profile/Oct4_loss: 370.9406 - counts/Sox2_loss: 0.7580 - counts/Oct4_loss: 1.0680 - val_loss: 645.7909 - val_profile/Sox2_loss: 261.1532 - val_profile/Oct4_loss: 366.1305 - val_counts/Sox2_loss: 0.7565 - val_counts/Oct4_loss: 1.0942
Epoch 20/100
5561/5561 [==============================] - 1s 160us/step - loss: 651.7556 - profile/Sox2_loss: 263.0767 - profile/Oct4_loss: 370.5082 - counts/Sox2_loss: 0.7546 - counts/Oct4_loss: 1.0625 - val_loss: 645.1581 - val_profile/Sox2_loss: 261.0209 - val_profile/Oct4_loss: 365.7640 - val_counts/Sox2_loss: 0.7531 - val_counts/Oct4_loss: 1.0842
Epoch 21/100
5561/5561 [==============================] - 1s 160us/step - loss: 650.8419 - profile/Sox2_loss: 262.6993 - profile/Oct4_loss: 370.0873 - counts/Sox2_loss: 0.7526 - counts/Oct4_loss: 1.0530 - val_loss: 644.4642 - val_profile/Sox2_loss: 260.7284 - val_profile/Oct4_loss: 365.4288 - val_counts/Sox2_loss: 0.7504 - val_counts/Oct4_loss: 1.0803
Epoch 22/100
5561/5561 [==============================] - 1s 162us/step - loss: 649.6208 - profile/Sox2_loss: 262.2337 - profile/Oct4_loss: 369.4882 - counts/Sox2_loss: 0.7465 - counts/Oct4_loss: 1.0434 - val_loss: 644.0904 - val_profile/Sox2_loss: 260.3538 - val_profile/Oct4_loss: 365.3700 - val_counts/Sox2_loss: 0.7498 - val_counts/Oct4_loss: 1.0869
Epoch 23/100
5561/5561 [==============================] - 1s 163us/step - loss: 649.4003 - profile/Sox2_loss: 262.1167 - profile/Oct4_loss: 369.4156 - counts/Sox2_loss: 0.7457 - counts/Oct4_loss: 1.0411 - val_loss: 643.4704 - val_profile/Sox2_loss: 260.3151 - val_profile/Oct4_loss: 365.1761 - val_counts/Sox2_loss: 0.7428 - val_counts/Oct4_loss: 1.0551
Epoch 24/100
5561/5561 [==============================] - 1s 157us/step - loss: 648.4630 - profile/Sox2_loss: 261.7800 - profile/Oct4_loss: 369.0758 - counts/Sox2_loss: 0.7395 - counts/Oct4_loss: 1.0212 - val_loss: 643.0744 - val_profile/Sox2_loss: 260.1975 - val_profile/Oct4_loss: 364.9920 - val_counts/Sox2_loss: 0.7429 - val_counts/Oct4_loss: 1.0456
Epoch 25/100
5561/5561 [==============================] - 1s 160us/step - loss: 648.2276 - profile/Sox2_loss: 261.6898 - profile/Oct4_loss: 369.0500 - counts/Sox2_loss: 0.7362 - counts/Oct4_loss: 1.0125 - val_loss: 642.9298 - val_profile/Sox2_loss: 260.3057 - val_profile/Oct4_loss: 364.9426 - val_counts/Sox2_loss: 0.7375 - val_counts/Oct4_loss: 1.0306
Epoch 26/100
5561/5561 [==============================] - 1s 163us/step - loss: 647.7037 - profile/Sox2_loss: 261.5056 - profile/Oct4_loss: 368.8003 - counts/Sox2_loss: 0.7353 - counts/Oct4_loss: 1.0045 - val_loss: 642.4764 - val_profile/Sox2_loss: 260.0836 - val_profile/Oct4_loss: 364.8066 - val_counts/Sox2_loss: 0.7334 - val_counts/Oct4_loss: 1.0253
Epoch 27/100
5561/5561 [==============================] - 1s 161us/step - loss: 646.8936 - profile/Sox2_loss: 261.1578 - profile/Oct4_loss: 368.5036 - counts/Sox2_loss: 0.7321 - counts/Oct4_loss: 0.9911 - val_loss: 642.2561 - val_profile/Sox2_loss: 259.9337 - val_profile/Oct4_loss: 364.7484 - val_counts/Sox2_loss: 0.7494 - val_counts/Oct4_loss: 1.0080
Epoch 28/100
5561/5561 [==============================] - 1s 159us/step - loss: 646.8193 - profile/Sox2_loss: 261.0489 - profile/Oct4_loss: 368.4135 - counts/Sox2_loss: 0.7403 - counts/Oct4_loss: 0.9954 - val_loss: 641.2318 - val_profile/Sox2_loss: 259.5668 - val_profile/Oct4_loss: 364.3768 - val_counts/Sox2_loss: 0.7334 - val_counts/Oct4_loss: 0.9954
Epoch 29/100
5561/5561 [==============================] - 1s 156us/step - loss: 645.6657 - profile/Sox2_loss: 260.5936 - profile/Oct4_loss: 368.1195 - counts/Sox2_loss: 0.7279 - counts/Oct4_loss: 0.9673 - val_loss: 641.8491 - val_profile/Sox2_loss: 259.5149 - val_profile/Oct4_loss: 365.1305 - val_counts/Sox2_loss: 0.7290 - val_counts/Oct4_loss: 0.9913
Epoch 30/100
5561/5561 [==============================] - 1s 156us/step - loss: 645.8690 - profile/Sox2_loss: 260.6735 - profile/Oct4_loss: 368.1085 - counts/Sox2_loss: 0.7320 - counts/Oct4_loss: 0.9767 - val_loss: 640.5910 - val_profile/Sox2_loss: 259.3059 - val_profile/Oct4_loss: 363.9527 - val_counts/Sox2_loss: 0.7406 - val_counts/Oct4_loss: 0.9926
Epoch 31/100
5561/5561 [==============================] - 1s 155us/step - loss: 644.7787 - profile/Sox2_loss: 260.3547 - profile/Oct4_loss: 367.5295 - counts/Sox2_loss: 0.7269 - counts/Oct4_loss: 0.9626 - val_loss: 640.5860 - val_profile/Sox2_loss: 259.6152 - val_profile/Oct4_loss: 364.1194 - val_counts/Sox2_loss: 0.7251 - val_counts/Oct4_loss: 0.9601
Epoch 32/100
5561/5561 [==============================] - 1s 158us/step - loss: 644.1967 - profile/Sox2_loss: 260.1386 - profile/Oct4_loss: 367.4967 - counts/Sox2_loss: 0.7212 - counts/Oct4_loss: 0.9350 - val_loss: 640.6749 - val_profile/Sox2_loss: 259.3332 - val_profile/Oct4_loss: 364.5604 - val_counts/Sox2_loss: 0.7264 - val_counts/Oct4_loss: 0.9517
Epoch 33/100
5561/5561 [==============================] - 1s 161us/step - loss: 643.8447 - profile/Sox2_loss: 260.0580 - profile/Oct4_loss: 367.3575 - counts/Sox2_loss: 0.7173 - counts/Oct4_loss: 0.9257 - val_loss: 639.6552 - val_profile/Sox2_loss: 259.1554 - val_profile/Oct4_loss: 363.6540 - val_counts/Sox2_loss: 0.7371 - val_counts/Oct4_loss: 0.9475
Epoch 34/100
5561/5561 [==============================] - 1s 161us/step - loss: 643.6425 - profile/Sox2_loss: 259.9314 - profile/Oct4_loss: 367.3332 - counts/Sox2_loss: 0.7182 - counts/Oct4_loss: 0.9196 - val_loss: 639.5473 - val_profile/Sox2_loss: 259.1789 - val_profile/Oct4_loss: 363.8052 - val_counts/Sox2_loss: 0.7228 - val_counts/Oct4_loss: 0.9335
Epoch 35/100
5561/5561 [==============================] - 1s 163us/step - loss: 643.0237 - profile/Sox2_loss: 259.7065 - profile/Oct4_loss: 367.0747 - counts/Sox2_loss: 0.7142 - counts/Oct4_loss: 0.9100 - val_loss: 640.1337 - val_profile/Sox2_loss: 259.4353 - val_profile/Oct4_loss: 364.0771 - val_counts/Sox2_loss: 0.7239 - val_counts/Oct4_loss: 0.9382
Epoch 36/100
5561/5561 [==============================] - 1s 162us/step - loss: 642.7619 - profile/Sox2_loss: 259.5764 - profile/Oct4_loss: 366.8718 - counts/Sox2_loss: 0.7190 - counts/Oct4_loss: 0.9124 - val_loss: 639.5448 - val_profile/Sox2_loss: 259.3162 - val_profile/Oct4_loss: 363.7716 - val_counts/Sox2_loss: 0.7224 - val_counts/Oct4_loss: 0.9233
Epoch 37/100
5561/5561 [==============================] - 1s 160us/step - loss: 642.6316 - profile/Sox2_loss: 259.5002 - profile/Oct4_loss: 366.6740 - counts/Sox2_loss: 0.7216 - counts/Oct4_loss: 0.9241 - val_loss: 641.6525 - val_profile/Sox2_loss: 259.2008 - val_profile/Oct4_loss: 364.2727 - val_counts/Sox2_loss: 0.7797 - val_counts/Oct4_loss: 1.0382
Epoch 38/100
5561/5561 [==============================] - 1s 161us/step - loss: 642.6588 - profile/Sox2_loss: 259.5484 - profile/Oct4_loss: 366.7286 - counts/Sox2_loss: 0.7198 - counts/Oct4_loss: 0.9184 - val_loss: 639.3696 - val_profile/Sox2_loss: 259.0768 - val_profile/Oct4_loss: 363.8936 - val_counts/Sox2_loss: 0.7217 - val_counts/Oct4_loss: 0.9182
Epoch 39/100
5561/5561 [==============================] - 1s 165us/step - loss: 641.7992 - profile/Sox2_loss: 259.2581 - profile/Oct4_loss: 366.5428 - counts/Sox2_loss: 0.7112 - counts/Oct4_loss: 0.8886 - val_loss: 639.5094 - val_profile/Sox2_loss: 258.9877 - val_profile/Oct4_loss: 364.0029 - val_counts/Sox2_loss: 0.7275 - val_counts/Oct4_loss: 0.9244
Epoch 40/100
5561/5561 [==============================] - 1s 169us/step - loss: 641.9835 - profile/Sox2_loss: 259.1934 - profile/Oct4_loss: 366.4385 - counts/Sox2_loss: 0.7249 - counts/Oct4_loss: 0.9102 - val_loss: 639.1305 - val_profile/Sox2_loss: 259.0530 - val_profile/Oct4_loss: 363.6297 - val_counts/Sox2_loss: 0.7235 - val_counts/Oct4_loss: 0.9212
Epoch 41/100
5561/5561 [==============================] - 1s 162us/step - loss: 641.3177 - profile/Sox2_loss: 259.0838 - profile/Oct4_loss: 366.2968 - counts/Sox2_loss: 0.7105 - counts/Oct4_loss: 0.8832 - val_loss: 638.7986 - val_profile/Sox2_loss: 258.9844 - val_profile/Oct4_loss: 363.3596 - val_counts/Sox2_loss: 0.7224 - val_counts/Oct4_loss: 0.9230
Epoch 42/100
5561/5561 [==============================] - 1s 163us/step - loss: 641.0427 - profile/Sox2_loss: 259.0406 - profile/Oct4_loss: 366.1243 - counts/Sox2_loss: 0.7088 - counts/Oct4_loss: 0.8790 - val_loss: 639.1451 - val_profile/Sox2_loss: 259.1575 - val_profile/Oct4_loss: 363.4015 - val_counts/Sox2_loss: 0.7304 - val_counts/Oct4_loss: 0.9282
Epoch 43/100
5561/5561 [==============================] - 1s 164us/step - loss: 640.8725 - profile/Sox2_loss: 259.0406 - profile/Oct4_loss: 365.8991 - counts/Sox2_loss: 0.7110 - counts/Oct4_loss: 0.8823 - val_loss: 641.2305 - val_profile/Sox2_loss: 259.0127 - val_profile/Oct4_loss: 363.9074 - val_counts/Sox2_loss: 0.7700 - val_counts/Oct4_loss: 1.0610
Epoch 44/100
5561/5561 [==============================] - 1s 156us/step - loss: 641.2508 - profile/Sox2_loss: 258.9904 - profile/Oct4_loss: 366.0558 - counts/Sox2_loss: 0.7166 - counts/Oct4_loss: 0.9039 - val_loss: 638.5251 - val_profile/Sox2_loss: 258.8918 - val_profile/Oct4_loss: 363.3486 - val_counts/Sox2_loss: 0.7231 - val_counts/Oct4_loss: 0.9054
Epoch 45/100
5561/5561 [==============================] - 1s 157us/step - loss: 640.6159 - profile/Sox2_loss: 258.7922 - profile/Oct4_loss: 365.9072 - counts/Sox2_loss: 0.7115 - counts/Oct4_loss: 0.8801 - val_loss: 638.7049 - val_profile/Sox2_loss: 258.6975 - val_profile/Oct4_loss: 363.5081 - val_counts/Sox2_loss: 0.7270 - val_counts/Oct4_loss: 0.9229
Epoch 46/100
5561/5561 [==============================] - 1s 159us/step - loss: 640.1167 - profile/Sox2_loss: 258.6941 - profile/Oct4_loss: 365.6384 - counts/Sox2_loss: 0.7076 - counts/Oct4_loss: 0.8709 - val_loss: 639.6338 - val_profile/Sox2_loss: 259.2368 - val_profile/Oct4_loss: 363.5453 - val_counts/Sox2_loss: 0.7269 - val_counts/Oct4_loss: 0.9583
Epoch 47/100
5561/5561 [==============================] - 1s 160us/step - loss: 639.8231 - profile/Sox2_loss: 258.6519 - profile/Oct4_loss: 365.4629 - counts/Sox2_loss: 0.7081 - counts/Oct4_loss: 0.8627 - val_loss: 638.7694 - val_profile/Sox2_loss: 258.9100 - val_profile/Oct4_loss: 363.5431 - val_counts/Sox2_loss: 0.7244 - val_counts/Oct4_loss: 0.9072
Epoch 48/100
5561/5561 [==============================] - 1s 159us/step - loss: 639.3890 - profile/Sox2_loss: 258.5663 - profile/Oct4_loss: 365.2723 - counts/Sox2_loss: 0.7009 - counts/Oct4_loss: 0.8541 - val_loss: 638.8391 - val_profile/Sox2_loss: 258.8668 - val_profile/Oct4_loss: 363.3858 - val_counts/Sox2_loss: 0.7314 - val_counts/Oct4_loss: 0.9272
Epoch 49/100
5561/5561 [==============================] - 1s 160us/step - loss: 639.5096 - profile/Sox2_loss: 258.4528 - profile/Oct4_loss: 365.3357 - counts/Sox2_loss: 0.7048 - counts/Oct4_loss: 0.8674 - val_loss: 640.6230 - val_profile/Sox2_loss: 259.0020 - val_profile/Oct4_loss: 364.0112 - val_counts/Sox2_loss: 0.7528 - val_counts/Oct4_loss: 1.0082
In [140]:
# Report validation losses of the reloaded best checkpoint.
# NOTE(review): `evaluate` is defined in an earlier (unseen) cell — the
# Out[140] dict suggests it wraps model.evaluate and returns name->loss.
evaluate(model, valid_nex[0], valid_nex[1])
Out[140]:
{'loss': 638.5251498536193,
 'profile/Sox2_loss': 258.891815639099,
 'profile/Oct4_loss': 363.348608450272,
 'counts/Sox2_loss': 0.7230639781668434,
 'counts/Oct4_loss': 0.9054083862122456}
In [141]:
from basepair.plots import regression_eval
In [142]:
# Test-set predictions from the newly trained model and from `model2`
# (the previous/"old" model loaded in an earlier, unseen cell) for comparison.
y_pred_new = model.predict(test_nex[0])
y_pred_old = model2.predict(test_nex[0])
In [143]:
# Counts-head regression diagnostics: OLD model, Sox2.
# NOTE(review): presumably test_nex[1][2] is the Sox2 counts target and
# ds.task2idx maps (task, head) -> output index — confirm against ds.
# NOTE(review): this cell is copy-pasted 4x below; consider a
# `compare(model_label, task)` helper + loop.
print("OLD SOX2")
regression_eval(test_nex[1][2].mean(-1), y_pred_old[ds.task2idx("Sox2", 'counts')].mean(-1))
OLD SOX2
In [144]:
# Counts-head regression diagnostics: NEW model, Sox2 (same target slice as
# the OLD-model cell above, only the predictions differ).
print("NEW SOX2")
regression_eval(test_nex[1][2].mean(-1), y_pred_new[ds.task2idx("Sox2", 'counts')].mean(-1))
NEW SOX2
In [145]:
# Counts-head regression diagnostics: OLD model, Oct4.
# NOTE(review): presumably test_nex[1][3] is the Oct4 counts target — confirm.
print("OLD OCT4")
regression_eval(test_nex[1][3].mean(-1), y_pred_old[ds.task2idx("Oct4", 'counts')].mean(-1))
OLD OCT4
In [146]:
# Counts-head regression diagnostics: NEW model, Oct4 (same target slice as
# the OLD-model cell above, only the predictions differ).
print("NEW OCT4")
regression_eval(test_nex[1][3].mean(-1), y_pred_new[ds.task2idx("Oct4", 'counts')].mean(-1))
NEW OCT4