Interpretation

  • Visualize the importance scores for the highest predicted peaks

Conclusions

  • To obtain a clean footprint, the positive- and negative-strand gradients have to agree

Expected motifs

In [7]:
from basepair.data import seq_inp_exo_out
from basepair.config import get_data_dir
from basepair.math import softmax
from keras.models import load_model
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from concise.utils.plot import seqlogo_fig, seqlogo
In [8]:
%env CUDA_VISIBLE_DEVICES=3
env: CUDA_VISIBLE_DEVICES=3
In [9]:
ddir = get_data_dir()  # project data root from basepair.config
In [10]:
# HOMER de-novo motif files for Sox2 ChIP-nexus (200 bp windows)
motifs = {"motif1": f"{ddir}/processed/chipnexus/motifs/sox2/homer_200bp/de-novo/motif1.motif",
          "motif2": f"{ddir}/processed/chipnexus/motifs/sox2/homer_200bp/de-novo/motif2.motif"}
In [11]:
from basepair.motif.homer import load_motif, read_motif_hits
# Parse every HOMER motif file into a PWM object; order follows the `motifs` dict
pwm_list = [load_motif(path) for path in motifs.values()]
In [12]:
# Draw the information-content view of each PWM, titled motif1, motif2, ...
for i,pwm in enumerate(pwm_list):
    pwm.plotPWMInfo((5,1.5))  # presumably (width, height) in inches — confirm in basepair
    plt.title(f"motif{i+1}")

Get the data

In [13]:
train, test = seq_inp_exo_out()
100%|██████████| 9396/9396 [00:02<00:00, 3375.14it/s]
In [91]:
labels = test[2].chr + ":" + test[2].start.astype(str) + "-" + test[2].end.astype(str)
In [15]:
# Peak height per example: maximum read count over positions, per strand
counts = test[1]
max_counts_pos = pd.Series(counts[:, :, 0].max(axis=-1))
max_counts_neg = pd.Series(counts[:, :, 1].max(axis=-1))
In [16]:
(max_counts_pos + max_counts_neg).plot(kind='hist')
Out[16]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f643ce54ba8>

Load the model

In [48]:
# Checkpoint of the trained model (Keras .h5); the filename encodes the hyperparameters
ckp_file = f"{ddir}/processed/chipnexus/exp/models/resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=25.h5"
In [49]:
import keras  # binds the top-level `keras` name (only submodules were imported above)
In [50]:
# Restore architecture + weights from the checkpoint
model = load_model(ckp_file)

Define gradient functions

In [51]:
# Define the gradient * input function w.r.t. to maximum output
# NOTE(review): K.gradients() returns a *list* of tensors. `list * inp` only works
# because the tensor's __rmul__ converts the 1-element list into a tensor with a
# leading axis of size 1 — which is why callers index the result with [0][0].
# TODO confirm this is intentional rather than a latent bug.
out = model.outputs[0]
inp = model.inputs[0]
# grad-of-mean vs grad-of-max of the profile, separately per strand (channel 0 = pos, 1 = neg)
pos_strand_ginp_avg = K.function([inp], [K.gradients(K.mean(out[:,:,0], axis=-1), inp) * inp])
neg_strand_ginp_avg = K.function([inp], [K.gradients(K.mean(out[:,:,1], axis=-1), inp) * inp])
pos_strand_ginp_max = K.function([inp], [K.gradients(K.max(out[:,:,0], axis=-1), inp) * inp])
neg_strand_ginp_max = K.function([inp], [K.gradients(K.max(out[:,:,1], axis=-1), inp) * inp])
In [55]:
# Extract the bottleneck activations of the network.
# NOTE(review): "add_180" is an auto-generated layer name and is brittle — it
# changes whenever the model graph is rebuilt; name the layer explicitly instead.
pred_bottleneck = K.function([inp], [model.get_layer("add_180").output])

Top 10 sequences based on counts

