Goals

  • correct for biases in ChIP-seq

how are model predictions correlated with the bias?

scatter-plot:

  • model predictions for total counts (before adding the bias term) vs. bias total counts (log-scale only the latter)
  • model predictions for local counts (before adding the bias term, in, say, 50 bp bins — bin the predictions with `bin_counts`) vs. bias local counts (log-scale only the latter)

Ideally, the model output does not correlate with the bias.

In [1]:
# Imports
# NOTE(review): wildcard import — `hv`, `paper_config`, `OrderedDict`, `np`, `plt`,
# `plot_tracks`, `filter_tracks`, `get_figsize`, `DataSpec` presumably all come from
# basepair.imports; verify against that module.
from basepair.imports import *
import warnings
warnings.filterwarnings('ignore')  # silences ALL warnings globally, incl. numeric ones (e.g. log10 of 0)
hv.extension('bokeh')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
paper_config()
Using TensorFlow backend.
WARNING: Font Arial not installed and is required by
In [2]:
import os
# Restrict TensorFlow to GPU 1.
# NOTE(review): the Keras/TF backend was already loaded by the imports cell above
# ("Using TensorFlow backend."); this only takes effect if no TF session/device
# was initialized yet — confirm, or move this cell before the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
In [3]:
dataspec_path = '/users/amr1/basepair/src/chipnexus/train/seqmodel/ChIP-seq.dataspec.yml'
In [4]:
from basepair.datasets import get_StrandedProfile_datasets2

# Build the train/validation datasets: 1 kb windows centered on peaks,
# task-first output labels, sex chromosomes held out, no bias pooling.
dataset_kwargs = dict(
    dataspec=dataspec_path,
    peak_width=1000,
    seq_width=1000,
    include_metadata=False,
    taskname_first=True,  # so that the output labels will be "{task}/profile"
    exclude_chr=['chrX', 'chrY'],
    profile_bias_pool_size=None,
)
train, valid = get_StrandedProfile_datasets2(**dataset_kwargs)
In [5]:
valid = valid[0][1]
In [6]:
!cat {dataspec_path}
task_specs:
  Oct4:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Oct4/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Oct4/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Oct4/idr-optimal-set.summit.bed.gz
  Sox2:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Sox2/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Sox2/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Sox2/idr-optimal-set.summit.bed.gz
  Nanog:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Nanog/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Nanog/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Nanog/idr-optimal-set.summit.bed.gz

bias_specs:
  input:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/input-control/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/input-control/counts.neg.bw
    tasks:
      - Oct4
      - Sox2
      - Nanog
    
