import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
from motif.read_motifs import pfm_info_content, pfm_to_pwm
from util import figure_to_vdom_image
import plot.viz_sequence as viz_sequence
import h5py
import numpy as np
import pyfaidx
import matplotlib.pyplot as plt
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
tqdm.tqdm_notebook()
# Constants/paths
input_length = 2114
filter_width = 21
reference_genome_path = "/users/amtseng/genomes/hg38.fasta"

# Define parameters/fetch arguments
# The two input paths are mandatory: a missing variable raises KeyError here
filter_activations_path = os.environ["TFM_FILTER_ACTIVATIONS"]
filter_weights_path = os.environ["TFM_FILTER_WEIGHTS"]
# The cache directory is optional; it is None when the variable is not set
activation_motifs_cache_dir = os.environ.get("TFM_MOTIF_CACHE")

print("Path to filter activations: %s" % filter_activations_path)
print("Path to filter weights: %s" % filter_weights_path)
print("Saved activation-derived motifs cache: %s" % activation_motifs_cache_dir)

# Create the cache directory up front so later saves cannot fail on a missing
# directory
if activation_motifs_cache_dir:
    os.makedirs(activation_motifs_cache_dir, exist_ok=True)
Helper functions for extracting motifs
def dna_to_one_hot(seqs):
    """
    Converts a list of DNA ("ACGT") sequences to one-hot encodings, where the
    position of 1s is ordered alphabetically by "ACGT". `seqs` must be a list
    of N strings, where every string is the same length L. Returns an N x L x 4
    NumPy array of one-hot encodings, in the same order as the input sequences.
    Bases are matched case-insensitively. Any character that is not one of
    "ACGT" (e.g. "N") is given an encoding of all 0s.
    An empty list of sequences yields an empty array of shape 0 x 0 x 4.
    """
    if len(seqs) == 0:
        return np.zeros((0, 0, 4))

    seq_len = len(seqs[0])
    assert all(len(s) == seq_len for s in seqs), \
        "All sequences must have the same length"

    # Fixed lookup table from byte value to base index. Every byte that is not
    # an upper- or lower-case A/C/G/T maps to index 4 (the all-zeros row).
    # A fixed table is needed because ranking the codes that happen to occur
    # in the input gives wrong indices whenever one of the bases is absent.
    base_to_index = np.full(256, 4, dtype=np.int64)
    for index, base in enumerate("ACGT"):
        base_to_index[ord(base)] = index
        base_to_index[ord(base.lower())] = index

    # Rows 0-3 are the one-hot vectors for A, C, G, T; row 4 is all 0s
    one_hot_map = np.identity(5)[:, :-1]

    # Join all sequences into one long string and view it as byte codes.
    # Non-ASCII characters are replaced by a single "?" byte each, so there is
    # still exactly one byte per input character
    seq_concat = "".join(seqs)
    base_vals = np.frombuffer(
        seq_concat.encode("ascii", "replace"), dtype=np.uint8
    )

    # Get the one-hot encoding for each base, and reshape back to separate
    # the sequences
    base_inds = base_to_index[base_vals]
    return one_hot_map[base_inds].reshape((len(seqs), seq_len, 4))
def extract_filter_activation_motifs(filter_activations_path, reference_genome_path):
    """
    Extracts the motifs that correspond to each filter. Returns an
    F x W x 4 array, where F is the number of filters and W is the width
    of each filter. The order of filters matches those in the saved HDF5/model.

    The HDF5 file at `filter_activations_path` is read for an "activations"
    dataset of shape (num_coords, 2, num_windows, num_filters) and a "coords"
    group holding "coords_chrom", "coords_start" and "coords_end". Sequences
    are fetched from the FASTA at `reference_genome_path`.

    For each filter, the motif is the average of the one-hot sequence windows
    at every (coordinate, strand, position) whose activation is at least half
    of that filter's maximum activation. If no window qualifies, the filter's
    PFM is filled with NaN.

    Relies on the module-level `input_length` and `filter_width`, and on
    `dna_to_one_hot`.
    """
    # The context manager releases the HDF5 handle even if an error occurs
    with h5py.File(filter_activations_path, "r") as reader:
        activations_reader = reader["activations"]
        num_coords, num_strands, num_windows, num_filters = \
            activations_reader.shape
        assert num_strands == 2
        assert num_windows == input_length - filter_width + 1

        print("Importing coordinates...")
        coords = np.empty((num_coords, 3), dtype=object)
        coords[:, 0] = reader["coords"]["coords_chrom"][:].astype(str)
        coords[:, 1] = reader["coords"]["coords_start"][:]
        coords[:, 2] = reader["coords"]["coords_end"][:]

        print("Fetching one-hot sequences...")
        genome_reader = pyfaidx.Fasta(reference_genome_path)
        try:
            one_hot_seqs = np.empty((num_coords, input_length, 4))
            batch_size = 128
            num_batches = int(np.ceil(num_coords / batch_size))
            for i in tqdm.notebook.trange(num_batches):
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                one_hot_seqs[batch_slice] = dna_to_one_hot([
                    genome_reader[chrom][start:end].seq
                    for chrom, start, end in coords[batch_slice]
                ])
        finally:
            # The FASTA is only needed while fetching sequences
            genome_reader.close()

        pfms = np.empty((num_filters, filter_width, 4))
        for filter_index in range(num_filters):
            print("Extracting motif for filter %d..." % filter_index)
            print("\tComputing maximum activation...")
            acts = activations_reader[:, :, :, filter_index]
            max_act = np.max(acts)
            # Indices (coordinate, strand, position) of sufficiently strong
            # activations
            inds = np.where(acts >= 0.5 * max_act)

            window_sum = np.zeros((filter_width, 4))
            window_count = len(inds[0])
            for coord_index, strand_index, pos_index in tqdm.notebook.tqdm(
                zip(*inds), total=window_count, desc="Extracting windows..."
            ):
                if strand_index == 0:
                    window = one_hot_seqs[
                        coord_index, pos_index : pos_index + filter_width
                    ]
                else:
                    # Reverse complement; the positions are flipped. Flipping
                    # both axes of an ACGT-ordered one-hot window reverses the
                    # sequence and swaps A<->T and C<->G
                    rc_start = input_length - filter_width - pos_index
                    window = np.flip(
                        one_hot_seqs[
                            coord_index, rc_start : rc_start + filter_width
                        ],
                        axis=(0, 1)
                    )
                window_sum += window

            if window_count:
                pfms[filter_index] = window_sum / window_count
            else:
                # No qualifying window (e.g. negative or NaN maximum); make
                # the undefined average explicit rather than dividing 0 by 0
                pfms[filter_index] = np.nan
    return pfms
def compute_filter_influence(filter_activations_path):
    """
    Extracts the influence of each filter by computing the difference
    in cross entropy when each filter is nullified.
    Returns an F-array, where F is the number of filters, containing the
    change in average cross entropy (after nullification - before
    nullification). The order of filters matches those in the saved
    HDF5/model.

    The HDF5 file at `filter_activations_path` is read for the datasets
    "predictions/cross_ents" and "nullified_predictions/cross_ents"; the
    second axis of the latter indexes filters. NaN entries are ignored when
    averaging.
    """
    # Read both datasets fully into memory, then release the file handle
    with h5py.File(filter_activations_path, "r") as reader:
        print("Reading in cross entropies...")
        before_null_cross_ents = reader["predictions"]["cross_ents"][:]
        after_null_cross_ents = \
            reader["nullified_predictions"]["cross_ents"][:]

    before_null = np.nanmean(before_null_cross_ents)
    num_filters = after_null_cross_ents.shape[1]
    influences = [
        np.nanmean(after_null_cross_ents[:, filter_index]) - before_null
        for filter_index in tqdm.notebook.trange(num_filters)
    ]
    return np.array(influences)
def save_activation_motifs(filter_pfms, filter_influences, path):
    """
    Writes the filter-activation-derived PFMs and the filter influence values
    to a new HDF5 file at `path`, as the gzip-compressed datasets "pfms" and
    "influences" respectively. The file is opened in "w" mode.
    """
    named_arrays = (("pfms", filter_pfms), ("influences", filter_influences))
    with h5py.File(path, "w") as out_file:
        for dataset_name, values in named_arrays:
            out_file.create_dataset(
                dataset_name, data=values, compression="gzip"
            )
def load_activation_motifs(path):
    """
    Reads back the filter-activation-derived PFMs and influence values from
    the HDF5 file at `path`. Returns the pair (pfms, influences), read in full
    from the datasets "pfms" and "influences".
    """
    with h5py.File(path, "r") as in_file:
        pfms = in_file["pfms"][:]
        influences = in_file["influences"][:]
    return pfms, influences
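Extract the motifs derived from each filter, ranked by filter influence.
Deriving a filter's motif: average the one-hot sequence windows (reverse-complemented for the second strand) at every position where the filter's activation is at least half of its maximal activation.
Deriving a filter's influence: take the average prediction cross entropy after the filter is nullified (i.e. replaced with its average activation) minus the average cross entropy before nullification.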
# Reuse cached results when a non-empty cache file exists; otherwise compute
# them from scratch (and cache them if a cache directory was given)
compute_motifs = True
if activation_motifs_cache_dir:
    cache_path = os.path.join(
        activation_motifs_cache_dir, "filter_activation_motifs.h5"
    )
    # A zero-byte file is treated the same as a missing cache
    if os.path.exists(cache_path) and os.path.getsize(cache_path) > 0:
        filter_pfms, filter_influences = load_activation_motifs(cache_path)
        compute_motifs = False

if compute_motifs:
    # PFMs of the sequences that highly activate each filter
    filter_pfms = extract_filter_activation_motifs(
        filter_activations_path, reference_genome_path
    )
    # Change in cross entropy when each filter is nullified
    filter_influences = compute_filter_influence(filter_activations_path)
    if activation_motifs_cache_dir:
        save_activation_motifs(filter_pfms, filter_influences, cache_path)
# Load the raw filter weights; they must be stored as W x 4 x F
filter_weights = np.load(filter_weights_path)
assert filter_weights.ndim == 3
assert filter_weights.shape[:2] == (filter_width, 4)
# Bring the filter axis to the front. Shape: F x W x 4
filter_weights = filter_weights.transpose(2, 0, 1)
For each filter, its motif is constructed by averaging all of the sequences that activate it to at least half of its maximal activation. We show the resulting PWMs, each flipped to its purine-rich orientation. The filters are ranked in descending order of influence (i.e. the average difference in prediction cross entropy when the filter is nullified, that is, replaced with its average activation).
# Table layout: three narrow metadata columns followed by a wide image column
colgroup = vdomh.colgroup(
    *[vdomh.col(style={"width": width}) for width in ("5%", "5%", "5%", "85%")]
)
header = vdomh.thead(
    vdomh.tr(*[
        vdomh.th(title, style={"text-align": "center"})
        for title in ("Rank", "Filter index", "Influence", "PWM")
    ])
)

# One table row per filter, from the highest influence to the lowest
body = []
ranked_filters = np.argsort(filter_influences)[::-1]
for rank, filter_index in enumerate(ranked_filters, start=1):
    pwm = pfm_to_pwm(filter_pfms[filter_index])
    purine_total = np.sum(pwm[:, [0, 2]])
    if purine_total < 0.5 * np.sum(pwm):
        # Flip to purine-rich version (reverse both positions and bases)
        pwm = pwm[::-1, ::-1]

    fig = viz_sequence.plot_weights(pwm, figsize=(20, 4), return_fig=True)
    fig.tight_layout()

    row = vdomh.tr(
        vdomh.td(str(rank)),
        vdomh.td(str(filter_index)),
        vdomh.td("%.3f" % filter_influences[filter_index]),
        vdomh.td(figure_to_vdom_image(fig))
    )
    body.append(row)

    if activation_motifs_cache_dir:
        # Save motif PWM image alongside the cached motifs
        image_path = os.path.join(
            activation_motifs_cache_dir,
            "filter_activation_motif_%d.png" % filter_index
        )
        fig.savefig(image_path)

display(vdomh.table(colgroup, header, vdomh.tbody(*body)))
# Release every figure created in the loop above
plt.close("all")
For each filter, we show its corresponding motif simply as the mean-normalized multiplicative weights in the filter (i.e. at each position, the mean weight across the four bases is subtracted). For consistency, we rank the filters by influence (as above).
# Table layout: three narrow metadata columns followed by a wide image column
colgroup = vdomh.colgroup(
    *[vdomh.col(style={"width": width}) for width in ("5%", "5%", "5%", "85%")]
)
header = vdomh.thead(
    vdomh.tr(*[
        vdomh.th(title, style={"text-align": "center"})
        for title in (
            "Rank", "Filter index", "Influence",
            "Mean-normalized filter weights"
        )
    ])
)

# One table row per filter, from the highest influence to the lowest
body = []
ranked_filters = np.argsort(filter_influences)[::-1]
for rank, filter_index in enumerate(ranked_filters, start=1):
    raw_weights = filter_weights[filter_index]
    # Centre each position by subtracting its mean weight across the 4 bases
    weights = raw_weights - np.mean(raw_weights, axis=1, keepdims=True)
    purine_total = np.sum(weights[:, [0, 2]])
    if purine_total < 0.5 * np.sum(weights):
        # Flip to purine-rich version (reverse both positions and bases)
        weights = weights[::-1, ::-1]

    fig = viz_sequence.plot_weights(weights, figsize=(20, 4), return_fig=True)
    fig.tight_layout()

    row = vdomh.tr(
        vdomh.td(str(rank)),
        vdomh.td(str(filter_index)),
        vdomh.td("%.3f" % filter_influences[filter_index]),
        vdomh.td(figure_to_vdom_image(fig))
    )
    body.append(row)

    if activation_motifs_cache_dir:
        # Save the weight-derived motif image alongside the cached motifs
        image_path = os.path.join(
            activation_motifs_cache_dir,
            "filter_weight_motif_%d.png" % filter_index
        )
        fig.savefig(image_path)

display(vdomh.table(colgroup, header, vdomh.tbody(*body)))
# Release every figure created in the loop above
plt.close("all")