In [56]:
top10_idx = (max_counts_pos + max_counts_neg).sort_values(ascending=False).index[:10]
In [57]:
y_true = test[1]  # observed count profiles (positions x 2 strands per example)
In [58]:
# softmax turns the model's logits into per-strand probability profiles
y_pred = softmax(model.predict(test[0]))
In [179]:
# Bottleneck activations for every test sequence
bottleneck = pred_bottleneck([test[0]])[0]
In [62]:
# TODO - run PCA
In [68]:
from sklearn.decomposition import PCA
In [108]:
pca = PCA(10)  # keep the top 10 principal components of the 32-dim bottleneck
In [109]:
# Fit on all positions of all sequences, flattened to (n_seq * seq_len, 32)
pca.fit(bottleneck.reshape((-1, 32)))
Out[109]:
PCA(copy=True, iterated_power='auto', n_components=10, random_state=None,
  svd_solver='auto', tol=0.0, whiten=False)
In [164]:
# Project to PC space, then restore the (n_seq, seq_len, 10) layout
bottleneck_pca = pca.transform(bottleneck.reshape((-1, 32))).reshape((len(bottleneck), -1, 10))
In [110]:
# Scree plot: per-component and cumulative variance explained
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(pca.explained_variance_ratio_, ".-")
plt.xlabel("# PC")
plt.ylabel("Var. explained")
plt.subplot(122)
plt.plot(np.cumsum(pca.explained_variance_ratio_))
plt.xlabel("# PC")
plt.ylabel("Total var. explained");
In [111]:
# Heatmap of PCA loadings: rows = components, columns = bottleneck channels
plt.imshow(pca.components_)
plt.xlabel("Component weights")
plt.ylabel("Component number")
Out[111]:
Text(0,0.5,'Component number')
In [112]:
# Loadings of the first principal component across the 32 channels
plt.plot(pca.components_[0])
Out[112]:
[<matplotlib.lines.Line2D at 0x7f61bad8c588>]
In [172]:
# Project each PC back through the final deconvolution kernels to see the
# strand-resolved profile shape that component encodes.
# NOTE(review): "conv2d_transpose_53" is an auto-generated layer name — brittle.
# NOTE(review): this loop rebinds `out`, shadowing the model-output tensor defined
# earlier (harmless only because the K.functions were already built).
plt.figure(figsize=(20, 1.5))
for i in range(8):
    plt.subplot(1,8,i+1)
    # weighted sum of the (25, 1, 2, 32) deconv kernels by the PC's 32 loadings -> (25, 2)
    out = (pca.components_[i].reshape((1,1,1,-1)) * model.get_layer("conv2d_transpose_53").get_weights()[0]).sum(axis=-1)[:,0]
    plt.plot(out[:,0])  # positive strand
    plt.plot(out[:,1])  # negative strand
    plt.title(f"Deconv filter w/ PC{i}")
    plt.tight_layout()
In [171]:
# Plot all 32 deconvolution kernels; weights have shape (25, 1, 2, 32)
# (confirmed by the shape-check cell further down)
plt.figure(figsize=(20, 6))
w=model.get_layer("conv2d_transpose_53").get_weights()[0]
for i in range(32):
    plt.subplot(4,8,i+1)
    out = w[:,0,:,i]  # (25 taps, 2 strands) kernel for filter i
    plt.plot(out[:,0])  # positive strand
    plt.plot(out[:,1])  # negative strand
    plt.title(f"deconv filter {i}")
    plt.tight_layout()
In [178]:
from sklearn.manifold import TSNE, MDS
In [180]:
tsne = TSNE()  # defaults: 2 components, default perplexity
In [ ]:
# 2-D embedding of all flattened bottleneck positions (slow; cell left unexecuted)
tsne = tsne.fit(bottleneck.reshape((-1, 32)), )
In [ ]:
plt.scatter(tsne.embedding_[:,0], tsne.embedding_[:,1])
In [ ]:
a=1  # NOTE(review): leftover scratch cell — safe to delete
In [180]:
mds = MDS()
In [ ]:
# MDS embedding of the same flattened activations (also unexecuted)
mds = mds.fit(bottleneck.reshape((-1, 32)))
In [ ]:
plt.scatter(mds.embedding_[:,0], mds.embedding_[:,1])
In [ ]:
import umap
In [ ]:
um = umap.UMAP(n_neighbors=5,
               min_dist=0.3)
In [ ]:
um.fit(mds.embedding_[:,0], mds.embedding_[:,1])
In [ ]:
plt.scatter(um.embedding_[:,0], um.embedding_[:,1])
In [129]:
model.get_layer("conv2d_transpose_53").get_weights()[0].shape
Out[129]:
(25, 1, 2, 32)
In [ ]:
plt.plot()
In [80]:
# TODO - visualize the profile of the first PC
In [67]:
bottleneck.reshape((-1, 32))
Out[67]:
True
In [61]:
# Heatmap of raw bottleneck activations (channels x positions) for the top-10 peaks
for i, idx in enumerate(top10_idx):
    plt.figure(figsize=(20,4))
    plt.imshow(bottleneck[idx].T,aspect='auto' )
