# Load the two de-novo HOMER motifs found for Sox2 and plot the
# information content of each one.
# NOTE(review): this cell ran before `ddir` and `plt` were defined (they are
# set up further down in the file), so both are brought into scope here.
import matplotlib.pyplot as plt
from basepair.config import get_data_dir
from basepair.motif.homer import load_motif, read_motif_hits

ddir = get_data_dir()
motif_dir = f"{ddir}/processed/chipnexus/motifs/sox2/homer_200bp/de-novo"
motifs = {"motif1": f"{motif_dir}/motif1.motif",
          "motif2": f"{motif_dir}/motif2.motif"}
pwm_list = [load_motif(fname) for fname in motifs.values()]
for motif_name, pwm in zip(motifs, pwm_list):
    pwm.plotPWMInfo((5, 1.5))
    # The dict keys are already "motif1", "motif2", ... so use them as titles.
    plt.title(motif_name)
from basepair.data import seq_inp_exo_out
from basepair.config import get_data_dir
from basepair.math import softmax
from keras.models import load_model
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from concise.utils.plot import seqlogo_fig, seqlogo
# Restrict this run to one GPU. `%env` is IPython-only syntax. Assigning to
# os.environ has the same effect and also works in a plain Python script.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

ddir = get_data_dir()
# NOTE(review): train/test are used below as (inputs, targets, metadata);
# confirm against seq_inp_exo_out.
train, test = seq_inp_exo_out()

# Build a "chr:start-end" label for each *test* example. Every later use
# (`labels.iloc[top10_idx[i]]`) indexes test examples, so the labels must come
# from the test metadata. Building them from train metadata showed the wrong
# regions.
test_meta = test[2]
labels = (test_meta.chr + ":" + test_meta.start.astype(str)
          + "-" + test_meta.end.astype(str))

# Per-example maximum of the target counts, taken separately for
# channel 0 (pos) and channel 1 (neg).
max_counts_pos = pd.Series(np.max(test[1][:, :, 0], axis=-1))
max_counts_neg = pd.Series(np.max(test[1][:, :, 1], axis=-1))
(max_counts_pos + max_counts_neg).plot(kind='hist')
ckp_file = f"{ddir}/processed/chipnexus/exp/models/resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=25.h5"
import keras
model = load_model(ckp_file)

# Gradient * input attribution functions. Each one takes one output strand
# (channel 0 = pos, channel 1 = neg) and reduces it over axis 1 with either
# K.mean or K.max. It then differentiates that reduction with respect to the
# model input and multiplies the gradient by the input.
out = model.outputs[0]
inp = model.inputs[0]


def _ginp_function(reduce_fn, strand):
    """Return a K.function mapping [input batch] -> [gradient * input] for one strand."""
    target = reduce_fn(out[:, :, strand], axis=-1)
    # K.gradients returns a *list* of tensors. Multiplying that list by `inp`
    # presumably broadcasts with an extra leading axis of size 1, which is why
    # callers index the result with [0][0]. NOTE(review): confirm.
    return K.function([inp], [K.gradients(target, inp) * inp])


pos_strand_ginp_avg = _ginp_function(K.mean, 0)
neg_strand_ginp_avg = _ginp_function(K.mean, 1)
pos_strand_ginp_max = _ginp_function(K.max, 0)
neg_strand_ginp_max = _ginp_function(K.max, 1)
# Compare predicted and observed profiles for the 10 test examples with the
# largest combined (pos + neg) maximum counts.
top10_idx = (max_counts_pos + max_counts_neg).sort_values(ascending=False).index[:10]
y_true = test[1]
y_pred = softmax(model.predict(test[0]))
for i, idx in enumerate(top10_idx):
    plt.figure(figsize=(10, 2))
    # Left panel shows the prediction, right panel the observation.
    # Titles are drawn only on the first row.
    for panel_pos, profiles, panel_title in ((121, y_pred, "Predicted"),
                                             (122, y_true, "Observed")):
        plt.subplot(panel_pos)
        # The legend reports the argmax position of each strand's profile.
        plt.plot(profiles[idx, :, 0], label='pos,m={}'.format(np.argmax(profiles[idx, :, 0])))
        plt.plot(profiles[idx, :, 1], label='neg,m={}'.format(np.argmax(profiles[idx, :, 1])))
        plt.legend()
        if i == 0:
            plt.title(panel_title)