fasta_file: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/mm10_no_alt_analysis_set_ENCODE.fasta
In [7]:
ds = DataSpec.load(dataspec_path)  # parsed dataspec: task_specs, bias_specs, fasta_file (see !cat above)
In [8]:
tasks = list(ds.task_specs)  # task names in YAML order: ['Oct4', 'Sox2', 'Nanog']
In [9]:
tasks
Out[9]:
['Oct4', 'Sox2', 'Nanog']
In [10]:
from basepair.trainers import SeqModelTrainer
from basepair.models import multihead_seq_model
from basepair.plot.evaluate import regression_eval
In [11]:
# Multi-task sequence model: one profile head + one counts head per TF.
# Loss weights: bias head disabled (b_loss_weight=0); counts loss up-weighted
# 10x relative to the profile loss.
m = multihead_seq_model(tasks=tasks,
                        filters=64,
                        n_dil_layers=6,  # depth of the dilated-conv stack
                        conv1_kernel_size=25,tconv_kernel_size=50, 
                        b_loss_weight=0, c_loss_weight=10, p_loss_weight=1, 
                        use_bias=True,  # inject the bias tracks into the heads
                        lr=0.004, padding='same', batchnorm=False)
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2019-02-27 23:35:29,192 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2019-02-27 23:35:41,473 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [12]:
output_dir='/tmp/exp/m3-wb'
In [13]:
!mkdir -p {output_dir}
In [14]:
!rm /tmp/exp/m3-wb/model.h5
In [15]:
tr = SeqModelTrainer(m, train, valid, output_dir=output_dir)
In [16]:
# NOTE(review): the log below stops after epoch 10 of 100 — presumably
# SeqModelTrainer applies early stopping internally; confirm.
tr.train(epochs=100)
Epoch 1/100
163/163 [==============================] - 90s 555ms/step - loss: 1359.9170 - Oct4/profile_loss: 338.8134 - Oct4/counts_loss: 1.8260 - Sox2/profile_loss: 369.9114 - Sox2/counts_loss: 0.4749 - Nanog/profile_loss: 616.0497 - Nanog/counts_loss: 1.2134 - val_loss: 1308.2303 - val_Oct4/profile_loss: 333.6591 - val_Oct4/counts_loss: 0.5528 - val_Sox2/profile_loss: 363.9022 - val_Sox2/counts_loss: 0.3144 - val_Nanog/profile_loss: 595.4966 - val_Nanog/counts_loss: 0.6500
Epoch 2/100
163/163 [==============================] - 85s 521ms/step - loss: 1281.1539 - Oct4/profile_loss: 326.3831 - Oct4/counts_loss: 0.5452 - Sox2/profile_loss: 359.7132 - Sox2/counts_loss: 0.3121 - Nanog/profile_loss: 580.0155 - Nanog/counts_loss: 0.6469 - val_loss: 1279.3761 - val_Oct4/profile_loss: 326.8178 - val_Oct4/counts_loss: 0.5382 - val_Sox2/profile_loss: 358.8862 - val_Sox2/counts_loss: 0.2980 - val_Nanog/profile_loss: 578.8158 - val_Nanog/counts_loss: 0.6494
Epoch 3/100
163/163 [==============================] - 80s 489ms/step - loss: 1260.4263 - Oct4/profile_loss: 321.7109 - Oct4/counts_loss: 0.5258 - Sox2/profile_loss: 356.2882 - Sox2/counts_loss: 0.2895 - Nanog/profile_loss: 567.8080 - Nanog/counts_loss: 0.6466 - val_loss: 1268.2871 - val_Oct4/profile_loss: 324.8762 - val_Oct4/counts_loss: 0.5140 - val_Sox2/profile_loss: 356.9931 - val_Sox2/counts_loss: 0.2719 - val_Nanog/profile_loss: 572.0614 - val_Nanog/counts_loss: 0.6498
Epoch 4/100
163/163 [==============================] - 78s 481ms/step - loss: 1249.5960 - Oct4/profile_loss: 319.7853 - Oct4/counts_loss: 0.5035 - Sox2/profile_loss: 354.4200 - Sox2/counts_loss: 0.2699 - Nanog/profile_loss: 561.2952 - Nanog/counts_loss: 0.6361 - val_loss: 1263.1316 - val_Oct4/profile_loss: 323.6433 - val_Oct4/counts_loss: 0.4983 - val_Sox2/profile_loss: 356.9825 - val_Sox2/counts_loss: 0.2562 - val_Nanog/profile_loss: 568.6514 - val_Nanog/counts_loss: 0.6309
Epoch 5/100
163/163 [==============================] - 81s 496ms/step - loss: 1243.3368 - Oct4/profile_loss: 318.6717 - Oct4/counts_loss: 0.4954 - Sox2/profile_loss: 353.6480 - Sox2/counts_loss: 0.2526 - Nanog/profile_loss: 557.2415 - Nanog/counts_loss: 0.6295 - val_loss: 1259.0238 - val_Oct4/profile_loss: 323.0029 - val_Oct4/counts_loss: 0.4854 - val_Sox2/profile_loss: 356.0501 - val_Sox2/counts_loss: 0.2384 - val_Nanog/profile_loss: 566.5448 - val_Nanog/counts_loss: 0.6188
Epoch 6/100
163/163 [==============================] - 80s 493ms/step - loss: 1239.9223 - Oct4/profile_loss: 318.3298 - Oct4/counts_loss: 0.4797 - Sox2/profile_loss: 352.9735 - Sox2/counts_loss: 0.2399 - Nanog/profile_loss: 555.2824 - Nanog/counts_loss: 0.6140 - val_loss: 1247.3692 - val_Oct4/profile_loss: 320.5650 - val_Oct4/counts_loss: 0.4998 - val_Sox2/profile_loss: 353.8672 - val_Sox2/counts_loss: 0.2239 - val_Nanog/profile_loss: 559.4864 - val_Nanog/counts_loss: 0.6214
Epoch 7/100
163/163 [==============================] - 80s 490ms/step - loss: 1236.0875 - Oct4/profile_loss: 317.7633 - Oct4/counts_loss: 0.4659 - Sox2/profile_loss: 352.4349 - Sox2/counts_loss: 0.2282 - Nanog/profile_loss: 552.8894 - Nanog/counts_loss: 0.6059 - val_loss: 1256.3485 - val_Oct4/profile_loss: 323.0284 - val_Oct4/counts_loss: 0.4559 - val_Sox2/profile_loss: 355.8593 - val_Sox2/counts_loss: 0.2217 - val_Nanog/profile_loss: 564.6800 - val_Nanog/counts_loss: 0.6005
Epoch 8/100
163/163 [==============================] - 80s 492ms/step - loss: 1235.9194 - Oct4/profile_loss: 317.5627 - Oct4/counts_loss: 0.4558 - Sox2/profile_loss: 352.5775 - Sox2/counts_loss: 0.2225 - Nanog/profile_loss: 553.0057 - Nanog/counts_loss: 0.5991 - val_loss: 1251.9311 - val_Oct4/profile_loss: 321.3869 - val_Oct4/counts_loss: 0.5344 - val_Sox2/profile_loss: 354.8136 - val_Sox2/counts_loss: 0.2343 - val_Nanog/profile_loss: 561.3183 - val_Nanog/counts_loss: 0.6726
Epoch 9/100
163/163 [==============================] - 80s 491ms/step - loss: 1231.9237 - Oct4/profile_loss: 316.9750 - Oct4/counts_loss: 0.4364 - Sox2/profile_loss: 351.9968 - Sox2/counts_loss: 0.2168 - Nanog/profile_loss: 550.5182 - Nanog/counts_loss: 0.5902 - val_loss: 1253.0037 - val_Oct4/profile_loss: 322.0271 - val_Oct4/counts_loss: 0.4407 - val_Sox2/profile_loss: 355.6989 - val_Sox2/counts_loss: 0.2146 - val_Nanog/profile_loss: 562.8966 - val_Nanog/counts_loss: 0.5827
Epoch 10/100
163/163 [==============================] - 80s 492ms/step - loss: 1231.2149 - Oct4/profile_loss: 316.8272 - Oct4/counts_loss: 0.4311 - Sox2/profile_loss: 351.9959 - Sox2/counts_loss: 0.2138 - Nanog/profile_loss: 550.0903 - Nanog/counts_loss: 0.5852 - val_loss: 1249.2393 - val_Oct4/profile_loss: 321.0792 - val_Oct4/counts_loss: 0.4277 - val_Sox2/profile_loss: 354.3677 - val_Sox2/counts_loss: 0.2092 - val_Nanog/profile_loss: 561.5210 - val_Nanog/counts_loss: 0.5902
In [17]:
#%debug
In [18]:
eval_metrics = tr.evaluate(metric=None)  # metric=None -> uses the default head metrics
# NOTE(review): Oct4 profile auprc is nan (n_positives == 0 in the output below) —
# check the Oct4 peak labels / ambiguity thresholds before trusting the averages.
print(eval_metrics)
2019-02-27 23:49:33,976 [INFO] Evaluating dataset: valid
52it [00:25,  2.03it/s]                        
2019-02-27 23:50:11,212 [INFO] Saved metrics to /tmp/exp/m3-wb/evaluation.valid.json
OrderedDict([('valid', {'Oct4/profile/binsize=1/auprc': nan, 'Oct4/profile/binsize=1/frac_ambigous': 0.1071008230452675, 'Oct4/profile/binsize=1/imbalance': 0.0, 'Oct4/profile/binsize=1/n_positives': 0, 'Oct4/profile/binsize=1/random_auprc': nan, 'Oct4/profile/binsize=10/auprc': nan, 'Oct4/profile/binsize=10/frac_ambigous': 0.37292181069958846, 'Oct4/profile/binsize=10/imbalance': 0.0, 'Oct4/profile/binsize=10/n_positives': 0, 'Oct4/profile/binsize=10/random_auprc': nan, 'Oct4/counts/mse': 0.5072772, 'Oct4/counts/var_explained': -0.007984399795532227, 'Oct4/counts/pearsonr': 0.07223449, 'Oct4/counts/spearmanr': 0.09506672016522448, 'Oct4/counts/mad': 0.5621106, 'Sox2/profile/binsize=1/auprc': 0.0027859028359930576, 'Sox2/profile/binsize=1/frac_ambigous': 0.11189115646258503, 'Sox2/profile/binsize=1/imbalance': 0.0003714994791347509, 'Sox2/profile/binsize=1/n_positives': 97, 'Sox2/profile/binsize=1/random_auprc': 0.0004065170968333897, 'Sox2/profile/binsize=10/auprc': 0.11527502049563305, 'Sox2/profile/binsize=10/frac_ambigous': 0.42149659863945577, 'Sox2/profile/binsize=10/imbalance': 0.0057031984948259645, 'Sox2/profile/binsize=10/n_positives': 97, 'Sox2/profile/binsize=10/random_auprc': 0.006858452999105067, 'Sox2/counts/mse': 0.22738394, 'Sox2/counts/var_explained': 0.0363810658454895, 'Sox2/counts/pearsonr': 0.2928128, 'Sox2/counts/spearmanr': 0.2751861695554017, 'Sox2/counts/mad': 0.37120578, 'Nanog/profile/binsize=1/auprc': 8.582896559906728e-05, 'Nanog/profile/binsize=1/frac_ambigous': 0.06657696945337621, 'Nanog/profile/binsize=1/imbalance': 2.2606348336269646e-05, 'Nanog/profile/binsize=1/n_positives': 105, 'Nanog/profile/binsize=1/random_auprc': 2.557201240095034e-05, 'Nanog/profile/binsize=10/auprc': 0.0009129673749272902, 'Nanog/profile/binsize=10/frac_ambigous': 0.24867363344051446, 'Nanog/profile/binsize=10/imbalance': 0.00015246348900658, 'Nanog/profile/binsize=10/n_positives': 57, 'Nanog/profile/binsize=10/random_auprc': 0.00014593180416573403, 
'Nanog/counts/mse': 0.6281811, 'Nanog/counts/var_explained': 0.04202836751937866, 'Nanog/counts/pearsonr': 0.20509867, 'Nanog/counts/spearmanr': 0.1950448578290948, 'Nanog/counts/mad': 0.6277407, 'avg/profile/binsize=1/auprc': nan, 'avg/profile/binsize=1/frac_ambigous': 0.09518964965374292, 'avg/profile/binsize=1/imbalance': 0.00013136860915700685, 'avg/profile/binsize=1/n_positives': 67.33333333333333, 'avg/profile/binsize=1/random_auprc': nan, 'avg/profile/binsize=10/auprc': nan, 'avg/profile/binsize=10/frac_ambigous': 0.3476973475931862, 'avg/profile/binsize=10/imbalance': 0.0019518873279441816, 'avg/profile/binsize=10/n_positives': 51.333333333333336, 'avg/profile/binsize=10/random_auprc': nan, 'avg/counts/mse': 0.4542807439963023, 'avg/counts/var_explained': 0.023475011189778645, 'avg/counts/pearsonr': 0.19004865239063898, 'avg/counts/spearmanr': 0.18843258251657366, 'avg/counts/mad': 0.5203523536523184})])

