Goals

  • correct for biases in ChIP-seq

How are model predictions correlated with the bias?

scatter-plot:

  • model predictions for total counts (before adding the bias term) vs bias total counts (log-scale only the latter)
  • model predictions for local counts before adding the bias term (in, say, 50 bp bins) {here you bin_counts the predictions} vs bias local counts (log-scale only the latter)

Basically, we hope that our model output doesn't correlate with the bias.

In [1]:
# Imports
# Wildcard project import — evidently supplies hv, plt, OrderedDict, DataSpec,
# paper_config, plot_tracks, filter_tracks, get_figsize, ... (used later
# without explicit imports)
from basepair.imports import *
import warnings
# Silence ALL warnings (also hides real deprecations) — deliberate, to keep the
# notebook output readable
warnings.filterwarnings('ignore')
# HoloViews with the bokeh backend for interactive plots
hv.extension('bokeh')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Project-wide matplotlib styling (the warning below is because Arial is missing)
paper_config()
Using TensorFlow backend.
WARNING: Font Arial not installed and is required by
In [2]:
# Pin the job to a single GPU — must run before TensorFlow initializes CUDA
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
In [3]:
# Dataspec YAML describing per-task count bigwigs, peak files, the bias
# (input-control) tracks and the genome fasta — printed in full a few cells below
dataspec_path = '/users/amr1/basepair/src/chipnexus/train/seqmodel/ChIP-seq.dataspec.yml'
In [4]:
from basepair.datasets import get_StrandedProfile_datasets2

# Build train/validation datasets of 1kb peak-centered regions.
# taskname_first=True makes the output labels "{task}/profile";
# chrX/chrY are held out for validation-hygiene; the bias profile is
# pooled over 50bp windows.
train, valid = get_StrandedProfile_datasets2(
    dataspec=dataspec_path,
    peak_width=1000,
    seq_width=1000,
    include_metadata=False,
    taskname_first=True,
    exclude_chr=['chrX', 'chrY'],
    profile_bias_pool_size=50,
)
In [5]:
# Keep only the dataset object — `valid` is presumably a list of
# (name, dataset) pairs; take the dataset of the first entry.
# NOTE(review): this rebinds `valid` to a different kind of object than the
# cell above produced — confirm the [0][1] structure against
# get_StrandedProfile_datasets2's return value.
valid = valid[0][1]
In [6]:
!cat {dataspec_path}
task_specs:
  Oct4:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Oct4/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Oct4/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Oct4/idr-optimal-set.summit.bed.gz
  Sox2:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Sox2/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Sox2/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Sox2/idr-optimal-set.summit.bed.gz
  Nanog:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Nanog/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Nanog/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/Nanog/idr-optimal-set.summit.bed.gz

bias_specs:
  input:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/input-control/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-seq/input-control/counts.neg.bw
    tasks:
      - Oct4
      - Sox2
      - Nanog
    
