Goal

  • explore different loss-function choices
In [4]:
# Imports
from basepair.imports import *
hv.extension('bokeh')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
paper_config()
Using TensorFlow backend.
/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/concise/utils/plot.py:115: FutureWarning: arrays to stack must be passed as a "sequence" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.
  min_coords = np.vstack(data.min(0) for data in polygons_data).min(0)
/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/concise/utils/plot.py:116: FutureWarning: arrays to stack must be passed as a "sequence" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.
  max_coords = np.vstack(data.max(0) for data in polygons_data).max(0)
In [5]:
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [6]:
# Common paths
# Trained model directory (read via ddir) and its modisco results
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / "modisco/all/profile/"
# Web-served output location
output_dir = (Path("/srv/www/kundaje/avsec/chipnexus")
              / "oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile")
In [7]:
# Create a TensorFlow session (the output below shows the returned tf Session).
# NOTE(review): the meaning of the argument `1` is not visible here (GPU index? count?) — confirm.
create_tf_session(1)
Out[7]:
<tensorflow.python.client.session.Session at 0x7f90088a87b8>
In [8]:
from basepair.cli.schemas import HParams

from basepair.models import seq_multitask
In [9]:
# Load the hyper-parameters and the data specification stored in model_dir.
# NOTE(review): `hparams` is not used in the visible cells.
hparams = HParams.load(model_dir / 'hparams.yaml')
ds = DataSpec.load(model_dir / 'dataspec.yaml')
In [ ]:
# (Removed the incomplete scratch line `model.compi`: `model` is not defined at this
#  point of the notebook, and seq_multitask() below already compiles the model.)
In [12]:
from basepair.datasets import get_StrandedProfile_datasets2
In [13]:
# Path recorded inside the DataSpec: it is relative and ends in '.yml' (see output),
# while the file loaded above is model_dir / 'dataspec.yaml'.
ds.path
Out[13]:
'dataspec.yml'
In [14]:
# Build the train/validation datasets from the model's dataspec with peak_width=1000
# (presumably bp). `valid` is a list of (name, dataset) pairs — see the next cell's output.
train,valid = get_StrandedProfile_datasets2(model_dir / 'dataspec.yaml', peak_width=1000)
In [16]:
# Inspect the validation datasets: a list of (name, StrandedProfile) pairs
valid
Out[16]:
[('valid-peaks', <basepair.datasets.StrandedProfile at 0x7f8ebc46fd30>),
 ('train-peaks', <basepair.datasets.StrandedProfile at 0x7f8ebc351d68>)]
In [15]:
# Load the whole training set into memory.
# NOTE: this rebinds `train` from the dataset object to the loaded data (with 'inputs'
# and 'targets' keys, as used below), so this cell cannot be re-run on its own.
train = train.load_all(num_workers=10)
1913it [00:25, 75.21it/s]                           
In [17]:
# Keep only the 'valid-peaks' dataset (valid[0]) and load it into memory; the
# 'train-peaks' entry (valid[1]) is dropped. Rebinds `valid`, so not re-runnable on its own.
valid = valid[0][1].load_all(num_workers=10)
599it [00:08, 70.61it/s]                         
In [19]:
# Available targets: 'profile/<task>' and 'counts/<task>' for every task
train['targets'].keys()
Out[19]:
dict_keys(['profile/Oct4', 'profile/Sox2', 'profile/Nanog', 'profile/Klf4', 'counts/Oct4', 'counts/Sox2', 'counts/Nanog', 'counts/Klf4'])
In [ ]:
# create random output dir
In [27]:
# Parent directory for the loss-function experiment runs.
# NOTE: this overwrites the `output_dir` defined in the "Common paths" cell.
output_dir = Path(f"{ddir}/processed/chipnexus/exp/models/loss-functions")
In [31]:
import datetime
In [43]:
def time_now():
    """Return the current UTC time formatted as ``YYYY-MM-DD_HH:MM:SS``.

    The format has one-second resolution, so two calls within the same second
    return the same string.
    """
    from time import gmtime, strftime

    return strftime("%Y-%m-%d_%H:%M:%S", gmtime())


def run_dir(output_dir):
    """Create and return a new run directory ``<output_dir>/run_<UTC timestamp>``.

    Parent directories are created as needed. Because the timestamp only has
    one-second resolution, a numeric suffix (``_1``, ``_2``, ...) is appended when
    a directory with that name already exists. A bare ``os.makedirs`` would raise
    ``FileExistsError`` instead, and reusing the directory would overwrite a
    previous run's files.

    Args:
      output_dir: parent directory (str or path-like).

    Returns:
      str: path of the freshly created (previously non-existent) directory.
    """
    import os

    base = os.path.join(output_dir, "run_" + time_now())
    outdir = base
    suffix = 0
    while True:
        try:
            os.makedirs(outdir)
            return outdir
        except FileExistsError:
            suffix += 1
            outdir = f"{base}_{suffix}"
In [225]:
import keras
In [287]:
from basepair.functions import softmax
In [341]:
import keras.backend as K
import keras.activations as ka
from basepair.models import *


def seq_multitask(filters=21,
                  conv1_kernel_size=21,
                  tconv_kernel_size=25,
                  n_dil_layers=6,
                  lr=0.004,
                  c_task_weight=100,
                  use_profile=True,
                  use_counts=True,
                  tasks=('Sox2', 'Oct4'),
                  outputs_per_task=2,
                  task_use_bias=False,
                  seq_len=1000,
                  pool_size=0,
                  connect_prev='add',
                  profile_loss='mc_multinomial_nll',
                  count_loss='mse'
                  ):  # TODO - automatically infer sequence length
    """Build and compile a dilated-CNN model with per-task profile and count outputs.

    Architecture:
    - a Conv1D on the (seq_len, 4) sequence input (presumably one-hot DNA);
    - `n_dil_layers` dilated Conv1D layers (dilation 2**i, i = 1..n_dil_layers),
      each fed with the merge ('add' or 'concat') of all previous conv outputs;
    - a transposed-conv head producing per-base outputs `profile/<task>`;
    - a global-average-pool + Dense head producing `counts/<task>`.

    Args:
      filters: number of filters of every conv layer.
      conv1_kernel_size: kernel size of the first conv layer.
      tconv_kernel_size: kernel size of the transposed-conv profile head.
      n_dil_layers: number of dilated conv layers (0 = first conv only).
      lr: Adam learning rate.
      c_task_weight: loss weight of each count output (profile outputs get
        weight 1). Ignored when profile_loss == 'poisson'.
      use_profile: add the per-base `profile/<task>` outputs.
      use_counts: add the `counts/<task>` outputs.
      tasks: sequence of task names.
      outputs_per_task (int or sequence of ints): output channels per task.
      task_use_bias (bool or a list of bools): if True, a
        bias term is assumed to be provided at the input
        (`bias/profile/<task>` and `bias/counts/<task>` inputs).
      seq_len: input sequence length.
      pool_size: if non-zero, average-pool the profile head with this size.
      connect_prev: 'add' or 'concat' — how previous conv outputs are merged.
      profile_loss: profile loss name; looked up as
        basepair.losses.get(f"{profile_loss}_{n_channels}"). With 'poisson', the
        profile and count heads are combined into one output per task named
        `<task>`: softmax over positions (axis=-2) of the profile output times
        the (exponentiated) count output, trained with the Keras 'poisson' loss.
      count_loss: Keras loss of the count outputs; with 'poisson' the count
        Dense layer uses an exp activation.

    Returns:
      A compiled keras Model.

    Raises:
      ValueError: for an unknown `connect_prev`, or when profile_loss ==
        'poisson' without both use_profile and use_counts.
    """
    # TODO - split the body of this model into multiple subparts:
    # - encoder
    # - profile_decoder
    # - profile_decoder_w_bias
    # - counts_decoder
    # - counts_decoder_w_bias
    n_tasks = len(tasks)
    if isinstance(outputs_per_task, int):
        outputs_per_task = [outputs_per_task] * n_tasks
    else:
        # list() so that tuples also work with the list concatenation used for start_idx
        outputs_per_task = list(outputs_per_task)
        assert n_tasks == len(outputs_per_task)
    if isinstance(task_use_bias, bool):
        task_use_bias = [task_use_bias] * n_tasks
    else:
        assert n_tasks == len(task_use_bias)

    if connect_prev == 'add':
        merge_previous = kl.add
    elif connect_prev == 'concat':
        merge_previous = kl.concatenate
    else:
        raise ValueError("connect_prev needs to be 'add' or 'concat'")
    if profile_loss == 'poisson' and not (use_profile and use_counts):
        # the 'poisson' head combines profile output i with counts output i
        raise ValueError("profile_loss='poisson' requires use_profile=True and use_counts=True")

    # TODO - build the reverse complement symmetry into the model
    inp = kl.Input(shape=(seq_len, 4), name='seq')
    first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)

    bias_profile_inputs = {task: kl.Input(shape=(seq_len, outputs_per_task[i]), name=f"bias/profile/{task}")
                           for i, task in enumerate(tasks) if task_use_bias[i]}
    bias_counts_inputs = [kl.Input(shape=(outputs_per_task[i], ), name=f"bias/counts/{task}")
                          for i, task in enumerate(tasks) if task_use_bias[i]]

    # Dilated conv tower: each layer sees the merge of all previous conv outputs
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers + 1):
        prev_sum = first_conv if i == 1 else merge_previous(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu',
                                dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)

    # add/concatenate need at least two inputs; with n_dil_layers=0 use first_conv directly
    combined_conv = merge_previous(prev_layers) if len(prev_layers) > 1 else first_conv
    # add one more layer in between (maps the concatenated channels back to `filters`)
    if connect_prev == 'concat':
        combined_conv = kl.Conv1D(filters,
                                  kernel_size=1,
                                  padding='same',
                                  activation='relu')(combined_conv)

    # De-conv
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(sum(outputs_per_task), kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    out = kl.Reshape((-1, sum(outputs_per_task)))(x)
    # batch x seqlen x sum(outputs_per_task) array
    # AvgPool if necessary
    if pool_size:
        out = kl.AvgPool1D(pool_size=pool_size, padding='valid')(out)
    # setup the output branches
    outputs = []
    losses = []
    loss_weights = []

    if use_profile:
        # TODO - use a different loss function for the same profiles
        # channel range [start_idx[i], end_idx[i]) of `out` belongs to task i
        start_idx = np.cumsum([0] + outputs_per_task[:-1])
        end_idx = np.cumsum(outputs_per_task)

        def get_output_name(task):
            # tasks with a bias input get a final 1x1 conv, which takes the "profile/" name
            if task in bias_profile_inputs:
                return "lambda/profile/" + task
            else:
                return "profile/" + task
        # indices are passed via `arguments` so that each Lambda keeps its own slice
        output = [kl.Lambda(lambda x, i, sidx, eidx: x[:, :, sidx:eidx],
                            output_shape=(seq_len, outputs_per_task[i]),
                            name=get_output_name(task),
                            arguments={"i": i, "sidx": start_idx[i], "eidx": end_idx[i]})(out)
                  for i, task in enumerate(tasks)]
        for i, task in enumerate(tasks):
            if task in bias_profile_inputs:
                output_with_bias = kl.concatenate([output[i],
                                                   bias_profile_inputs[task]], axis=-1)  # batch x seqlen x (2+2)
                output[i] = kl.Conv1D(outputs_per_task[i],
                                      1,
                                      name="profile/" + task)(output_with_bias)

        outputs += output
        if profile_loss != 'poisson':
            losses += [basepair.losses.get(f"{profile_loss}_{nt}") for nt in outputs_per_task]
        loss_weights += [1] * n_tasks

    if use_counts:
        pooled = kl.GlobalAvgPool1D()(combined_conv)
        if bias_counts_inputs:
            pooled = kl.concatenate([pooled] + bias_counts_inputs, axis=-1)  # add bias as additional features

        # the Poisson count loss needs non-negative rates -> exp activation
        activation = K.exp if count_loss == 'poisson' else 'linear'
        counts = [kl.Dense(outputs_per_task[i], name="counts/" + task, activation=activation)(pooled)
                  for i, task in enumerate(tasks)]
        outputs += counts
        losses += [count_loss] * n_tasks
        loss_weights += [c_task_weight] * n_tasks

    if profile_loss == 'poisson':
        # override the output and loss function: one per-base Poisson rate per task,
        # rate = softmax over positions (axis=-2) of the profile output * total counts.
        # outputs = [profile/<task> ...] + [counts/<task> ...], hence the i + n_tasks offset.
        losses = 'poisson'
        poisson_outputs = []
        for i, task in enumerate(tasks):
            profile_probs = kl.Lambda(lambda x: ka.softmax(x, axis=-2))(outputs[i])
            count_out = outputs[i + n_tasks]
            if count_loss == 'poisson':
                # the counts Dense layer already applies K.exp; don't exponentiate twice
                total_counts = count_out
            else:
                total_counts = kl.Lambda(lambda x: K.exp(x))(count_out)
            poisson_outputs.append(kl.multiply([profile_probs, total_counts], name=task))
        outputs = poisson_outputs
        loss_weights = None

    model = Model([inp] + list(bias_profile_inputs.values()) + bias_counts_inputs, outputs)
    model.compile(Adam(lr=lr), loss=losses, loss_weights=loss_weights)
    return model
In [334]:
# Keras version this notebook was run with (see output)
keras.__version__
Out[334]:
'2.1.5'
In [361]:
tasks = list(ds.task_specs)
# Hyper-parameters for seq_multitask(**kwargs); keyword order is preserved (PEP 468)
kwargs = OrderedDict(
    filters=64,
    seq_len=1000,
    tasks=tasks,
    conv1_kernel_size=25,
    tconv_kernel_size=25,
    n_dil_layers=9,
    lr=0.004,
    c_task_weight=1,  # don't use it for the poisson loss
    profile_loss='mc_multinomial_nll',  # alternative: 'poisson'
    count_loss='poisson',  # alternative: 'mse'
)
In [362]:
# Build and compile the model with the hyper-parameters above
model = seq_multitask(**kwargs)
In [363]:
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
In [364]:
# Profile-only variant of the training cell, kept for reference: it fits only the
# profile/<task> targets. The active training cell further below also passes the
# count targets (train_targets / valid_targets).
# ckp_file = os.path.join(run_dir(output_dir), 'model.h5')
# hist = model.fit(train['inputs'], [train['targets'][f'profile/{t}'] for t in tasks], 
#           batch_size=256, 
#           epochs=200,
#           validation_data=(valid['inputs'], [valid['targets'][f'profile/{t}'] for t in tasks]),
#           validation_split=0.2,
#           callbacks=[EarlyStopping(patience=5),
#                      History(),
#                      ModelCheckpoint(ckp_file, save_best_only=True)]
#          )
# # get the best model
# model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})
In [367]:
def make_targets(data, tasks):
    """Return Keras targets in seq_multitask's output order.

    First the per-base profiles `profile/<task>` for every task, then the total
    counts per region (profile summed over axis 1, the sequence axis) for every
    task. These are raw counts, matching count_loss='poisson' (exp-activated
    count outputs).
    """
    profiles = [data['targets'][f'profile/{t}'] for t in tasks]
    totals = [p.sum(axis=1) for p in profiles]
    return profiles + totals


train_targets = make_targets(train, tasks)
valid_targets = make_targets(valid, tasks)
In [ ]:
# Train with early stopping; ModelCheckpoint keeps the best (lowest val_loss) model on disk.
ckp_file = os.path.join(run_dir(output_dir), 'model.h5')
hist = model.fit(train['inputs'], train_targets,
                 batch_size=256,
                 epochs=200,
                 # Keras ignores validation_split when validation_data is given, so it is omitted
                 validation_data=(valid['inputs'], valid_targets),
                 callbacks=[EarlyStopping(patience=5),
                            History(),
                            ModelCheckpoint(ckp_file, save_best_only=True)]
                 )
# Get the best model: rebuild the same architecture and load the checkpointed weights.
# load_model() would have to deserialize the losses actually used here
# (basepair "mc_multinomial_nll_<n>" and 'poisson') and the K.exp count activation,
# none of which are covered by the previous custom_objects={"twochannel_multinomial_nll": ...}.
model = seq_multitask(**kwargs)
model.load_weights(ckp_file)
Train on 61205 samples, validate on 19137 samples
Epoch 1/200
61205/61205 [==============================] - 61s 995us/step - loss: 1307.8679 - profile/Oct4_loss: 958.6797 - profile/Sox2_loss: 601.6909 - profile/Nanog_loss: 815.4192 - profile/Klf4_loss: 699.9609 - counts/Oct4_loss: -607.9884 - counts/Sox2_loss: -301.1674 - counts/Nanog_loss: -362.5082 - counts/Klf4_loss: -496.2189 - val_loss: 306.9996 - val_profile/Oct4_loss: 939.1318 - val_profile/Sox2_loss: 595.4532 - val_profile/Nanog_loss: 794.0035 - val_profile/Klf4_loss: 675.0223 - val_counts/Oct4_loss: -943.0823 - val_counts/Sox2_loss: -360.4248 - val_counts/Nanog_loss: -858.9426 - val_counts/Klf4_loss: -534.1614
Epoch 2/200
61205/61205 [==============================] - 57s 931us/step - loss: 328.5366 - profile/Oct4_loss: 922.1844 - profile/Sox2_loss: 584.2270 - profile/Nanog_loss: 776.3516 - profile/Klf4_loss: 674.1250 - counts/Oct4_loss: -913.3605 - counts/Sox2_loss: -349.6661 - counts/Nanog_loss: -825.9776 - counts/Klf4_loss: -539.3472 - val_loss: 263.8716 - val_profile/Oct4_loss: 929.5950 - val_profile/Sox2_loss: 592.3644 - val_profile/Nanog_loss: 778.5075 - val_profile/Klf4_loss: 671.1834 - val_counts/Oct4_loss: -946.3210 - val_counts/Sox2_loss: -361.3781 - val_counts/Nanog_loss: -862.9281 - val_counts/Klf4_loss: -537.1515
Epoch 3/200
61205/61205 [==============================] - 57s 935us/step - loss: 294.5885 - profile/Oct4_loss: 917.8220 - profile/Sox2_loss: 581.4642 - profile/Nanog_loss: 761.8087 - profile/Klf4_loss: 670.9724 - counts/Oct4_loss: -914.5997 - counts/Sox2_loss: -350.3340 - counts/Nanog_loss: -829.7943 - counts/Klf4_loss: -542.7509 - val_loss: 296.8739 - val_profile/Oct4_loss: 927.6375 - val_profile/Sox2_loss: 592.2585 - val_profile/Nanog_loss: 770.7176 - val_profile/Klf4_loss: 667.3212 - val_counts/Oct4_loss: -924.4453 - val_counts/Sox2_loss: -353.0119 - val_counts/Nanog_loss: -850.9596 - val_counts/Klf4_loss: -532.6440
Epoch 4/200
61205/61205 [==============================] - 56s 922us/step - loss: 286.2913 - profile/Oct4_loss: 915.7771 - profile/Sox2_loss: 580.4478 - profile/Nanog_loss: 756.3393 - profile/Klf4_loss: 668.8322 - counts/Oct4_loss: -913.0520 - counts/Sox2_loss: -349.8910 - counts/Nanog_loss: -829.6867 - counts/Klf4_loss: -542.4753 - val_loss: 234.0737 - val_profile/Oct4_loss: 925.7988 - val_profile/Sox2_loss: 589.8926 - val_profile/Nanog_loss: 765.0034 - val_profile/Klf4_loss: 665.0141 - val_counts/Oct4_loss: -945.9929 - val_counts/Sox2_loss: -361.5764 - val_counts/Nanog_loss: -865.4327 - val_counts/Klf4_loss: -538.6332
Epoch 5/200
61205/61205 [==============================] - 57s 927us/step - loss: 274.6142 - profile/Oct4_loss: 914.2557 - profile/Sox2_loss: 579.6427 - profile/Nanog_loss: 753.5346 - profile/Klf4_loss: 667.0715 - counts/Oct4_loss: -914.5346 - counts/Sox2_loss: -350.4429 - counts/Nanog_loss: -831.7296 - counts/Klf4_loss: -543.1832 - val_loss: 226.3022 - val_profile/Oct4_loss: 924.2079 - val_profile/Sox2_loss: 588.8716 - val_profile/Nanog_loss: 763.9809 - val_profile/Klf4_loss: 663.0628 - val_counts/Oct4_loss: -945.7217 - val_counts/Sox2_loss: -360.8130 - val_counts/Nanog_loss: -866.5131 - val_counts/Klf4_loss: -540.7731
Epoch 6/200
61205/61205 [==============================] - 57s 928us/step - loss: 268.3993 - profile/Oct4_loss: 913.1257 - profile/Sox2_loss: 578.9323 - profile/Nanog_loss: 751.2489 - profile/Klf4_loss: 665.7884 - counts/Oct4_loss: -914.7278 - counts/Sox2_loss: -350.5154 - counts/Nanog_loss: -831.6762 - counts/Klf4_loss: -543.7766 - val_loss: 217.0670 - val_profile/Oct4_loss: 922.8253 - val_profile/Sox2_loss: 588.1021 - val_profile/Nanog_loss: 759.6476 - val_profile/Klf4_loss: 662.2689 - val_counts/Oct4_loss: -946.5943 - val_counts/Sox2_loss: -361.6901 - val_counts/Nanog_loss: -866.3034 - val_counts/Klf4_loss: -541.1891
Epoch 7/200
28928/61205 [=============>................] - ETA: 27s - loss: 253.0225 - profile/Oct4_loss: 913.8248 - profile/Sox2_loss: 579.1849 - profile/Nanog_loss: 753.1606 - profile/Klf4_loss: 665.8231 - counts/Oct4_loss: -919.1467 - counts/Sox2_loss: -351.6918 - counts/Nanog_loss: -843.2728 - counts/Klf4_loss: -544.8597
In [350]:
# Predict on the validation set; returns one array per model output
preds = model.predict(valid['inputs'])
In [351]:
from basepair.functions import softmax
In [353]:
# Number of model outputs: profile/<task> and counts/<task> for each of the 4 tasks
len(preds)
Out[353]:
8
In [290]:
# `preds` follows seq_multitask's output order: first profile/<task> for every task,
# then counts/<task> for every task (hence 8 arrays for 4 tasks).
n_tasks = len(tasks)
y_pred_profile = {t: softmax(preds[i]) for i, t in enumerate(tasks)}
# counts outputs come after the n_tasks profile outputs (preds[i] would be the profile logits)
y_pred_counts = {t: preds[n_tasks + i] for i, t in enumerate(tasks)}


def _softmax_over_positions(logits):
    """Numerically stable softmax along axis 1 (positions) of (n, seq_len, channels)."""
    shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
    return shifted / shifted.sum(axis=1, keepdims=True)


# Expected per-base counts, shape (n_regions, seq_len, channels), used by the cells below:
# per-position probabilities times the predicted total counts per channel.
# NOTE(review): assumes the profile logits are normalized over positions (as in the
# 'poisson' branch of seq_multitask, softmax along axis=-2) and that the counts outputs
# are raw counts (exp activation with count_loss='poisson') — confirm for other settings.
y_pred = {t: _softmax_over_positions(preds[i]) * y_pred_counts[t][:, np.newaxis, :]
          for i, t in enumerate(tasks)}
In [291]:
# Shape of the per-base predictions for one task: (n_regions, seq_len, 2 channels —
# presumably the two strands)
y_pred['Nanog'].shape
Out[291]:
(19137, 1000, 2)
In [292]:
# Observed per-base targets for the same task, to compare with y_pred['Nanog'].shape
# (the previous version referenced an undefined `t` and raised a NameError)
valid['targets']['profile/Nanog'].shape
In [294]:
from basepair.metrics import BPNetMetric, PeakPredictionProfileMetric, pearson_spearman
In [295]:
# NOTE(review): `bpm` is not used in the visible cells.
bpm = BPNetMetric(tasks, pearson_spearman, PeakPredictionProfileMetric)
In [302]:
from basepair.plot.evaluate import regression_eval
In [306]:
# Observed vs. predicted total counts per region on the validation set
# (log10(1 + counts summed over positions and channels)), one panel per task.
n_panels = len(tasks)
# squeeze=False keeps `axes` 2-D, so a single task also works
fig, axes = plt.subplots(1, n_panels, figsize=get_figsize(1, 1/n_panels),
                         sharex=True, sharey=True, squeeze=False)
for ax, t in zip(axes[0], tasks):
    regression_eval(np.log10(1 + valid['targets'][f'profile/{t}'].sum(axis=(1, 2))),
                    np.log10(1 + y_pred[t].sum(axis=(1, 2))),
                    ax=ax)
    ax.set_title(t)
In [313]:
# Per-base peak-prediction metric, evaluated at binsize 1 only (`binsizes = [1]`).
# NOTE(review): the exact semantics of the thresholds are defined in basepair.metrics — confirm there.
ppm = PeakPredictionProfileMetric(pos_min_threshold = 0.015,
                                    neg_max_threshold = 0.005,
                                    required_min_pos_counts = 2.5,
                                    binsizes = [1])
In [ ]:
# per-base accuracy
In [315]:
# Per-base auPRC (binsize=1) for each task; predictions are normalized to
# per-position probabilities along the sequence axis (-2) before scoring.
auprc = {}
for task in tasks:
    y_prob = y_pred[task] / y_pred[task].sum(axis=-2, keepdims=True)
    auprc[task] = ppm(valid['targets'][f'profile/{task}'], y_prob)['binsize=1']['auprc']
auprc
Out[315]:
{'Oct4': 0.03574602290766749,
 'Sox2': 0.047750038496828975,
 'Nanog': 0.1058989835519579,
 'Klf4': 0.06399314314664085}
In [58]:
from statsmodels.graphics.gofplots import qqplot
from scipy.stats import poisson, nbinom, norm
In [59]:
import scipy.stats as stats

Analyze the count distributions

TODO - get the count distributions

In [62]:
# counts/<task> targets: one value per region and channel (see shape)
train['targets']['counts/Nanog'].shape
Out[62]:
(61205, 2)
In [96]:
# Nanog counts: per base (summed over channels, flattened over regions and positions)
# and per region (summed over positions and channels)
nanog_profile = train['targets']['profile/Nanog']
per_base_counts = nanog_profile.sum(axis=-1).ravel()
total_counts = nanog_profile.sum(axis=(-2, -1))
In [97]:
# Distribution of per-base counts; log-scaled y axis (`log` expects a bool)
plt.hist(per_base_counts, bins=100, log=True);
plt.xlabel("Per-base counts for Nanog");
plt.ylabel("Frequency");
In [100]:
# Distribution of per-region (1 kb) total counts on a log10(1 + x) scale; log-scaled y axis
plt.hist(np.log10(total_counts + 1), bins=100, log=True);
plt.xlabel("log10(1 + 1-kb counts) for Nanog");
plt.ylabel("Frequency");
In [69]:
# Average the two channels of the Nanog count targets (shape (n_regions, 2), see above).
# NOTE(review): the exp(x) - 1 in the next cell implies these are log(1 + counts) — confirm.
log_counts = train['targets']['counts/Nanog'].mean(axis=-1)
In [70]:
# Back-transform assuming log(1 + x) targets. NOTE: exp of the mean log-count over
# channels is not the mean count over channels.
counts = np.exp(log_counts) - 1
In [64]:
# NOTE(review): this cell ran (In [64]) before the cells defining `counts` (In [69], In [70]),
# so the displayed output may not correspond to the cells above — re-run to confirm.
counts.mean()
Out[64]:
4.7004776

Total-count Q-Q plots

In [163]:
# Mean 1-kb total count for Nanog
total_counts.mean()
Out[163]:
384.30978
In [116]:
counts = total_counts
fig, axes= plt.subplots(1, 4, figsize=get_figsize(1.5, aspect=1/5))
ax=axes[0]
stats.probplot(counts, dist="norm", plot=ax);
ax.set_title("Normal")

ax=axes[1]
stats.probplot(np.log(counts+1), dist="norm", plot=ax);
ax.set_title("Log-normal")

ax=axes[2]
stats.probplot(counts, sparams=(counts.mean()), dist="poisson", plot=ax);
ax.set_title("Poisson")

ax=axes[3]
stats.probplot(counts, sparams=(0.1, 0.001), dist="nbinom", plot=ax);
ax.set_title("NB")
plt.tight_layout()
In [117]:
counts = np.random.choice(per_base_counts, 10000)
fig, axes= plt.subplots(1, 4, figsize=get_figsize(1.5, aspect=1/5))
ax=axes[0]
stats.probplot(counts, dist="norm", plot=ax);
ax.set_title("Normal")

ax=axes[1]
stats.probplot(np.log(counts+1), dist="norm", plot=ax);
ax.set_title("Log-normal")

ax=axes[2]
stats.probplot(counts, sparams=(counts.mean()), dist="poisson", plot=ax);
ax.set_title("Poisson")

ax=axes[3]
stats.probplot(counts, sparams=(0.1, 0.001), dist="nbinom", plot=ax);
ax.set_title("NB")
plt.tight_layout()
In [136]:
N = 100  # example total count; defined here so the cell also runs on a fresh kernel
np.log(N + 1)
Out[136]:
4.61512051684126
In [194]:
# Compare, for an observed total count N, the Poisson NLL as a function of the predicted
# log-count C with its quadratic (mse on log-counts) approximation around C = log(N).
N = 40

C = np.linspace(-5, 5, 201) + np.log(N)

# Poisson NLL up to a C-independent constant is exp(C) - N*C; its minimum is at
# C = log(N) with value N - N*log(N). Shift it to 0 so both curves share the minimum.
poisson_loss = np.exp(C) - N * C - (N - N * np.log(N))
# 2nd-order Taylor expansion of the Poisson NLL at C = log(N) (curvature exp(log N) = N)
mse_loss = N / 2 * (np.log(N) - C)**2

plt.plot(C, poisson_loss, label='Poisson loss');
plt.plot(C, mse_loss, label='mse loss');
plt.vlines(np.log(N), 0, max(poisson_loss));
plt.xlabel("Predicted log-counts C");
plt.ylabel("Loss (shifted to minimum 0)");
plt.legend();

Try 1

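  • Per-base Poisson log-likelihood, where the per-base rate is constructed as
$$\mu_j = e^{C}\, p_j\;,$$

where $C$ is the predicted log total count and $p_j$ is the predicted probability of position $j$ (with $\sum_j p_j = 1$, so that $\sum_j \mu_j = e^{C}$). The loss is the Poisson negative log-likelihood $\sum_j \left(\mu_j - y_j \log \mu_j\right)$ (up to terms that do not depend on the prediction), where $y_j$ is the observed count at position $j$.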

In [195]:
from keras.losses import poisson
In [196]:
poisson??
Signature: poisson(y_true, y_pred)
Docstring: <no docstring>
Source:   
def poisson(y_true, y_pred):
    return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)
File:      ~/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/keras/losses.py
Type:      function
In [ ]: