Goals

  • correct for biases in ChIP-seq

How are the model predictions correlated with the bias?

scatter-plot:

  • model predictions for total counts (before adding the bias term) vs bias-track total counts (log scale applied only to the latter)
  • model predictions for local counts before adding the bias term (in, say, 50-bp bins; bin the predictions with bin_counts) vs bias-track local counts (log scale applied only to the latter)

Basically, we hope that our model output doesn't correlate with the bias.

In [1]:
# Imports and notebook-wide display configuration.
from basepair.imports import *  # NOTE(review): wildcard import -- hv, paper_config,
                                # plt, np, etc. presumably come from here; verify.
import warnings
warnings.filterwarnings('ignore')  # silence library warnings to keep the output readable
hv.extension('bokeh')  # holoviews with the bokeh backend
%matplotlib inline
%config InlineBackend.figure_format = 'retina'  # high-DPI inline figures
paper_config()  # presumably applies the project's paper-figure matplotlib styling
Using TensorFlow backend.
2019-02-28 00:54:16,367 [WARNING] git-lfs not installed
WARNING: Font Arial not installed and is required by
In [2]:
# Restrict TensorFlow to specific GPUs.
# NOTE(review): the value contains spaces ("2, 4, 6"); the CUDA runtime usually
# tolerates them, but the conventional form is "2,4,6" -- confirm all three GPUs
# are actually visible. This must also run before any CUDA context is created.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 4, 6"
In [3]:
# Dataspec YAML describing per-TF count bigwigs, peak files and the bias track
# (its full contents are printed a few cells below for provenance).
dataspec_path = '/users/amr1/basepair/src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml'
In [4]:
from basepair.datasets import get_StrandedProfile_datasets2

# Build the train/valid datasets from the dataspec.
# Sex chromosomes are excluded; 1000-bp peak-centered windows.
train, valid = get_StrandedProfile_datasets2(
    dataspec=dataspec_path,
    peak_width=1000,
    seq_width=1000,
    include_metadata=False,
    taskname_first=True,  # so that the output labels will be "{task}/profile"
    exclude_chr=['chrX', 'chrY'],
    profile_bias_pool_size=None,
)
In [5]:
# Keep only the dataset from the first entry of the validation list.
# NOTE(review): assumes `valid` is a list of (name, dataset) pairs -- verify.
# This cell is non-idempotent: running it twice would index into the
# already-unwrapped dataset.
valid = valid[0][1]
In [6]:
# Print the dataspec contents for provenance (tasks, bigwig paths, bias spec).
!cat {dataspec_path}
task_specs:
  Oct4:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Oct4/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Oct4/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Oct4/idr-optimal-set.summit.bed.gz
  Sox2:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Sox2/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Sox2/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Sox2/idr-optimal-set.summit.bed.gz
  Nanog:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Nanog/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Nanog/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Nanog/idr-optimal-set.summit.bed.gz
  Klf4:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Klf4/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Klf4/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Klf4/idr-optimal-set.summit.bed.gz

bias_specs:
  input:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/patchcap/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/patchcap/counts.neg.bw
    tasks:
      - Oct4
      - Sox2
      - Nanog
      - Klf4

