from basepair.data import seq_inp_exo_out
from basepair.config import get_data_dir
from basepair.math import softmax
from keras.models import load_model
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from concise.utils.plot import seqlogo_fig, seqlogo
# Pin this process to GPU 3. The original used the IPython magic
# `%env CUDA_VISIBLE_DEVICES=3`, which is a SyntaxError in a plain .py
# file; os.environ is the portable equivalent (TF reads it lazily when
# the first session is created, so setting it here still works).
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

ddir = get_data_dir()
# Paths to the two de-novo HOMER motifs (Sox2 ChIP-nexus, 200 bp windows).
motifs = {"motif1": f"{ddir}/processed/chipnexus/motifs/sox2/homer_200bp/de-novo/motif1.motif",
          "motif2": f"{ddir}/processed/chipnexus/motifs/sox2/homer_200bp/de-novo/motif2.motif"}
from basepair.motif.homer import load_motif, read_motif_hits
# Load each HOMER motif file as a PWM and plot its information-content logo.
# (The for-loop body lost its indentation in the notebook export; restored.)
pwm_list = [load_motif(fname) for k, fname in motifs.items()]
for i, pwm in enumerate(pwm_list):
    pwm.plotPWMInfo((5, 1.5))  # (width, height) of the logo figure
    # Titles follow the dict insertion order: motif1, motif2.
    plt.title(f"motif{i+1}")
# Load the train/test splits. From the usage below: test[1] is a 3-D count
# array (example, position, strand) and test[2] is a ranges table with
# chr/start/end columns.
train, test = seq_inp_exo_out()
ranges = test[2]
labels = ranges.chr + ":" + ranges.start.astype(str) + "-" + ranges.end.astype(str)
# Per-example peak height on each strand, then a histogram of their sum.
max_counts_pos = pd.Series(test[1][:, :, 0].max(axis=-1))
max_counts_neg = pd.Series(test[1][:, :, 1].max(axis=-1))
(max_counts_pos + max_counts_neg).plot(kind='hist')
# Trained ChIP-nexus model checkpoint (filters=32 matches the 32-channel
# bottleneck reshapes used below — presumably the bottleneck width; confirm
# against the training config).
ckp_file = f"{ddir}/processed/chipnexus/exp/models/resnest_allconnect_nconv=7_filters=32_lr=0.004_dilated=True,out=25.h5"
import keras
model = load_model(ckp_file)
# Gradient * input attribution functions w.r.t. the model's profile output.
# Output channel 0 = positive strand, channel 1 = negative strand (same
# convention as test[1][:, :, 0/1] above). "avg" differentiates the mean
# profile value, "max" the peak value.
# NOTE(review): K.gradients() returns a *list* of tensors, so
# `K.gradients(...) * inp` relies on the backend broadcasting the singleton
# list against `inp`; this is also why callers below index the result with
# `[0][0]`. Confirm, or change to `K.gradients(...)[0] * inp`.
out = model.outputs[0]
inp = model.inputs[0]
pos_strand_ginp_avg = K.function([inp], [K.gradients(K.mean(out[:,:,0], axis=-1), inp) * inp])
neg_strand_ginp_avg = K.function([inp], [K.gradients(K.mean(out[:,:,1], axis=-1), inp) * inp])
pos_strand_ginp_max = K.function([inp], [K.gradients(K.max(out[:,:,0], axis=-1), inp) * inp])
neg_strand_ginp_max = K.function([inp], [K.gradients(K.max(out[:,:,1], axis=-1), inp) * inp])
# Function returning the bottleneck activations of layer "add_180"
# (auto-generated layer name — tied to this specific checkpoint).
pred_bottleneck = K.function([inp], [model.get_layer("add_180").output])
# Pick the 10 test examples with the largest combined strand-wise peak counts,
# then compute full-test-set predictions and bottleneck activations.
total_counts = max_counts_pos + max_counts_neg
top10_idx = total_counts.sort_values(ascending=False).index[:10]
y_true = test[1]
y_pred = softmax(model.predict(test[0]))
bottleneck = pred_bottleneck([test[0]])[0]
# PCA over the 32-channel bottleneck activations, pooling all positions of
# all examples into one (n*positions, 32) matrix.
from sklearn.decomposition import PCA

