Goals

  • correct for biases in ChIP-seq

how are model predictions correlated with the bias?

scatter-plot:

  • model predictions for total counts (before adding the bias term) vs. total bias counts (log scale applied only to the latter)
  • model predictions for local counts (before adding the bias term, in, say, 50 bp bins — bin the predictions with `bin_counts`) vs. local bias counts (log scale applied only to the latter)

Ideally the model's pre-bias output does not correlate with the bias track.

In [1]:
# Imports
# Wildcard import brings in the shared analysis namespace (hv, plt, np,
# paper_config, DataSpec, OrderedDict, plot_tracks, ... used below).
# NOTE(review): `import *` hides where names come from — confirm
# basepair.imports is the intended convenience module.
from basepair.imports import *
import warnings
# Silence all warnings for a cleaner notebook; re-enable while debugging.
warnings.filterwarnings('ignore')
hv.extension('bokeh')  # bokeh backend for holoviews
%matplotlib inline
%config InlineBackend.figure_format = 'retina'  # high-DPI inline figures
paper_config()  # apply the project's publication plotting style
Using TensorFlow backend.
2019-02-28 00:54:17,865 [WARNING] git-lfs not installed
WARNING: Font Arial not installed and is required by
In [2]:
import os
# Restrict TensorFlow to these GPU ids (as enumerated by the driver).
os.environ["CUDA_VISIBLE_DEVICES"] = "3, 5, 7"
In [3]:
dataspec_path = '/users/amr1/basepair/src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml'
In [4]:
from basepair.datasets import get_StrandedProfile_datasets2

# Build train/validation datasets from the dataspec.
# (PEP8: keyword arguments take no spaces around '=' — spacing was inconsistent.)
train, valid = get_StrandedProfile_datasets2(
    dataspec=dataspec_path,
    peak_width=1000,       # window width around each peak summit
    seq_width=1000,        # input sequence length
    include_metadata=False,
    taskname_first=True,   # so that the output labels will be "{task}/profile"
    exclude_chr=['chrX', 'chrY'],  # held out from training
    profile_bias_pool_size=50)  # presumably the bp pooling window for the bias profile — confirm
In [5]:
valid = valid[0][1]
In [6]:
!cat {dataspec_path}
task_specs:
  Oct4:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Oct4/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Oct4/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Oct4/idr-optimal-set.summit.bed.gz
  Sox2:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Sox2/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Sox2/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Sox2/idr-optimal-set.summit.bed.gz
  Nanog:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Nanog/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Nanog/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Nanog/idr-optimal-set.summit.bed.gz
  Klf4:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Klf4/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Klf4/counts.neg.bw
    peaks: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/Klf4/idr-optimal-set.summit.bed.gz

bias_specs:
  input:
    pos_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/patchcap/counts.pos.bw
    neg_counts: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/chip-nexus/patchcap/counts.neg.bw
    tasks:
      - Oct4
      - Sox2
      - Nanog
      - Klf4