fasta_file: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/mm10_no_alt_analysis_set_ENCODE.fasta
In [7]:
# Parse the dataspec YAML (contents printed in the cell above)
ds = DataSpec.load(dataspec_path)
In [8]:
# Task names in dataspec order (task_specs is presumably a mapping keyed by task)
tasks = list(ds.task_specs)
In [9]:
tasks
Out[9]:
['Oct4', 'Sox2', 'Nanog']
In [10]:
# Trainer, model factory and evaluation helpers from the project
from basepair.trainers import SeqModelTrainer
from basepair.models import multihead_seq_model
from basepair.plot.evaluate import regression_eval
In [11]:
# Multi-task sequence model with profile + counts heads per task.
# Loss weights: counts=10, profile=1, binary head disabled (b_loss_weight=0);
# use_bias=True wires the bias tracks into the heads.
m = multihead_seq_model(
    tasks=tasks,
    filters=64,
    n_dil_layers=6,
    conv1_kernel_size=25,
    tconv_kernel_size=50,
    b_loss_weight=0,
    c_loss_weight=10,
    p_loss_weight=1,
    use_bias=True,
    lr=0.004,
    padding='same',
    batchnorm=False,
)
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2019-02-27 23:07:26,306 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2019-02-27 23:07:37,160 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [12]:
# Where the trainer writes model.h5 and evaluation JSONs.
# NOTE(review): /tmp is volatile — move to persistent storage if this run matters.
output_dir='/tmp/exp/m3-wb'
In [13]:
!mkdir -p {output_dir}
In [14]:
!rm /tmp/exp/m3-wb/model.h5
rm: cannot remove '/tmp/exp/m3-wb/model.h5': No such file or directory
In [15]:
# Bundle model + datasets; checkpoints and metrics go to output_dir
tr = SeqModelTrainer(m, train, valid, output_dir=output_dir)
In [16]:
# Train for up to 100 epochs. The log below stops at epoch 19 — presumably
# early stopping on val_loss; confirm the trainer's default callbacks.
tr.train(epochs=100)
Epoch 1/100
163/163 [==============================] - 100s 612ms/step - loss: 1409.1082 - Oct4/profile_loss: 340.9865 - Oct4/counts_loss: 0.6305 - Sox2/profile_loss: 376.8929 - Sox2/counts_loss: 2.8810 - Nanog/profile_loss: 628.8230 - Nanog/counts_loss: 2.7290 - val_loss: 1350.7902 - val_Oct4/profile_loss: 337.0607 - val_Oct4/counts_loss: 0.4978 - val_Sox2/profile_loss: 366.1328 - val_Sox2/counts_loss: 1.9768 - val_Nanog/profile_loss: 603.5566 - val_Nanog/counts_loss: 1.9294
Epoch 2/100
163/163 [==============================] - 93s 573ms/step - loss: 1298.5329 - Oct4/profile_loss: 328.9437 - Oct4/counts_loss: 0.4792 - Sox2/profile_loss: 361.3166 - Sox2/counts_loss: 0.8536 - Nanog/profile_loss: 584.8409 - Nanog/counts_loss: 1.0103 - val_loss: 1297.1741 - val_Oct4/profile_loss: 328.0854 - val_Oct4/counts_loss: 0.4882 - val_Sox2/profile_loss: 359.5980 - val_Sox2/counts_loss: 1.0019 - val_Nanog/profile_loss: 582.6770 - val_Nanog/counts_loss: 1.1913
Epoch 3/100
163/163 [==============================] - 92s 566ms/step - loss: 1268.7988 - Oct4/profile_loss: 322.4276 - Oct4/counts_loss: 0.4749 - Sox2/profile_loss: 356.1359 - Sox2/counts_loss: 0.7321 - Nanog/profile_loss: 568.8356 - Nanog/counts_loss: 0.9330 - val_loss: 1280.6324 - val_Oct4/profile_loss: 325.9762 - val_Oct4/counts_loss: 0.4798 - val_Sox2/profile_loss: 357.8287 - val_Sox2/counts_loss: 0.7787 - val_Nanog/profile_loss: 574.8751 - val_Nanog/counts_loss: 0.9367
Epoch 4/100
163/163 [==============================] - 92s 566ms/step - loss: 1258.6435 - Oct4/profile_loss: 320.9749 - Oct4/counts_loss: 0.4695 - Sox2/profile_loss: 354.7910 - Sox2/counts_loss: 0.6605 - Nanog/profile_loss: 562.6533 - Nanog/counts_loss: 0.8924 - val_loss: 1272.3416 - val_Oct4/profile_loss: 324.3810 - val_Oct4/counts_loss: 0.4717 - val_Sox2/profile_loss: 356.2002 - val_Sox2/counts_loss: 0.7901 - val_Nanog/profile_loss: 569.2671 - val_Nanog/counts_loss: 0.9875
Epoch 5/100
163/163 [==============================] - 93s 570ms/step - loss: 1249.0319 - Oct4/profile_loss: 319.0524 - Oct4/counts_loss: 0.4616 - Sox2/profile_loss: 353.5332 - Sox2/counts_loss: 0.5871 - Nanog/profile_loss: 557.5971 - Nanog/counts_loss: 0.8363 - val_loss: 1266.4680 - val_Oct4/profile_loss: 323.6355 - val_Oct4/counts_loss: 0.4645 - val_Sox2/profile_loss: 356.5731 - val_Sox2/counts_loss: 0.5977 - val_Nanog/profile_loss: 567.3437 - val_Nanog/counts_loss: 0.8294
Epoch 6/100
163/163 [==============================] - 95s 583ms/step - loss: 1247.2060 - Oct4/profile_loss: 319.1412 - Oct4/counts_loss: 0.4583 - Sox2/profile_loss: 353.5230 - Sox2/counts_loss: 0.5946 - Nanog/profile_loss: 555.4134 - Nanog/counts_loss: 0.8600 - val_loss: 1267.8831 - val_Oct4/profile_loss: 323.9353 - val_Oct4/counts_loss: 0.4581 - val_Sox2/profile_loss: 356.5354 - val_Sox2/counts_loss: 0.6583 - val_Nanog/profile_loss: 566.8567 - val_Nanog/counts_loss: 0.9392
Epoch 7/100
163/163 [==============================] - 91s 560ms/step - loss: 1240.5010 - Oct4/profile_loss: 317.9671 - Oct4/counts_loss: 0.4487 - Sox2/profile_loss: 352.8163 - Sox2/counts_loss: 0.5085 - Nanog/profile_loss: 552.2398 - Nanog/counts_loss: 0.7906 - val_loss: 1258.6785 - val_Oct4/profile_loss: 322.0693 - val_Oct4/counts_loss: 0.4493 - val_Sox2/profile_loss: 355.1304 - val_Sox2/counts_loss: 0.5149 - val_Nanog/profile_loss: 563.8883 - val_Nanog/counts_loss: 0.7948
Epoch 8/100
163/163 [==============================] - 90s 552ms/step - loss: 1237.8967 - Oct4/profile_loss: 317.3567 - Oct4/counts_loss: 0.4446 - Sox2/profile_loss: 352.4240 - Sox2/counts_loss: 0.4730 - Nanog/profile_loss: 551.2338 - Nanog/counts_loss: 0.7706 - val_loss: 1258.9126 - val_Oct4/profile_loss: 322.7872 - val_Oct4/counts_loss: 0.4436 - val_Sox2/profile_loss: 355.6591 - val_Sox2/counts_loss: 0.4895 - val_Nanog/profile_loss: 563.4535 - val_Nanog/counts_loss: 0.7682
Epoch 9/100
163/163 [==============================] - 93s 573ms/step - loss: 1236.6143 - Oct4/profile_loss: 317.3014 - Oct4/counts_loss: 0.4388 - Sox2/profile_loss: 352.5816 - Sox2/counts_loss: 0.4434 - Nanog/profile_loss: 550.4262 - Nanog/counts_loss: 0.7483 - val_loss: 1257.2230 - val_Oct4/profile_loss: 322.5825 - val_Oct4/counts_loss: 0.4380 - val_Sox2/profile_loss: 355.0676 - val_Sox2/counts_loss: 0.4285 - val_Nanog/profile_loss: 563.7002 - val_Nanog/counts_loss: 0.7208
Epoch 10/100
163/163 [==============================] - 94s 577ms/step - loss: 1232.2643 - Oct4/profile_loss: 316.7986 - Oct4/counts_loss: 0.4308 - Sox2/profile_loss: 351.8690 - Sox2/counts_loss: 0.3957 - Nanog/profile_loss: 548.2363 - Nanog/counts_loss: 0.7095 - val_loss: 1256.7176 - val_Oct4/profile_loss: 322.7922 - val_Oct4/counts_loss: 0.4330 - val_Sox2/profile_loss: 355.9063 - val_Sox2/counts_loss: 0.3771 - val_Nanog/profile_loss: 562.9331 - val_Nanog/counts_loss: 0.6985
Epoch 11/100
163/163 [==============================] - 91s 556ms/step - loss: 1231.7976 - Oct4/profile_loss: 316.5949 - Oct4/counts_loss: 0.4261 - Sox2/profile_loss: 352.1184 - Sox2/counts_loss: 0.3691 - Nanog/profile_loss: 548.1867 - Nanog/counts_loss: 0.6945 - val_loss: 1251.4352 - val_Oct4/profile_loss: 320.1111 - val_Oct4/counts_loss: 0.4238 - val_Sox2/profile_loss: 354.5254 - val_Sox2/counts_loss: 0.3657 - val_Nanog/profile_loss: 562.0920 - val_Nanog/counts_loss: 0.6812
Epoch 12/100
163/163 [==============================] - 91s 560ms/step - loss: 1228.3451 - Oct4/profile_loss: 316.0978 - Oct4/counts_loss: 0.4210 - Sox2/profile_loss: 351.4782 - Sox2/counts_loss: 0.3445 - Nanog/profile_loss: 546.3999 - Nanog/counts_loss: 0.6714 - val_loss: 1251.6111 - val_Oct4/profile_loss: 321.3332 - val_Oct4/counts_loss: 0.4221 - val_Sox2/profile_loss: 355.0525 - val_Sox2/counts_loss: 0.3374 - val_Nanog/profile_loss: 560.9384 - val_Nanog/counts_loss: 0.6692
Epoch 13/100
163/163 [==============================] - 99s 608ms/step - loss: 1226.5174 - Oct4/profile_loss: 315.7577 - Oct4/counts_loss: 0.4134 - Sox2/profile_loss: 351.4087 - Sox2/counts_loss: 0.3225 - Nanog/profile_loss: 545.3712 - Nanog/counts_loss: 0.6621 - val_loss: 1253.0739 - val_Oct4/profile_loss: 321.6195 - val_Oct4/counts_loss: 0.4110 - val_Sox2/profile_loss: 354.7605 - val_Sox2/counts_loss: 0.3300 - val_Nanog/profile_loss: 562.4547 - val_Nanog/counts_loss: 0.6829
Epoch 14/100
163/163 [==============================] - 100s 615ms/step - loss: 1224.7103 - Oct4/profile_loss: 315.4180 - Oct4/counts_loss: 0.4095 - Sox2/profile_loss: 351.1617 - Sox2/counts_loss: 0.2982 - Nanog/profile_loss: 544.6189 - Nanog/counts_loss: 0.6434 - val_loss: 1250.9987 - val_Oct4/profile_loss: 321.5569 - val_Oct4/counts_loss: 0.4146 - val_Sox2/profile_loss: 354.8048 - val_Sox2/counts_loss: 0.2845 - val_Nanog/profile_loss: 561.4698 - val_Nanog/counts_loss: 0.6177
Epoch 15/100
163/163 [==============================] - 101s 617ms/step - loss: 1223.8045 - Oct4/profile_loss: 315.9277 - Oct4/counts_loss: 0.4022 - Sox2/profile_loss: 351.1150 - Sox2/counts_loss: 0.2753 - Nanog/profile_loss: 543.8324 - Nanog/counts_loss: 0.6155 - val_loss: 1248.7559 - val_Oct4/profile_loss: 320.6742 - val_Oct4/counts_loss: 0.4155 - val_Sox2/profile_loss: 354.4938 - val_Sox2/counts_loss: 0.2732 - val_Nanog/profile_loss: 560.5940 - val_Nanog/counts_loss: 0.6107
Epoch 16/100
163/163 [==============================] - 99s 605ms/step - loss: 1222.8327 - Oct4/profile_loss: 315.4618 - Oct4/counts_loss: 0.3993 - Sox2/profile_loss: 351.1285 - Sox2/counts_loss: 0.2714 - Nanog/profile_loss: 543.3295 - Nanog/counts_loss: 0.6206 - val_loss: 1253.7526 - val_Oct4/profile_loss: 321.7595 - val_Oct4/counts_loss: 0.4008 - val_Sox2/profile_loss: 355.1321 - val_Sox2/counts_loss: 0.2641 - val_Nanog/profile_loss: 564.2945 - val_Nanog/counts_loss: 0.5917
Epoch 17/100
163/163 [==============================] - 92s 565ms/step - loss: 1217.9581 - Oct4/profile_loss: 314.4696 - Oct4/counts_loss: 0.3901 - Sox2/profile_loss: 350.2056 - Sox2/counts_loss: 0.2470 - Nanog/profile_loss: 540.9778 - Nanog/counts_loss: 0.5934 - val_loss: 1255.1750 - val_Oct4/profile_loss: 321.6718 - val_Oct4/counts_loss: 0.3906 - val_Sox2/profile_loss: 355.7823 - val_Sox2/counts_loss: 0.2594 - val_Nanog/profile_loss: 565.0459 - val_Nanog/counts_loss: 0.6175
Epoch 18/100
163/163 [==============================] - 94s 578ms/step - loss: 1218.9724 - Oct4/profile_loss: 314.7826 - Oct4/counts_loss: 0.3863 - Sox2/profile_loss: 350.5734 - Sox2/counts_loss: 0.2406 - Nanog/profile_loss: 541.3959 - Nanog/counts_loss: 0.5952 - val_loss: 1250.0182 - val_Oct4/profile_loss: 320.7275 - val_Oct4/counts_loss: 0.3844 - val_Sox2/profile_loss: 354.5637 - val_Sox2/counts_loss: 0.2381 - val_Nanog/profile_loss: 562.7131 - val_Nanog/counts_loss: 0.5789
Epoch 19/100
163/163 [==============================] - 98s 600ms/step - loss: 1218.7746 - Oct4/profile_loss: 314.7897 - Oct4/counts_loss: 0.3803 - Sox2/profile_loss: 350.7729 - Sox2/counts_loss: 0.2235 - Nanog/profile_loss: 541.3990 - Nanog/counts_loss: 0.5775 - val_loss: 1251.1286 - val_Oct4/profile_loss: 322.0096 - val_Oct4/counts_loss: 0.3828 - val_Sox2/profile_loss: 354.7087 - val_Sox2/counts_loss: 0.2271 - val_Nanog/profile_loss: 562.4199 - val_Nanog/counts_loss: 0.5890
In [17]:
#%debug
In [18]:
# Evaluate on the validation set; per the log this also writes
# {output_dir}/evaluation.valid.json
eval_metrics = tr.evaluate(metric=None)  # metric=None -> uses the default head metrics
print(eval_metrics)
2019-02-27 23:37:52,680 [INFO] Evaluating dataset: valid
52it [00:26,  1.96it/s]                        
2019-02-27 23:38:32,044 [INFO] Saved metrics to /tmp/exp/m3-wb/evaluation.valid.json
OrderedDict([('valid', {'Oct4/profile/binsize=1/auprc': nan, 'Oct4/profile/binsize=1/frac_ambigous': 0.1071008230452675, 'Oct4/profile/binsize=1/imbalance': 0.0, 'Oct4/profile/binsize=1/n_positives': 0, 'Oct4/profile/binsize=1/random_auprc': nan, 'Oct4/profile/binsize=10/auprc': nan, 'Oct4/profile/binsize=10/frac_ambigous': 0.37292181069958846, 'Oct4/profile/binsize=10/imbalance': 0.0, 'Oct4/profile/binsize=10/n_positives': 0, 'Oct4/profile/binsize=10/random_auprc': nan, 'Oct4/counts/mse': 0.4144217, 'Oct4/counts/var_explained': 0.13405901193618774, 'Oct4/counts/pearsonr': 0.37713197, 'Oct4/counts/spearmanr': 0.3708756991738943, 'Oct4/counts/mad': 0.50533843, 'Sox2/profile/binsize=1/auprc': 0.0029206681309995575, 'Sox2/profile/binsize=1/frac_ambigous': 0.11189115646258503, 'Sox2/profile/binsize=1/imbalance': 0.0003714994791347509, 'Sox2/profile/binsize=1/n_positives': 97, 'Sox2/profile/binsize=1/random_auprc': 0.0004464317950030088, 'Sox2/profile/binsize=10/auprc': 0.1380789501703978, 'Sox2/profile/binsize=10/frac_ambigous': 0.42149659863945577, 'Sox2/profile/binsize=10/imbalance': 0.0057031984948259645, 'Sox2/profile/binsize=10/n_positives': 97, 'Sox2/profile/binsize=10/random_auprc': 0.005810156108111966, 'Sox2/counts/mse': 0.27205563, 'Sox2/counts/var_explained': -0.14232099056243896, 'Sox2/counts/pearsonr': 0.06327891, 'Sox2/counts/spearmanr': 0.10930730118969495, 'Sox2/counts/mad': 0.38956323, 'Nanog/profile/binsize=1/auprc': 5.4847291730822176e-05, 'Nanog/profile/binsize=1/frac_ambigous': 0.06657696945337621, 'Nanog/profile/binsize=1/imbalance': 2.2606348336269646e-05, 'Nanog/profile/binsize=1/n_positives': 105, 'Nanog/profile/binsize=1/random_auprc': 2.323210601250873e-05, 'Nanog/profile/binsize=10/auprc': 0.000589836151978855, 'Nanog/profile/binsize=10/frac_ambigous': 0.24867363344051446, 'Nanog/profile/binsize=10/imbalance': 0.00015246348900658, 'Nanog/profile/binsize=10/n_positives': 57, 'Nanog/profile/binsize=10/random_auprc': 0.00013819208371292546, 
'Nanog/counts/mse': 0.60470206, 'Nanog/counts/var_explained': 0.04283881187438965, 'Nanog/counts/pearsonr': 0.22235757, 'Nanog/counts/spearmanr': 0.20974558023976633, 'Nanog/counts/mad': 0.63093126, 'avg/profile/binsize=1/auprc': nan, 'avg/profile/binsize=1/frac_ambigous': 0.09518964965374292, 'avg/profile/binsize=1/imbalance': 0.00013136860915700685, 'avg/profile/binsize=1/n_positives': 67.33333333333333, 'avg/profile/binsize=1/random_auprc': nan, 'avg/profile/binsize=10/auprc': nan, 'avg/profile/binsize=10/frac_ambigous': 0.3476973475931862, 'avg/profile/binsize=10/imbalance': 0.0019518873279441816, 'avg/profile/binsize=10/n_positives': 51.333333333333336, 'avg/profile/binsize=10/random_auprc': nan, 'avg/counts/mse': 0.43039312958717346, 'avg/counts/var_explained': 0.011525611082712809, 'avg/counts/pearsonr': 0.22092281778653464, 'avg/counts/spearmanr': 0.22997619353445187, 'avg/counts/mad': 0.5086109737555186})])