In [32]:
idx = top10_idx[2]  # inspect the 3rd-highest peak
In [33]:
idx
Out[33]:
1943
In [ ]:
bottleneck  # NOTE(review): unexecuted inspection cell — safe to delete
In [157]:
# Side-by-side predicted vs observed strand profiles for the top-10 peaks;
# the legend shows m = argmax position of each profile.
# NOTE(review): this cell is duplicated almost verbatim for the random-10 case
# further down — consider extracting a plot_profiles(y_pred, y_true, idx) helper.
for i, idx in enumerate(top10_idx):
    plt.figure(figsize=(10,2))
    plt.subplot(121)
    plt.plot(y_pred[idx,:,0], label='pos,m={}'.format(np.argmax(y_pred[idx,:,0])))
    plt.plot(y_pred[idx,:,1], label='neg,m={}'.format(np.argmax(y_pred[idx,:,1])))
    plt.legend();
    if i==0:
        plt.title("Predicted")
    plt.subplot(122)
    plt.plot(y_true[idx,:,0], label='pos,m={}'.format(np.argmax(y_true[idx,:,0])))
    plt.plot(y_true[idx,:,1], label='neg,m={}'.format(np.argmax(y_true[idx,:,1])))
    plt.legend();
    if i==0:
        plt.title("Observed")

Max gradient*input