fasta_file: /oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/mm10_no_alt_analysis_set_ENCODE.fasta
In [7]:
ds = DataSpec.load(dataspec_path)
In [8]:
tasks = list(ds.task_specs)
In [9]:
tasks
Out[9]:
['Oct4', 'Sox2', 'Nanog', 'Klf4']
In [10]:
# Project helpers: training loop, model constructor, evaluation plotting.
from basepair.trainers import SeqModelTrainer
from basepair.models import multihead_seq_model
from basepair.plot.evaluate import regression_eval
In [11]:
# Multi-task profile model. One keyword per line (the original crammed
# two kwargs on one line with no space after the comma).
# NOTE(review): b_loss_weight=0 presumably disables the bias-head loss while
# c/p weights balance the count vs profile losses — confirm against
# multihead_seq_model's signature.
m = multihead_seq_model(tasks=tasks,
                        filters=64,
                        n_dil_layers=9,
                        conv1_kernel_size=25,
                        tconv_kernel_size=25,
                        b_loss_weight=0,
                        c_loss_weight=10,
                        p_loss_weight=1,
                        use_bias=True,  # feed the bias tracks as model inputs
                        lr=0.004,
                        padding='same',
                        batchnorm=False)
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2019-02-28 00:54:22,289 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2019-02-28 00:54:32,916 [WARNING] From /users/amr1/miniconda3/envs/basepair/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [12]:
output_dir='/tmp/exp/m3-wb'
In [13]:
!mkdir -p {output_dir}
In [14]:
!rm /tmp/exp/m3-wb/model.h5
rm: cannot remove '/tmp/exp/m3-wb/model.h5': No such file or directory
In [15]:
tr = SeqModelTrainer(m, train, valid, output_dir=output_dir)
In [16]:
tr.train(epochs=100)
Epoch 1/100
355/355 [==============================] - 307s 866ms/step - loss: 3393.1239 - Oct4/profile_loss: 842.8376 - Oct4/counts_loss: 1.5186 - Sox2/profile_loss: 466.5282 - Sox2/counts_loss: 0.7830 - Nanog/profile_loss: 1111.6718 - Nanog/counts_loss: 1.8019 - Klf4/profile_loss: 920.1893 - Klf4/counts_loss: 1.0862 - val_loss: 3316.3708 - val_Oct4/profile_loss: 833.3349 - val_Oct4/counts_loss: 0.6739 - val_Sox2/profile_loss: 468.0766 - val_Sox2/counts_loss: 0.4639 - val_Nanog/profile_loss: 1083.4257 - val_Nanog/counts_loss: 0.9850 - val_Klf4/profile_loss: 903.3880 - val_Klf4/counts_loss: 0.6917
Epoch 2/100
355/355 [==============================] - 312s 879ms/step - loss: 3267.4557 - Oct4/profile_loss: 827.1931 - Oct4/counts_loss: 0.6378 - Sox2/profile_loss: 460.5485 - Sox2/counts_loss: 0.4627 - Nanog/profile_loss: 1048.4798 - Nanog/counts_loss: 1.0031 - Klf4/profile_loss: 903.5021 - Klf4/counts_loss: 0.6697 - val_loss: 3276.3369 - val_Oct4/profile_loss: 829.2569 - val_Oct4/counts_loss: 0.5710 - val_Sox2/profile_loss: 466.1402 - val_Sox2/counts_loss: 0.4153 - val_Nanog/profile_loss: 1056.7718 - val_Nanog/counts_loss: 0.9220 - val_Klf4/profile_loss: 898.7633 - val_Klf4/counts_loss: 0.6321
Epoch 3/100
355/355 [==============================] - 310s 873ms/step - loss: 3236.6962 - Oct4/profile_loss: 822.8839 - Oct4/counts_loss: 0.5583 - Sox2/profile_loss: 458.3800 - Sox2/counts_loss: 0.4076 - Nanog/profile_loss: 1031.0737 - Nanog/counts_loss: 0.9203 - Klf4/profile_loss: 899.3030 - Klf4/counts_loss: 0.6194 - val_loss: 3252.2771 - val_Oct4/profile_loss: 822.6902 - val_Oct4/counts_loss: 0.5231 - val_Sox2/profile_loss: 463.4735 - val_Sox2/counts_loss: 0.3689 - val_Nanog/profile_loss: 1047.4470 - val_Nanog/counts_loss: 0.8668 - val_Klf4/profile_loss: 895.2803 - val_Klf4/counts_loss: 0.5798
Epoch 4/100
355/355 [==============================] - 328s 925ms/step - loss: 3207.1324 - Oct4/profile_loss: 815.9638 - Oct4/counts_loss: 0.4928 - Sox2/profile_loss: 454.6378 - Sox2/counts_loss: 0.3586 - Nanog/profile_loss: 1018.7484 - Nanog/counts_loss: 0.8487 - Klf4/profile_loss: 894.9288 - Klf4/counts_loss: 0.5852 - val_loss: 3226.0633 - val_Oct4/profile_loss: 817.3654 - val_Oct4/counts_loss: 0.4534 - val_Sox2/profile_loss: 460.3721 - val_Sox2/counts_loss: 0.3251 - val_Nanog/profile_loss: 1035.1182 - val_Nanog/counts_loss: 0.7875 - val_Klf4/profile_loss: 892.0636 - val_Klf4/counts_loss: 0.5484
Epoch 5/100
355/355 [==============================] - 327s 921ms/step - loss: 3185.4393 - Oct4/profile_loss: 811.5903 - Oct4/counts_loss: 0.4370 - Sox2/profile_loss: 452.8927 - Sox2/counts_loss: 0.3124 - Nanog/profile_loss: 1009.7150 - Nanog/counts_loss: 0.7823 - Klf4/profile_loss: 890.3798 - Klf4/counts_loss: 0.5544 - val_loss: 3213.2752 - val_Oct4/profile_loss: 814.9148 - val_Oct4/counts_loss: 0.4123 - val_Sox2/profile_loss: 459.5208 - val_Sox2/counts_loss: 0.2859 - val_Nanog/profile_loss: 1031.3779 - val_Nanog/counts_loss: 0.7340 - val_Klf4/profile_loss: 887.9422 - val_Klf4/counts_loss: 0.5198
Epoch 6/100
355/355 [==============================] - 304s 857ms/step - loss: 3172.8330 - Oct4/profile_loss: 809.1857 - Oct4/counts_loss: 0.4011 - Sox2/profile_loss: 452.0725 - Sox2/counts_loss: 0.2810 - Nanog/profile_loss: 1004.0324 - Nanog/counts_loss: 0.7442 - Klf4/profile_loss: 887.9812 - Klf4/counts_loss: 0.5298 - val_loss: 3213.9731 - val_Oct4/profile_loss: 814.4909 - val_Oct4/counts_loss: 0.3608 - val_Sox2/profile_loss: 460.2455 - val_Sox2/counts_loss: 0.2524 - val_Nanog/profile_loss: 1035.2397 - val_Nanog/counts_loss: 0.6996 - val_Klf4/profile_loss: 885.9340 - val_Klf4/counts_loss: 0.4935
Epoch 7/100
355/355 [==============================] - 336s 947ms/step - loss: 3158.4185 - Oct4/profile_loss: 805.8290 - Oct4/counts_loss: 0.3713 - Sox2/profile_loss: 451.6069 - Sox2/counts_loss: 0.2597 - Nanog/profile_loss: 998.8330 - Nanog/counts_loss: 0.7192 - Klf4/profile_loss: 883.6748 - Klf4/counts_loss: 0.4973 - val_loss: 3200.6128 - val_Oct4/profile_loss: 812.5294 - val_Oct4/counts_loss: 0.3712 - val_Sox2/profile_loss: 459.3895 - val_Sox2/counts_loss: 0.2458 - val_Nanog/profile_loss: 1027.9904 - val_Nanog/counts_loss: 0.7054 - val_Klf4/profile_loss: 882.7277 - val_Klf4/counts_loss: 0.4751
Epoch 8/100
355/355 [==============================] - 371s 1s/step - loss: 3152.7198 - Oct4/profile_loss: 803.4860 - Oct4/counts_loss: 0.3499 - Sox2/profile_loss: 451.3736 - Sox2/counts_loss: 0.2447 - Nanog/profile_loss: 996.8339 - Nanog/counts_loss: 0.6995 - Klf4/profile_loss: 883.3001 - Klf4/counts_loss: 0.4785 - val_loss: 3187.1553 - val_Oct4/profile_loss: 809.2691 - val_Oct4/counts_loss: 0.3237 - val_Sox2/profile_loss: 457.9100 - val_Sox2/counts_loss: 0.2222 - val_Nanog/profile_loss: 1022.8658 - val_Nanog/counts_loss: 0.6479 - val_Klf4/profile_loss: 880.3688 - val_Klf4/counts_loss: 0.4803
Epoch 9/100
355/355 [==============================] - 373s 1s/step - loss: 3138.9192 - Oct4/profile_loss: 800.7040 - Oct4/counts_loss: 0.3296 - Sox2/profile_loss: 450.4392 - Sox2/counts_loss: 0.2314 - Nanog/profile_loss: 990.7243 - Nanog/counts_loss: 0.6712 - Klf4/profile_loss: 880.1480 - Klf4/counts_loss: 0.4581 - val_loss: 3189.0848 - val_Oct4/profile_loss: 810.4898 - val_Oct4/counts_loss: 0.3151 - val_Sox2/profile_loss: 458.5842 - val_Sox2/counts_loss: 0.2185 - val_Nanog/profile_loss: 1023.9766 - val_Nanog/counts_loss: 0.6686 - val_Klf4/profile_loss: 879.5616 - val_Klf4/counts_loss: 0.4450
Epoch 10/100
355/355 [==============================] - 359s 1s/step - loss: 3133.4933 - Oct4/profile_loss: 800.1688 - Oct4/counts_loss: 0.3163 - Sox2/profile_loss: 450.0886 - Sox2/counts_loss: 0.2200 - Nanog/profile_loss: 987.8874 - Nanog/counts_loss: 0.6401 - Klf4/profile_loss: 879.1191 - Klf4/counts_loss: 0.4466 - val_loss: 3187.9365 - val_Oct4/profile_loss: 810.6954 - val_Oct4/counts_loss: 0.3151 - val_Sox2/profile_loss: 458.1860 - val_Sox2/counts_loss: 0.2192 - val_Nanog/profile_loss: 1022.4499 - val_Nanog/counts_loss: 0.7301 - val_Klf4/profile_loss: 879.6750 - val_Klf4/counts_loss: 0.4286
Epoch 11/100
355/355 [==============================] - 365s 1s/step - loss: 3125.1986 - Oct4/profile_loss: 798.2635 - Oct4/counts_loss: 0.3053 - Sox2/profile_loss: 449.2469 - Sox2/counts_loss: 0.2132 - Nanog/profile_loss: 983.7885 - Nanog/counts_loss: 0.6165 - Klf4/profile_loss: 878.2231 - Klf4/counts_loss: 0.4327 - val_loss: 3192.5034 - val_Oct4/profile_loss: 811.0699 - val_Oct4/counts_loss: 0.3062 - val_Sox2/profile_loss: 458.5073 - val_Sox2/counts_loss: 0.2117 - val_Nanog/profile_loss: 1026.5345 - val_Nanog/counts_loss: 0.6008 - val_Klf4/profile_loss: 880.7759 - val_Klf4/counts_loss: 0.4428
Epoch 12/100
355/355 [==============================] - 364s 1s/step - loss: 3120.2278 - Oct4/profile_loss: 797.6870 - Oct4/counts_loss: 0.2992 - Sox2/profile_loss: 448.9931 - Sox2/counts_loss: 0.2091 - Nanog/profile_loss: 980.7733 - Nanog/counts_loss: 0.6027 - Klf4/profile_loss: 877.4574 - Klf4/counts_loss: 0.4207 - val_loss: 3183.8380 - val_Oct4/profile_loss: 808.8835 - val_Oct4/counts_loss: 0.2891 - val_Sox2/profile_loss: 457.3775 - val_Sox2/counts_loss: 0.2043 - val_Nanog/profile_loss: 1023.5060 - val_Nanog/counts_loss: 0.6194 - val_Klf4/profile_loss: 878.8006 - val_Klf4/counts_loss: 0.4142
Epoch 13/100
355/355 [==============================] - 325s 916ms/step - loss: 3118.0341 - Oct4/profile_loss: 797.5360 - Oct4/counts_loss: 0.2946 - Sox2/profile_loss: 448.8805 - Sox2/counts_loss: 0.2078 - Nanog/profile_loss: 979.0174 - Nanog/counts_loss: 0.5948 - Klf4/profile_loss: 877.3977 - Klf4/counts_loss: 0.4231 - val_loss: 3178.9318 - val_Oct4/profile_loss: 808.0237 - val_Oct4/counts_loss: 0.3100 - val_Sox2/profile_loss: 457.1798 - val_Sox2/counts_loss: 0.2284 - val_Nanog/profile_loss: 1018.4684 - val_Nanog/counts_loss: 0.6519 - val_Klf4/profile_loss: 878.4582 - val_Klf4/counts_loss: 0.4899
Epoch 14/100
355/355 [==============================] - 359s 1s/step - loss: 3112.2162 - Oct4/profile_loss: 796.3311 - Oct4/counts_loss: 0.2833 - Sox2/profile_loss: 448.5605 - Sox2/counts_loss: 0.2015 - Nanog/profile_loss: 976.3498 - Nanog/counts_loss: 0.5784 - Klf4/profile_loss: 876.2987 - Klf4/counts_loss: 0.4044 - val_loss: 3180.1342 - val_Oct4/profile_loss: 807.9870 - val_Oct4/counts_loss: 0.2915 - val_Sox2/profile_loss: 457.3991 - val_Sox2/counts_loss: 0.2042 - val_Nanog/profile_loss: 1022.4448 - val_Nanog/counts_loss: 0.5994 - val_Klf4/profile_loss: 877.2591 - val_Klf4/counts_loss: 0.4093
Epoch 15/100
355/355 [==============================] - 378s 1s/step - loss: 3112.2488 - Oct4/profile_loss: 796.5752 - Oct4/counts_loss: 0.2752 - Sox2/profile_loss: 448.8555 - Sox2/counts_loss: 0.1987 - Nanog/profile_loss: 976.2623 - Nanog/counts_loss: 0.5661 - Klf4/profile_loss: 876.1669 - Klf4/counts_loss: 0.3989 - val_loss: 3186.0508 - val_Oct4/profile_loss: 810.1154 - val_Oct4/counts_loss: 0.2764 - val_Sox2/profile_loss: 458.2777 - val_Sox2/counts_loss: 0.1989 - val_Nanog/profile_loss: 1023.1222 - val_Nanog/counts_loss: 0.5879 - val_Klf4/profile_loss: 879.9889 - val_Klf4/counts_loss: 0.3915
Epoch 16/100
355/355 [==============================] - 374s 1s/step - loss: 3105.1588 - Oct4/profile_loss: 795.4869 - Oct4/counts_loss: 0.2728 - Sox2/profile_loss: 448.0568 - Sox2/counts_loss: 0.1969 - Nanog/profile_loss: 971.6240 - Nanog/counts_loss: 0.5520 - Klf4/profile_loss: 875.8387 - Klf4/counts_loss: 0.3935 - val_loss: 3173.8931 - val_Oct4/profile_loss: 807.0639 - val_Oct4/counts_loss: 0.2757 - val_Sox2/profile_loss: 456.8668 - val_Sox2/counts_loss: 0.2053 - val_Nanog/profile_loss: 1018.0127 - val_Nanog/counts_loss: 0.5772 - val_Klf4/profile_loss: 877.5009 - val_Klf4/counts_loss: 0.3867
Epoch 17/100
355/355 [==============================] - 374s 1s/step - loss: 3106.7564 - Oct4/profile_loss: 795.7710 - Oct4/counts_loss: 0.2699 - Sox2/profile_loss: 448.1937 - Sox2/counts_loss: 0.1958 - Nanog/profile_loss: 972.7874 - Nanog/counts_loss: 0.5447 - Klf4/profile_loss: 875.9224 - Klf4/counts_loss: 0.3978 - val_loss: 3166.2019 - val_Oct4/profile_loss: 805.8648 - val_Oct4/counts_loss: 0.2679 - val_Sox2/profile_loss: 455.8197 - val_Sox2/counts_loss: 0.2043 - val_Nanog/profile_loss: 1014.6580 - val_Nanog/counts_loss: 0.6162 - val_Klf4/profile_loss: 875.2123 - val_Klf4/counts_loss: 0.3763
Epoch 18/100
355/355 [==============================] - 339s 956ms/step - loss: 3101.7218 - Oct4/profile_loss: 795.1697 - Oct4/counts_loss: 0.2663 - Sox2/profile_loss: 447.8593 - Sox2/counts_loss: 0.1938 - Nanog/profile_loss: 969.9122 - Nanog/counts_loss: 0.5307 - Klf4/profile_loss: 874.9192 - Klf4/counts_loss: 0.3954 - val_loss: 3182.0221 - val_Oct4/profile_loss: 809.1289 - val_Oct4/counts_loss: 0.2780 - val_Sox2/profile_loss: 457.7811 - val_Sox2/counts_loss: 0.2138 - val_Nanog/profile_loss: 1020.8056 - val_Nanog/counts_loss: 0.5998 - val_Klf4/profile_loss: 879.5068 - val_Klf4/counts_loss: 0.3884
Epoch 19/100
355/355 [==============================] - 303s 854ms/step - loss: 3100.3769 - Oct4/profile_loss: 794.5896 - Oct4/counts_loss: 0.2585 - Sox2/profile_loss: 448.0376 - Sox2/counts_loss: 0.1910 - Nanog/profile_loss: 969.2384 - Nanog/counts_loss: 0.5270 - Klf4/profile_loss: 874.8899 - Klf4/counts_loss: 0.3857 - val_loss: 3178.2824 - val_Oct4/profile_loss: 809.1879 - val_Oct4/counts_loss: 0.2674 - val_Sox2/profile_loss: 457.7574 - val_Sox2/counts_loss: 0.1971 - val_Nanog/profile_loss: 1019.5547 - val_Nanog/counts_loss: 0.5617 - val_Klf4/profile_loss: 877.6648 - val_Klf4/counts_loss: 0.3856
Epoch 20/100
355/355 [==============================] - 293s 825ms/step - loss: 3096.5185 - Oct4/profile_loss: 793.9582 - Oct4/counts_loss: 0.2533 - Sox2/profile_loss: 447.5863 - Sox2/counts_loss: 0.1881 - Nanog/profile_loss: 967.4552 - Nanog/counts_loss: 0.5145 - Klf4/profile_loss: 874.2049 - Klf4/counts_loss: 0.3755 - val_loss: 3175.1027 - val_Oct4/profile_loss: 806.2642 - val_Oct4/counts_loss: 0.2752 - val_Sox2/profile_loss: 455.8690 - val_Sox2/counts_loss: 0.2253 - val_Nanog/profile_loss: 1020.4275 - val_Nanog/counts_loss: 0.7340 - val_Klf4/profile_loss: 876.1223 - val_Klf4/counts_loss: 0.4075
Epoch 21/100
355/355 [==============================] - 293s 824ms/step - loss: 3096.2840 - Oct4/profile_loss: 794.1508 - Oct4/counts_loss: 0.2458 - Sox2/profile_loss: 447.6994 - Sox2/counts_loss: 0.1834 - Nanog/profile_loss: 967.2462 - Nanog/counts_loss: 0.5037 - Klf4/profile_loss: 874.1680 - Klf4/counts_loss: 0.3690 - val_loss: 3180.7173 - val_Oct4/profile_loss: 807.0079 - val_Oct4/counts_loss: 0.2608 - val_Sox2/profile_loss: 457.2582 - val_Sox2/counts_loss: 0.2019 - val_Nanog/profile_loss: 1021.7963 - val_Nanog/counts_loss: 0.6567 - val_Klf4/profile_loss: 879.6097 - val_Klf4/counts_loss: 0.3851
---------------------------------------------------------------------------
ParserError                               Traceback (most recent call last)
<ipython-input-16-721fca6bf50e> in <module>()
----> 1 tr.train(epochs=100)