Get the importance scores

In [19]:
# Imports for the importance-score section
import pybedtools
from basepair.utils import flatten_list
# Re-apply project styling (emits the Arial warning again)
paper_config()
WARNING: Font Arial not installed and is required by
In [20]:
seq_model = tr.seq_model  # extract the seq_model (used for preds + importance scores)
In [21]:
# One-hot sequence extraction from the genome fasta
from genomelake.extractors import FastaExtractor
In [22]:
# Get data for the oct4 enhancer
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
seq = FastaExtractor(ds.fasta_file)([interval])
In [23]:
# Single-example iterator over the validation set, in random order.
# NOTE(review): no seed is set — cells below are not reproducible across runs.
it = valid.batch_iter(batch_size=1, shuffle=True)
In [24]:
batch = next(it)
# NOTE(review): this overwrites the Oct4-enhancer `seq` from the earlier cell
# with a random validation example, while `obs` (plotted below) still comes
# from the fixed interval — confirm mixing the two is intended.
seq = batch['inputs']['seq']
In [25]:
# Sanity check: one example, 1000bp, 4 bases
seq.shape
Out[25]:
(1, 1000, 4)
In [26]:
# Sanity check: the first Oct4 head (presumably the profile head) was built
# with the bias input enabled
seq_model.all_heads['Oct4'][0].use_bias
Out[26]:
True
In [27]:
# DeepLIFT importance scores for every head/output w.r.t. the input sequence
# (the log below shows one "deeplift" run per head/output)
imp_scores = seq_model.imp_score_all(seq, batch_size=1)
WARNING:tensorflow:From /users/amr1/basepair/basepair/heads.py:323: calling softmax (from tensorflow.python.ops.nn_ops) with dim is deprecated and will be removed in a future version.
Instructions for updating:
dim is deprecated, use axis instead
2019-02-27 23:38:32,440 [WARNING] From /users/amr1/basepair/basepair/heads.py:323: calling softmax (from tensorflow.python.ops.nn_ops) with dim is deprecated and will be removed in a future version.
Instructions for updating:
dim is deprecated, use axis instead
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
In [28]:
# One score track per task and head: profile (wn/w1/w2/winf) and counts pre-act
imp_scores.keys()
Out[28]:
dict_keys(['Oct4/profile/wn', 'Oct4/profile/w1', 'Oct4/profile/w2', 'Oct4/profile/winf', 'Oct4/counts/pre-act', 'Sox2/profile/wn', 'Sox2/profile/w1', 'Sox2/profile/w2', 'Sox2/profile/winf', 'Sox2/counts/pre-act', 'Nanog/profile/wn', 'Nanog/profile/w1', 'Nanog/profile/w2', 'Nanog/profile/winf', 'Nanog/counts/pre-act'])
In [29]:
# Importance scores have the same shape as the one-hot input
imp_scores['Oct4/profile/wn'].shape
Out[29]:
(1, 1000, 4)
In [30]:
# Model predictions BEFORE the bias term is added (pre-activation) — this is
# the quantity we want to be uncorrelated with the bias
preds = seq_model.predict_preact(seq)
In [31]:
# NOTE(review): duplicate of the `seq.shape` check a few cells above
seq.shape
Out[31]:
(1, 1000, 4)

