# Select GPUs BEFORE importing keras/TensorFlow: CUDA_VISIBLE_DEVICES is read
# when the CUDA context is first initialized, so setting it after the
# framework import may have no effect.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,5"  # use GPUs 3 and 5 (comma-separated, no spaces)

import basepair
from keras.models import Model, load_model
from basepair.losses import twochannel_multinomial_nll
from basepair.utils import read_pkl
from basepair.eval import evaluate

# Load the trained BPNet model; the custom loss has to be registered so
# keras can deserialize the saved model.
model = load_model("model.h5",
                   custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})

# (inputs, targets) tuples per split, plus the fitted pre-processor.
train, valid, test = read_pkl("/users/avsec/workspace/basepair-workflow/models/0/data.pkl")
preproc = read_pkl("/users/avsec/workspace/basepair-workflow/models/0/preprocessor.pkl")

# Sanity-check the restored model on the validation split.
evaluate(model, valid[0], valid[1])
task_names = ["profile/Oct4", "profile/Sox2", "profile/Klf4", "profile/Nanog",
              "counts/Oct4", "counts/Sox2", "counts/Klf4", "counts/Nanog"]

import keras.backend as K

inp = model.inputs[0]


def _strand_grad_fn(task_id, task_name, strand):
    """Build a K.function computing d(importance)/d(input) for one task/strand.

    For counts heads (output shape (batch, 2)) the importance is the raw
    count prediction itself.  For profile heads (output shape
    (batch, len, 2)) it is sum(stop_grad(softmax(logits)) * logits) over
    positions, i.e. the logits weighted by their (non-differentiated)
    softmax probabilities.
    """
    out = model.outputs[task_id]
    if "counts" in task_name:
        target = out[:, strand]
    else:
        logits = out[:, :, strand]
        target = K.sum(K.stop_gradient(K.softmax(logits)) * logits, axis=-1)
    return K.function([inp], K.gradients(target, inp))


# strand 0 = positive strand, strand 1 = negative strand
fn_pos = {name: _strand_grad_fn(i, name, 0) for i, name in enumerate(task_names)}
fn_neg = {name: _strand_grad_fn(i, name, 1) for i, name in enumerate(task_names)}
import numpy as np
from basepair.data import numpy_minibatch


def _batched_grads(fn, x, batch_size=512):
    """Evaluate gradient function `fn` over `x` in mini-batches and stack."""
    return np.concatenate([np.array(fn([batch])).squeeze()
                           for batch in numpy_minibatch(x, batch_size)])


# Input gradients per task, computed separately for each strand.
grads_pos = {name: _batched_grads(fn_pos[name], valid[0]) for name in task_names}
grads_neg = {name: _batched_grads(fn_neg[name], valid[0]) for name in task_names}
# Setup different scores
hyp_scores = {}
scores = {}
for task_name in task_names:
    # combine the gradients of both strands
    total_grad = grads_pos[task_name] + grads_neg[task_name]
    # mean-center across the 4 base channels -> hypothetical contributions
    centered = total_grad - total_grad.mean(-1, keepdims=True)
    hyp_scores[task_name] = centered
    # keep only the contribution of the observed base at each position
    scores[task_name] = centered * valid[0]

onehot_data = valid[0]
task_to_scores = scores
task_to_hyp_scores = hyp_scores
from concise.utils.plot import seqlogo_fig, seqlogo
import matplotlib.pyplot as plt

# Visualize actual vs. hypothetical importance scores for the first example.
fig, (ax_scores, ax_hyp) = plt.subplots(2, 1, sharex=True, figsize=(20, 6))
ax_scores.set_title("scores")
seqlogo(scores["profile/Oct4"][0], ax=ax_scores)
ax_hyp.set_title("hyp_scores")
seqlogo(hyp_scores["profile/Oct4"][0], ax=ax_hyp)
tfs = ["Oct4", "Sox2", "Klf4", "Nanog"]


def bpnet_predict(seqs, tasks=tfs, preproc=preproc):
    """Run the BPNet model on one-hot sequences and post-process the output.

    Returns a list with one dict per sequence containing the sequence itself,
    the predictions (profile probabilities, counts and absolute-scaled
    profiles) and the per-task profile scaling factors.
    """
    from basepair.math import softmax
    from kipoi.data_utils import get_dataset_item

    preds = model.predict(seqs, batch_size=512)
    # Profile heads come first; turn their logits into per-position
    # probabilities via softmax.
    preds[:len(tasks)] = [softmax(p) for p in preds[:len(tasks)]]
    preds_dict = dict(profile=preds[:len(tasks)],
                      counts=preds[len(tasks):])

    # Dev
    def pred2scale_strands(preds):
        """Compute the scaling factor for the profile in order to
        obtain the absolute counts
        """
        # invert the count transform of the fitted pre-processor per task
        return {task: np.exp(preproc.objects[f'profile/{task}']
                             .steps[1][1].inverse_transform(preds['counts'][i])) - 1
                for i, task in enumerate(tasks)}

    def append_scaled_profile(preds):
        """Attach 'scaled_profile' = profile * per-task count scale."""
        scales = pred2scale_strands(preds)
        preds['scaled_profile'] = {task: preds['profile'][i] * scales[task][np.newaxis]
                                   for i, task in enumerate(tasks)}
        return preds

    return [dict(
        seq=get_dataset_item(seqs, i),
        preds=append_scaled_profile(get_dataset_item(preds_dict, i)),
        # scaling factor to go from relative -> absolute counts
        scale_profile=pred2scale_strands(get_dataset_item(preds_dict, i)),
    ) for i in range(len(seqs))]


