Interpretation

  • visualize the importance scores for the highest predicted peaks

Conclusions

  • to obtain a clean footprint, both the positive- and the negative-strand gradients have to agree

Expected motifs

In [11]:
# Homer de-novo motif files for Sox2 (200 bp peak regions)
motif_dir = f"{ddir}/processed/chipnexus/motifs/sox2/homer_200bp/de-novo"
motifs = {name: f"{motif_dir}/{name}.motif" for name in ["motif1", "motif2"]}
In [12]:
from basepair.motif.homer import load_motif, read_motif_hits
# Parse each Homer .motif file into a PWM object. Only the file paths
# (dict values) are needed, so iterate .values() instead of unpacking an
# unused key variable.
pwm_list = [load_motif(fname) for fname in motifs.values()]
In [15]:
# Plot the information-content logo of each discovered PWM
for motif_number, pwm in enumerate(pwm_list, start=1):
    pwm.plotPWMInfo((5, 1.5))
    plt.title(f"motif{motif_number}")
In [60]:
from basepair.data import seq_inp_exo_out
from basepair.config import get_data_dir
from basepair.math import softmax
from keras.models import load_model
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from concise.utils.plot import seqlogo_fig, seqlogo
In [4]:
%env CUDA_VISIBLE_DEVICES=3
env: CUDA_VISIBLE_DEVICES=3

Get the data

In [5]:
ddir = get_data_dir()
In [94]:
train, test = seq_inp_exo_out()
100%|██████████| 9396/9396 [00:47<00:00, 197.67it/s]
In [143]:
labels = train[2].chr + ":" + train[2].start.astype(str) + "-" + train[2].end.astype(str)
In [113]:
# Per-region maximum read count on each strand (channel 0 = pos, 1 = neg)
max_counts_pos = pd.Series(test[1][:, :, 0].max(axis=-1))
max_counts_neg = pd.Series(test[1][:, :, 1].max(axis=-1))
In [114]:
(max_counts_pos + max_counts_neg).plot(kind='hist')
Out[114]:
<matplotlib.axes._subplots.AxesSubplot at 0x7ff3ae6f1f60>

Load the model

In [13]:
ckp_file = f"{ddir}/processed/chipnexus/exp/models/resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=25.h5"
In [14]:
import keras
In [15]:
model = load_model(ckp_file)
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2018-04-25 01:04:40,112 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2018-04-25 01:04:49,072 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.

Define gradient functions

In [120]:
# Define the gradient * input function w.r.t. to maximum output
# For each strand (output channel 0 = pos, 1 = neg) build a Keras backend
# function computing a grad*input attribution of the profile w.r.t. the
# one-hot input sequence:
#   *_avg: gradient of the MEAN profile value across positions
#   *_max: gradient of the MAX profile value across positions
# NOTE(review): K.gradients() returns a *list* of tensors, so `... * inp`
# multiplies a Python list by a tensor and relies on TF operator
# overloading (list -> tensor conversion adds a leading dimension).
# This is presumably why callers index the result with [0][0] -- verify
# the intended output shape before reusing these functions.
out = model.outputs[0]
inp = model.inputs[0]
pos_strand_ginp_avg = K.function([inp], [K.gradients(K.mean(out[:,:,0], axis=-1), inp) * inp])
neg_strand_ginp_avg = K.function([inp], [K.gradients(K.mean(out[:,:,1], axis=-1), inp) * inp])
pos_strand_ginp_max = K.function([inp], [K.gradients(K.max(out[:,:,0], axis=-1), inp) * inp])
neg_strand_ginp_max = K.function([inp], [K.gradients(K.max(out[:,:,1], axis=-1), inp) * inp])

Top 10 sequences based on counts

