Goal

  • make a model that also predicts the absolute counts in the regions

TODO

  • [x] histogram of the counts
    • what's the right distribution?
      • make a q-q plot with Poisson and negative binomial
  • [x] train a model that also predicts the counts
    • focus only on Sox2
    • make the correlation plot between predicted and observed counts

Modeling pooled counts

  • [x] implement the data function which returns the pooled signal (assumed wiring sketched after this list)
  • [x] allow the model function to also work on this dataset
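A minimal sketch of what that data function could look like. seq_inp_exo_out and bin_counts are the basepair functions used later in this notebook; the wrapper itself and its wiring are an assumption (the real seq_inp_exo_out below accepts bin_size directly):

from basepair.datasets import seq_inp_exo_out
from basepair.preproc import bin_counts

def seq_inp_exo_out_pooled(truncate_len=200, bin_size=None):
    """Hypothetical wrapper: load the Sox2 dataset and pool the
    output counts into non-overlapping bins of size bin_size."""
    splits = seq_inp_exo_out(truncate_len=truncate_len)
    if bin_size is None:
        return splits
    return tuple((x, bin_counts(y, binsize=bin_size)) for x, y in splits)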

Next steps

  • [ ] run the same thing for the multi-task model
  • [ ] set up the Sacred experiment for it
    • run the evaluation overnight
      • run 3 runs per model and select the best validation loss
        • report the test performance
  • [ ] model the counts using the negative binomial distribution (hypothetical loss sketched below)
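For that last item, a hypothetical negative binomial NLL for Keras (not the basepair implementation), under a (mu, r) parameterization and dropping the lgamma(y+1) term that is constant in the parameters:

import tensorflow as tf
import keras.backend as K

def neg_binom_nll(y_true, y_pred):
    """Hypothetical NB negative log-likelihood; the last axis of
    y_pred carries log-mean and log-dispersion."""
    mu = K.exp(y_pred[..., 0])   # mean
    r = K.exp(y_pred[..., 1])    # dispersion (number-of-failures parameter)
    # -log NB(y; r, p) with p = r / (r + mu), up to the lgamma(y+1) constant
    nll = (tf.lgamma(r) - tf.lgamma(y_true + r)
           + (y_true + r) * K.log(r + mu)
           - r * K.log(r) - y_true * K.log(mu))
    return K.mean(nll)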
In [6]:
import pandas as pd
import numpy as np
from pybedtools import BedTool
from basepair.config import get_data_dir, create_tf_session
from tqdm import tqdm
from concise.preprocessing import encodeDNA
from basepair.datasets import get_sox2_data
from basepair.plots import plot_loss
import pyBigWig
from basepair.math import softmax
import matplotlib.pyplot as plt
from basepair import samplers

ddir = get_data_dir()
In [91]:
create_tf_session(1)
Out[91]:
<tensorflow.python.client.session.Session at 0x7f1eefa056d8>
In [101]:
from basepair.datasets import *
from basepair.models import *
from basepair.plots import *
In [5]:
train, valid, test = seq_inp_exo_out()  # default Sox2 dataset
In [97]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
from basepair.layers import SpatialLifetimeSparsity
from basepair.models import seq_mutlitask
In [98]:
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
In [19]:
def get_model(mfn, mkwargs):
    """Get the model"""
    import datetime
    mdir = f"{ddir}/processed/chipnexus/exp/models/multi-task"
    name = mfn + "_" + \
            ",".join([f'{k}={v}' for k,v in mkwargs.items()]) + \
            "." + str(datetime.datetime.now()).replace(" ", "::")
    !mkdir -p {mdir}
    ckp_file = f"{mdir}/{name}.h5"
    return eval(mfn)(**mkwargs), name, ckp_file

histogram of the counts

  • what's the right distribution?
    • make a q-q plot with Poisson and negative binomial
In [9]:
train[1].shape
Out[9]:
(5561, 201, 2)
In [13]:
plt.hist(train[1].sum(1).sum(1), bins=100)
plt.xlabel("Total counts");
In [16]:
counts = train[1].sum(1).sum(1)
In [14]:
plt.scatter(train[1].sum(1)[:,0], train[1].sum(1)[:,1])
plt.xlabel("Forward strand counts");
plt.ylabel("Reverse strand counts");
In [21]:
from statsmodels.graphics.gofplots import qqplot
from scipy.stats import poisson, nbinom, norm
In [30]:
import scipy.stats as stats
In [45]:
counts.mean()
Out[45]:
153.76802
In [68]:
stats.probplot(counts, dist="norm", plot=plt);
In [67]:
# Log-normal seems to model the data pretty well
stats.probplot(np.log(counts+1), dist="norm", plot=plt);
In [63]:
stats.probplot(counts, sparams=(counts.mean(),), dist="poisson", plot=plt);
In [54]:
stats.probplot(counts, sparams=(0.2, 0.001), dist="nbinom", plot=plt);
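The sparams=(0.2, 0.001) above look hand-picked; a method-of-moments fit would tie them to the data. A sketch, assuming scipy's (n, p) parameterization with mean n(1-p)/p and variance n(1-p)/p**2:

m, v = counts.mean(), counts.var()
p_hat = m / v                # p = mean / variance
n_hat = m**2 / (v - m)       # n = mean^2 / (variance - mean)
stats.probplot(counts, sparams=(n_hat, p_hat), dist="nbinom", plot=plt);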

train a model that also predicts the counts

  • focus only on Sox2
  • make the correlation plot between predicted and observed counts
In [70]:
from keras.models import Sequential
In [128]:
import concise.metrics as ce
import concise.eval_metrics as cem
In [120]:
from concise.metrics import var_explained
In [302]:
def single_output(filters=32,
                  conv1_kernel_size=21,
                  dropout=0.1,
                  seq_len=201,
                  pool_type='avg',
                  lr=0.004):
    def get_pool(pool_type):
        if pool_type == 'max':
            return kl.MaxPool1D()
        elif pool_type == 'avg':
            return kl.AveragePooling1D()
        else:
            raise ValueError(f"pool_type must be 'max' or 'avg', got {pool_type!r}")
    model = Sequential([
        kl.Conv1D(filters, conv1_kernel_size, activation='relu', 
                  input_shape=(seq_len,4), padding='same'),
        kl.Conv1D(filters, 1, activation='relu', padding='same'),
        get_pool(pool_type),
        kl.Conv1D(2*filters, 7, activation='relu', padding='same'),
        get_pool(pool_type),
        kl.Conv1D(2*filters, 7, activation='relu', padding='same'),
        #get_pool(pool_type),
        kl.GlobalAveragePooling1D(),
        #kl.Dense(8*filters, activation='relu'),
        #kl.Dropout(dropout),
        kl.Dense(1)
    ])
    model.compile(Adam(lr=lr), 
                  loss="mse",
                  metrics=[ce.var_explained]
                 )
    return model
In [303]:
def get_model(mfn, mkwargs):
    """Get the model"""
    import datetime
    mdir = f"{ddir}/processed/chipnexus/exp/models/count-output"
    name = mfn + "_" + \
            ",".join([f'{k}={v}' for k,v in mkwargs.items()]) + \
            "." + str(datetime.datetime.now()).replace(" ", "::")
    !mkdir -p {mdir}
    ckp_file = f"{mdir}/{name}.h5"
    return eval(mfn)(**mkwargs), name, ckp_file
In [312]:
# hyper-parameters
mfn = "single_output"
mkwargs = dict(filters=64, 
               conv1_kernel_size=21,
               dropout=0.5,
               pool_type='max',
               lr=0.004)
In [313]:
valid_counts = (valid[0], np.log(valid[1].sum(1).sum(1)+1))
In [314]:
test_counts = (test[0], np.log(test[1].sum(1).sum(1)+1))
In [315]:
# best valid so far: 108238.6558
model, name, ckp_file = get_model(mfn, mkwargs)
history = model.fit(train[0], 
                    np.log(train[1].sum(1).sum(1)+1),
          batch_size=256, 
          epochs=100,
          validation_data=(valid[0], np.log(valid[1].sum(1).sum(1)+1)),
          callbacks=[EarlyStopping(patience=5),
                     History(),
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model
model = load_model(ckp_file)
Train on 5561 samples, validate on 1884 samples
Epoch 1/100
5561/5561 [==============================] - 7s 1ms/step - loss: 5.1914 - var_explained: -0.0084 - val_loss: 1.1616 - val_var_explained: -0.0331
Epoch 2/100
5561/5561 [==============================] - 0s 87us/step - loss: 0.6862 - var_explained: -0.0056 - val_loss: 0.4661 - val_var_explained: -0.0132
Epoch 3/100
5561/5561 [==============================] - 0s 71us/step - loss: 0.4950 - var_explained: -0.0016 - val_loss: 0.4896 - val_var_explained: -0.0055
Epoch 4/100
5561/5561 [==============================] - 0s 70us/step - loss: 0.4870 - var_explained: 5.1517e-04 - val_loss: 0.4689 - val_var_explained: -0.0030
Epoch 5/100
5561/5561 [==============================] - 0s 78us/step - loss: 0.4669 - var_explained: 0.0023 - val_loss: 0.4856 - val_var_explained: -0.0021
Epoch 6/100
5561/5561 [==============================] - 0s 80us/step - loss: 0.4739 - var_explained: 0.0041 - val_loss: 0.4655 - val_var_explained: -8.1162e-04
Epoch 7/100
5561/5561 [==============================] - 0s 77us/step - loss: 0.4657 - var_explained: 0.0060 - val_loss: 0.4650 - val_var_explained: -3.0170e-05
Epoch 8/100
5561/5561 [==============================] - 0s 90us/step - loss: 0.4682 - var_explained: 0.0076 - val_loss: 0.4856 - val_var_explained: 0.0014
Epoch 9/100
5561/5561 [==============================] - 0s 72us/step - loss: 0.4747 - var_explained: 0.0094 - val_loss: 0.4638 - val_var_explained: 0.0020
Epoch 10/100
5561/5561 [==============================] - 0s 74us/step - loss: 0.4664 - var_explained: 0.0118 - val_loss: 0.4729 - val_var_explained: 0.0021
Epoch 11/100
5561/5561 [==============================] - 0s 70us/step - loss: 0.4614 - var_explained: 0.0139 - val_loss: 0.4830 - val_var_explained: 0.0035
Epoch 12/100
5561/5561 [==============================] - 0s 76us/step - loss: 0.4750 - var_explained: 0.0166 - val_loss: 0.4662 - val_var_explained: 0.0033
Epoch 13/100
5561/5561 [==============================] - 0s 70us/step - loss: 0.4660 - var_explained: 0.0201 - val_loss: 0.4611 - val_var_explained: 0.0039
Epoch 14/100
5561/5561 [==============================] - 0s 77us/step - loss: 0.4622 - var_explained: 0.0244 - val_loss: 0.4666 - val_var_explained: 0.0020
Epoch 15/100
5561/5561 [==============================] - 1s 95us/step - loss: 0.4565 - var_explained: 0.0304 - val_loss: 0.4557 - val_var_explained: -0.0048
Epoch 16/100
5561/5561 [==============================] - 0s 72us/step - loss: 0.4560 - var_explained: 0.0386 - val_loss: 0.4736 - val_var_explained: -0.0098
Epoch 17/100
5561/5561 [==============================] - 0s 71us/step - loss: 0.4725 - var_explained: 0.0423 - val_loss: 0.4610 - val_var_explained: -0.0307
Epoch 18/100
5561/5561 [==============================] - 0s 74us/step - loss: 0.4640 - var_explained: 0.0500 - val_loss: 0.4494 - val_var_explained: -0.0386
Epoch 19/100
5561/5561 [==============================] - 0s 83us/step - loss: 0.4503 - var_explained: 0.0557 - val_loss: 0.4465 - val_var_explained: -0.0617
Epoch 20/100
5561/5561 [==============================] - 0s 69us/step - loss: 0.4529 - var_explained: 0.0628 - val_loss: 0.4514 - val_var_explained: -0.0501
Epoch 21/100
5561/5561 [==============================] - 0s 78us/step - loss: 0.4532 - var_explained: 0.0684 - val_loss: 0.4699 - val_var_explained: -0.0777
Epoch 22/100
5561/5561 [==============================] - 1s 91us/step - loss: 0.4425 - var_explained: 0.0745 - val_loss: 0.4511 - val_var_explained: -0.1280
Epoch 23/100
5561/5561 [==============================] - 0s 72us/step - loss: 0.4431 - var_explained: 0.0800 - val_loss: 0.4390 - val_var_explained: -0.0889
Epoch 24/100
5561/5561 [==============================] - 0s 72us/step - loss: 0.4412 - var_explained: 0.0885 - val_loss: 0.4486 - val_var_explained: -0.0772
Epoch 25/100
5561/5561 [==============================] - 0s 72us/step - loss: 0.4275 - var_explained: 0.0977 - val_loss: 0.4654 - val_var_explained: -0.2303
Epoch 26/100
5561/5561 [==============================] - 0s 80us/step - loss: 0.4197 - var_explained: 0.1065 - val_loss: 0.4328 - val_var_explained: -0.1248
Epoch 27/100
5561/5561 [==============================] - 0s 71us/step - loss: 0.4171 - var_explained: 0.1176 - val_loss: 0.4324 - val_var_explained: -0.2651
Epoch 28/100
5561/5561 [==============================] - 0s 82us/step - loss: 0.4481 - var_explained: 0.1261 - val_loss: 0.5877 - val_var_explained: -0.1834
Epoch 29/100
5561/5561 [==============================] - 0s 88us/step - loss: 0.4425 - var_explained: 0.1336 - val_loss: 0.4710 - val_var_explained: -0.3490
Epoch 30/100
5561/5561 [==============================] - 0s 72us/step - loss: 0.4358 - var_explained: 0.1425 - val_loss: 0.4568 - val_var_explained: -0.2747
Epoch 31/100
5561/5561 [==============================] - 0s 74us/step - loss: 0.4193 - var_explained: 0.1530 - val_loss: 0.4254 - val_var_explained: -0.2709
Epoch 32/100
5561/5561 [==============================] - 0s 79us/step - loss: 0.3953 - var_explained: 0.1650 - val_loss: 0.4246 - val_var_explained: -0.3283
Epoch 33/100
5561/5561 [==============================] - 0s 74us/step - loss: 0.4009 - var_explained: 0.1798 - val_loss: 0.4912 - val_var_explained: -0.4428
Epoch 34/100
5561/5561 [==============================] - 0s 74us/step - loss: 0.3890 - var_explained: 0.1959 - val_loss: 0.4940 - val_var_explained: -0.5467
Epoch 35/100
5561/5561 [==============================] - 0s 75us/step - loss: 0.3978 - var_explained: 0.2105 - val_loss: 0.4182 - val_var_explained: -0.6147
Epoch 36/100
5561/5561 [==============================] - 0s 89us/step - loss: 0.3762 - var_explained: 0.2359 - val_loss: 0.4113 - val_var_explained: -0.4747
Epoch 37/100
5561/5561 [==============================] - 0s 73us/step - loss: 0.3620 - var_explained: 0.2511 - val_loss: 0.4894 - val_var_explained: -0.4662
Epoch 38/100
5561/5561 [==============================] - 0s 73us/step - loss: 0.3693 - var_explained: 0.2698 - val_loss: 0.4055 - val_var_explained: -0.8390
Epoch 39/100
5561/5561 [==============================] - 0s 77us/step - loss: 0.3448 - var_explained: 0.2918 - val_loss: 0.4115 - val_var_explained: -1.0317
Epoch 40/100
5561/5561 [==============================] - 0s 77us/step - loss: 0.3323 - var_explained: 0.3061 - val_loss: 0.4142 - val_var_explained: -1.1133
Epoch 41/100
5561/5561 [==============================] - 0s 82us/step - loss: 0.3244 - var_explained: 0.3195 - val_loss: 0.4149 - val_var_explained: -1.0482
Epoch 42/100
5561/5561 [==============================] - 0s 86us/step - loss: 0.3302 - var_explained: 0.3335 - val_loss: 0.4154 - val_var_explained: -1.0183
Epoch 43/100
5561/5561 [==============================] - 0s 79us/step - loss: 0.3118 - var_explained: 0.3436 - val_loss: 0.4279 - val_var_explained: -1.4607
In [316]:
dfh = pd.DataFrame(history.history)
plt.figure(figsize=(8,2))
plt.subplot(121)
plt.plot(dfh.loss, label="loss")
plt.plot(dfh.val_loss, label="val_loss")
plt.legend()
plt.xlabel("epoch")
plt.subplot(122)
plt.plot(dfh.var_explained, label="var_explained")
plt.plot(dfh.val_var_explained, label="val_var_explained")
plt.legend()
plt.xlabel("epoch");
In [317]:
y_pred = model.predict(valid_counts[0])
In [318]:
regression_eval(valid_counts[1], y_pred[:,0])
In [311]:
cem.var_explained(valid_counts[1], y_pred)
Out[311]:
-0.06671547889709473

Train a joint model

In [627]:
def seq_dense_count(filters=21, 
                    conv1_kernel_size=21,
                    tconv_kernel_size=25,
                    n_dil_layers=6,
                    use_profile=True,
                    seq_len=201,
                    profile_pool=None,
                    count_weight=100,
                    lr=0.004):
    """
    Dense
    """
    # TODO - build the reverse complement symmetry into the model
    inp = kl.Input(shape=(seq_len, 4), name='seq')
    first_conv = kl.Conv1D(filters, kernel_size=conv1_kernel_size, padding='same', 
                   #kernel_initializer = ci.PSSMKernelInitializer(pwm_list, stddev=0, add_noise_before_Pwm2Pssm=False),
                   #bias_initializer = 'zeros',
                   activation='relu')(inp)
    
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers+1):
        if i == 1:
            prev_sum = first_conv
        else:
            prev_sum = kl.add(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)
    # De-conv
    combined_conv = kl.add(prev_layers)
    if profile_pool is not None:
        combined_conv = kl.AveragePooling1D(pool_size=profile_pool)(combined_conv)
    # Bottleneck layer
    #combined_conv = kl.Conv1D(3, kernel_size=1, padding='same', activation='relu')(combined_conv)
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(2, kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    #kl.Conv2DTranspose(32, kernel_size=(7, 1), padding='same', activation='relu'),
    #kl.Conv2DTranspose(2, kernel_size=(3, 1), padding='same'),
    out = kl.Reshape((-1, 2), name='profile')(x)
    
    pooled = kl.GlobalAvgPool1D()(combined_conv)
    #hidden = kl.Dense(1, pooled)
    count_out = kl.Dense(2, name='count')(pooled)
    if use_profile:
        model = Model(inp, [out, count_out])
        model.compile(Adam(lr=lr), 
                  loss=[twochannel_multinomial_nll, 'mse'],
                  loss_weights=[1, count_weight])
    else:
        model = Model(inp, count_out)
        model.compile(Adam(lr=lr), 
                      loss='mse',)
    return model
In [628]:
# hyper-parameters
mfn = "seq_dense_count"
mkwargs = dict(filters=21, 
               conv1_kernel_size=21,
               tconv_kernel_size=25,
               n_dil_layers=6,
               seq_len=201,
               profile_pool=None,
               use_profile=False,
               lr=0.004)
In [629]:
valid_counts = (valid[0], np.log(valid[1].sum(1).sum(1)+1))
In [630]:
test_counts = (test[0], np.log(test[1].sum(1).sum(1)+1))
In [631]:
model, name, ckp_file = get_model(mfn, mkwargs)
history = model.fit(train[0], 
                    [train[1], np.log(train[1].sum(1)+1)],
          batch_size=256, 
          epochs=100,
          validation_data=(valid[0], [np.log(valid[1].sum(1).sum(1)+1)]),
          callbacks=[EarlyStopping(patience=5),
                     History(),
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model
model = load_model(ckp_file)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-631-2def576d1faa> in <module>()
      7           callbacks=[EarlyStopping(patience=5),
      8                      History(),
----> 9                      ModelCheckpoint(ckp_file, save_best_only=True)]
     10          )
     11 # get the best model

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1628             sample_weight=sample_weight,
   1629             class_weight=class_weight,
-> 1630             batch_size=batch_size)
   1631         # Prepare validation data.
   1632         do_validation = False

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
   1474                                     self._feed_input_shapes,
   1475                                     check_batch_axis=False,
-> 1476                                     exception_prefix='input')
   1477         y = _standardize_input_data(y, self._feed_output_names,
   1478                                     output_shapes,

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/engine/training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    121                             ': expected ' + names[i] + ' to have shape ' +
    122                             str(shape) + ' but got array with shape ' +
--> 123                             str(data_shape))
    124     return data
    125 

ValueError: Error when checking input: expected seq to have shape (201, 4) but got array with shape (200, 4)
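The dataset loaded at this point was 200 bp while the model was built with seq_len=201. (With use_profile=False the model also has only the count output, so the two-element target list would not have matched either.) The corrected setup appears in In [738] below.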
In [242]:
y_pred = model.predict(valid[0])
yc_pred = y_pred[:,0]
yc_true = np.log(valid[1].sum(1).sum(1)+1)
In [243]:
cem.var_explained(yc_true, yc_pred)
Out[243]:
0.12870138883590698
In [245]:
regression_eval(yc_true, yc_pred)
In [229]:
regression_eval(yc_true, yc_pred)

Make the plot

  • train the profile model on a coarser and coarser profile
    • report the prediction accuracy
In [320]:
x = train[1]
In [321]:
x.shape
Out[321]:
(5561, 201, 2)
In [338]:
from basepair.preproc import bin_counts
In [324]:
coarsen = bin_counts(x,3)
In [325]:
coarsen.shape
Out[325]:
(5561, 67, 2)
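bin_counts itself isn't shown in this notebook; from the shapes above ((5561, 201, 2) → (5561, 67, 2) for binsize 3), it presumably sums counts over non-overlapping windows along the position axis and drops the trailing remainder. A sketch under that assumption:

def bin_counts_sketch(x, binsize):
    """Sum counts in non-overlapping bins of size binsize along axis 1.
    Trailing positions that don't fill a bin are dropped (201 // 3 = 67)."""
    n_bins = x.shape[1] // binsize
    trimmed = x[:, :n_bins * binsize]
    return trimmed.reshape(x.shape[0], n_bins, binsize, x.shape[2]).sum(axis=2)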
In [339]:
plt.plot(x[-1,:,0])
Out[339]:
[<matplotlib.lines.Line2D at 0x7f1c75a47e48>]
In [340]:
plt.plot(coarsen[-1,:,0])
Out[340]:
[<matplotlib.lines.Line2D at 0x7f1c7561f0b8>]

Train the model on the coarsened signal

In [361]:
200/4
Out[361]:
50.0
In [369]:
200/8
Out[369]:
25.0
In [734]:
# hyper-parameters
mfn = "seq_dense_count"
bin_size=None
mkwargs = dict(filters=21, 
               conv1_kernel_size=21,
               tconv_kernel_size=25,
               n_dil_layers=6,
               seq_len=200,
               profile_pool=bin_size,
               use_profile=True,
               count_weight=100,
               lr=0.004)
In [735]:
train, valid, test = seq_inp_exo_out(truncate_len=200, bin_size=bin_size)
In [736]:
valid_counts = (valid[0], np.log(valid[1].sum(1).sum(1)+1))
In [737]:
test_counts = (test[0], np.log(test[1].sum(1).sum(1)+1))
In [738]:
model, name, ckp_file = get_model(mfn, mkwargs)
history = model.fit(train[0], 
                    [train[1], np.log(train[1].sum(1)+1)],
          batch_size=256, 
          epochs=100,
          validation_data=(valid[0], [valid[1], np.log(valid[1].sum(1)+1)]),
          callbacks=[EarlyStopping(patience=5),
                     History(),
                     ModelCheckpoint(ckp_file, save_best_only=True)]
         )
# get the best model
model = load_model(ckp_file)
Train on 5561 samples, validate on 1884 samples
Epoch 1/100
5561/5561 [==============================] - 24s 4ms/step - loss: 581.9729 - profile_loss: 324.5300 - count_loss: 2.5744 - val_loss: 378.7302 - val_profile_loss: 295.2163 - val_count_loss: 0.8351
Epoch 2/100
5561/5561 [==============================] - 1s 128us/step - loss: 354.9713 - profile_loss: 297.8666 - count_loss: 0.5710 - val_loss: 337.9158 - val_profile_loss: 286.3280 - val_count_loss: 0.5159
Epoch 3/100
5561/5561 [==============================] - 1s 126us/step - loss: 339.6143 - profile_loss: 290.2381 - count_loss: 0.4938 - val_loss: 331.1505 - val_profile_loss: 282.0630 - val_count_loss: 0.4909
Epoch 4/100
5561/5561 [==============================] - 1s 123us/step - loss: 335.1150 - profile_loss: 286.8514 - count_loss: 0.4826 - val_loss: 328.9046 - val_profile_loss: 279.9710 - val_count_loss: 0.4893
Epoch 5/100
5561/5561 [==============================] - 1s 134us/step - loss: 332.8500 - profile_loss: 284.9278 - count_loss: 0.4792 - val_loss: 327.5398 - val_profile_loss: 278.6295 - val_count_loss: 0.4891
Epoch 6/100
5561/5561 [==============================] - 1s 116us/step - loss: 331.6904 - profile_loss: 283.6524 - count_loss: 0.4804 - val_loss: 325.4729 - val_profile_loss: 277.3233 - val_count_loss: 0.4815
Epoch 7/100
5561/5561 [==============================] - 1s 115us/step - loss: 330.6200 - profile_loss: 282.7402 - count_loss: 0.4788 - val_loss: 324.7494 - val_profile_loss: 276.5879 - val_count_loss: 0.4816
Epoch 8/100
5561/5561 [==============================] - 1s 124us/step - loss: 329.8408 - profile_loss: 281.9537 - count_loss: 0.4789 - val_loss: 324.3670 - val_profile_loss: 276.0028 - val_count_loss: 0.4836
Epoch 9/100
5561/5561 [==============================] - 1s 129us/step - loss: 329.0383 - profile_loss: 281.3367 - count_loss: 0.4770 - val_loss: 323.7850 - val_profile_loss: 275.7869 - val_count_loss: 0.4800
Epoch 10/100
5561/5561 [==============================] - 1s 114us/step - loss: 328.6522 - profile_loss: 280.9290 - count_loss: 0.4772 - val_loss: 323.0098 - val_profile_loss: 275.1021 - val_count_loss: 0.4791
Epoch 11/100
5561/5561 [==============================] - 1s 124us/step - loss: 328.0022 - profile_loss: 280.4106 - count_loss: 0.4759 - val_loss: 323.2358 - val_profile_loss: 274.7449 - val_count_loss: 0.4849
Epoch 12/100
5561/5561 [==============================] - 1s 123us/step - loss: 327.7168 - profile_loss: 280.0264 - count_loss: 0.4769 - val_loss: 322.1925 - val_profile_loss: 274.3336 - val_count_loss: 0.4786
Epoch 13/100
5561/5561 [==============================] - 1s 139us/step - loss: 327.1838 - profile_loss: 279.4705 - count_loss: 0.4771 - val_loss: 322.0539 - val_profile_loss: 273.7784 - val_count_loss: 0.4828
Epoch 14/100
5561/5561 [==============================] - 1s 116us/step - loss: 326.4604 - profile_loss: 278.9650 - count_loss: 0.4750 - val_loss: 322.2030 - val_profile_loss: 273.4818 - val_count_loss: 0.4872
Epoch 15/100
5561/5561 [==============================] - 1s 126us/step - loss: 325.9706 - profile_loss: 278.4192 - count_loss: 0.4755 - val_loss: 320.9968 - val_profile_loss: 273.2228 - val_count_loss: 0.4777
Epoch 16/100
5561/5561 [==============================] - 1s 125us/step - loss: 325.2576 - profile_loss: 277.8898 - count_loss: 0.4737 - val_loss: 319.6207 - val_profile_loss: 272.1487 - val_count_loss: 0.4747
Epoch 17/100
5561/5561 [==============================] - 1s 138us/step - loss: 324.5662 - profile_loss: 276.9869 - count_loss: 0.4758 - val_loss: 318.6510 - val_profile_loss: 271.3586 - val_count_loss: 0.4729
Epoch 18/100
5561/5561 [==============================] - 1s 122us/step - loss: 322.8960 - profile_loss: 276.1257 - count_loss: 0.4677 - val_loss: 318.9686 - val_profile_loss: 271.1058 - val_count_loss: 0.4786
Epoch 19/100
5561/5561 [==============================] - 1s 120us/step - loss: 322.7893 - profile_loss: 275.5935 - count_loss: 0.4720 - val_loss: 318.0213 - val_profile_loss: 270.4937 - val_count_loss: 0.4753
Epoch 20/100
5561/5561 [==============================] - 1s 125us/step - loss: 321.5505 - profile_loss: 274.9535 - count_loss: 0.4660 - val_loss: 317.3012 - val_profile_loss: 270.1762 - val_count_loss: 0.4712
Epoch 21/100
5561/5561 [==============================] - 1s 136us/step - loss: 320.9144 - profile_loss: 274.3661 - count_loss: 0.4655 - val_loss: 316.1316 - val_profile_loss: 269.3739 - val_count_loss: 0.4676
Epoch 22/100
5561/5561 [==============================] - 1s 116us/step - loss: 321.0694 - profile_loss: 273.9612 - count_loss: 0.4711 - val_loss: 316.6534 - val_profile_loss: 269.1817 - val_count_loss: 0.4747
Epoch 23/100
5561/5561 [==============================] - 1s 127us/step - loss: 321.1701 - profile_loss: 273.4214 - count_loss: 0.4775 - val_loss: 316.9469 - val_profile_loss: 268.9387 - val_count_loss: 0.4801
Epoch 24/100
5561/5561 [==============================] - 1s 126us/step - loss: 319.3404 - profile_loss: 272.9238 - count_loss: 0.4642 - val_loss: 314.9400 - val_profile_loss: 268.2347 - val_count_loss: 0.4671
Epoch 25/100
5561/5561 [==============================] - 1s 140us/step - loss: 318.3941 - profile_loss: 272.5035 - count_loss: 0.4589 - val_loss: 316.7414 - val_profile_loss: 267.8628 - val_count_loss: 0.4888
Epoch 26/100
5561/5561 [==============================] - 1s 119us/step - loss: 318.3852 - profile_loss: 271.8039 - count_loss: 0.4658 - val_loss: 314.4808 - val_profile_loss: 267.3152 - val_count_loss: 0.4717
Epoch 27/100
5561/5561 [==============================] - 1s 124us/step - loss: 317.7832 - profile_loss: 271.3184 - count_loss: 0.4646 - val_loss: 312.7866 - val_profile_loss: 266.8961 - val_count_loss: 0.4589
Epoch 28/100
5561/5561 [==============================] - 1s 124us/step - loss: 316.7158 - profile_loss: 270.8478 - count_loss: 0.4587 - val_loss: 315.8774 - val_profile_loss: 266.9353 - val_count_loss: 0.4894
Epoch 29/100
5561/5561 [==============================] - 1s 139us/step - loss: 316.4495 - profile_loss: 270.3043 - count_loss: 0.4615 - val_loss: 313.9079 - val_profile_loss: 266.0692 - val_count_loss: 0.4784
Epoch 30/100
5561/5561 [==============================] - 1s 122us/step - loss: 315.2683 - profile_loss: 269.9679 - count_loss: 0.4530 - val_loss: 311.4959 - val_profile_loss: 265.5375 - val_count_loss: 0.4596
Epoch 31/100
5561/5561 [==============================] - 1s 134us/step - loss: 313.5458 - profile_loss: 269.2550 - count_loss: 0.4429 - val_loss: 310.7917 - val_profile_loss: 265.0950 - val_count_loss: 0.4570
Epoch 32/100
5561/5561 [==============================] - 1s 123us/step - loss: 312.8685 - profile_loss: 268.7878 - count_loss: 0.4408 - val_loss: 309.1058 - val_profile_loss: 264.6470 - val_count_loss: 0.4446
Epoch 33/100
5561/5561 [==============================] - 1s 139us/step - loss: 312.3862 - profile_loss: 268.4468 - count_loss: 0.4394 - val_loss: 312.6747 - val_profile_loss: 264.7383 - val_count_loss: 0.4794
Epoch 34/100
5561/5561 [==============================] - 1s 122us/step - loss: 311.5082 - profile_loss: 268.0025 - count_loss: 0.4351 - val_loss: 308.0171 - val_profile_loss: 264.5704 - val_count_loss: 0.4345
Epoch 35/100
5561/5561 [==============================] - 1s 127us/step - loss: 311.1898 - profile_loss: 267.6225 - count_loss: 0.4357 - val_loss: 307.6581 - val_profile_loss: 263.7171 - val_count_loss: 0.4394
Epoch 36/100
5561/5561 [==============================] - 1s 155us/step - loss: 309.9959 - profile_loss: 267.1746 - count_loss: 0.4282 - val_loss: 305.0013 - val_profile_loss: 263.2737 - val_count_loss: 0.4173
Epoch 37/100
5561/5561 [==============================] - 1s 111us/step - loss: 308.2199 - profile_loss: 266.8584 - count_loss: 0.4136 - val_loss: 305.6155 - val_profile_loss: 263.0672 - val_count_loss: 0.4255
Epoch 38/100
5561/5561 [==============================] - 1s 119us/step - loss: 307.4238 - profile_loss: 266.5029 - count_loss: 0.4092 - val_loss: 303.3342 - val_profile_loss: 262.7605 - val_count_loss: 0.4057
Epoch 39/100
5561/5561 [==============================] - 1s 125us/step - loss: 307.8091 - profile_loss: 266.2349 - count_loss: 0.4157 - val_loss: 302.8293 - val_profile_loss: 262.6381 - val_count_loss: 0.4019
Epoch 40/100
5561/5561 [==============================] - 1s 120us/step - loss: 304.7989 - profile_loss: 266.0380 - count_loss: 0.3876 - val_loss: 301.8704 - val_profile_loss: 262.2395 - val_count_loss: 0.3963
Epoch 41/100
5561/5561 [==============================] - 1s 137us/step - loss: 304.7946 - profile_loss: 265.7924 - count_loss: 0.3900 - val_loss: 305.2180 - val_profile_loss: 261.9581 - val_count_loss: 0.4326
Epoch 42/100
5561/5561 [==============================] - 1s 115us/step - loss: 303.6025 - profile_loss: 265.2957 - count_loss: 0.3831 - val_loss: 301.5889 - val_profile_loss: 261.7186 - val_count_loss: 0.3987
Epoch 43/100
5561/5561 [==============================] - 1s 123us/step - loss: 302.9854 - profile_loss: 265.2677 - count_loss: 0.3772 - val_loss: 301.5679 - val_profile_loss: 261.8282 - val_count_loss: 0.3974
Epoch 44/100
5561/5561 [==============================] - 1s 123us/step - loss: 302.1772 - profile_loss: 265.0005 - count_loss: 0.3718 - val_loss: 301.2587 - val_profile_loss: 261.5682 - val_count_loss: 0.3969
Epoch 45/100
5561/5561 [==============================] - 1s 134us/step - loss: 302.7714 - profile_loss: 264.7575 - count_loss: 0.3801 - val_loss: 300.6915 - val_profile_loss: 261.3687 - val_count_loss: 0.3932
Epoch 46/100
5561/5561 [==============================] - 1s 122us/step - loss: 301.4153 - profile_loss: 264.4950 - count_loss: 0.3692 - val_loss: 300.1476 - val_profile_loss: 261.1966 - val_count_loss: 0.3895
Epoch 47/100
5561/5561 [==============================] - 1s 123us/step - loss: 300.8368 - profile_loss: 264.3322 - count_loss: 0.3650 - val_loss: 298.9299 - val_profile_loss: 261.0135 - val_count_loss: 0.3792
Epoch 48/100
5561/5561 [==============================] - 1s 124us/step - loss: 300.7554 - profile_loss: 264.1722 - count_loss: 0.3658 - val_loss: 301.9275 - val_profile_loss: 260.9415 - val_count_loss: 0.4099
Epoch 49/100
5561/5561 [==============================] - 1s 133us/step - loss: 299.5900 - profile_loss: 263.9988 - count_loss: 0.3559 - val_loss: 298.9610 - val_profile_loss: 261.1381 - val_count_loss: 0.3782
Epoch 50/100
5561/5561 [==============================] - 1s 122us/step - loss: 299.2379 - profile_loss: 263.9963 - count_loss: 0.3524 - val_loss: 298.5308 - val_profile_loss: 261.1621 - val_count_loss: 0.3737
Epoch 51/100
5561/5561 [==============================] - 1s 127us/step - loss: 298.1016 - profile_loss: 263.9937 - count_loss: 0.3411 - val_loss: 298.5316 - val_profile_loss: 260.9006 - val_count_loss: 0.3763
Epoch 52/100
5561/5561 [==============================] - 1s 127us/step - loss: 298.0668 - profile_loss: 263.6853 - count_loss: 0.3438 - val_loss: 298.0477 - val_profile_loss: 260.6708 - val_count_loss: 0.3738
Epoch 53/100
5561/5561 [==============================] - 1s 132us/step - loss: 300.0721 - profile_loss: 263.7202 - count_loss: 0.3635 - val_loss: 297.9732 - val_profile_loss: 260.7079 - val_count_loss: 0.3727
Epoch 54/100
5561/5561 [==============================] - 1s 121us/step - loss: 298.1221 - profile_loss: 263.5162 - count_loss: 0.3461 - val_loss: 297.8576 - val_profile_loss: 260.4001 - val_count_loss: 0.3746
Epoch 55/100
5561/5561 [==============================] - 1s 120us/step - loss: 296.6148 - profile_loss: 263.3362 - count_loss: 0.3328 - val_loss: 297.2974 - val_profile_loss: 260.2850 - val_count_loss: 0.3701
Epoch 56/100
5561/5561 [==============================] - 1s 123us/step - loss: 297.0446 - profile_loss: 263.3350 - count_loss: 0.3371 - val_loss: 298.8881 - val_profile_loss: 260.4195 - val_count_loss: 0.3847
Epoch 57/100
5561/5561 [==============================] - 1s 135us/step - loss: 296.3586 - profile_loss: 263.2219 - count_loss: 0.3314 - val_loss: 298.1859 - val_profile_loss: 260.3842 - val_count_loss: 0.3780
Epoch 58/100
5561/5561 [==============================] - 1s 114us/step - loss: 296.0716 - profile_loss: 263.3076 - count_loss: 0.3276 - val_loss: 299.0662 - val_profile_loss: 260.1043 - val_count_loss: 0.3896
Epoch 59/100
5561/5561 [==============================] - 1s 114us/step - loss: 295.8462 - profile_loss: 263.1631 - count_loss: 0.3268 - val_loss: 301.0598 - val_profile_loss: 260.3188 - val_count_loss: 0.4074
Epoch 60/100
5561/5561 [==============================] - 1s 118us/step - loss: 295.1769 - profile_loss: 263.0260 - count_loss: 0.3215 - val_loss: 301.1162 - val_profile_loss: 260.0158 - val_count_loss: 0.4110
In [739]:
y_pred = model.predict(valid[0])
yc_pred = y_pred[1][:,0]
yc_true = np.log(valid[1].sum(1)[:,0]+1)
In [740]:
cem.var_explained(yc_true, yc_pred)
Out[740]:
0.2363957166671753
In [741]:
regression_eval(yc_true, yc_pred)
In [760]:
y_pred = model.predict(test[0])
yc_pred = y_pred[1][:,0]
yc_true = np.log(test[1].sum(1).sum(1)+1)
In [ ]:
cem.var_explained(yc_true, yc_pred)
In [745]:
ckp_file
Out[745]:
'/users/avsec/workspace/basepair/basepair/../data/processed/chipnexus/exp/models/count-output/seq_dense_count_filters=21,conv1_kernel_size=21,tconv_kernel_size=25,n_dil_layers=6,seq_len=200,profile_pool=None,use_profile=True,count_weight=100,lr=0.004.2018-05-18::22:06:05.059394.h5'
In [744]:
regression_eval(yc_true, yc_pred)
In [654]:
test[1].shape
Out[654]:
(1951, 200, 2)
In [657]:
y_pred[0][:,:,0].shape
Out[657]:
(1951, 200)
In [662]:
y_pred[1].shape
Out[662]:
(1951, 2)
In [665]:
y_pred[0][:,:,0].shape
Out[665]:
(1951, 200)
In [757]:
binsize=1
regression_eval(np.log(1+np.ravel(bin_counts(test[1], binsize=binsize)[:,:,0])), 
                np.log(1+np.ravel((bin_counts(softmax(y_pred[0]), binsize=binsize)[:,:,0])*
                          y_pred[1][:,:1])))
---------------------------------------------------------------------------
AxisError                                 Traceback (most recent call last)
<ipython-input-757-34b78974bfae> in <module>()
      1 binsize=1
      2 regression_eval(np.log(1+np.ravel(bin_counts(test[1], binsize=binsize)[:,:,0])), 
----> 3                 np.log(1+np.ravel((bin_counts(softmax(y_pred[0]), binsize=binsize)[:,:,0])*
      4                           y_pred[1][:,:1])))

~/workspace/basepair/basepair/math.py in softmax(x)
      4 def softmax(x):
      5     """Compute softmax values for each sets of scores in x."""
----> 6     e_x = np.exp(x - np.max(x, axis=-2, keepdims=True))
      7     return e_x / e_x.sum(axis=-2, keepdims=True)

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/numpy/core/fromnumeric.py in amax(a, axis, out, keepdims)
   2315             pass
   2316         else:
-> 2317             return amax(axis=axis, out=out, **kwargs)
   2318 
   2319     return _methods._amax(a, axis=axis,

~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/numpy/core/_methods.py in _amax(a, axis, out, keepdims)
     24 # small reductions
     25 def _amax(a, axis=None, out=None, keepdims=False):
---> 26     return umr_maximum(a, axis, None, out, keepdims)
     27 
     28 def _amin(a, axis=None, out=None, keepdims=False):

AxisError: axis -2 is out of bounds for array of dimension 0
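Likely a stale y_pred from an earlier single-output model: y_pred[0] was then a scalar, and basepair.math.softmax reduces over axis -2, which a 0-d array doesn't have. The identical expression runs fine in In [802] below after re-predicting with the joint model (In [760]).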
In [755]:
from scipy.stats import pearsonr, spearmanr
In [802]:
binsizes = [1, 2, 4, 10, 25, 50, 100, 200]
perf = []
for binsize in binsizes:
    yc_true = np.log(1+np.ravel(bin_counts(test[1], binsize=binsize)[:,:,0]))
    yc_pred = np.log(1+np.ravel((bin_counts(softmax(y_pred[0]), binsize=binsize)[:,:,0])*
                          y_pred[1][:,:1]))
    perf.append([pearsonr(yc_true, yc_pred)[0], spearmanr(yc_true, yc_pred)[0]])
In [803]:
df = pd.DataFrame(perf, columns=['pearson','spearman']).assign(binsize=binsizes)
In [811]:
plt.figure(figsize=(3,2))
plt.semilogx(df.binsize, df.spearman,'-o' )
plt.ylabel("R_spearman")
plt.xlabel("Bin size")
plt.xticks(binsizes, binsizes);
In [ ]:
# Show just the plot of different binning
In [710]:
binsize=200//25
regression_eval(np.log(1+np.ravel(bin_counts(test[1], binsize=binsize)[:,:,0])), 
                np.log(1+np.ravel((bin_counts(softmax(y_pred[0]), binsize=binsize)[:,:,0])*
                          y_pred[1][:,:1])))

The correlation performance is roughly the same if we train the model on the pooled signal (pool size 25)

In [698]:
R = [0.45, 0.43, 0.39, 0.42, 0.38, 0.42, 0.42, 0.42, 0.38, 0.27]
pool_size = [0, 2, 4, 8, 10, 20, 25, 50, 100, 200]
In [579]:
plt.semilogx(np.array(pool_size)+1, R, "-o")
plt.xlabel("Pool Size")
plt.ylabel("R_spearman")
Out[579]:
Text(0,0.5,'R_spearman')
In [580]:
plt.plot(np.array(pool_size), R, "-o")
plt.xlabel("Pool Size")
plt.ylabel("R_spearman")
Out[580]:
Text(0,0.5,'R_spearman')
In [410]:
200/16
Out[410]:
12.5

Evaluate