fasta_file: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/mm10_no_alt_analysis_set_ENCODE.fasta
In [7]:
# Parse the dataspec YAML into a DataSpec object for programmatic access.
ds = DataSpec.load(dataspec_path)
In [8]:
# Task names = keys of the dataspec's task_specs (Oct4, Sox2, Nanog, Klf4).
tasks = list(ds.task_specs)
In [9]:
# Display the task list (rich output of the last expression).
tasks
Out[9]:
['Oct4', 'Sox2', 'Nanog', 'Klf4']
In [10]:
from basepair.trainers import SeqModelTrainer
from basepair.models import multihead_seq_model
from basepair.plot.evaluate import regression_eval
In [11]:
# Multi-head sequence model over all four TF tasks with the bias track enabled
# (use_bias=True). The count loss is up-weighted (c_loss_weight=10) relative to
# the profile loss; b_loss_weight=0 presumably disables that loss term -- verify.
m = multihead_seq_model(
    tasks=tasks,
    filters=64,
    n_dil_layers=9,
    conv1_kernel_size=25,
    tconv_kernel_size=25,
    b_loss_weight=0,
    c_loss_weight=10,
    p_loss_weight=1,
    use_bias=True,
    lr=0.004,
    padding='same',
    batchnorm=False,
)
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2019-02-28 00:54:21,856 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2019-02-28 00:54:32,728 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [12]:
# Output directory for the model checkpoint and training-history CSV.
output_dir='/tmp/exp/m3-wb'
In [13]:
# Make sure the output directory exists.
!mkdir -p {output_dir}
In [14]:
!rm /tmp/exp/m3-wb/model.h5
rm: cannot remove '/tmp/exp/m3-wb/model.h5': No such file or directory
In [15]:
# Trainer wraps the model plus train/valid datasets; it writes checkpoints and
# a training-history CSV under output_dir (see the traceback below).
tr = SeqModelTrainer(m, train, valid, output_dir=output_dir)
In [16]:
# Train for up to 100 epochs.
# NOTE(review): this run crashed after epoch 20 with a pandas ParserError while
# the trainer re-read its history CSV (line 9 had 76 fields instead of 37 --
# presumably a restarted run appended rows with a different column set).
# The per-epoch checkpoints written before the crash should still be usable --
# verify before relying on `tr` downstream.
tr.train(epochs=100)
Epoch 1/100
355/355 [==============================] - 304s 858ms/step - loss: 3389.6669 - Oct4/profile_loss: 842.2641 - Oct4/counts_loss: 0.6645 - Sox2/profile_loss: 466.5606 - Sox2/counts_loss: 1.2903 - Nanog/profile_loss: 1107.2829 - Nanog/counts_loss: 1.9055 - Klf4/profile_loss: 925.5513 - Klf4/counts_loss: 0.9405 - val_loss: 3325.9709 - val_Oct4/profile_loss: 833.2521 - val_Oct4/counts_loss: 0.5210 - val_Sox2/profile_loss: 468.1291 - val_Sox2/counts_loss: 0.9726 - val_Nanog/profile_loss: 1085.9972 - val_Nanog/counts_loss: 1.2311 - val_Klf4/profile_loss: 904.2933 - val_Klf4/counts_loss: 0.7052
Epoch 2/100
355/355 [==============================] - 332s 936ms/step - loss: 3261.3870 - Oct4/profile_loss: 826.0909 - Oct4/counts_loss: 0.4977 - Sox2/profile_loss: 460.0636 - Sox2/counts_loss: 0.6587 - Nanog/profile_loss: 1043.0790 - Nanog/counts_loss: 1.0774 - Klf4/profile_loss: 903.3053 - Klf4/counts_loss: 0.6511 - val_loss: 3266.5104 - val_Oct4/profile_loss: 826.8660 - val_Oct4/counts_loss: 0.4566 - val_Sox2/profile_loss: 465.4070 - val_Sox2/counts_loss: 0.5796 - val_Nanog/profile_loss: 1050.9433 - val_Nanog/counts_loss: 0.9616 - val_Klf4/profile_loss: 897.3054 - val_Klf4/counts_loss: 0.6011
Epoch 3/100
355/355 [==============================] - 338s 951ms/step - loss: 3230.3713 - Oct4/profile_loss: 821.9681 - Oct4/counts_loss: 0.4646 - Sox2/profile_loss: 457.4972 - Sox2/counts_loss: 0.4934 - Nanog/profile_loss: 1027.9464 - Nanog/counts_loss: 0.9538 - Klf4/profile_loss: 897.8750 - Klf4/counts_loss: 0.5967 - val_loss: 3249.1786 - val_Oct4/profile_loss: 823.7564 - val_Oct4/counts_loss: 0.4469 - val_Sox2/profile_loss: 463.0988 - val_Sox2/counts_loss: 0.4162 - val_Nanog/profile_loss: 1045.4647 - val_Nanog/counts_loss: 0.8392 - val_Klf4/profile_loss: 893.9879 - val_Klf4/counts_loss: 0.5848
Epoch 4/100
355/355 [==============================] - 322s 907ms/step - loss: 3201.7717 - Oct4/profile_loss: 817.5797 - Oct4/counts_loss: 0.4348 - Sox2/profile_loss: 454.2683 - Sox2/counts_loss: 0.3803 - Nanog/profile_loss: 1017.4447 - Nanog/counts_loss: 0.8500 - Klf4/profile_loss: 890.2526 - Klf4/counts_loss: 0.5575 - val_loss: 3220.9606 - val_Oct4/profile_loss: 818.6632 - val_Oct4/counts_loss: 0.4114 - val_Sox2/profile_loss: 460.6107 - val_Sox2/counts_loss: 0.3726 - val_Nanog/profile_loss: 1035.3290 - val_Nanog/counts_loss: 0.8274 - val_Klf4/profile_loss: 885.0898 - val_Klf4/counts_loss: 0.5154
Epoch 5/100
355/355 [==============================] - 341s 961ms/step - loss: 3179.0727 - Oct4/profile_loss: 812.2971 - Oct4/counts_loss: 0.4082 - Sox2/profile_loss: 452.4684 - Sox2/counts_loss: 0.3102 - Nanog/profile_loss: 1009.4461 - Nanog/counts_loss: 0.7731 - Klf4/profile_loss: 884.7532 - Klf4/counts_loss: 0.5194 - val_loss: 3204.8840 - val_Oct4/profile_loss: 814.5301 - val_Oct4/counts_loss: 0.4036 - val_Sox2/profile_loss: 459.0520 - val_Sox2/counts_loss: 0.3067 - val_Nanog/profile_loss: 1028.8022 - val_Nanog/counts_loss: 0.7369 - val_Klf4/profile_loss: 882.7784 - val_Klf4/counts_loss: 0.5249
Epoch 6/100
355/355 [==============================] - 330s 930ms/step - loss: 3164.7743 - Oct4/profile_loss: 808.8898 - Oct4/counts_loss: 0.3820 - Sox2/profile_loss: 451.3949 - Sox2/counts_loss: 0.2752 - Nanog/profile_loss: 1002.3759 - Nanog/counts_loss: 0.7379 - Klf4/profile_loss: 883.2044 - Klf4/counts_loss: 0.4957 - val_loss: 3199.9032 - val_Oct4/profile_loss: 814.0785 - val_Oct4/counts_loss: 0.3813 - val_Sox2/profile_loss: 459.4840 - val_Sox2/counts_loss: 0.3085 - val_Nanog/profile_loss: 1024.7675 - val_Nanog/counts_loss: 0.7406 - val_Klf4/profile_loss: 882.5083 - val_Klf4/counts_loss: 0.4760
Epoch 7/100
355/355 [==============================] - 315s 887ms/step - loss: 3154.1298 - Oct4/profile_loss: 805.9319 - Oct4/counts_loss: 0.3547 - Sox2/profile_loss: 450.8794 - Sox2/counts_loss: 0.2476 - Nanog/profile_loss: 997.6222 - Nanog/counts_loss: 0.6916 - Klf4/profile_loss: 881.9445 - Klf4/counts_loss: 0.4813 - val_loss: 3191.5360 - val_Oct4/profile_loss: 809.8232 - val_Oct4/counts_loss: 0.3315 - val_Sox2/profile_loss: 457.7474 - val_Sox2/counts_loss: 0.2335 - val_Nanog/profile_loss: 1025.6768 - val_Nanog/counts_loss: 0.7005 - val_Klf4/profile_loss: 880.6773 - val_Klf4/counts_loss: 0.4956
Epoch 8/100
355/355 [==============================] - 295s 831ms/step - loss: 3140.7919 - Oct4/profile_loss: 801.7694 - Oct4/counts_loss: 0.3302 - Sox2/profile_loss: 450.1052 - Sox2/counts_loss: 0.2297 - Nanog/profile_loss: 991.8798 - Nanog/counts_loss: 0.6625 - Klf4/profile_loss: 880.1622 - Klf4/counts_loss: 0.4650 - val_loss: 3186.3651 - val_Oct4/profile_loss: 810.1372 - val_Oct4/counts_loss: 0.3155 - val_Sox2/profile_loss: 458.5401 - val_Sox2/counts_loss: 0.2232 - val_Nanog/profile_loss: 1021.8779 - val_Nanog/counts_loss: 0.6841 - val_Klf4/profile_loss: 879.1737 - val_Klf4/counts_loss: 0.4408
Epoch 9/100
355/355 [==============================] - 304s 855ms/step - loss: 3135.6923 - Oct4/profile_loss: 800.3921 - Oct4/counts_loss: 0.3194 - Sox2/profile_loss: 449.9242 - Sox2/counts_loss: 0.2261 - Nanog/profile_loss: 989.1133 - Nanog/counts_loss: 0.6632 - Klf4/profile_loss: 879.6023 - Klf4/counts_loss: 0.4574 - val_loss: 3184.2566 - val_Oct4/profile_loss: 807.7639 - val_Oct4/counts_loss: 0.4198 - val_Sox2/profile_loss: 457.1672 - val_Sox2/counts_loss: 0.2875 - val_Nanog/profile_loss: 1020.7835 - val_Nanog/counts_loss: 0.6950 - val_Klf4/profile_loss: 878.6566 - val_Klf4/counts_loss: 0.5863
Epoch 10/100
355/355 [==============================] - 295s 831ms/step - loss: 3127.3713 - Oct4/profile_loss: 798.5103 - Oct4/counts_loss: 0.3090 - Sox2/profile_loss: 449.3202 - Sox2/counts_loss: 0.2182 - Nanog/profile_loss: 985.0993 - Nanog/counts_loss: 0.6383 - Klf4/profile_loss: 878.2927 - Klf4/counts_loss: 0.4494 - val_loss: 3175.7718 - val_Oct4/profile_loss: 807.7467 - val_Oct4/counts_loss: 0.3041 - val_Sox2/profile_loss: 457.1615 - val_Sox2/counts_loss: 0.2147 - val_Nanog/profile_loss: 1015.7222 - val_Nanog/counts_loss: 0.6378 - val_Klf4/profile_loss: 879.2670 - val_Klf4/counts_loss: 0.4309
Epoch 11/100
355/355 [==============================] - 278s 783ms/step - loss: 3123.3708 - Oct4/profile_loss: 798.1409 - Oct4/counts_loss: 0.3011 - Sox2/profile_loss: 449.1353 - Sox2/counts_loss: 0.2116 - Nanog/profile_loss: 981.8099 - Nanog/counts_loss: 0.6170 - Klf4/profile_loss: 878.5345 - Klf4/counts_loss: 0.4453 - val_loss: 3176.7018 - val_Oct4/profile_loss: 808.4095 - val_Oct4/counts_loss: 0.3689 - val_Sox2/profile_loss: 457.8434 - val_Sox2/counts_loss: 0.2686 - val_Nanog/profile_loss: 1014.2842 - val_Nanog/counts_loss: 0.6535 - val_Klf4/profile_loss: 878.2822 - val_Klf4/counts_loss: 0.4972
Epoch 12/100
355/355 [==============================] - 271s 763ms/step - loss: 3116.6314 - Oct4/profile_loss: 796.6006 - Oct4/counts_loss: 0.2949 - Sox2/profile_loss: 448.6071 - Sox2/counts_loss: 0.2105 - Nanog/profile_loss: 979.0698 - Nanog/counts_loss: 0.6038 - Klf4/profile_loss: 876.9328 - Klf4/counts_loss: 0.4330 - val_loss: 3175.3735 - val_Oct4/profile_loss: 807.7516 - val_Oct4/counts_loss: 0.3030 - val_Sox2/profile_loss: 457.4699 - val_Sox2/counts_loss: 0.2157 - val_Nanog/profile_loss: 1016.5674 - val_Nanog/counts_loss: 0.6015 - val_Klf4/profile_loss: 878.2251 - val_Klf4/counts_loss: 0.4157
Epoch 13/100
355/355 [==============================] - 292s 822ms/step - loss: 3113.7557 - Oct4/profile_loss: 796.4270 - Oct4/counts_loss: 0.2876 - Sox2/profile_loss: 448.6123 - Sox2/counts_loss: 0.2056 - Nanog/profile_loss: 977.1403 - Nanog/counts_loss: 0.5907 - Klf4/profile_loss: 876.5674 - Klf4/counts_loss: 0.4169 - val_loss: 3175.1636 - val_Oct4/profile_loss: 807.6774 - val_Oct4/counts_loss: 0.2898 - val_Sox2/profile_loss: 457.4953 - val_Sox2/counts_loss: 0.2091 - val_Nanog/profile_loss: 1016.8700 - val_Nanog/counts_loss: 0.5923 - val_Klf4/profile_loss: 878.0472 - val_Klf4/counts_loss: 0.4162
Epoch 14/100
355/355 [==============================] - 332s 935ms/step - loss: 3110.1870 - Oct4/profile_loss: 795.8236 - Oct4/counts_loss: 0.2832 - Sox2/profile_loss: 448.1657 - Sox2/counts_loss: 0.2045 - Nanog/profile_loss: 974.8956 - Nanog/counts_loss: 0.5858 - Klf4/profile_loss: 876.4125 - Klf4/counts_loss: 0.4154 - val_loss: 3172.9572 - val_Oct4/profile_loss: 806.7719 - val_Oct4/counts_loss: 0.2969 - val_Sox2/profile_loss: 456.6316 - val_Sox2/counts_loss: 0.2073 - val_Nanog/profile_loss: 1016.1903 - val_Nanog/counts_loss: 0.6277 - val_Klf4/profile_loss: 877.8077 - val_Klf4/counts_loss: 0.4236
Epoch 15/100
355/355 [==============================] - 280s 787ms/step - loss: 3105.7419 - Oct4/profile_loss: 795.1627 - Oct4/counts_loss: 0.2773 - Sox2/profile_loss: 448.1386 - Sox2/counts_loss: 0.2025 - Nanog/profile_loss: 972.4154 - Nanog/counts_loss: 0.5752 - Klf4/profile_loss: 875.4388 - Klf4/counts_loss: 0.4036 - val_loss: 3172.7137 - val_Oct4/profile_loss: 807.3202 - val_Oct4/counts_loss: 0.2739 - val_Sox2/profile_loss: 456.9684 - val_Sox2/counts_loss: 0.2056 - val_Nanog/profile_loss: 1016.7330 - val_Nanog/counts_loss: 0.5741 - val_Klf4/profile_loss: 877.3036 - val_Klf4/counts_loss: 0.3852
Epoch 16/100
355/355 [==============================] - 295s 831ms/step - loss: 3105.5972 - Oct4/profile_loss: 794.6833 - Oct4/counts_loss: 0.2730 - Sox2/profile_loss: 448.0970 - Sox2/counts_loss: 0.1985 - Nanog/profile_loss: 972.3295 - Nanog/counts_loss: 0.5570 - Klf4/profile_loss: 876.2045 - Klf4/counts_loss: 0.3997 - val_loss: 3167.5106 - val_Oct4/profile_loss: 807.3943 - val_Oct4/counts_loss: 0.3134 - val_Sox2/profile_loss: 457.0907 - val_Sox2/counts_loss: 0.2040 - val_Nanog/profile_loss: 1010.6578 - val_Nanog/counts_loss: 0.6000 - val_Klf4/profile_loss: 877.3374 - val_Klf4/counts_loss: 0.3856
Epoch 17/100
355/355 [==============================] - 282s 795ms/step - loss: 3100.1499 - Oct4/profile_loss: 793.9257 - Oct4/counts_loss: 0.2661 - Sox2/profile_loss: 447.7924 - Sox2/counts_loss: 0.1949 - Nanog/profile_loss: 970.0066 - Nanog/counts_loss: 0.5479 - Klf4/profile_loss: 874.4411 - Klf4/counts_loss: 0.3895 - val_loss: 3179.8758 - val_Oct4/profile_loss: 807.2529 - val_Oct4/counts_loss: 0.3364 - val_Sox2/profile_loss: 457.2615 - val_Sox2/counts_loss: 0.2301 - val_Nanog/profile_loss: 1020.5329 - val_Nanog/counts_loss: 0.5832 - val_Klf4/profile_loss: 878.7312 - val_Klf4/counts_loss: 0.4600
Epoch 18/100
355/355 [==============================] - 294s 827ms/step - loss: 3099.2059 - Oct4/profile_loss: 794.2802 - Oct4/counts_loss: 0.2604 - Sox2/profile_loss: 447.7754 - Sox2/counts_loss: 0.1926 - Nanog/profile_loss: 968.4587 - Nanog/counts_loss: 0.5329 - Klf4/profile_loss: 874.9693 - Klf4/counts_loss: 0.3863 - val_loss: 3171.8861 - val_Oct4/profile_loss: 808.5212 - val_Oct4/counts_loss: 0.3170 - val_Sox2/profile_loss: 457.7331 - val_Sox2/counts_loss: 0.2453 - val_Nanog/profile_loss: 1013.1652 - val_Nanog/counts_loss: 0.6000 - val_Klf4/profile_loss: 875.9972 - val_Klf4/counts_loss: 0.4846
Epoch 19/100
355/355 [==============================] - 277s 779ms/step - loss: 3096.9057 - Oct4/profile_loss: 793.5210 - Oct4/counts_loss: 0.2576 - Sox2/profile_loss: 447.7834 - Sox2/counts_loss: 0.1921 - Nanog/profile_loss: 967.8509 - Nanog/counts_loss: 0.5295 - Klf4/profile_loss: 874.1481 - Klf4/counts_loss: 0.3810 - val_loss: 3171.0293 - val_Oct4/profile_loss: 807.3552 - val_Oct4/counts_loss: 0.2861 - val_Sox2/profile_loss: 456.3391 - val_Sox2/counts_loss: 0.2363 - val_Nanog/profile_loss: 1015.2711 - val_Nanog/counts_loss: 0.6179 - val_Klf4/profile_loss: 876.1881 - val_Klf4/counts_loss: 0.4473
Epoch 20/100
355/355 [==============================] - 291s 820ms/step - loss: 3091.8675 - Oct4/profile_loss: 792.7794 - Oct4/counts_loss: 0.2534 - Sox2/profile_loss: 447.1234 - Sox2/counts_loss: 0.1898 - Nanog/profile_loss: 965.1960 - Nanog/counts_loss: 0.5182 - Klf4/profile_loss: 873.3432 - Klf4/counts_loss: 0.3812 - val_loss: 3168.4297 - val_Oct4/profile_loss: 807.5179 - val_Oct4/counts_loss: 0.2592 - val_Sox2/profile_loss: 457.3833 - val_Sox2/counts_loss: 0.2030 - val_Nanog/profile_loss: 1010.6274 - val_Nanog/counts_loss: 0.5620 - val_Klf4/profile_loss: 878.5093 - val_Klf4/counts_loss: 0.4149
---------------------------------------------------------------------------
ParserError                               Traceback (most recent call last)
<ipython-input-16-721fca6bf50e> in <module>()
----> 1 tr.train(epochs=100)