Get the importance scores

In [19]:
import pybedtools
from basepair.utils import flatten_list
paper_config()
WARNING: Font Arial not installed and is required by
In [20]:
seq_model = tr.seq_model  # extract the seq_model
In [21]:
from genomelake.extractors import FastaExtractor
In [22]:
# Get data for the oct4 enhancer
# NOTE(review): coordinates are passed as ints; BED fields are conventionally
# strings — confirm create_interval_from_list accepts ints. Also, a 3-field
# interval has no name field (used in the plot title further below).
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}  # observed counts per task
seq = FastaExtractor(ds.fasta_file)([interval])  # one-hot sequence; presumably (1, 1000, 4) given seq_width=1000
In [23]:
it = valid.batch_iter(batch_size=1, shuffle=True)  # NOTE(review): unseeded shuffle — not reproducible across runs
In [24]:
batch = next(it)
# WARNING(review): this overwrites the FastaExtractor `seq` from the enhancer cell
# above, while `obs` still refers to that fixed interval — the downstream plot
# therefore mixes data from two different regions; confirm this is intended.
seq = batch['inputs']['seq']
In [25]:
seq.shape
Out[25]:
(1, 1000, 4)
In [26]:
seq_model.all_heads['Oct4'][0].use_bias
Out[26]:
True
In [27]:
imp_scores = seq_model.imp_score_all(seq, batch_size=1)
WARNING:tensorflow:From /users/amr1/basepair/basepair/heads.py:323: calling softmax (from tensorflow.python.ops.nn_ops) with dim is deprecated and will be removed in a future version.
Instructions for updating:
dim is deprecated, use axis instead
2019-02-27 23:50:11,548 [WARNING] From /users/amr1/basepair/basepair/heads.py:323: calling softmax (from tensorflow.python.ops.nn_ops) with dim is deprecated and will be removed in a future version.
Instructions for updating:
dim is deprecated, use axis instead
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
In [28]:
imp_scores.keys()
Out[28]:
dict_keys(['Oct4/profile/wn', 'Oct4/profile/w1', 'Oct4/profile/w2', 'Oct4/profile/winf', 'Oct4/counts/pre-act', 'Sox2/profile/wn', 'Sox2/profile/w1', 'Sox2/profile/w2', 'Sox2/profile/winf', 'Sox2/counts/pre-act', 'Nanog/profile/wn', 'Nanog/profile/w1', 'Nanog/profile/w2', 'Nanog/profile/winf', 'Nanog/counts/pre-act'])
In [29]:
imp_scores['Oct4/profile/wn'].shape
Out[29]:
(1, 1000, 4)
In [30]:
preds = seq_model.predict_preact(seq)
In [31]:
seq.shape
Out[31]:
(1, 1000, 4)