~/miniconda3/envs/basepair/lib/python3.6/site-packages/gin_train/trainers.py in train(self, batch_size, epochs, early_stop_patience, num_workers, train_epoch_frac, valid_epoch_frac, train_samples_per_epoch, validation_samples, train_batch_sampler, tensorboard)
    130 
    131         # log metrics from the best epoch
--> 132         dfh = pd.read_csv(self.history_path)
    133         m = dict(dfh.iloc[dfh.val_loss.idxmin()])
    134         if self.cometml_experiment is not None:

~/miniconda3/envs/basepair/lib/python3.6/site-packages/pandas/io/parsers.py in parser_f(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, squeeze, prefix, mangle_dupe_cols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, dayfirst, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, escapechar, comment, encoding, dialect, tupleize_cols, error_bad_lines, warn_bad_lines, skipfooter, doublequote, delim_whitespace, low_memory, memory_map, float_precision)
    676                     skip_blank_lines=skip_blank_lines)
    677 
--> 678         return _read(filepath_or_buffer, kwds)
    679 
    680     parser_f.__name__ = name

~/miniconda3/envs/basepair/lib/python3.6/site-packages/pandas/io/parsers.py in _read(filepath_or_buffer, kwds)
    444 
    445     try:
--> 446         data = parser.read(nrows)
    447     finally:
    448         parser.close()