Pred and observed

In [32]:
# Tracks to visualize per task: observed counts and the importance scores
# projected onto the actual sequence (score * one-hot seq).
# Fix: the original iterated `enumerate(tasks)` but never used `task_idx`.
viz_dict = OrderedDict(flatten_list([[
                    (f"{task} Obs", obs[task]),
                    (f"{task} Imp profile", imp_scores[f"{task}/profile/wn"][0] * seq[0]),
                ] for task in tasks]))

# Restrict all tracks to the 420-575bp window
viz_dict = filter_tracks(viz_dict, [420, 575])
In [33]:
# Shared y-limits so the same feature is comparable across tasks:
# importance tracks span (min, max); observed-count tracks start at 0.
features = ['Imp profile', 'Obs']
fmax = {feature: max(viz_dict[f"{task} {feature}"].max() for task in tasks)
        for feature in features}
fmin = {feature: min(viz_dict[f"{task} {feature}"].min() for task in tasks)
        for feature in features}


ylim = []
for track_name in viz_dict:
    # track names are "{task} {feature}" — recover the feature part
    feature = track_name.split(" ", 1)[1]
    lower = fmin[feature] if "Imp" in feature else 0
    ylim.append((lower, fmax[feature]))
In [34]:
fig = plot_tracks(viz_dict,
                  #seqlets=shifted_seqlets,
                  # NOTE(review): the title names the Oct4-enhancer interval,
                  # but the importance tracks were computed on a random
                  # validation sequence (`seq` was overwritten) — confirm.
                  title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                  fig_height_per_track=2,
                  rotate_y=0,
                  fig_width=20,
                  ylim=ylim,
                  legend=False)