Pred and observed

In [32]:
# Build the track dictionary for plotting: per task, the observed counts and the
# importance scores projected onto the one-hot sequence (score * one-hot keeps
# only the contribution of the actual base at each position).
# NOTE(review): `obs` comes from the fixed Oct4-enhancer interval while the
# importance scores were computed on a random validation sequence — confirm.
# (dropped the unused `task_idx` / `enumerate` from the comprehension)
viz_dict = OrderedDict(flatten_list([[
                    (f"{task} Obs", obs[task]),
                    (f"{task} Imp profile", imp_scores[f"{task}/profile/wn"][0] * seq[0]),
                ] for task in tasks]))

viz_dict = filter_tracks(viz_dict, [420, 575])  # zoom to positions 420-575
In [33]:
# Shared y-axis limits per feature type, so the tasks are directly comparable:
# importance tracks span (global min, global max); observed counts start at 0.
features = ['Imp profile', 'Obs']
fmax = {feature: max(viz_dict[f"{task} {feature}"].max() for task in tasks)
        for feature in features}
fmin = {feature: min(viz_dict[f"{task} {feature}"].min() for task in tasks)
        for feature in features}


# One (ymin, ymax) pair per track, in viz_dict order; the feature name is
# everything after the first space in the track key ("<task> <feature>").
ylim = [(fmin[feat], fmax[feat]) if "Imp" in feat else (0, fmax[feat])
        for feat in (key.split(" ", 1)[1] for key in viz_dict)]