~/miniconda3/envs/basepair/lib/python3.6/site-packages/gin_train/trainers.py in train(self, batch_size, epochs, early_stop_patience, num_workers, train_epoch_frac, valid_epoch_frac, train_samples_per_epoch, validation_samples, train_batch_sampler, tensorboard)
    130 
    131         # log metrics from the best epoch
--> 132         dfh = pd.read_csv(self.history_path)
    133         m = dict(dfh.iloc[dfh.val_loss.idxmin()])
    134         if self.cometml_experiment is not None:

~/miniconda3/envs/basepair/lib/python3.6/site-packages/pandas/io/parsers.py in parser_f(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, squeeze, prefix, mangle_dupe_cols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, dayfirst, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, escapechar, comment, encoding, dialect, tupleize_cols, error_bad_lines, warn_bad_lines, skipfooter, doublequote, delim_whitespace, low_memory, memory_map, float_precision)
    676                     skip_blank_lines=skip_blank_lines)
    677 
--> 678         return _read(filepath_or_buffer, kwds)
    679 
    680     parser_f.__name__ = name

~/miniconda3/envs/basepair/lib/python3.6/site-packages/pandas/io/parsers.py in _read(filepath_or_buffer, kwds)
    444 
    445     try:
--> 446         data = parser.read(nrows)
    447     finally:
    448         parser.close()