~/miniconda3/envs/basepair/lib/python3.6/site-packages/pandas/io/parsers.py in read(self, nrows)
   1034                 raise ValueError('skipfooter not supported for iteration')
   1035 
-> 1036         ret = self._engine.read(nrows)
   1037 
   1038         # May alter columns / col_dict

~/miniconda3/envs/basepair/lib/python3.6/site-packages/pandas/io/parsers.py in read(self, nrows)
   1846     def read(self, nrows=None):
   1847         try:
-> 1848             data = self._reader.read(nrows)
   1849         except StopIteration:
   1850             if self._first_chunk:

pandas/_libs/parsers.pyx in pandas._libs.parsers.TextReader.read()

pandas/_libs/parsers.pyx in pandas._libs.parsers.TextReader._read_low_memory()

pandas/_libs/parsers.pyx in pandas._libs.parsers.TextReader._read_rows()

pandas/_libs/parsers.pyx in pandas._libs.parsers.TextReader._tokenize_rows()

pandas/_libs/parsers.pyx in pandas._libs.parsers.raise_parser_error()

ParserError: Error tokenizing data. C error: Expected 37 fields in line 9, saw 76
In [ ]:
#%debug
In [ ]:
# Evaluate on the validation set with each head's default metrics.
eval_metrics = tr.evaluate(metric=None)  # metric=None -> uses the default head metrics
# Rich display via the cell's last expression instead of print().
eval_metrics