In [34]:
# Plot observed counts vs. importance tracks for all tasks.
# NOTE(review): the title is formatted from the fixed enhancer `interval`, but the
# tracks come from a randomly drawn validation batch (see the `seq` overwrite
# above), and a 3-field interval may have an empty `.name` — confirm both.
fig = plot_tracks(viz_dict,
                  #seqlets=shifted_seqlets,
                  title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                  fig_height_per_track=2,
                  rotate_y=0,
                  fig_width=20,
                  ylim=ylim,
                  legend=False)
In [35]:
train_set = train.load_all(batch_size=32, num_workers=5)
100%|██████████| 1307/1307 [01:50<00:00, 11.87it/s]
In [36]:
import matplotlib
%matplotlib inline
from scipy.stats import spearmanr
from basepair.preproc import bin_counts
In [37]:
binsize=50  # bp per bin for the "local counts" comparison (see goals at the top)
# Pre-bias pre-activation predictions for every training example.
train_preds = seq_model.predict_preact(train_set['inputs']['seq'])
In [39]:
# Scatter the model's pre-bias predictions against log10 bias counts, per task.
# Row 0: total counts per region; row 1: local counts in `binsize`-bp bins.
# We hope the Spearman correlation (Rs) is near zero in every panel.
fig, axes = plt.subplots(2, len(tasks), figsize=get_figsize(2/4*len(tasks), 2/len(tasks)))
matplotlib.rcParams.update({'font.size': 32})  # set once, not once per task