~/miniconda3/envs/basepair/lib/python3.6/site-packages/pandas/io/parsers.py in read(self, nrows)
   1034                 raise ValueError('skipfooter not supported for iteration')
   1035 
-> 1036         ret = self._engine.read(nrows)
   1037 
   1038         # May alter columns / col_dict

~/miniconda3/envs/basepair/lib/python3.6/site-packages/pandas/io/parsers.py in read(self, nrows)
   1846     def read(self, nrows=None):
   1847         try:
-> 1848             data = self._reader.read(nrows)
   1849         except StopIteration:
   1850             if self._first_chunk:

pandas/_libs/parsers.pyx in pandas._libs.parsers.TextReader.read()

pandas/_libs/parsers.pyx in pandas._libs.parsers.TextReader._read_low_memory()

pandas/_libs/parsers.pyx in pandas._libs.parsers.TextReader._read_rows()

pandas/_libs/parsers.pyx in pandas._libs.parsers.TextReader._tokenize_rows()

pandas/_libs/parsers.pyx in pandas._libs.parsers.raise_parser_error()

ParserError: Error tokenizing data. C error: Expected 37 fields in line 9, saw 76
In [ ]:
# %debug opens the post-mortem debugger for the ParserError traceback above;
# kept commented out so a full re-run doesn't stop here.
#%debug
In [ ]:
# Evaluate the trained model on the validation set.
eval_metrics = tr.evaluate(metric=None)  # metric=None -> uses the default head metrics
print(eval_metrics)