flat_bottleneck = bottleneck.reshape((-1, 32))
pca = PCA(10)
pca.fit(flat_bottleneck)
# Project and restore the (example, position, 10 PCs) layout.
bottleneck_pca = pca.transform(flat_bottleneck).reshape((len(bottleneck), -1, 10))

# Scree plots: per-component and cumulative explained variance.
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(pca.explained_variance_ratio_, ".-")
plt.xlabel("# PC")
plt.ylabel("Var. explained")
plt.subplot(122)
plt.plot(np.cumsum(pca.explained_variance_ratio_))
plt.xlabel("# PC")
plt.ylabel("Total var. explained")

# Component loading matrix, plus the first component's weights on their own.
plt.imshow(pca.components_)
plt.xlabel("Component weights")
plt.ylabel("Component number")
plt.plot(pca.components_[0])
# Project each of the first 8 PCs through the transposed-conv ("deconv")
# filters and plot the resulting two-strand profile per PC.
# (Loop-body indentation was lost in the notebook export; restored here.
# plt.tight_layout() is assumed to belong after the loop — confirm.)
plt.figure(figsize=(20, 1.5))
deconv_w = model.get_layer("conv2d_transpose_53").get_weights()[0]  # hoisted: loop-invariant
for i in range(8):
    plt.subplot(1, 8, i + 1)
    # Weight the deconv filters by the PC loadings and sum over the filter
    # axis -> (kernel_len, 2 strands). NOTE: reuses/clobbers the name `out`
    # from the gradient-function cell, as the original did.
    out = (pca.components_[i].reshape((1, 1, 1, -1)) * deconv_w).sum(axis=-1)[:, 0]
    plt.plot(out[:, 0])
    plt.plot(out[:, 1])
    plt.title(f"Deconv filter w/ PC{i}")
plt.tight_layout()
# Plot all 32 deconv filters, one subplot each, with the two strand traces.
# (Loop-body indentation was lost in the notebook export; restored here;
# tight_layout assumed to follow the loop.)
plt.figure(figsize=(20, 6))
w = model.get_layer("conv2d_transpose_53").get_weights()[0]  # `w` is reused further below
for i in range(32):
    plt.subplot(4, 8, i + 1)
    out = w[:, 0, :, i]  # (kernel position, 2 strands) slice for filter i
    plt.plot(out[:, 0])
    plt.plot(out[:, 1])
    plt.title(f"deconv filter {i}")
plt.tight_layout()
# 2-D embeddings of the pooled bottleneck activations: t-SNE, then MDS.
from sklearn.manifold import TSNE, MDS

flat_act = bottleneck.reshape((-1, 32))
tsne = TSNE().fit(flat_act)
plt.scatter(tsne.embedding_[:, 0], tsne.embedding_[:, 1])
a = 1  # stray notebook cell; kept as-is
mds = MDS().fit(flat_act)
plt.scatter(mds.embedding_[:, 0], mds.embedding_[:, 1])
import umap

# UMAP of the MDS embedding.
um = umap.UMAP(n_neighbors=5,
               min_dist=0.3)
# BUG FIX: the original called um.fit(mds.embedding_[:, 0], mds.embedding_[:, 1]),
# which passes a 1-D x-coordinate as X and the y-coordinate as the *supervised
# target* (UMAP.fit's second argument is y). Fit on the 2-D embedding instead.
um.fit(mds.embedding_)
plt.scatter(um.embedding_[:, 0], um.embedding_[:, 1])
# Stray notebook inspection cells (no effect in a script).
model.get_layer("conv2d_transpose_53").get_weights()[0].shape
plt.plot()
# TODO - visualize the profile of the first PC
# Heatmap of the raw 32-channel bottleneck activations for each top-10 example.
# (Loop-body indentation was lost in the notebook export; restored here.)
bottleneck.reshape((-1, 32))  # stray inspection expression (no effect in a script)
for i, idx in enumerate(top10_idx):
    plt.figure(figsize=(20, 4))
    plt.imshow(bottleneck[idx].T, aspect='auto')