def _bias_scatter(ax, preds, log_bias, xlabel, ylabel, title=None, show_ylabel=True):
    """Scatter `preds` vs `log_bias` on `ax`, annotating the Spearman Rs in the legend."""
    cc, _ = spearmanr(preds, log_bias)  # p-value unused
    ax.scatter(preds, log_bias, s=10, c="b", alpha=0.5, marker='o',
               label="Rs={0:.2f}".format(cc))
    ax.legend()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel if show_ylabel else "")
    if title is not None:
        ax.set_title(title)


for i, task in enumerate(train.tasks):
    # Total counts: sum over positions and both strands.
    preds_for_total = np.sum(train_preds[f'{task}/profile'], axis=(1, 2), dtype=np.float32)
    # NOTE(review): log10 yields -inf for zero-count regions/bins (and spearmanr
    # may misbehave on them) — confirm the bias tracks are strictly positive.
    log_bias_total_counts = np.log10(np.sum(train_set['inputs'][f'bias/{task}/profile'],
                                            axis=(1, 2), dtype=np.float32))
    _bias_scatter(axes[0, i], preds_for_total, log_bias_total_counts,
                  "preds_for_total", "log_bias_total_counts",
                  title=task, show_ylabel=(i == 0))

    # Local counts: bin the profiles into `binsize`-bp bins, sum strands, flatten.
    preds_for_local = np.ravel(np.sum(bin_counts(train_preds[f'{task}/profile'],
                                                 binsize=binsize), axis=-1, dtype=np.float32))
    log_bias_local_counts = np.log10(np.ravel(np.sum(bin_counts(train_set['inputs'][f'bias/{task}/profile'],
                                                                binsize=binsize), axis=-1, dtype=np.float32)))
    _bias_scatter(axes[1, i], preds_for_local, log_bias_local_counts,
                  "preds_for_local", "log_bias_local_counts",
                  show_ylabel=(i == 0))