Get the importance scores

In [ ]:
# Extra imports for the importance-score section.
import pybedtools
from basepair.utils import flatten_list
paper_config()  # re-apply paper figure styling
In [ ]:
seq_model = tr.seq_model  # extract the underlying seq_model used for importance scoring
In [ ]:
from genomelake.extractors import FastaExtractor
In [ ]:
# Get data for the Oct4 enhancer region.
# NOTE(review): the coordinates are passed as ints; pybedtools typically expects
# string fields in create_interval_from_list -- confirm this works, otherwise use
# ['chr17', '35503550', '35504550'].
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])
# Observed per-task counts for this interval.
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
# Sequence for the interval (presumably one-hot encoded -- verify).
# NOTE(review): `seq` is overwritten two cells below by a random validation
# example, so later importance scores are NOT for this enhancer.
seq = FastaExtractor(ds.fasta_file)([interval])
In [ ]:
# Iterator over single-example validation batches in random order.
it = valid.batch_iter(batch_size=1, shuffle=True)
In [ ]:
# Draw one random validation example.
# NOTE(review): this overwrites the enhancer `seq` extracted above -- the
# importance scores below are for this random region, not the Oct4 enhancer.
batch = next(it)
seq = batch['inputs']['seq']
In [ ]:
# Sanity-check the input shape (presumably (1, 1000, 4) -- verify).
seq.shape
In [ ]:
# Confirm the Oct4 profile head was built with the bias input enabled.
seq_model.all_heads['Oct4'][0].use_bias
In [ ]:
# Importance (contribution) scores for every head/task on this sequence.
imp_scores = seq_model.imp_score_all(seq, batch_size=1)
In [ ]:
# Available score keys, e.g. "Oct4/profile/wn".
imp_scores.keys()
In [ ]:
# Shape of the Oct4 profile importance scores (should match the input seq shape).
imp_scores['Oct4/profile/wn'].shape
In [ ]:
# predict_preact: model predictions before the bias term is added
# (NOTE(review): inferred from the name and the notebook goal -- verify).
preds = seq_model.predict_preact(seq)
In [ ]:
# (Duplicate of the earlier seq.shape check; could be removed.)
seq.shape