Get the importance scores

In [ ]:
# NOTE(review): imports mid-notebook — consider consolidating into the top
# imports cell so Restart-and-Run-All dependencies are visible up front.
import pybedtools
from basepair.utils import flatten_list
# Re-apply the paper plotting style (cheap; keeps figures consistent).
paper_config()
In [ ]:
seq_model = tr.seq_model  # extract the seq_model
In [ ]:
from genomelake.extractors import FastaExtractor
In [ ]:
# Get data for the oct4 enhancer
# NOTE(review): pybedtools interval fields are conventionally strings —
# confirm integer start/end are accepted by create_interval_from_list.
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])
# Observed ChIP-nexus counts per task over this interval.
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
# One-hot encoded genomic sequence for the interval.
seq = FastaExtractor(ds.fasta_file)([interval])
In [ ]:
it = valid.batch_iter(batch_size=1, shuffle=True)
In [ ]:
batch = next(it)
# NOTE(review): this overwrites the Oct4-enhancer `seq` extracted above with a
# randomly sampled validation window, while `obs`/`interval` still refer to the
# enhancer — the Obs vs Imp tracks plotted below may be misaligned. Confirm intent.
seq = batch['inputs']['seq']
In [ ]:
seq.shape
In [ ]:
seq_model.all_heads['Oct4'][0].use_bias
In [ ]:
imp_scores = seq_model.imp_score_all(seq, batch_size=1)
In [ ]:
imp_scores.keys()
In [ ]:
imp_scores['Oct4/profile/wn'].shape
In [ ]:
preds = seq_model.predict_preact(seq)
In [ ]:
seq.shape