In [174]:
# grad(max profile value) * input for the top-10 sequences
# ([0][0] unwraps the K.function output list and the leading gradient-list axis)
ginp_pos = pos_strand_ginp_max([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_max([test[0][top10_idx]])[0][0]
In [175]:
# NOTE(review): rebinds y_true/y_pred/bottleneck from full-test-set arrays to
# top-10 subsets — earlier cells will not re-run cleanly after this point.
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
bottleneck = pred_bottleneck([test[0][top10_idx]])[0]
bottleneck_pca = pca.transform(bottleneck.reshape((-1, 32))).reshape((len(bottleneck), -1, 10))
In [177]:
# Five-panel summary per top-10 peak: PCA-compressed bottleneck activations,
# observed and predicted strand profiles, and grad*input sequence logos per strand.
for i in range(len(top10_idx)):
    fig, (ax0, ax1, ax2, ax3, ax4) = plt.subplots(5, 1, sharex=True, figsize=(20,6))
    ax0.imshow(bottleneck_pca[i].T, aspect='auto')
    ax0.set_ylabel("pca\n(Bottleneck)")  # fixed unbalanced parenthesis in the label
    ax0.set_title('{}'.format(labels.iloc[top10_idx[i]]))
    ax1.plot(np.arange(1, 202), y_true[i,:,0], label="pos")
    ax1.plot(np.arange(1, 202), y_true[i,:,1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(np.arange(1, 202), y_pred[i,:,0], label="pos")
    ax2.plot(np.arange(1, 202), y_pred[i,:,1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3);
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4);
    # removed unused local `x_range` (was assigned but never read)
    ax4.set_xticks(list(range(0, 201, 5)));

Average gradient*input

In [168]:
# Same as above but gradients w.r.t. the *average* profile value
ginp_pos = pos_strand_ginp_avg([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_avg([test[0][top10_idx]])[0][0]
In [169]:
# Re-subset predictions/activations to the top-10 examples
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
bottleneck = pred_bottleneck([test[0][top10_idx]])[0]
bottleneck_pca = pca.transform(bottleneck.reshape((-1, 32))).reshape((len(bottleneck), -1, 10))
In [170]:
# Five-panel summary per peak for the average-gradient attribution
for i in range(len(top10_idx)):
    fig, (ax0, ax1, ax2, ax3, ax4) = plt.subplots(5, 1, sharex=True, figsize=(20,6))
    ax0.imshow(bottleneck_pca[i].T, aspect='auto')
    ax0.set_ylabel("pca(Bottleneck act)")
    ax0.set_title('{}'.format(labels.iloc[top10_idx[i]]))
    ax1.plot(np.arange(1, 202), y_true[i,:,0], label="pos")
    ax1.plot(np.arange(1, 202), y_true[i,:,1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(np.arange(1, 202), y_pred[i,:,0], label="pos")
    ax2.plot(np.arange(1, 202), y_pred[i,:,1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3);
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4);
    # removed unused local `x_range` (was assigned but never read)
    ax4.set_xticks(list(range(0, 201, 5)));

Randomly selected 10 sequences

In [136]:
# Random 10 idx — seeded so the selection (and all downstream figures) is
# reproducible across kernel restarts; previously unseeded, so every re-run
# analyzed a different set of regions.
top10_idx = pd.Series(np.arange(len(test[0]))).sample(10, random_state=42).values
In [137]:
top10_idx
Out[137]:
array([1168, 1843, 1908, 1733, 1318,  229, 1597, 1398, 1345,  715])
In [138]:
y_true = test[1]  # back to full-test-set arrays before indexing by the random sample
In [139]:
y_pred = softmax(model.predict(test[0]))
In [140]:
# Predicted vs observed profiles for the 10 random regions.
# NOTE(review): verbatim duplicate of the top-10 plotting cell above — extract a helper.
for i, idx in enumerate(top10_idx):
    plt.figure(figsize=(10,2))
    plt.subplot(121)
    plt.plot(y_pred[idx,:,0], label='pos,m={}'.format(np.argmax(y_pred[idx,:,0])))
    plt.plot(y_pred[idx,:,1], label='neg,m={}'.format(np.argmax(y_pred[idx,:,1])))
    plt.legend();
    if i==0:
        plt.title("Predicted")
    plt.subplot(122)
    plt.plot(y_true[idx,:,0], label='pos,m={}'.format(np.argmax(y_true[idx,:,0])))
    plt.plot(y_true[idx,:,1], label='neg,m={}'.format(np.argmax(y_true[idx,:,1])))
    plt.legend();
    if i==0:
        plt.title("Observed")

Max gradient*input

In [141]:
# Max-gradient attribution for the random regions
ginp_pos = pos_strand_ginp_max([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_max([test[0][top10_idx]])[0][0]
In [149]:
# Subset targets/predictions to the sampled regions
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
In [150]:
# Four-panel summary per random region (no bottleneck panel here)
for i in range(len(top10_idx)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20,6))
    ax1.plot(np.arange(1, 202), y_true[i,:,0], label="pos")
    ax1.plot(np.arange(1, 202), y_true[i,:,1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(np.arange(1, 202), y_pred[i,:,0], label="pos")
    ax2.plot(np.arange(1, 202), y_pred[i,:,1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3);
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4);
    # removed unused local `x_range` (was assigned but never read)
    ax4.set_xticks(list(range(0, 201, 5)));
    plt.suptitle('{}'.format(labels.iloc[top10_idx[i]]))

Average gradient*input

In [151]:
# Average-gradient attribution for the random regions
ginp_pos = pos_strand_ginp_avg([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_avg([test[0][top10_idx]])[0][0]
In [152]:
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
In [153]:
# Four-panel summary per random region for the average-gradient attribution
for i in range(len(top10_idx)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20,6))
    ax1.plot(np.arange(1, 202), y_true[i,:,0], label="pos")
    ax1.plot(np.arange(1, 202), y_true[i,:,1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(np.arange(1, 202), y_pred[i,:,0], label="pos")
    ax2.plot(np.arange(1, 202), y_pred[i,:,1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3);
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4);
    # removed unused local `x_range` (was assigned but never read)
    ax4.set_xticks(list(range(0, 201, 5)));

Filter interpretation

In [108]:
# Reorder filter-weight axes so each w[i] can be drawn as a sequence logo.
# NOTE(review): execution counts In[108]/In[109] here collide with the PCA cells
# above, so `w` may not be the deconv weights loaded earlier — confirm provenance
# before trusting these logos. Also non-idempotent: re-running these two swap
# cells scrambles the axes again.
w = np.swapaxes(w, 1,2)
In [109]:
w = np.swapaxes(w, 0, 1)
In [114]:
# One logo figure per filter; the >20-open-figures RuntimeWarning below is expected
# because inline display retains every figure created in the loop.
for i in range(len(w)):
    seqlogo_fig(w[i], figsize=(5,2));
/users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/matplotlib/pyplot.py:537: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  max_open_warning, RuntimeWarning)