Predicted and observed

In [ ]:
# Per-task tracks for plotting: observed counts and importance scores
# (per-base scores projected onto the one-hot sequence).
# NOTE(review): obs[task] was loaded for the fixed enhancer interval while
# `seq`/imp_scores come from a random validation example -- the two track
# groups may show different regions; verify.
viz_dict = OrderedDict(flatten_list([[
    (f"{task} Obs", obs[task]),
    (f"{task} Imp profile", imp_scores[f"{task}/profile/wn"][0] * seq[0]),
] for task in tasks]))  # enumerate() dropped: the task index was never used

# Restrict all tracks to positions 420-575 for a readable zoomed-in figure.
viz_dict = filter_tracks(viz_dict, [420, 575])
In [ ]:
# Shared y-axis limits per feature type, so the panels are comparable
# across tasks.
fmax = {feature: max(viz_dict[f"{task} {feature}"].max() for task in tasks)
        for feature in ['Imp profile', 'Obs']}
fmin = {feature: min(viz_dict[f"{task} {feature}"].min() for task in tasks)
        for feature in ['Imp profile', 'Obs']}


# Importance tracks can be negative -> span (fmin, fmax);
# observed count tracks are anchored at 0.
ylim = []
for track_name in viz_dict:
    feature = track_name.split(" ", 1)[1]
    if "Imp" in feature:
        ylim.append((fmin[feature], fmax[feature]))
    else:
        ylim.append((0, fmax[feature]))