y_preds = bpnet_predict(onehot_data)
def plot_preds(y_preds, num_samples=5, figsize=(20, 2), binsize=1):
    """Plot predicted absolute-scaled profiles for the first `num_samples`
    entries of `y_preds` (the output of bpnet_predict).

    One figure per sample, one subplot per TF in `tfs`, one line per strand;
    the legend shows the argmax position of the (un-binned) profile.
    """
    import matplotlib.pyplot as plt
    from basepair.preproc import bin_counts
    # for visualization, we use bucketize
    for idx in range(num_samples):
        fig = plt.figure(figsize=figsize)
        for col, tf in enumerate(tfs):
            plt.subplot(1, 4, col + 1)
            if idx == 0:
                plt.title(tf)
            profile = y_preds[idx]['preds']['scaled_profile'][tf]
            # bin once per TF instead of once per strand
            binned = bin_counts(profile, binsize=binsize)
            for strand, strand_name in enumerate(['pos', 'neg']):
                plt.plot(binned[:, strand],
                         label='{},m={}'.format(strand_name,
                                                np.argmax(profile[:, strand])))
            plt.legend()


plot_preds(y_preds)
def onehot_to_seq(onehot):
    """Convert a one-hot encoded sequence ((len, 4) array, ACGT channel
    order) back into a string.

    All-zero rows -- which the one-hot encoder in this file produces for
    'N'/'n' bases -- are decoded as 'N' instead of raising an IndexError.
    """
    chars = []
    for row in onehot:
        hot = [i for i, e in enumerate(row) if e != 0]
        if not hot:
            chars.append('N')  # ambiguous base was encoded as all zeros
        else:
            chars.append('ACGT'[hot[0]])
    return ''.join(chars)
# This is set up for 1d convolutions where examples have dimensions
# (len, num_channels); the channel axis is the axis for one-hot encoding.
def one_hot_encode_along_channel_axis(sequence):
    """One-hot encode `sequence` into a (len, 4) int8 array (ACGT order)."""
    encoded = np.zeros((len(sequence), 4), dtype=np.int8)
    seq_to_one_hot_fill_in_array(zeros_array=encoded,
                                 sequence=sequence, one_hot_axis=1)
    return encoded
def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    """Fill `zeros_array` in place with the one-hot encoding of `sequence`.

    `one_hot_axis` selects which axis indexes the 4 bases (0 or 1); the
    other axis must have length len(sequence).  'N'/'n' positions are left
    as all zeros; any other unrecognized character raises RuntimeError.
    """
    assert one_hot_axis == 0 or one_hot_axis == 1
    # the non-channel axis must match the sequence length
    assert zeros_array.shape[1 - one_hot_axis] == len(sequence)
    base_index = {"A": 0, "a": 0, "C": 1, "c": 1,
                  "G": 2, "g": 2, "T": 3, "t": 3}
    # will mutate zeros_array
    for pos, char in enumerate(sequence):
        if char == "N" or char == "n":
            continue  # leave that position as all 0's
        try:
            char_idx = base_index[char]
        except KeyError:
            raise RuntimeError("Unsupported character: " + str(char))
        if one_hot_axis == 0:
            zeros_array[char_idx, pos] = 1
        else:
            zeros_array[pos, char_idx] = 1
from deeplift.dinuc_shuffle import dinuc_shuffle

num_samples = 5
num_shuffles = 50

# For each sample, average the model predictions over `num_shuffles`
# dinucleotide-shuffled versions of the sequence to obtain a "reference"
# prediction (a per-sequence background).
# NOTE(review): ref_pred is overwritten on each sample, so only the LAST
# sample's reference survives to the plot below -- this mirrors the original
# control flow; confirm whether the plot was meant to be inside the loop.
for idx in range(num_samples):
    # same nested structure as bpnet_predict() output so plot_preds() can
    # consume it directly; profiles are (200 positions, 2 strands)
    ref_pred = {0: {'preds': {'scaled_profile': {tf: np.zeros((200, 2))
                                                 for tf in tfs}}}}
    for shuff in range(num_shuffles):
        reference = one_hot_encode_along_channel_axis(
            dinuc_shuffle(onehot_to_seq(onehot_data[idx])))
        pred = bpnet_predict([[reference]])
        for tf in tfs:
            ref_pred[0]['preds']['scaled_profile'][tf] += pred[0]['preds']['scaled_profile'][tf]
    for tf in tfs:
        ref_pred[0]['preds']['scaled_profile'][tf] /= num_shuffles

# reference = np.zeros((200,4))#np.array([[0.3, 0.2, 0.2, 0.3]]*200)
# ref_pred = np.array(bpnet_predict([[reference]]))
plot_preds(ref_pred, num_samples=1)
def plot_truth(y, num_samples=5, figsize=(20, 2), binsize=1):
    """Plot measured (ground-truth) profiles for the first `num_samples`
    examples of a targets dict keyed by 'profile/<TF>'.

    One figure per sample, one subplot per TF in `tfs`, one line per strand;
    the legend shows the argmax position of the (un-binned) profile.
    """
    import matplotlib.pyplot as plt
    from basepair.preproc import bin_counts
    # for visualization, we use bucketize
    for idx in range(num_samples):
        fig = plt.figure(figsize=figsize)
        for col, tf in enumerate(tfs):
            plt.subplot(1, 4, col + 1)
            if idx == 0:
                plt.title(tf)
            profile = y[f'profile/{tf}']
            # bin once per TF instead of once per strand
            binned = bin_counts(profile, binsize=binsize)
            for strand, strand_name in enumerate(['pos', 'neg']):
                plt.plot(binned[idx, :, strand],
                         label='{},m={}'.format(strand_name,
                                                np.argmax(profile[idx, :, strand])))
            plt.legend()


plot_truth(valid[1])