Goal

  • improve the count model

Tasks

  • [ ] load the data
  • [ ] compute the bottlenecks
  • [ ] add a simple linear model on top (trained with sklearn)

TODO

  • [ ] Update the weights in the main BPNet model
    • make sure you also undo the inverse scaling
  • [ ] redo the scatterplots for the main figure
  • [ ] redo the importance scores

Required files

-

In [1]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
hv.extension('bokeh')
Using TensorFlow backend.
In [2]:
paper_config()
In [3]:
from basepair.config import valid_chr, test_chr
In [4]:
# Common paths
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / f"modisco/all/profile/"
output_dir = Path("/srv/www/kundaje/avsec/chipnexus/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile")

Load the data

In [5]:
dataspec_file = model_dir / "dataspec.yaml"
ds = DataSpec.load(dataspec_file)
tasks = list(ds.task_specs)
In [6]:
from basepair.cli.imp_score import ImpScoreFile
In [22]:
imp_file = ImpScoreFile(model_dir / "grad.all.h5")
In [23]:
seq = imp_file.f.f['/inputs'][:]

Get bottleneck predictions

In [15]:
create_tf_session(0)
Out[15]:
<tensorflow.python.client.session.Session at 0x7f619745c6a0>
In [16]:
bpnet = BPNetPredictor.from_mdir(model_dir)
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [17]:
bpnet.model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
seq (InputLayer)                (None, 1000, 4)      0                                            
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 1000, 64)     6464        seq[0][0]                        
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 1000, 64)     12352       conv1d_1[0][0]                   
__________________________________________________________________________________________________
add_1 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, 1000, 64)     12352       add_1[0][0]                      
__________________________________________________________________________________________________
add_2 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, 1000, 64)     12352       add_2[0][0]                      
__________________________________________________________________________________________________
add_3 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
                                                                 conv1d_4[0][0]                   
__________________________________________________________________________________________________
conv1d_5 (Conv1D)               (None, 1000, 64)     12352       add_3[0][0]                      
__________________________________________________________________________________________________
add_4 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
                                                                 conv1d_4[0][0]                   
                                                                 conv1d_5[0][0]                   
__________________________________________________________________________________________________
conv1d_6 (Conv1D)               (None, 1000, 64)     12352       add_4[0][0]                      
__________________________________________________________________________________________________
add_5 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
                                                                 conv1d_4[0][0]                   
                                                                 conv1d_5[0][0]                   
                                                                 conv1d_6[0][0]                   
__________________________________________________________________________________________________
conv1d_7 (Conv1D)               (None, 1000, 64)     12352       add_5[0][0]                      
__________________________________________________________________________________________________
add_6 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
                                                                 conv1d_4[0][0]                   
                                                                 conv1d_5[0][0]                   
                                                                 conv1d_6[0][0]                   
                                                                 conv1d_7[0][0]                   
__________________________________________________________________________________________________
conv1d_8 (Conv1D)               (None, 1000, 64)     12352       add_6[0][0]                      
__________________________________________________________________________________________________
add_7 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
                                                                 conv1d_4[0][0]                   
                                                                 conv1d_5[0][0]                   
                                                                 conv1d_6[0][0]                   
                                                                 conv1d_7[0][0]                   
                                                                 conv1d_8[0][0]                   
__________________________________________________________________________________________________
conv1d_9 (Conv1D)               (None, 1000, 64)     12352       add_7[0][0]                      
__________________________________________________________________________________________________
add_8 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
                                                                 conv1d_4[0][0]                   
                                                                 conv1d_5[0][0]                   
                                                                 conv1d_6[0][0]                   
                                                                 conv1d_7[0][0]                   
                                                                 conv1d_8[0][0]                   
                                                                 conv1d_9[0][0]                   
__________________________________________________________________________________________________
conv1d_10 (Conv1D)              (None, 1000, 64)     12352       add_8[0][0]                      
__________________________________________________________________________________________________
add_9 (Add)                     (None, 1000, 64)     0           conv1d_1[0][0]                   
                                                                 conv1d_2[0][0]                   
                                                                 conv1d_3[0][0]                   
                                                                 conv1d_4[0][0]                   
                                                                 conv1d_5[0][0]                   
                                                                 conv1d_6[0][0]                   
                                                                 conv1d_7[0][0]                   
                                                                 conv1d_8[0][0]                   
                                                                 conv1d_9[0][0]                   
                                                                 conv1d_10[0][0]                  
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 1000, 1, 64)  0           add_9[0][0]                      
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 1000, 1, 8)   12808       reshape_1[0][0]                  
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 1000, 8)      0           conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
global_average_pooling1d_1 (Glo (None, 64)           0           add_9[0][0]                      
__________________________________________________________________________________________________
profile/Oct4 (Lambda)           (None, 1000, 2)      0           reshape_2[0][0]                  
__________________________________________________________________________________________________
profile/Sox2 (Lambda)           (None, 1000, 2)      0           reshape_2[0][0]                  
__________________________________________________________________________________________________
profile/Nanog (Lambda)          (None, 1000, 2)      0           reshape_2[0][0]                  
__________________________________________________________________________________________________
profile/Klf4 (Lambda)           (None, 1000, 2)      0           reshape_2[0][0]                  
__________________________________________________________________________________________________
counts/Oct4 (Dense)             (None, 2)            130         global_average_pooling1d_1[0][0] 
__________________________________________________________________________________________________
counts/Sox2 (Dense)             (None, 2)            130         global_average_pooling1d_1[0][0] 
__________________________________________________________________________________________________
counts/Nanog (Dense)            (None, 2)            130         global_average_pooling1d_1[0][0] 
__________________________________________________________________________________________________
counts/Klf4 (Dense)             (None, 2)            130         global_average_pooling1d_1[0][0] 
==================================================================================================
Total params: 130,960
Trainable params: 130,960
Non-trainable params: 0
__________________________________________________________________________________________________
In [18]:
bottleneck_model = Model(bpnet.model.inputs, bpnet.model.get_layer("add_9").output)
In [19]:
bottleneck_predictions = bottleneck_model.predict(seq, batch_size=32, verbose=1)
98428/98428 [==============================] - 51s 523us/step
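Computing the bottleneck over all ~98k sequences takes about a minute, so it can be worth caching the array to disk for downstream cells. A minimal sketch (the `cached_predict` helper and the cache path are hypothetical, not part of basepair):

```python
import os
import numpy as np

def cached_predict(cache_path, predict_fn, x):
    """Load cached predictions if available, else compute and cache them.

    `predict_fn` is any callable mapping inputs to predictions,
    e.g. `lambda x: bottleneck_model.predict(x, batch_size=32)`.
    """
    if os.path.exists(cache_path):
        return np.load(cache_path)
    preds = predict_fn(x)
    np.save(cache_path, preds)  # use a path ending in .npy so load/save agree
    return preds
```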

Get target values

In [25]:
chromosomes = pd.Series(imp_file.f.f['/metadata/range/chr'][:])
In [26]:
# generate the data split
in_train = ~ chromosomes.isin(valid_chr + test_chr)
in_valid = chromosomes.isin(valid_chr)
in_test = chromosomes.isin(test_chr)

assert not np.any(in_train & in_valid)
assert not np.any(in_train & in_test)
assert not np.any(in_valid & in_test)
In [27]:
profiles = imp_file.get_profiles()
In [28]:
imp_file.close()
In [29]:
# Get the count matrix
total_counts = np.concatenate([profiles[t].sum(axis=1) for t in bpnet.tasks], axis=1)
column_names = [f"{t}/{s}" for t in bpnet.tasks for s in ['pos', 'neg']]
total_counts = pd.DataFrame(total_counts.astype(int), columns=column_names)
In [30]:
del profiles
In [31]:
total_counts.head()
Out[31]:
Oct4/pos Oct4/neg Sox2/pos Sox2/neg Nanog/pos Nanog/neg Klf4/pos Klf4/neg
0 7804 15399 6881 13611 3311 9240 2011 5439
1 7263 8935 2007 2749 5927 7832 195 201
2 3868 3063 752 638 1693 1426 1710 1337
3 2451 4505 713 1278 726 1012 700 1097
4 3267 3031 1639 1495 531 575 793 899

Prepare the train/valid/test arrays

In [32]:
from sklearn.preprocessing import StandardScaler
In [33]:
log_total_counts = np.log(1 + total_counts).values
scaler = StandardScaler()
In [34]:
train = (bottleneck_predictions[in_train], scaler.fit_transform(log_total_counts[in_train]))
valid = (bottleneck_predictions[in_valid], scaler.transform(log_total_counts[in_valid]))
test = (bottleneck_predictions[in_test], scaler.transform(log_total_counts[in_test]))
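For reference when undoing the scaling later (see the TODO above), StandardScaler is just per-column z-scoring. A sketch in plain numpy (the helper names are mine, not sklearn's):

```python
import numpy as np

def standardize(y):
    """Per-column z-scoring, mirroring StandardScaler.fit_transform."""
    mean = y.mean(axis=0)
    std = y.std(axis=0)
    return (y - mean) / std, mean, std

def unstandardize(y_scaled, mean, std):
    """Invert the scaling, mirroring StandardScaler.inverse_transform."""
    return y_scaled * std + mean
```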
In [35]:
train_averaged = train[0].mean(axis=1)
valid_averaged = valid[0].mean(axis=1)

Train the top model using sklearn

In [36]:
from basepair.plot.evaluate import regression_eval

Linear model

In [37]:
from sklearn.multioutput import MultiOutputRegressor
from sklearn.linear_model import LinearRegression
In [38]:
m = MultiOutputRegressor(LinearRegression())
In [39]:
m.fit(train_averaged, log_total_counts[in_train])  # Don't use the scaled version when training the linear model
y_pred = m.predict(valid_averaged)
In [92]:
fig, axes = plt.subplots(1, len(bpnet.tasks), figsize=get_figsize(frac=1.5, aspect=0.3), sharex=True, sharey=True)
for i, (a, ax) in enumerate(zip(bpnet.tasks, axes)):
    s = slice(2*i, 2*(i+1))
    regression_eval(log_total_counts[in_valid][:,s].mean(axis=1), y_pred[:,s].mean(axis=1), alpha=0.05, task=a, ax=ax);
plt.tight_layout()
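Alongside the regression_eval scatterplots, a single correlation number per task is handy for comparing the models below. A quick sketch (`pearson_r` is a hypothetical helper, not part of basepair.plot.evaluate):

```python
import numpy as np

def pearson_r(y_true, y_pred):
    """Pearson correlation: a scale-free summary to pair with the plots."""
    yt = y_true - y_true.mean()
    yp = y_pred - y_pred.mean()
    return (yt * yp).sum() / np.sqrt((yt ** 2).sum() * (yp ** 2).sum())
```

Usage would be e.g. `pearson_r(log_total_counts[in_valid][:, s].mean(axis=1), y_pred[:, s].mean(axis=1))` per task slice.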

Do the model surgery on bpnet

In [93]:
replace_weights = False
plt.figure(figsize=(4,4))
for task_i, task in enumerate(tasks):
    W_lm = np.stack([m.estimators_[2*task_i + strand_i].coef_ for strand_i in range(2)], axis=-1)
    b_lm = np.stack([m.estimators_[2*task_i + strand_i].intercept_ for strand_i in range(2)])
    
    W, b = bpnet.model.get_layer(f'counts/{task}').get_weights()
    
    if replace_weights:
        bpnet.model.get_layer(f'counts/{task}').set_weights((W_lm, b_lm))
    plt.scatter(W.ravel(), W_lm.ravel(), alpha=0.5, s=4);
    plt.xlabel("Neural net")
    plt.ylabel("Linear model");
In [94]:
# Save back the model
# bpnet.model.save(model_dir / 'model.h5')
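The linear model above was trained on the unscaled log counts, so its weights drop into the Dense heads directly. For a head trained on the z-scored targets instead, the surgery also has to undo the scaling (the "undo the inverse scaling" TODO). For a linear count head this folds into the weights; a sketch, assuming a plain Dense head as in the model summary (the helper name is hypothetical):

```python
import numpy as np

def fold_scaler_into_dense(W, b, mean, std):
    """Fold per-output standardization back into Dense weights.

    If the head predicts z-scored targets, y_scaled = x @ W + b,
    the equivalent head on the original scale is
    y = x @ (W * std) + (b * std + mean).
    `mean` and `std` are per-output vectors (StandardScaler.mean_ / .scale_).
    """
    return W * std, b * std + mean
```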

Random forest

In [102]:
from sklearn.ensemble import RandomForestRegressor
In [103]:
m = MultiOutputRegressor(RandomForestRegressor(n_estimators=20), n_jobs=8)
In [104]:
m.fit(train_averaged, train[1])
y_pred = m.predict(valid_averaged)
In [105]:
fig, axes = plt.subplots(1, len(bpnet.tasks), figsize=get_figsize(frac=1.5, aspect=0.3), sharex=True, sharey=True)
for i, (a, ax) in enumerate(zip(bpnet.tasks, axes)):
    s = slice(2*i, 2*(i+1))
    regression_eval(valid[1][:,s].mean(axis=1), y_pred[:,s].mean(axis=1), alpha=0.05, task=a, ax=ax);
plt.tight_layout()

Linear model on binned counts

In [108]:
from basepair.preproc import bin_counts
In [118]:
# bin and flatten
train_avg_pool = bin_counts(train[0], 50).reshape((len(train[0]), -1))
valid_avg_pool = bin_counts(valid[0], 50).reshape((len(valid[0]), -1))
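`bin_counts` comes from basepair.preproc; as a rough mental model (an assumption about its behavior, not its actual implementation), non-overlapping binning along the length axis is a reshape plus a reduction:

```python
import numpy as np

def bin_sum(x, binsize):
    """Sum over non-overlapping bins along axis 1 of an (N, L, C) array.

    A guess at what `bin_counts` does; the real function may average
    instead of sum. Assumes L is divisible by `binsize`.
    """
    n, length, channels = x.shape
    return x.reshape(n, length // binsize, binsize, channels).sum(axis=2)
```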
In [119]:
m = MultiOutputRegressor(LinearRegression())
In [120]:
m.fit(train_avg_pool, train[1])
y_pred = m.predict(valid_avg_pool)
In [121]:
fig, axes = plt.subplots(1, len(bpnet.tasks), figsize=get_figsize(frac=1.5, aspect=0.3), sharex=True, sharey=True)
for i, (a, ax) in enumerate(zip(bpnet.tasks, axes)):
    s = slice(2*i, 2*(i+1))
    regression_eval(valid[1][:,s].mean(axis=1), y_pred[:,s].mean(axis=1), alpha=0.05, task=a, ax=ax);
plt.tight_layout()

MLP

In [158]:
from keras.models import Sequential
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint
i = 0
In [159]:
!mkdir -p {model_dir}/count-models
In [244]:
m = Sequential([kl.BatchNormalization(input_shape=train_averaged.shape[1:]),
                kl.Dense(64, activation='relu'),
                kl.BatchNormalization(),
                kl.Dense(train[1].shape[1])])  # one pos/neg output per task
m.compile(Adam(0.01), 'mse')
In [245]:
i+=1
ckp_file = str(model_dir / f"count-models/{i}.h5")
m.fit(train_averaged, train[1],
      batch_size=512,
      epochs=50,
      callbacks=[EarlyStopping(patience=5),
                 ModelCheckpoint(ckp_file, save_best_only=True)],
      validation_data=(valid_averaged, valid[1]))
m = load_model(ckp_file)
y_pred = m.predict(valid_averaged)
Train on 61205 samples, validate on 19137 samples
Epoch 1/50
61205/61205 [==============================] - 6s 104us/step - loss: 0.7021 - val_loss: 0.6909
Epoch 2/50
61205/61205 [==============================] - 1s 15us/step - loss: 0.6283 - val_loss: 0.6661
Epoch 3/50
61205/61205 [==============================] - 1s 15us/step - loss: 0.6200 - val_loss: 0.6757
Epoch 4/50
61205/61205 [==============================] - 1s 16us/step - loss: 0.6187 - val_loss: 0.6722
Epoch 5/50
61205/61205 [==============================] - 1s 15us/step - loss: 0.6146 - val_loss: 0.6647
Epoch 6/50
61205/61205 [==============================] - 1s 15us/step - loss: 0.6146 - val_loss: 0.6623
Epoch 7/50
61205/61205 [==============================] - 1s 13us/step - loss: 0.6119 - val_loss: 0.6699
Epoch 8/50
61205/61205 [==============================] - 1s 17us/step - loss: 0.6105 - val_loss: 0.6781
Epoch 9/50
61205/61205 [==============================] - 1s 16us/step - loss: 0.6091 - val_loss: 0.6620
Epoch 10/50
61205/61205 [==============================] - 1s 15us/step - loss: 0.6078 - val_loss: 0.6632
Epoch 11/50
61205/61205 [==============================] - 1s 17us/step - loss: 0.6072 - val_loss: 0.6583
Epoch 12/50
61205/61205 [==============================] - 1s 17us/step - loss: 0.6075 - val_loss: 0.6658
Epoch 13/50
61205/61205 [==============================] - 1s 17us/step - loss: 0.6054 - val_loss: 0.6760
Epoch 14/50
61205/61205 [==============================] - 1s 15us/step - loss: 0.6058 - val_loss: 0.6730
Epoch 15/50
61205/61205 [==============================] - 1s 16us/step - loss: 0.6051 - val_loss: 0.6717
Epoch 16/50
61205/61205 [==============================] - 1s 17us/step - loss: 0.6043 - val_loss: 0.6605
In [246]:
fig, axes = plt.subplots(1, len(bpnet.tasks), figsize=get_figsize(frac=1.5, aspect=0.3), sharex=True, sharey=True)
for i, (a, ax) in enumerate(zip(bpnet.tasks, axes)):
    s = slice(2*i, 2*(i+1))
    regression_eval(valid[1][:,s].mean(axis=1), y_pred[:,s].mean(axis=1), alpha=0.05, task=a, ax=ax);
plt.tight_layout()

MLP on strided pooling

In [226]:
# note: AveragePooling1D is used here, despite the *_maxpool variable names
mp = Sequential([kl.AveragePooling1D(100, input_shape=train[0].shape[1:]),
                 kl.Flatten()])
train_maxpool = mp.predict(train[0])
valid_maxpool = mp.predict(valid[0])
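`AveragePooling1D(100)` with its default stride (equal to the pool size) is a strided mean; a numpy equivalent, handy for checking shapes:

```python
import numpy as np

def avg_pool_1d(x, pool):
    """Non-overlapping average pooling over axis 1 of an (N, L, C) array,
    matching AveragePooling1D(pool) with its default stride == pool size.
    Trims any remainder when L is not a multiple of `pool`."""
    n, length, channels = x.shape
    trimmed = x[:, : (length // pool) * pool]
    return trimmed.reshape(n, length // pool, pool, channels).mean(axis=2)
```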
In [238]:
m = Sequential([kl.BatchNormalization(input_shape=train_maxpool.shape[1:]),
                kl.Dense(64, activation='relu'),
                kl.Dropout(0.3),
                kl.BatchNormalization(),
                kl.Dense(train[1].shape[1])])  # one pos/neg output per task
m.compile(Adam(0.01), 'mse')
In [239]:
i+=1
ckp_file = str(model_dir / f"count-models/{i}.h5")
m.fit(train_maxpool, train[1],
      batch_size=512,
      epochs=50,
      callbacks=[EarlyStopping(patience=5),
                 ModelCheckpoint(ckp_file, save_best_only=True)],
      validation_data=(valid_maxpool, valid[1]))
m = load_model(ckp_file)
y_pred = m.predict(valid_maxpool)
Train on 61205 samples, validate on 19137 samples
Epoch 1/50
61205/61205 [==============================] - 8s 130us/step - loss: 0.7910 - val_loss: 0.6648
Epoch 2/50
61205/61205 [==============================] - 1s 22us/step - loss: 0.6364 - val_loss: 0.6721
Epoch 3/50
61205/61205 [==============================] - 1s 21us/step - loss: 0.6247 - val_loss: 0.6531
Epoch 4/50
61205/61205 [==============================] - 1s 22us/step - loss: 0.6178 - val_loss: 0.6608
Epoch 5/50
61205/61205 [==============================] - 1s 20us/step - loss: 0.6122 - val_loss: 0.6476
Epoch 6/50
61205/61205 [==============================] - 1s 19us/step - loss: 0.6076 - val_loss: 0.6447
Epoch 7/50
61205/61205 [==============================] - 1s 21us/step - loss: 0.6050 - val_loss: 0.6543
Epoch 8/50
61205/61205 [==============================] - 1s 22us/step - loss: 0.6018 - val_loss: 0.6582
Epoch 9/50
61205/61205 [==============================] - 1s 18us/step - loss: 0.5990 - val_loss: 0.6669
Epoch 10/50
61205/61205 [==============================] - 1s 19us/step - loss: 0.5945 - val_loss: 0.6498
Epoch 11/50
61205/61205 [==============================] - 1s 18us/step - loss: 0.5913 - val_loss: 0.6617
In [240]:
fig, axes = plt.subplots(1, len(bpnet.tasks), figsize=get_figsize(frac=1.5, aspect=0.3), sharex=True, sharey=True)
for i, (a, ax) in enumerate(zip(bpnet.tasks, axes)):
    s = slice(2*i, 2*(i+1))
    regression_eval(valid[1][:,s].mean(axis=1), y_pred[:,s].mean(axis=1), alpha=0.05, task=a, ax=ax);
plt.tight_layout()