In [ ]:
# Plot all tracks with the shared per-feature y-limits computed above.
# NOTE(review): the title uses `interval` (the Oct4 enhancer) but the importance
# tracks were computed from a randomly drawn validation sequence -- the title is
# likely wrong for this figure; verify which region is actually plotted.
fig = plot_tracks(viz_dict,
                  #seqlets=shifted_seqlets,
                  title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                  fig_height_per_track=2,
                  rotate_y=0,
                  fig_width=20,
                  ylim=ylim,
                  legend=False)
In [ ]:
# Load the full training set into memory for the bias-correlation analysis.
train_set = train.load_all(batch_size=32, num_workers=5)
In [ ]:
import matplotlib
%matplotlib inline
from scipy.stats import spearmanr
from basepair.preproc import bin_counts
In [ ]:
# Bin size (bp) for the local-counts comparison.
binsize=50
# Model predictions before the bias term is added, on the whole training set.
train_preds = seq_model.predict_preact(train_set['inputs']['seq'])
In [ ]:
# Scatter plots of model predictions (pre-bias) vs. log10 bias counts:
# top row = total counts per region, bottom row = counts in `binsize`-bp bins.
# A low Spearman correlation (Rs) means the model output is not explained by
# the experimental bias -- which is what we hope for.
matplotlib.rcParams.update({'font.size': 32})  # set once, not inside the loop