In [35]:
# Materialize the whole training set in memory (1307 batches of 32 per the
# progress bar) — needed below to correlate predictions with the bias tracks
train_set = train.load_all(batch_size=32, num_workers=5)
100%|██████████| 1307/1307 [02:06<00:00, 10.35it/s]
In [36]:
# Imports for the bias-correlation scatter plots.
# NOTE(review): `%matplotlib inline` was already set in the first cell.
import matplotlib
%matplotlib inline
from scipy.stats import spearmanr
# bin_counts: presumably sums counts within fixed-size bins along the profile axis
from basepair.preproc import bin_counts
In [37]:
# Bin size (bp) for the "local counts" comparison described at the top
binsize=50
# Pre-bias predictions over the entire training set
train_preds = seq_model.predict_preact(train_set['inputs']['seq'])
In [44]:
# Scatter model predictions (pre-bias) against log10 bias counts, per task:
# top row = total counts per region, bottom row = 50bp-binned local counts.
# We hope for near-zero Spearman correlation (Rs in each legend).
# Fixes vs. original: rcParams.update was re-executed on every loop iteration
# (loop-invariant — hoisted); the unused spearmanr p-value is discarded; the
# duplicated total/local plotting code is factored into one helper.
fig, axes = plt.subplots(2, len(tasks), figsize=get_figsize(2/4*len(tasks), 2/len(tasks)))
matplotlib.rcParams.update({'font.size': 32})