In [154]:
top10_idx = (max_counts_pos + max_counts_neg).sort_values(ascending=False).index[:10]
In [155]:
y_true = test[1]
In [156]:
y_pred = softmax(model.predict(test[0]))
In [157]:
# Side-by-side predicted (left) vs. observed (right) profiles for each
# selected region; the legend reports the argmax position ("m=") per strand.
for row, region_idx in enumerate(top10_idx):
    plt.figure(figsize=(10, 2))
    plt.subplot(121)
    plt.plot(y_pred[region_idx, :, 0], label='pos,m={}'.format(np.argmax(y_pred[region_idx, :, 0])))
    plt.plot(y_pred[region_idx, :, 1], label='neg,m={}'.format(np.argmax(y_pred[region_idx, :, 1])))
    plt.legend()
    if row == 0:
        plt.title("Predicted")
    plt.subplot(122)
    plt.plot(y_true[region_idx, :, 0], label='pos,m={}'.format(np.argmax(y_true[region_idx, :, 0])))
    plt.plot(y_true[region_idx, :, 1], label='neg,m={}'.format(np.argmax(y_true[region_idx, :, 1])))
    plt.legend()
    if row == 0:
        plt.title("Observed")

Max gradient*input

In [158]:
# grad*input attribution for the top-10 sequences, one call per strand.
# [0] selects the single K.function output; the second [0] presumably drops
# the extra leading dimension introduced by the gradients-list
# multiplication -- NOTE(review): verify the shape.
ginp_pos = pos_strand_ginp_max([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_max([test[0][top10_idx]])[0][0]
In [159]:
# Restrict observed/predicted profiles to the selected top-10 regions
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
In [160]:
# 4-row summary per region: observed counts, predicted profile, and the
# grad*input sequence logo for each strand.
seq_len = y_true.shape[1]              # profile length (201 bp here); derived instead of hard-coded
positions = np.arange(1, seq_len + 1)  # 1-based position axis for plotting
for i in range(len(top10_idx)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    ax1.plot(positions, y_true[i, :, 0], label="pos")
    ax1.plot(positions, y_true[i, :, 1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(positions, y_pred[i, :, 0], label="pos")
    ax2.plot(positions, y_pred[i, :, 1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    # removed unused local `x_range`; tick every 5 bp
    ax4.set_xticks(list(range(0, seq_len, 5)))
    # NOTE(review): `labels` is built from train metadata but indexed with
    # test-set indices here -- confirm which split these labels belong to.
    plt.suptitle('{}'.format(labels.iloc[top10_idx[i]]))

Average gradient*input

In [161]:
# grad*input attribution using the MEAN-output gradient (cf. the max-output
# variant above); [0][0] presumably unwraps the K.function output and the
# extra leading dim -- NOTE(review): verify the shape.
ginp_pos = pos_strand_ginp_avg([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_avg([test[0][top10_idx]])[0][0]
In [162]:
# Restrict observed/predicted profiles to the selected top-10 regions
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
In [163]:
# Same 4-row summary as above, for the mean-output grad*input attribution.
seq_len = y_true.shape[1]              # profile length (201 bp here); derived instead of hard-coded
positions = np.arange(1, seq_len + 1)  # 1-based position axis for plotting
for i in range(len(top10_idx)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    ax1.plot(positions, y_true[i, :, 0], label="pos")
    ax1.plot(positions, y_true[i, :, 1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(positions, y_pred[i, :, 0], label="pos")
    ax2.plot(positions, y_pred[i, :, 1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    # removed unused local `x_range`; tick every 5 bp
    ax4.set_xticks(list(range(0, seq_len, 5)))

10 randomly selected sequences

In [136]:
# Random 10 idx
# Draw 10 random test-set indices. A fixed random_state makes the selection
# reproducible across kernel restarts (previously unseeded, so the figures
# below could not be reproduced).
# NOTE(review): keeps the (now misleading) name `top10_idx` so the
# downstream plotting cells continue to work unchanged.
top10_idx = pd.Series(np.arange(len(test[0]))).sample(10, random_state=42).values
In [137]:
top10_idx
Out[137]:
array([1168, 1843, 1908, 1733, 1318,  229, 1597, 1398, 1345,  715])
In [138]:
y_true = test[1]
In [139]:
y_pred = softmax(model.predict(test[0]))
In [140]:
# Predicted (left) vs. observed (right) profiles for the random regions;
# legend reports the per-strand argmax position ("m=").
for row, region_idx in enumerate(top10_idx):
    plt.figure(figsize=(10, 2))
    plt.subplot(121)
    plt.plot(y_pred[region_idx, :, 0], label='pos,m={}'.format(np.argmax(y_pred[region_idx, :, 0])))
    plt.plot(y_pred[region_idx, :, 1], label='neg,m={}'.format(np.argmax(y_pred[region_idx, :, 1])))
    plt.legend()
    if row == 0:
        plt.title("Predicted")
    plt.subplot(122)
    plt.plot(y_true[region_idx, :, 0], label='pos,m={}'.format(np.argmax(y_true[region_idx, :, 0])))
    plt.plot(y_true[region_idx, :, 1], label='neg,m={}'.format(np.argmax(y_true[region_idx, :, 1])))
    plt.legend()
    if row == 0:
        plt.title("Observed")

Max gradient*input

In [141]:
# grad*input (max-output gradient) for the randomly selected sequences;
# [0][0] presumably unwraps the K.function output and the extra leading
# dim -- NOTE(review): verify the shape.
ginp_pos = pos_strand_ginp_max([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_max([test[0][top10_idx]])[0][0]
In [149]:
# Restrict observed/predicted profiles to the selected random regions
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
In [150]:
# 4-row summary per region: observed counts, predicted profile, and the
# grad*input sequence logo for each strand.
seq_len = y_true.shape[1]              # profile length (201 bp here); derived instead of hard-coded
positions = np.arange(1, seq_len + 1)  # 1-based position axis for plotting
for i in range(len(top10_idx)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    ax1.plot(positions, y_true[i, :, 0], label="pos")
    ax1.plot(positions, y_true[i, :, 1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(positions, y_pred[i, :, 0], label="pos")
    ax2.plot(positions, y_pred[i, :, 1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    # removed unused local `x_range`; tick every 5 bp
    ax4.set_xticks(list(range(0, seq_len, 5)))
    # NOTE(review): `labels` is built from train metadata but indexed with
    # test-set indices here -- confirm which split these labels belong to.
    plt.suptitle('{}'.format(labels.iloc[top10_idx[i]]))

Average gradient*input

In [151]:
# grad*input (mean-output gradient) for the randomly selected sequences;
# [0][0] presumably unwraps the K.function output and the extra leading
# dim -- NOTE(review): verify the shape.
ginp_pos = pos_strand_ginp_avg([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_avg([test[0][top10_idx]])[0][0]
In [152]:
# Restrict observed/predicted profiles to the selected random regions
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
In [153]:
# Same 4-row summary, for the mean-output grad*input attribution.
seq_len = y_true.shape[1]              # profile length (201 bp here); derived instead of hard-coded
positions = np.arange(1, seq_len + 1)  # 1-based position axis for plotting
for i in range(len(top10_idx)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    ax1.plot(positions, y_true[i, :, 0], label="pos")
    ax1.plot(positions, y_true[i, :, 1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(positions, y_pred[i, :, 0], label="pos")
    ax2.plot(positions, y_pred[i, :, 1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    # removed unused local `x_range`; tick every 5 bp
    ax4.set_xticks(list(range(0, seq_len, 5)))

Filter interpretation

In [108]:
w = np.swapaxes(w, 1,2)
In [109]:
w = np.swapaxes(w, 0, 1)
In [114]:
# One sequence logo per first-layer filter.
# (matplotlib warns once more than 20 figures are open -- expected here)
for filter_weights in w:
    seqlogo_fig(filter_weights, figsize=(5, 2))
/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/matplotlib/pyplot.py:537: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  max_open_warning, RuntimeWarning)