def _scatter_with_spearman(ax, x, y, xlabel, ylabel, title=None):
    """Scatter y vs. x on `ax`, annotating the Spearman Rs in the legend."""
    rs, _pval = spearmanr(x, y)  # p-value intentionally unused
    ax.scatter(x, y, s=10, c="b", alpha=0.5, marker='o',
               label="Rs={0:.2f}".format(rs))
    ax.legend()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)


fig, axes = plt.subplots(2, len(tasks), figsize=get_figsize(2/4*len(tasks), 2/len(tasks)))
for i, task in enumerate(train.tasks):
    # --- Total counts: sum over positions and strands (axes 1, 2) ---
    preds_for_total = np.sum(train_preds[f'{task}/profile'], axis=(1, 2), dtype=np.float32)
    log_bias_total_counts = np.log10(np.sum(train_set['inputs'][f'bias/{task}/profile'],
                                            axis=(1, 2), dtype=np.float32))
    _scatter_with_spearman(axes[0, i], preds_for_total, log_bias_total_counts,
                           xlabel="preds_for_total",
                           ylabel="log_bias_total_counts" if i == 0 else "",
                           title=task)

    # --- Local counts: sum within binsize-bp bins, flattened over all regions ---
    preds_for_local = np.ravel(np.sum(bin_counts(train_preds[f'{task}/profile'],
                                                 binsize=binsize), axis=-1, dtype=np.float32))
    log_bias_local_counts = np.log10(np.ravel(np.sum(bin_counts(train_set['inputs'][f'bias/{task}/profile'],
                                                                binsize=binsize), axis=-1, dtype=np.float32)))
    _scatter_with_spearman(axes[1, i], preds_for_local, log_bias_local_counts,
                           xlabel="preds_for_local",
                           ylabel="log_bias_local_counts" if i == 0 else "")