def _scatter_vs_bias(ax, preds, log_bias, xlabel, ylabel, title=None):
    """Scatter `preds` vs `log_bias`, annotating the Spearman Rs in the legend."""
    rs, _ = spearmanr(preds, log_bias)
    ax.scatter(preds, log_bias, s=10, c="b", alpha=0.5, marker='o',
               label="Rs={0:.2f}".format(rs))
    ax.legend()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)


for i, task in enumerate(train.tasks):
    # Total counts: sum over positions and strands (axes 1, 2)
    preds_for_total = np.sum(train_preds[f'{task}/profile'], axis=(1, 2), dtype=np.float32)
    log_bias_total_counts = np.log10(np.sum(train_set['inputs'][f'bias/{task}/profile'],
                                            axis=(1, 2), dtype=np.float32))
    _scatter_vs_bias(axes[0, i], preds_for_total, log_bias_total_counts,
                     "preds_for_total",
                     "log_bias_total_counts" if i == 0 else "",
                     title=task)

    # Local counts: sum within `binsize`-bp bins, flattened across regions/bins
    preds_for_local = np.ravel(np.sum(bin_counts(train_preds[f'{task}/profile'],
                                                 binsize=binsize), axis=-1, dtype=np.float32))
    log_bias_local_counts = np.log10(np.ravel(np.sum(bin_counts(train_set['inputs'][f'bias/{task}/profile'],
                                                     binsize=binsize), axis=-1, dtype=np.float32)))
    _scatter_vs_bias(axes[1, i], preds_for_local, log_bias_local_counts,
                     "preds_for_local",
                     "log_bias_local_counts" if i == 0 else "")