Predicted and observed profiles

In [ ]:
# Tracks to visualize per task: observed counts and the profile importance
# scores projected onto the one-hot sequence (imp * seq keeps only the
# observed base at each position). The original `enumerate` bound an unused
# `task_idx`; iterate the tasks directly.
viz_dict = OrderedDict(flatten_list([[
                    (f"{task} Obs", obs[task]),
                    (f"{task} Imp profile", imp_scores[f"{task}/profile/wn"][0] * seq[0]),
                ] for task in tasks]))

# Zoom in on positions 420-575 of the window.
viz_dict = filter_tracks(viz_dict, [420, 575])
In [ ]:
# Shared per-feature y-limits so panels are comparable across tasks.
fmax = {feature: max(viz_dict[f"{task} {feature}"].max() for task in tasks)
        for feature in ['Imp profile', 'Obs']}
fmin = {feature: min(viz_dict[f"{task} {feature}"].min() for task in tasks)
        for feature in ['Imp profile', 'Obs']}


# Importance tracks can be negative, so they keep their true minimum;
# observed count tracks are anchored at 0.
ylim = [(fmin[feat], fmax[feat]) if "Imp" in feat else (0, fmax[feat])
        for feat in (key.split(" ", 1)[1] for key in viz_dict)]
In [ ]:
# Genome-browser-style multi-track figure.
# NOTE(review): the title is built from `interval` (the Oct4 enhancer), but the
# plotted importance comes from a randomly sampled validation window — confirm.
fig = plot_tracks(viz_dict,
                  #seqlets=shifted_seqlets,
                  title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                  fig_height_per_track=2,
                  rotate_y=0,
                  fig_width=20,
                  ylim=ylim,
                  legend=False)