# Stray notebook inspection cells (no effect in a script).
idx = top10_idx[2]
idx
bottleneck
# Predicted (left) vs observed (right) strand profiles for the top-10 examples;
# legend labels carry the argmax position of each trace.
# (Loop/if-body indentation was lost in the notebook export; restored here.)
for i, idx in enumerate(top10_idx):
    plt.figure(figsize=(10, 2))
    plt.subplot(121)
    plt.plot(y_pred[idx, :, 0], label='pos,m={}'.format(np.argmax(y_pred[idx, :, 0])))
    plt.plot(y_pred[idx, :, 1], label='neg,m={}'.format(np.argmax(y_pred[idx, :, 1])))
    plt.legend()
    if i == 0:
        plt.title("Predicted")
    plt.subplot(122)
    plt.plot(y_true[idx, :, 0], label='pos,m={}'.format(np.argmax(y_true[idx, :, 0])))
    plt.plot(y_true[idx, :, 1], label='neg,m={}'.format(np.argmax(y_true[idx, :, 1])))
    plt.legend()
    if i == 0:
        plt.title("Observed")
# Per-example figures for the top-10 set: PCA-reduced bottleneck, observed
# counts, predicted profile, and grad*input (w.r.t. the *max* output) logos
# for both strands.
# NOTE: y_true/y_pred/bottleneck are re-bound here to the top-10 subset only.
# (Loop-body indentation was lost in the notebook export; restored here;
# removed the unused `x_range` local and a commented-out suptitle.)
ginp_pos = pos_strand_ginp_max([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_max([test[0][top10_idx]])[0][0]
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
bottleneck = pred_bottleneck([test[0][top10_idx]])[0]
bottleneck_pca = pca.transform(bottleneck.reshape((-1, 32))).reshape((len(bottleneck), -1, 10))
for i in range(len(top10_idx)):
    fig, (ax0, ax1, ax2, ax3, ax4) = plt.subplots(5, 1, sharex=True, figsize=(20, 6))
    ax0.imshow(bottleneck_pca[i].T, aspect='auto')
    ax0.set_ylabel("pca\n(Bottleneck)")  # fixed unbalanced-paren label typo
    ax0.set_title('{}'.format(labels.iloc[top10_idx[i]]))
    # Profiles are plotted on a 1-based 1..201 axis.
    ax1.plot(np.arange(1, 202), y_true[i, :, 0], label="pos")
    ax1.plot(np.arange(1, 202), y_true[i, :, 1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(np.arange(1, 202), y_pred[i, :, 0], label="pos")
    ax2.plot(np.arange(1, 202), y_pred[i, :, 1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    ax4.set_xticks(list(range(0, 201, 5)))
# Same per-example figures as above, but attribution is grad*input w.r.t. the
# *average* output instead of the max.
# (Loop-body indentation was lost in the notebook export; restored here;
# removed the unused `x_range` local and a commented-out suptitle.)
ginp_pos = pos_strand_ginp_avg([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_avg([test[0][top10_idx]])[0][0]
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
bottleneck = pred_bottleneck([test[0][top10_idx]])[0]
bottleneck_pca = pca.transform(bottleneck.reshape((-1, 32))).reshape((len(bottleneck), -1, 10))
for i in range(len(top10_idx)):
    fig, (ax0, ax1, ax2, ax3, ax4) = plt.subplots(5, 1, sharex=True, figsize=(20, 6))
    ax0.imshow(bottleneck_pca[i].T, aspect='auto')
    ax0.set_ylabel("pca(Bottleneck act)")
    ax0.set_title('{}'.format(labels.iloc[top10_idx[i]]))
    ax1.plot(np.arange(1, 202), y_true[i, :, 0], label="pos")
    ax1.plot(np.arange(1, 202), y_true[i, :, 1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(np.arange(1, 202), y_pred[i, :, 0], label="pos")
    ax2.plot(np.arange(1, 202), y_pred[i, :, 1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    ax4.set_xticks(list(range(0, 201, 5)))
# Random 10 idx
# Replace the top-10 selection with 10 uniformly drawn test examples
# (unseeded, so a fresh draw each run) and restore full-set y_true/y_pred.
top10_idx = pd.Series(np.arange(len(test[0]))).sample(n=10).values
top10_idx  # stray notebook echo (no effect in a script)
y_true = test[1]
y_pred = softmax(model.predict(test[0]))
# Predicted (left) vs observed (right) profiles for the 10 random examples;
# legend labels carry the argmax position of each trace.
# (Loop/if-body indentation was lost in the notebook export; restored here.)
for i, idx in enumerate(top10_idx):
    plt.figure(figsize=(10, 2))
    plt.subplot(121)
    plt.plot(y_pred[idx, :, 0], label='pos,m={}'.format(np.argmax(y_pred[idx, :, 0])))
    plt.plot(y_pred[idx, :, 1], label='neg,m={}'.format(np.argmax(y_pred[idx, :, 1])))
    plt.legend()
    if i == 0:
        plt.title("Predicted")
    plt.subplot(122)
    plt.plot(y_true[idx, :, 0], label='pos,m={}'.format(np.argmax(y_true[idx, :, 0])))
    plt.plot(y_true[idx, :, 1], label='neg,m={}'.format(np.argmax(y_true[idx, :, 1])))
    plt.legend()
    if i == 0:
        plt.title("Observed")
# Per-example figures for the random set: observed counts, predicted profile,
# and grad*input (w.r.t. the *max* output) logos for both strands.
# (Loop-body indentation was lost in the notebook export; restored here;
# removed the unused `x_range` local.)
ginp_pos = pos_strand_ginp_max([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_max([test[0][top10_idx]])[0][0]
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
for i in range(len(top10_idx)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    ax1.plot(np.arange(1, 202), y_true[i, :, 0], label="pos")
    ax1.plot(np.arange(1, 202), y_true[i, :, 1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(np.arange(1, 202), y_pred[i, :, 0], label="pos")
    ax2.plot(np.arange(1, 202), y_pred[i, :, 1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    ax4.set_xticks(list(range(0, 201, 5)))
    plt.suptitle('{}'.format(labels.iloc[top10_idx[i]]))
# Same figures for the random set, but attribution is grad*input w.r.t. the
# *average* output instead of the max.
# (Loop-body indentation was lost in the notebook export; restored here;
# removed the unused `x_range` local.)
ginp_pos = pos_strand_ginp_avg([test[0][top10_idx]])[0][0]
ginp_neg = neg_strand_ginp_avg([test[0][top10_idx]])[0][0]
y_true = test[1][top10_idx]
y_pred = softmax(model.predict(test[0][top10_idx]))
for i in range(len(top10_idx)):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(20, 6))
    ax1.plot(np.arange(1, 202), y_true[i, :, 0], label="pos")
    ax1.plot(np.arange(1, 202), y_true[i, :, 1], label="neg")
    ax1.set_ylabel("Observed\ncounts")
    ax1.legend()
    ax2.plot(np.arange(1, 202), y_pred[i, :, 0], label="pos")
    ax2.plot(np.arange(1, 202), y_pred[i, :, 1], label="neg")
    ax2.set_ylabel("Predicted\n")
    ax2.legend()
    ax3.set_ylabel("Positive strand")
    seqlogo(ginp_pos[i], ax=ax3)
    ax4.set_ylabel("Negative strand")
    seqlogo(ginp_neg[i], ax=ax4)
    ax4.set_xticks(list(range(0, 201, 5)))
# Rearrange the deconv weight tensor so the loop iterates over its former
# third axis: two swaps (a,b,c,d) -> (a,c,b,d) -> (c,a,b,d), i.e.
# np.transpose(w, (2, 0, 1, 3)), then draw one seqlogo figure per slice.
# (Loop-body indentation was lost in the notebook export; restored here.)
w = np.swapaxes(w, 1, 2)
w = np.swapaxes(w, 0, 1)
for i in range(len(w)):
    seqlogo_fig(w[i], figsize=(5, 2))