# For each selected example, draw four stacked panels: observed profile,
# predicted profile, and gradient*input sequence logos for the max-reduced
# pos and neg outputs.
ginp_pos = pos_strand_ginp_max([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_max([test[0][top10_idx]])[0][0]
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
# Derive the x axis from the data instead of hard-coding 201 positions.
seq_len = y_true.shape[1]
positions = np.arange(1, seq_len + 1)  # 1-based x coordinates
for i, region_idx in enumerate(top10_idx):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    for ax, profiles, ylabel in ((ax1, y_true, "Observed\ncounts"),
                                 (ax2, y_pred, "Predicted\n")):
        ax.plot(positions, profiles[i, :, 0], label="pos")
        ax.plot(positions, profiles[i, :, 1], label="neg")
        ax.set_ylabel(ylabel)
        ax.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    ax4.set_xticks(list(range(0, seq_len, 5)))
    plt.suptitle('{}'.format(labels.iloc[region_idx]))
# Same four-panel view as the max-reduced cell, but the gradient*input logos
# come from the mean-reduced output of each strand.
ginp_pos = pos_strand_ginp_avg([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_avg([test[0][top10_idx]])[0][0]
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
# Derive the x axis from the data instead of hard-coding 201 positions.
seq_len = y_true.shape[1]
positions = np.arange(1, seq_len + 1)  # 1-based x coordinates
for i, region_idx in enumerate(top10_idx):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    for ax, profiles, ylabel in ((ax1, y_true, "Observed\ncounts"),
                                 (ax2, y_pred, "Predicted\n")):
        ax.plot(positions, profiles[i, :, 0], label="pos")
        ax.plot(positions, profiles[i, :, 1], label="neg")
        ax.set_ylabel(ylabel)
        ax.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    ax4.set_xticks(list(range(0, seq_len, 5)))
    # Title each figure with its region, as the max-reduced cell does.
    plt.suptitle('{}'.format(labels.iloc[region_idx]))
# Same predicted-vs-observed comparison for 10 randomly chosen test examples.
# The name `top10_idx` is reused on purpose so the following cells plot these
# examples.
top10_idx = pd.Series(np.arange(len(test[0]))).sample(10).values
top10_idx  # notebook display of the sampled indices
y_true = test[1]
y_pred = softmax(model.predict(test[0]))
for i, idx in enumerate(top10_idx):
    plt.figure(figsize=(10, 2))
    # Left panel shows the prediction, right panel the observation.
    # Titles are drawn only on the first row.
    for panel_pos, profiles, panel_title in ((121, y_pred, "Predicted"),
                                             (122, y_true, "Observed")):
        plt.subplot(panel_pos)
        # The legend reports the argmax position of each strand's profile.
        plt.plot(profiles[idx, :, 0], label='pos,m={}'.format(np.argmax(profiles[idx, :, 0])))
        plt.plot(profiles[idx, :, 1], label='neg,m={}'.format(np.argmax(profiles[idx, :, 1])))
        plt.legend()
        if i == 0:
            plt.title(panel_title)
# Four-panel view (observed, predicted, pos/neg gradient*input logos from the
# max-reduced outputs) for the randomly sampled examples.
ginp_pos = pos_strand_ginp_max([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_max([test[0][top10_idx]])[0][0]
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
# Derive the x axis from the data instead of hard-coding 201 positions.
seq_len = y_true.shape[1]
positions = np.arange(1, seq_len + 1)  # 1-based x coordinates
for i, region_idx in enumerate(top10_idx):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    for ax, profiles, ylabel in ((ax1, y_true, "Observed\ncounts"),
                                 (ax2, y_pred, "Predicted\n")):
        ax.plot(positions, profiles[i, :, 0], label="pos")
        ax.plot(positions, profiles[i, :, 1], label="neg")
        ax.set_ylabel(ylabel)
        ax.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    ax4.set_xticks(list(range(0, seq_len, 5)))
    plt.suptitle('{}'.format(labels.iloc[region_idx]))
# Four-panel view for the randomly sampled examples, with gradient*input logos
# from the mean-reduced output of each strand.
ginp_pos = pos_strand_ginp_avg([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_avg([test[0][top10_idx]])[0][0]
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
# Derive the x axis from the data instead of hard-coding 201 positions.
seq_len = y_true.shape[1]
positions = np.arange(1, seq_len + 1)  # 1-based x coordinates
for i, region_idx in enumerate(top10_idx):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    for ax, profiles, ylabel in ((ax1, y_true, "Observed\ncounts"),
                                 (ax2, y_pred, "Predicted\n")):
        ax.plot(positions, profiles[i, :, 0], label="pos")
        ax.plot(positions, profiles[i, :, 1], label="neg")
        ax.set_ylabel(ylabel)
        ax.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    ax4.set_xticks(list(range(0, seq_len, 5)))
    # Title each figure with its region, as the max-reduced cell does.
    plt.suptitle('{}'.format(labels.iloc[region_idx]))
# Plot each slice w[i] (after reordering axes) as a sequence logo.
# NOTE(review): `w` is not defined anywhere in the visible source. It is
# presumably a first-layer Conv1D kernel of shape (width, 4, n_filters);
# TODO confirm and define it before this cell.
# Move the last axis to the front. This is the same permutation as
# swapaxes(1, 2) followed by swapaxes(0, 1), so out[c, a, b] == w[a, b, c].
w = np.transpose(w, (2, 0, 1))
for filter_weights in w:
    seqlogo_fig(filter_weights, figsize=(5, 2))