Goal

  • develop an activity model predicting:
    • H3K27ac
    • Pol II
    • GRO-seq

See also: metaplots.html

Design decisions

  • dataset = all ChIP-nexus peaks
  • output = total number of counts within ±1 kb of the region center (see the sketch below)
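
As a reference for this target definition, a minimal sketch of the extraction (assuming pyBigWig; the actual loading below goes through ActivityDataset):

import numpy as np
import pyBigWig

def total_counts(bw_path, chrom, center, flank=1000):
    """Sum bigwig coverage within +/- `flank` bp of `center`."""
    bw = pyBigWig.open(bw_path)
    values = np.nan_to_num(bw.values(chrom, max(0, center - flank), center + flank))
    bw.close()
    return values.sum()

# e.g. total_counts('H3K27ac_rep1.bw', 'chr1', 3000000)  # hypothetical file/locus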

Tasks

  • [x] write the dataloader
  • [x] load all the data
  • [x] plot the response variable distribution (histogram, QQ-plots)
  • [x] get the predictions across all the regions from bpnet
    • [x] save to the hdf5 file
  • [x] train a simple model on top
  • [x] evaluate using simple scatterplots
  • [ ] exclude promoter regions and re-train (see the pybedtools sketch below)
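
A possible filtering step for the promoter-exclusion task, sketched with pybedtools (promoters.bed is a hypothetical annotation file, e.g. TSS ± 500 bp):

from pybedtools import BedTool

peaks = BedTool(f"{ddir}/processed/activity/data/peaks.bed")
# v=True keeps only peaks that do NOT overlap any promoter
peaks.intersect("promoters.bed", v=True).saveas(
    f"{ddir}/processed/activity/data/peaks-no-promoter.bed")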

Other tasks

  • [ ] fine-tune the whole model
  • [ ] get the importance scores
In [1]:
from basepair.imports import *
import pybedtools
from genomelake.extractors import BigwigExtractor
from basepair.extractors import StrandedBigWigExtractor, bw_extract
from kipoiseq.transforms import ResizeInterval
from basepair.modisco.table import ModiscoData
from pybedtools import BedTool
from basepair.cli.modisco import load_ranges
from basepair.plot.heatmaps import (heatmap_importance_profile, normalize, multiple_heatmap_stranded_profile,
                                    heatmap_stranded_profile)
from basepair.plot.profiles import plot_stranded_profile
import hvplot.pandas
from basepair.plots import regression_eval
Using TensorFlow backend.
In [2]:
create_tf_session(1)
Out[2]:
<tensorflow.python.client.session.Session at 0x7f45ef72e1d0>

Load the data

In [3]:
df = pd.read_csv(f"{ddir}/processed/chipnexus/external-data.tsv", sep='\t')
df = df.set_index('assay')
In [4]:
df
Out[4]:
axis path
assay
DNase 0 /srv/scratch/avsec/wo...
DNase 0 /srv/scratch/avsec/wo...
DNase-HINT 0 /srv/scratch/avsec/wo...
... ... ...
Groseq 0 /srv/scratch/avsec/wo...
MNase-wt 0 /srv/scratch/avsec/wo...
MNase-h4 0 /srv/scratch/avsec/wo...

16 rows × 2 columns

Write the bed file

In [9]:
from basepair.cli.schemas import DataSpec

ds = DataSpec.load(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf-sall/dataspec-incl-Sall4.yml")

!mkdir -p {ddir}/processed/activity/data/

def read_factor(factor, filename):
    """Read a peak BED file and label each interval with the factor name."""
    df = pd.read_table(filename, header=None, usecols=[0, 1, 2])
    df[3] = factor
    df.columns = ['chrom', 'start', 'end', 'name']
    return df

dfc = pd.concat([read_factor(k, ds.task_specs[k].peaks) for k in ds.task_specs], axis=0)

dfc.to_csv(f"{ddir}/processed/activity/data/peaks.bed", sep='\t', index=False, header=False)
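
Note that the peaks of the five factors are simply concatenated, so a region bound by several factors appears once per factor. If a non-redundant region set were needed, one option is a sorted merge with pybedtools (a sketch, not used below):

from pybedtools import BedTool

# Sort, then merge overlapping intervals, collapsing the factor names into column 4
(BedTool(f"{ddir}/processed/activity/data/peaks.bed")
 .sort()
 .merge(c=4, o='collapse')
 .saveas(f"{ddir}/processed/activity/data/peaks-merged.bed"))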

Prepare the bigwigs

In [10]:
assays = ['H3K27ac', 'PolII', 'Groseq']
In [11]:
def tolist(s):
    """Wrap a single path in a list so one- and multi-replicate assays look alike."""
    if isinstance(s, str):
        return [s]
    else:
        return list(s)
In [12]:
bigwigs = {a: tolist(df.loc[a].path) for a in assays}
In [13]:
bigwigs
Out[13]:
{'H3K27ac': ['/srv/scratch/avsec/workspace/chipnexus/data/raw/2018-10-13-histone-chipseq-PMID-28483418/H3K27ac_ChIP-seq_WT_rep1_blacklisted.bw',
  '/srv/scratch/avsec/workspace/chipnexus/data/raw/2018-10-13-histone-chipseq-PMID-28483418/H3K27ac_ChIP-seq_WT_rep2_blacklisted.bw'],
 'PolII': ['/srv/scratch/avsec/workspace/chipnexus/data/raw/2018-10-13-histone-chipseq-PMID-28483418/PolII_ChIP-seq_WT_rep1_blacklisted.bw',
  '/srv/scratch/avsec/workspace/chipnexus/data/raw/2018-10-13-histone-chipseq-PMID-28483418/PolII_ChIP-seq_WT_rep2_blacklisted.bw'],
 'Groseq': ['/srv/scratch/avsec/workspace/chipnexus/data/raw/2018-10-15-groseq-from-melanie/GRO-seq_WT_1_blacklisted.bw']}

Write the dataloader

In [14]:
from basepair.config import valid_chr, test_chr

from basepair.datasets import *
In [15]:
dl = ActivityDataset(f"{ddir}/processed/activity/data/peaks.bed", ds.fasta_file, bigwigs, 
                     excl_chromosomes=valid_chr + test_chr)
In [16]:
# load all
train = dl.load_all(num_workers=10)
100%|██████████| 2646/2646 [00:39<00:00, 66.90it/s] 
In [17]:
dfy = pd.DataFrame(train['targets'])
In [18]:
valid = ActivityDataset(f"{ddir}/processed/activity/data/peaks.bed", ds.fasta_file, bigwigs, 
                        incl_chromosomes=valid_chr).load_all(num_workers=10)
100%|██████████| 832/832 [00:06<00:00, 126.95it/s]
In [19]:
dfy_valid = pd.DataFrame(valid['targets'])
In [20]:
test = ActivityDataset(f"{ddir}/processed/activity/data/peaks.bed", ds.fasta_file, bigwigs, 
                       incl_chromosomes=test_chr).load_all(num_workers=10)
100%|██████████| 781/781 [00:07<00:00, 110.59it/s]
In [21]:
dfy_test = pd.DataFrame(test['targets'])
In [22]:
dfy.head()
Out[22]:
H3K27ac PolII Groseq
0 72177.0 230052.0 1577.0
1 440983.0 164513.0 320665.0
2 135278.0 83775.0 228978.0
3 44406.0 22818.0 77130.0
4 246385.0 90010.0 282323.0
In [23]:
len(dl)
Out[23]:
84659
In [24]:
dl[0]
Out[24]:
{'inputs': {'seq': array([[0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 1., 0., 0.],
         ...,
         [0., 1., 0., 0.],
         [1., 0., 0., 0.]], dtype=float32)},
 'targets': {'H3K27ac': 72177.0, 'PolII': 230052.0, 'Groseq': 1577.0},
 'metadata': {'ranges': GenomicRanges(chr='chrX', start=143482544, end=143483544, id='0', strand='*'),
  'ranges_wide': GenomicRanges(chr='chrX', start=143482044, end=143484044, id='Oct4', strand='.'),
  'name': 'Oct4'}}

Output distribution

Histograms

QQ-plots

In [26]:
import scipy.stats as stats
In [27]:
fig, axes = plt.subplots(1, len(assays), figsize=(9, 3), sharex=True, sharey=True)
for a, ax in zip(assays, axes):
    stats.probplot(np.log10(1 + dfy[a]), dist="norm", plot=ax);
    ax.set_title(a)
plt.tight_layout()
In [28]:
%matplotlib inline
paper_config()

Get bottleneck predictions

In [29]:
from basepair.BPNet import BPNetPredictor
from keras.models import Model, Sequential
import keras.layers as kl
In [30]:
mdir = f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf-sall/models/default/"
In [34]:
bpnet = BPNetPredictor.from_mdir(mdir)
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [35]:
bpnet.tasks
Out[35]:
['Oct4', 'Sox2', 'Nanog', 'Klf4', 'Sall4']
In [36]:
bpnet.model.get_layer("add_9").output
Out[36]:
<tf.Tensor 'add_9/add_8:0' shape=(?, 1000, 64) dtype=float32>
In [37]:
bottleneck_model = Model(bpnet.model.inputs, bpnet.model.get_layer("add_9").output)
In [38]:
bottleneck_predictions = bottleneck_model.predict(train['inputs'], batch_size=32, verbose=1)
84659/84659 [==============================] - 31s 361us/step
In [39]:
bottleneck_predictions_valid = bottleneck_model.predict(valid['inputs'], batch_size=32, verbose=1)
bottleneck_predictions_test = bottleneck_model.predict(test['inputs'], batch_size=32, verbose=1)
26615/26615 [==============================] - 9s 333us/step
24971/24971 [==============================] - 8s 322us/step
In [40]:
bottleneck_predictions.shape
Out[40]:
(84659, 1000, 64)
In [41]:
# Baseline head: global average pooling followed by a linear readout
top_model = Sequential([
    kl.GlobalAvgPool1D(input_shape=(1000, 64)),
    kl.Dense(3)
])
In [42]:
# Replaces the baseline above: max-pool in 50-bp bins, then a small MLP
top_model = Sequential([
    kl.MaxPool1D(pool_size=50, input_shape=(1000, 64)),
    kl.Flatten(),
    kl.Dense(64, activation='relu'),
    # kl.Dropout(.5),
    kl.Dense(3)
])
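
For reference, the tensor shapes through this pooled head (verifiable with top_model.summary()):

# (1000, 64) -> MaxPool1D(50) -> (20, 64) -> Flatten -> (1280,)
# (1280,) -> Dense(64, relu) -> (64,) -> Dense(3) -> (3,)
# ~ 1280*64 + 64 + 64*3 + 3 = 82,179 trainable parameters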
In [43]:
from concise.metrics import var_explained
from keras.callbacks import EarlyStopping
from sklearn.preprocessing import StandardScaler
In [44]:
top_model.compile("adam", 'mse', metrics=[var_explained])
In [45]:
preproc = StandardScaler()
In [46]:
y = preproc.fit_transform(np.log10(1 + dfy).values)
y_valid = preproc.transform(np.log10(1 + dfy_valid).values)
y_test = preproc.transform(np.log10(1 + dfy_test).values)
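
Predictions from this model live in standardized log10 space; mapping them back to raw counts just inverts the two transforms above (a sketch using the fitted scaler):

def to_counts(y_scaled):
    """Invert StandardScaler and log10(1 + x) to recover raw counts."""
    return 10 ** preproc.inverse_transform(y_scaled) - 1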
In [ ]:
top_model.fit(bottleneck_predictions, y, batch_size=512,
              epochs=100,
              validation_data=(bottleneck_predictions_valid, y_valid),
              callbacks=[EarlyStopping(patience=5)]
             )
Train on 84659 samples, validate on 26615 samples
Epoch 1/100
84659/84659 [==============================] - 42s 494us/step - loss: 1.0608 - var_explained: -0.0255 - val_loss: 0.8753 - val_var_explained: 0.1291
Epoch 2/100
84659/84659 [==============================] - 39s 457us/step - loss: 0.8489 - var_explained: 0.1575 - val_loss: 0.8646 - val_var_explained: 0.1396
Epoch 3/100
84659/84659 [==============================] - 40s 469us/step - loss: 0.8416 - var_explained: 0.1707 - val_loss: 0.8451 - val_var_explained: 0.1614
Epoch 4/100
84659/84659 [==============================] - 37s 439us/step - loss: 0.8265 - var_explained: 0.1863 - val_loss: 0.8386 - val_var_explained: 0.1714
Epoch 5/100
84659/84659 [==============================] - 36s 430us/step - loss: 0.8282 - var_explained: 0.1908 - val_loss: 0.8342 - val_var_explained: 0.1738
Epoch 6/100
84659/84659 [==============================] - 32s 384us/step - loss: 0.8104 - var_explained: 0.1993 - val_loss: 0.8282 - val_var_explained: 0.1770
Epoch 7/100
84659/84659 [==============================] - 35s 409us/step - loss: 0.8059 - var_explained: 0.2023 - val_loss: 0.8254 - val_var_explained: 0.1833
Epoch 8/100
84659/84659 [==============================] - 34s 397us/step - loss: 0.8054 - var_explained: 0.2071 - val_loss: 0.8289 - val_var_explained: 0.1781
Epoch 9/100
84659/84659 [==============================] - 33s 393us/step - loss: 0.7987 - var_explained: 0.2121 - val_loss: 0.8442 - val_var_explained: 0.1791
Epoch 10/100
84659/84659 [==============================] - 36s 424us/step - loss: 0.7982 - var_explained: 0.2121 - val_loss: 0.8396 - val_var_explained: 0.1620
Epoch 11/100
84659/84659 [==============================] - 32s 375us/step - loss: 0.8057 - var_explained: 0.2075 - val_loss: 0.8166 - val_var_explained: 0.1892
Epoch 12/100
84659/84659 [==============================] - 36s 422us/step - loss: 0.7875 - var_explained: 0.2186 - val_loss: 0.8162 - val_var_explained: 0.1868
Epoch 13/100
84659/84659 [==============================] - 34s 406us/step - loss: 0.7882 - var_explained: 0.2203 - val_loss: 0.8169 - val_var_explained: 0.1915
Epoch 14/100
84659/84659 [==============================] - 36s 430us/step - loss: 0.7853 - var_explained: 0.2245 - val_loss: 0.8113 - val_var_explained: 0.1941
Epoch 15/100
84659/84659 [==============================] - 36s 428us/step - loss: 0.7892 - var_explained: 0.2250 - val_loss: 0.8116 - val_var_explained: 0.1916
Epoch 16/100
84659/84659 [==============================] - 34s 398us/step - loss: 0.7803 - var_explained: 0.2289 - val_loss: 0.8351 - val_var_explained: 0.1890
Epoch 17/100
84659/84659 [==============================] - 38s 453us/step - loss: 0.7847 - var_explained: 0.2255 - val_loss: 0.8124 - val_var_explained: 0.1940
Epoch 18/100
84659/84659 [==============================] - 30s 353us/step - loss: 0.7790 - var_explained: 0.2315 - val_loss: 0.8236 - val_var_explained: 0.1964
Epoch 19/100
47104/84659 [===============>..............] - ETA: 10s - loss: 0.7819 - var_explained: 0.2311

Evaluate

In [ ]:
ypred_valid = top_model.predict(bottleneck_predictions_valid)
In [52]:
ypred_valid.shape
Out[52]:
(26615, 3)
In [53]:
fig, axes = plt.subplots(len(assays), 1, figsize=(5, 11), sharex=True, sharey=True)
for i, (a, ax) in enumerate(zip(assays, axes)):
    regression_eval(ypred_valid[:,i], y_valid[:,i], alpha=0.05, task=a, ax=ax);
plt.tight_layout()

Fine-tune

In [122]:
whole_model = Sequential([bottleneck_model, top_model])
In [123]:
whole_model.compile("adam", 'mse', metrics=[var_explained])
In [127]:
whole_model.fit(train['inputs']['seq'], y, batch_size=512,
                epochs=100,
                validation_data=(valid['inputs']['seq'], y_valid),
                callbacks=[EarlyStopping(patience=5)])
Train on 61205 samples, validate on 19137 samples
Epoch 1/100
61205/61205 [==============================] - 55s 899us/step - loss: 0.6910 - var_explained: 0.3158 - val_loss: 0.8573 - val_var_explained: 0.1735
Epoch 2/100
61205/61205 [==============================] - 55s 898us/step - loss: 0.6619 - var_explained: 0.3485 - val_loss: 0.8192 - val_var_explained: 0.1749
Epoch 3/100
61205/61205 [==============================] - 55s 900us/step - loss: 0.6298 - var_explained: 0.3824 - val_loss: 0.8426 - val_var_explained: 0.1515
Epoch 4/100
61205/61205 [==============================] - 55s 901us/step - loss: 0.5947 - var_explained: 0.4155 - val_loss: 0.9512 - val_var_explained: 0.1563
Epoch 5/100
61205/61205 [==============================] - 55s 901us/step - loss: 0.5648 - var_explained: 0.4474 - val_loss: 0.8477 - val_var_explained: 0.1414
Epoch 6/100
61205/61205 [==============================] - 55s 902us/step - loss: 0.5450 - var_explained: 0.4743 - val_loss: 0.8559 - val_var_explained: 0.1323
Epoch 7/100
61205/61205 [==============================] - 55s 901us/step - loss: 0.5167 - var_explained: 0.5014 - val_loss: 0.8765 - val_var_explained: 0.1221
Out[127]:
<keras.callbacks.History at 0x7f08c54f6400>
In [128]:
ypred_valid = whole_model.predict(valid['inputs']['seq'])
In [129]:
ypred_valid.shape
Out[129]:
(19137, 3)
In [130]:
fig, axes = plt.subplots(len(assays), 1, figsize=(5, 11), sharex=True, sharey=True)
for i, (a, ax) in enumerate(zip(assays, axes)):
    regression_eval(ypred_valid[:,i], y_valid[:,i], alpha=0.05, task=a, ax=ax);
plt.tight_layout()

Train from scratch

In [141]:
import keras.backend as K
from keras.models import load_model

def reset_weights(model):
    """Re-run the kernel initializers of all layers (note: biases are not reset)."""
    session = K.get_session()
    for layer in model.layers:
        if hasattr(layer, 'kernel_initializer'):
            layer.kernel.initializer.run(session=session)
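
Note that reset_weights only re-draws the kernels; biases keep their trained values. A variant that also re-runs bias initializers (a sketch against the same TF1 session API) could look like:

def reset_weights_full(model):
    """Re-initialize both kernels and biases of every layer."""
    session = K.get_session()
    for layer in model.layers:
        for attr in ('kernel', 'bias'):
            var = getattr(layer, attr, None)
            if var is not None:
                var.initializer.run(session=session)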
In [137]:
whole_model.save(f"{ddir}/processed/activity/models/dense/fine-tuned.h5")
In [145]:
reinitialized_model = load_model(f"{ddir}/processed/activity/models/dense/fine-tuned.h5")
In [163]:
reset_weights(reinitialized_model.layers[0])
reset_weights(reinitialized_model.layers[1])
In [165]:
reinitialized_model.compile("adam", 'mse', metrics=[var_explained])
In [166]:
reinitialized_model.fit(train['inputs']['seq'], y, batch_size=512,
                epochs=100,
                validation_data=(valid['inputs']['seq'], y_valid),
                callbacks=[EarlyStopping(patience=5)])
Train on 61205 samples, validate on 19137 samples
Epoch 1/100
61205/61205 [==============================] - 58s 942us/step - loss: 1.2780 - var_explained: -0.2762 - val_loss: 0.9018 - val_var_explained: 0.0977
Epoch 2/100
61205/61205 [==============================] - 55s 895us/step - loss: 0.8525 - var_explained: 0.1490 - val_loss: 0.8654 - val_var_explained: 0.1464
Epoch 3/100
61205/61205 [==============================] - 55s 898us/step - loss: 0.8386 - var_explained: 0.1638 - val_loss: 0.8479 - val_var_explained: 0.1573
Epoch 4/100
61205/61205 [==============================] - 55s 897us/step - loss: 0.8318 - var_explained: 0.1739 - val_loss: 0.8420 - val_var_explained: 0.1624
Epoch 5/100
61205/61205 [==============================] - 55s 899us/step - loss: 0.8190 - var_explained: 0.1854 - val_loss: 0.8296 - val_var_explained: 0.1719
Epoch 6/100
61205/61205 [==============================] - 55s 900us/step - loss: 0.8099 - var_explained: 0.1979 - val_loss: 0.8332 - val_var_explained: 0.1779
Epoch 7/100
61205/61205 [==============================] - 55s 902us/step - loss: 0.7944 - var_explained: 0.2122 - val_loss: 0.8246 - val_var_explained: 0.1844
Epoch 8/100
61205/61205 [==============================] - 55s 898us/step - loss: 0.7846 - var_explained: 0.2242 - val_loss: 0.8429 - val_var_explained: 0.1800
Epoch 9/100
61205/61205 [==============================] - 55s 899us/step - loss: 0.7831 - var_explained: 0.2324 - val_loss: 0.8144 - val_var_explained: 0.1911
Epoch 10/100
61205/61205 [==============================] - 55s 899us/step - loss: 0.7681 - var_explained: 0.2446 - val_loss: 0.8132 - val_var_explained: 0.1915
Epoch 11/100
61205/61205 [==============================] - 55s 899us/step - loss: 0.7558 - var_explained: 0.2576 - val_loss: 0.8215 - val_var_explained: 0.1958
Epoch 12/100
61205/61205 [==============================] - 55s 899us/step - loss: 0.7397 - var_explained: 0.2733 - val_loss: 0.8182 - val_var_explained: 0.1868
Epoch 13/100
61205/61205 [==============================] - 55s 899us/step - loss: 0.7221 - var_explained: 0.2911 - val_loss: 0.8235 - val_var_explained: 0.1914
Epoch 14/100
61205/61205 [==============================] - 55s 899us/step - loss: 0.7095 - var_explained: 0.3094 - val_loss: 0.8200 - val_var_explained: 0.1966
Epoch 15/100
61205/61205 [==============================] - 55s 901us/step - loss: 0.6958 - var_explained: 0.3258 - val_loss: 0.8235 - val_var_explained: 0.1777
Out[166]:
<keras.callbacks.History at 0x7f0551c6a0f0>
In [167]:
ypred_valid = reinitialized_model.predict(valid['inputs']['seq'])
In [168]:
ypred_valid.shape
Out[168]:
(19137, 3)
In [169]:
fig, axes = plt.subplots(len(assays), 1, figsize=(5, 11), sharex=True, sharey=True)
for i, (a, ax) in enumerate(zip(assays, axes)):
    regression_eval(ypred_valid[:,i], y_valid[:,i], alpha=0.05, task=a, ax=ax);
plt.tight_layout()