In [ ]:
train_set = train.load_all(batch_size=32, num_workers=5)
In [ ]:
# NOTE(review): more mid-notebook imports — consider moving to the top cell.
import matplotlib
%matplotlib inline
from scipy.stats import spearmanr
from basepair.preproc import bin_counts
In [ ]:
# Bin size (bp) for the "local counts" comparison (PEP8: spaces around '='
# for a top-level assignment).
binsize = 50
# Model predictions before the bias term is added (pre-activation outputs),
# per the analysis goal stated at the top of the notebook.
train_preds = seq_model.predict_preact(train_set['inputs']['seq'])
In [ ]:
# Scatter plots: model predictions (pre-bias) vs log10 bias counts.
# Row 0: total counts per region; row 1: local counts in `binsize`-bp bins.
# We hope for a low Spearman correlation (model output independent of bias).
matplotlib.rcParams.update({'font.size': 32})  # hoisted: was re-run every loop iteration
fig, axes = plt.subplots(2, len(tasks), figsize=get_figsize(2/4*len(tasks), 2/len(tasks)))
for i, task in enumerate(train.tasks):
    # --- total counts per region ---
    preds_for_total = np.sum(train_preds[f'{task}/profile'], axis=(1, 2), dtype=np.float32)
    log_bias_total_counts = np.log10(np.sum(train_set['inputs'][f'bias/{task}/profile'],
                                            axis=(1, 2), dtype=np.float32))
    # p-value unused; keep the float correlation and its label in separate names
    # (the original rebound `cc` from float to str).
    cc, _pval = spearmanr(preds_for_total, log_bias_total_counts)
    cc_str = "{0:.2f}".format(cc)
    ax = axes[0, i]
    ax.scatter(preds_for_total, log_bias_total_counts, s=10, c="b", alpha=0.5, marker='o',
               label=f'Rs={cc_str}')
    ax.legend()
    ax.set_xlabel("preds_for_total")
    if i == 0:
        ax.set_ylabel("log_bias_total_counts")
    else:
        ax.set_ylabel("")
    ax.set_title(task)

    # --- local counts: sum both strands within binsize-bp bins, flattened ---
    preds_for_local = np.ravel(np.sum(bin_counts(train_preds[f'{task}/profile'],
                                                 binsize=binsize), axis=-1, dtype=np.float32))
    log_bias_local_counts = np.log10(np.ravel(np.sum(bin_counts(train_set['inputs'][f'bias/{task}/profile'],
                                                                binsize=binsize), axis=-1, dtype=np.float32)))
    cc, _pval = spearmanr(preds_for_local, log_bias_local_counts)
    cc_str = "{0:.2f}".format(cc)
    ax = axes[1, i]
    ax.scatter(preds_for_local, log_bias_local_counts, s=10, c="b", alpha=0.5, marker='o',
               label=f'Rs={cc_str}')
    ax.legend()
    ax.set_xlabel("preds_for_local")
    if i == 0:
        ax.set_ylabel("log_bias_local_counts")
    else:
        ax.set_ylabel("")