import deepdish as dd
import json
import numpy as np
import tensorflow as tf
import pandas as pd
import shap
from tensorflow import keras
import pyfaidx
import shutil
import errno
import os
import sys
sys.path.append("../../../lib/chrombpnet/src/")
from training.utils.losses import multinomial_nll
from training.utils.one_hot import dna_to_one_hot
from evaluation.interpret.shap_utils import *
sys.path.append("../../")
from cross_species_models.utils.viz_sequence import *
import matplotlib.pyplot as plt
import seaborn as sns
# DeepSHAP's TF explainer requires TF1-style graph mode.
tf.compat.v1.disable_eager_execution()
# Model input length expected by ChromBPNet (one-hot sequence of 2114bp).
SEQ_LEN = 2114
# ENCODE accession of the ATAC assay whose model we interpret.
assay = "ENCSR460DKJ"
out_dir = "/oak/stanford/groups/akundaje/patelas/cross_species_models/experiments/20220509_chen_lab_collab/results/human_promoter_subs_atlas/"
# One species/task name per row (column 0).
species_list = pd.read_csv("/users/patelas/cross_species_models/experiments/20220509_chen_lab_collab/files/task_list.txt", sep="\t", header=None)
# One promoter sequence per row (column 0); assumed to align row-for-row
# with species_list — TODO confirm the two files share ordering.
promoter_list = pd.read_csv("/users/patelas/cross_species_models/experiments/20220509_chen_lab_collab/files/promoter_list.txt", sep="\t", header=None)
human_fasta = pyfaidx.Fasta("/oak/stanford/groups/akundaje/patelas/reference/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta", sequence_always_upper=True)
# Bias-factorized ChromBPNet model for this assay.
model_file = "/oak/stanford/groups/akundaje/projects/chromatin-atlas-2022/ATAC/%s/chrombpnet_model_feb15/chrombpnet_wo_bias.h5"%(assay)
# Species groupings used for plot coloring (orange vs. green below).
single_motifers = ["chimpanzee", "human", "gorilla", "orangutan", "olive_baboon", "rhesus", "mas_night_monkey",
"porcupine", "paca", "naked_mole-rat", "nutria", "degu", "capybara", "american_beaver"]
double_motifers = ["medium_lemur", "galago", "algerian_mouse", "house_mouse", "ryukyu_mouse",
"shrewmouse", "rat", "deer_mouse", "muskrat", "golden_hamster", "woodchuck", "ground_squirrel",
"egyptian_jerboa", "pika", "blind_mole-rat"]
predict_dir = "/oak/stanford/groups/akundaje/patelas/cross_species_models/experiments/20220509_chen_lab_collab/results/human_promoter_subs_atlas/"
# Predicted counts, log-transformed element-wise so magnitudes are
# comparable across species.
# NOTE(review): DataFrame.applymap is deprecated in pandas >= 2.1
# (DataFrame.map is the replacement) — fine on the pandas pinned here.
preds = pd.read_csv(predict_dir + assay + ".tsv", header=None, sep="\t").applymap(np.log)
def get_seq_for_predict(tert_prom, fasta, seq_len=None):
    '''
    Build a model-ready sequence of length `seq_len` from a promoter.

    The promoter is centered and, when shorter than `seq_len`, padded on
    both sides with the real human genome sequence flanking the human
    TERT promoter (hg38 chr5, left flank ending at 1294899, right flank
    starting at 1295300). This is equivalent to replacing the human
    promoter with another species' promoter and extending with human
    context to reach the model input length. When the promoter is longer
    than `seq_len`, its central `seq_len` bases are taken instead.

    Parameters
    ----------
    tert_prom : str
        Promoter sequence to embed.
    fasta : pyfaidx.Fasta-like
        Indexable genome; `fasta["chr5"][a:b].seq` must yield a string.
    seq_len : int, optional
        Target length. Defaults to the module-level SEQ_LEN (2114).

    Returns
    -------
    str
        Sequence of exactly `seq_len` characters.
    '''
    if seq_len is None:
        seq_len = SEQ_LEN
    leftover_len = seq_len - len(tert_prom)
    if leftover_len > 0:
        # Pad: left flank gets floor(leftover/2) bases; the right flank
        # absorbs the extra base when leftover_len is odd.
        left_flanker = fasta["chr5"][1294899 - leftover_len // 2:1294899].seq
        right_len = leftover_len // 2 + (leftover_len % 2)
        right_flanker = fasta["chr5"][1295300 : 1295300 + right_len].seq
        final_seq = left_flanker + tert_prom + right_flanker
    else:
        # Trim: central seq_len window of the promoter, with one extra
        # base on the right when seq_len is odd.
        mid = len(tert_prom) // 2
        final_seq = tert_prom[mid - seq_len // 2 : mid + seq_len // 2 + (seq_len % 2)]
    assert len(final_seq) == seq_len
    return final_seq
def generate_shap_dict(seqs, scores):
    """
    Package one-hot sequences and their SHAP scores for MODISCO.

    Both inputs must share shape (n_examples, input_len, 4). MODISCO
    expects arrays shaped (n_examples, 4, input_len), so every array is
    transposed on the way out. The returned dict holds the raw one-hot
    sequences ('raw'), the hypothetical scores ('shap'), and the scores
    projected onto the actual bases ('projected_shap').
    """
    print(seqs.shape, scores.shape)
    assert seqs.shape == scores.shape
    assert seqs.shape[2] == 4

    def _to_modisco(arr):
        # (n, input_len, 4) -> (n, 4, input_len)
        return np.transpose(arr, (0, 2, 1))

    return {
        'raw': {'seq': _to_modisco(seqs)},
        'shap': {'seq': _to_modisco(scores)},
        'projected_shap': {'seq': _to_modisco(seqs * scores)},
    }
def softmax(x, temp=1):
    """
    Temperature-scaled softmax along axis 1.

    Parameters
    ----------
    x : np.ndarray
        Logits of shape (n, k); the softmax is taken over axis 1.
    temp : float
        Inverse-temperature multiplier applied to the shifted logits.

    Returns
    -------
    np.ndarray of the same shape whose rows sum to 1.
    """
    # Shift by the row max instead of the row mean: softmax is invariant
    # to any per-row shift, but mean-centering still lets exp() overflow
    # when logits are widely spread; max-centering keeps exp() <= 1.
    norm_x = x - np.max(x, axis=1, keepdims=True)
    exp_x = np.exp(temp * norm_x)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)
# Load the bias-factorized ChromBPNet model; the custom loss and the
# `tf` symbol must be registered for h5 deserialization to succeed.
with keras.utils.CustomObjectScope({'multinomial_nll':multinomial_nll, 'tf':tf}):
    model = keras.models.load_model(model_file)
# One-hot encode one SEQ_LEN-bp sequence per species: each species'
# promoter embedded in human TERT flanking sequence.
# NOTE(review): promoter_list is indexed by position in species_list —
# assumes the two files are row-aligned; confirm upstream.
one_hot_seqs = dna_to_one_hot([get_seq_for_predict(promoter_list.loc[spec, 0], human_fasta) for spec in range(len(species_list))])
# Length of the profile head (first model output).
outlen = model.output_shape[0][1]
# Both explainers below attribute against the same input tensor and the
# same one-hot batch; kept as separate names for readability.
profile_model_input = model.input
profile_input = one_hot_seqs
counts_model_input = model.input
counts_input = one_hot_seqs
Here, we plot predictions for all species from the relevant Atlas model. Single-motifers are depicted in orange, and double-motifers are depicted in green. Because predicted count values may vary wildly in magnitude, plots are in log space.
def plot_preds_separate(preds, assay, task_list):
    """
    Plot each species' predicted log-count profile as its own figure.

    Single-motifers are drawn in orange; all other species (the
    double-motifers) in green. A shared y-limit — the global max of
    `preds` — keeps the panels visually comparable.

    Parameters
    ----------
    preds : pd.DataFrame
        One row per species (row order matching the global species_list),
        one column per output position; values are log counts.
    assay : str
        Assay identifier used in the plot titles.
    task_list : list[str]
        Species names to plot, in order.
    """
    plt.rc('axes', labelsize=6)
    plt.rc('ytick', labelsize=6)
    plt.rc('xtick', labelsize=6)
    maxval = preds.max().max()
    for spec in task_list:
        plt.figure(dpi=300, figsize=(5, 1))
        # Row index in `preds` follows the order of the global species_list.
        df_index = list(species_list[0]).index(spec)
        # assumes a 1000-bin output profile — TODO confirm against outlen
        color = "orange" if spec in single_motifers else "green"
        sns.lineplot(x=np.arange(1000), y=preds.loc[df_index], alpha=0.7, color=color)
        plt.ylim(0, maxval)
        plt.title("Predictions for %s - %s" % (assay, spec.upper()), size=8)
        plt.ylabel("Predicted Log Counts")
        plt.show()
# Render one prediction panel per species (orange = single-motifers,
# green = double-motifers).
plot_preds_separate(preds, assay, single_motifers + double_motifers)
In this section, we plot SHAP scores, which reflect the importance of each input nucleotide to the model's prediction. Because the model predicts two quantities (probability profile and total counts), we have two sets of scores. These scores are very helpful in identifying sequence motifs that influenced the predicted output. For each species, the corresponding plots show the 100bp window centered on the nucleotide of highest importance.
# Counts-head explainer: attribute the summed log-count output to input
# bases, using dinucleotide-shuffled sequences as the SHAP reference.
# NOTE(review): shap.explainers.deep.TFDeepExplainer is the legacy shap
# API path — presumably a pinned shap version is required; verify.
profile_model_counts_explainer = shap.explainers.deep.TFDeepExplainer(
    (counts_model_input, tf.reduce_sum(model.outputs[1], axis=-1)),
    shuffle_several_times,
    combine_mult_and_diffref=combine_mult_and_diffref)
counts_shap_scores = profile_model_counts_explainer.shap_values(
    counts_input, progress_message=100)
print(counts_shap_scores.shape)
# counts_shap_scores = counts_shap_scores[0]
# Repackage as (n, 4, len) arrays for MODISCO-style downstream use.
counts_scores_dict = generate_shap_dict(one_hot_seqs, counts_shap_scores)
# Profile-head explainer: attribute the weighted sum of mean-normalized
# profile logits (standard ChromBPNet profile-attribution target).
weightedsum_meannormed_logits = get_weightedsum_meannormed_logits(model)
profile_model_profile_explainer = shap.explainers.deep.TFDeepExplainer(
    (profile_model_input, weightedsum_meannormed_logits),
    shuffle_several_times,
    combine_mult_and_diffref=combine_mult_and_diffref)
profile_shap_scores = profile_model_profile_explainer.shap_values(
    profile_input, progress_message=100)
profile_scores_dict = generate_shap_dict(one_hot_seqs, profile_shap_scores)
We first plot the importance scores for the profile prediction.
# Plot, for each species, the 100bp window of profile-SHAP importance
# centered on the highest-importance base. Shared y-limit across panels.
max_score = profile_scores_dict["projected_shap"]["seq"].max()
for spec in single_motifers + double_motifers:
    spec_index = list(species_list[0]).index(spec)
    spec_scores = profile_scores_dict["projected_shap"]["seq"][spec_index]
    # Position along the sequence with the highest per-base importance.
    curr_argmax = spec_scores.max(axis=0).argmax()
    # Clamp the window inside the sequence: the previous `curr_argmax - 50`
    # could go negative near the left edge, yielding an empty/garbled slice.
    start = max(0, min(curr_argmax - 50, spec_scores.shape[1] - 100))
    plot_weights(spec_scores[:, start : start + 100], subticks_frequency=100)
    plt.title("Profile SHAP - %s" % (species_list.loc[spec_index, 0].upper()))
    # NOTE(review): ylim floor of 0 clips negative importance — confirm
    # this is intended before relying on these plots.
    plt.ylim(0, max_score)
    plt.show()
Next, we plot the importance scores for the count prediction.
# Plot, for each species, the 100bp window of counts-SHAP importance
# centered on the highest-importance base. Shared y-limit across panels.
max_score = counts_scores_dict["projected_shap"]["seq"].max()
for spec in single_motifers + double_motifers:
    spec_index = list(species_list[0]).index(spec)
    spec_scores = counts_scores_dict["projected_shap"]["seq"][spec_index]
    # Position along the sequence with the highest per-base importance.
    curr_argmax = spec_scores.max(axis=0).argmax()
    # Clamp the window inside the sequence: the previous `curr_argmax - 50`
    # could go negative near the left edge, yielding an empty/garbled slice.
    start = max(0, min(curr_argmax - 50, spec_scores.shape[1] - 100))
    plot_weights(spec_scores[:, start : start + 100], subticks_frequency=100)
    plt.title("Counts SHAP - %s" % (species_list.loc[spec_index, 0].upper()))
    # NOTE(review): ylim floor of 0 clips negative importance — confirm
    # this is intended before relying on these plots.
    plt.ylim(0, max_score)
    plt.show()