In [1]:
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
from motif.read_motifs import pfm_info_content, pfm_to_pwm
from util import figure_to_vdom_image
import plot.viz_sequence as viz_sequence
import h5py
import numpy as np
import pyfaidx
import matplotlib.pyplot as plt
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
import tqdm.notebook  # needed for the tqdm.notebook.tqdm/trange progress bars used below

Define constants and paths

In [2]:
# Define parameters/fetch arguments
filter_activations_path = os.environ["TFM_FILTER_ACTIVATIONS"]
filter_weights_path = os.environ["TFM_FILTER_WEIGHTS"]

if "TFM_MOTIF_CACHE" in os.environ:
    activation_motifs_cache_dir = os.environ["TFM_MOTIF_CACHE"]
else:
    activation_motifs_cache_dir = None

print("Path to filter activations: %s" % filter_activations_path)
print("Path to filter weights: %s" % filter_weights_path)
print("Saved activation-derived motifs cache: %s" % activation_motifs_cache_dir)
Path to filter activations: /users/amtseng/tfmodisco/results/filter_activations/singletask_profile_finetune/NR3C1-reddytime_singletask_profile_finetune_task14_fold5_filter_activations.h5
Path to filter weights: /users/amtseng/tfmodisco/results/filter_weights/singletask_profile_finetune/NR3C1-reddytime_singletask_profile_finetune_task14_fold5_filter_weights.npy
Saved activation-derived motifs cache: /users/amtseng/tfmodisco/results/reports/filter_derived_motifs/cache/NR3C1-reddytime_singletask_profile_finetune_task14_fold5_filter
In [3]:
# Constants/paths
input_length = 2114
filter_width = 21
reference_genome_path = "/users/amtseng/genomes/hg38.fasta"
In [4]:
if activation_motifs_cache_dir:
    os.makedirs(activation_motifs_cache_dir, exist_ok=True)

Helper functions

For extracting motifs

In [5]:
def dna_to_one_hot(seqs):
    """
    Converts a list of DNA ("ACGT") sequences to one-hot encodings, where the
    position of 1s is ordered alphabetically by "ACGT". `seqs` must be a list
    of N strings, where every string is the same length L. Returns an N x L x 4
    NumPy array of one-hot encodings, in the same order as the input sequences.
    All bases will be converted to upper-case prior to performing the encoding.
    Any bases that are not "ACGT" will be given an encoding of all 0s.
    """
    seq_len = len(seqs[0])
    assert np.all(np.array([len(s) for s in seqs]) == seq_len)

    # Join all sequences together into one long string, all uppercase
    seq_concat = "".join(seqs).upper()

    # One-hot map for A, C, G, T, and "other" (the last row is all zeros)
    one_hot_map = np.identity(5)[:, :-1]

    # Convert the string into an array of ASCII character codes
    base_vals = np.frombuffer(bytearray(seq_concat, "utf8"), dtype=np.int8)

    # Anything that's not an A, C, G, or T gets assigned a higher code
    base_vals[~np.isin(base_vals, np.array([65, 67, 71, 84]))] = 85

    # Convert the codes into indices in [0, 4], in ascending order by code
    _, base_inds = np.unique(base_vals, return_inverse=True)

    # Get the one-hot encoding for those indices, and reshape back to separate
    return one_hot_map[base_inds].reshape((len(seqs), seq_len, 4))
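
As a quick sanity check of the encoder above (a usage sketch only, not part of the pipeline): the bases map to the alphabetically ordered ACGT channels, lower-case input is upper-cased first, and anything outside ACGT becomes all zeros. Note that the index trick via np.unique implicitly assumes every one of A, C, G, and T occurs somewhere in the batch, which is effectively always true for genomic-scale inputs.

# Usage sketch for dna_to_one_hot (illustrative values only)
demo = dna_to_one_hot(["ACGTN", "ttacg"])
print(demo.shape)    # (2, 5, 4)
print(demo[0, 0])    # A -> [1. 0. 0. 0.]
print(demo[0, 4])    # N -> [0. 0. 0. 0.]
print(demo[1, 0])    # t -> [0. 0. 0. 1.] (upper-cased before encoding)
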
In [6]:
def extract_filter_activation_motifs(filter_activations_path, reference_genome_path):
    """
    Extracts the motifs that correspond to each filter. Returns an
    F x W x 4 array, where F is the number of filters and W is the width
    of each filter. The order of filters matches those in the saved HDF5/model.
    """
    reader = h5py.File(filter_activations_path, "r")
    activations_reader = reader["activations"]
    num_coords, two, num_windows, num_filters = activations_reader.shape
    
    assert two == 2
    assert num_windows == input_length - filter_width + 1
    
    print("Importing coordinates...")
    coords = np.empty((num_coords, 3), dtype=object)
    coords[:, 0] = reader["coords"]["coords_chrom"][:].astype(str)
    coords[:, 1] = reader["coords"]["coords_start"][:]
    coords[:, 2] = reader["coords"]["coords_end"][:]
    
    print("Fetching one-hot sequences...")
    genome_reader = pyfaidx.Fasta(reference_genome_path)
    one_hot_seqs = np.empty((num_coords, input_length, 4))
    batch_size = 128
    num_batches = int(np.ceil(num_coords / batch_size))
    for i in tqdm.notebook.trange(num_batches):
        batch_slice = slice(i * batch_size, (i + 1) * batch_size)
        one_hot_seqs[batch_slice] = dna_to_one_hot([
            genome_reader[chrom][start:end].seq for chrom, start, end in coords[batch_slice]
        ])
    
    pfms = np.empty((num_filters, filter_width, 4))
    for filter_index in range(num_filters):
        print("Extracting motif for filter %d..." % filter_index)
    
        print("\tComputing maximum activation...")
        acts = activations_reader[:, :, :, filter_index]
        max_act = np.max(acts)
        
        inds = np.where(acts >= 0.5 * max_act)
        
        windows, num_windows = np.zeros((filter_width, 4)), 0
        for coord_index, strand_index, pos_index in tqdm.notebook.tqdm(
            zip(*inds), total=len(inds[0]), desc="Extracting windows..."
        ):
            if strand_index == 0:
                window = one_hot_seqs[coord_index, pos_index : pos_index + filter_width]
            else:
                # Reverse complement; the positions are flipped
                window = np.flip(
                    one_hot_seqs[coord_index, input_length - filter_width - pos_index : input_length - pos_index],
                    axis=(0, 1)
                )
            windows = windows + window
            num_windows += 1
        
        pfms[filter_index] = windows / num_windows
    
    return pfms
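
The reverse-strand branch above is easy to get wrong, so here is a small self-contained check (toy values only, independent of the pipeline) that the forward-strand slice [L - w - p, L - p), flipped along both the position and base axes, is exactly the window the filter sees at position p of the reverse-complement sequence:

# Toy check of the reverse-strand window arithmetic (illustrative only)
L, w, p = 10, 3, 2                        # toy input length, filter width, window position
fwd = dna_to_one_hot(["ACGTACGTAC"])[0]   # forward one-hot sequence, shape (L, 4)

# Reverse complement of the whole sequence: flip positions and flip the
# ACGT channel axis (A<->T, C<->G)
rc = np.flip(fwd, axis=(0, 1))

# Window at position p on the reverse-complement strand...
window_on_rc = rc[p : p + w]
# ...equals the forward-strand slice [L - w - p, L - p), reverse-complemented
window_from_fwd = np.flip(fwd[L - w - p : L - p], axis=(0, 1))

assert np.array_equal(window_on_rc, window_from_fwd)
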
In [7]:
def compute_filter_influence(filter_activations_path):
    """
    Extracts the influence of each filter by computing the difference
    in cross entropy when each filter is nullified.
    Returns an F-array, where F is the number of filters, containing the
    change in average cross entropy (after nullification - before
    nullification). The order of filters matches those in the saved
    HDF5/model.
    """
    reader = h5py.File(filter_activations_path, "r")
    print("Reading in cross entropies...")
    before_null_cross_ents = reader["predictions"]["cross_ents"][:]
    after_null_cross_ents = reader["nullified_predictions"]["cross_ents"][:]
    
    before_null = np.nanmean(before_null_cross_ents)
    
    num_filters = after_null_cross_ents.shape[1]
    
    influences = []
    for filter_index in tqdm.notebook.trange(num_filters):
        after_null = np.nanmean(after_null_cross_ents[:, filter_index])
        influences.append(after_null - before_null)
        
    return np.array(influences)
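
For intuition, the influence computation is just a difference of NaN-aware means; a toy example with made-up cross-entropy values:

# Toy illustration of the influence calculation (made-up numbers)
before = np.array([0.50, 0.60, np.nan, 0.55])   # per-example cross entropy, original model
after = np.array([                              # per-example, per-filter cross entropy after nullification
    [0.70, 0.52],
    [0.80, 0.61],
    [np.nan, np.nan],
    [0.65, 0.56],
])
influence = np.nanmean(after, axis=0) - np.nanmean(before)
print(influence)  # ~[0.167, 0.013]: nullifying filter 0 hurts predictions far more than filter 1
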
In [8]:
def save_activation_motifs(filter_pfms, filter_influences, path):
    """
    Saves the filter-activation-derived PFMs and influence values.
    """
    with h5py.File(path, "w") as f:
        f.create_dataset("pfms", data=filter_pfms, compression="gzip")
        f.create_dataset("influences", data=filter_influences, compression="gzip")
In [9]:
def load_activation_motifs(path):
    """
    Loads the filter-activation-derived PFMs and influence values.
    """
    with h5py.File(path, "r") as f:
        return f["pfms"][:], f["influences"][:]
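
A quick round-trip of the two cache helpers above (a usage sketch with random placeholder arrays; "/tmp/filter_motifs_demo.h5" is a hypothetical path):

# Usage sketch: round-trip the cache helpers with placeholder arrays
_pfms = np.random.rand(4, filter_width, 4)   # 4 fake filters
_infl = np.random.rand(4)
save_activation_motifs(_pfms, _infl, "/tmp/filter_motifs_demo.h5")
_pfms2, _infl2 = load_activation_motifs("/tmp/filter_motifs_demo.h5")
assert np.allclose(_pfms, _pfms2) and np.allclose(_infl, _infl2)
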

Extract motifs from filter activations

Extract the motifs derived from each filter, ranked by filter influence.

Deriving a filter's motif:

  1. Identify the top 10000 best-predicted input sequences (ranked by cross entropy)
  2. For each window of each of these sequences, compute the activation of every first-layer filter
  3. A filter's motif is the average (as a PFM) of all sequence windows that activate the filter to at least half of its maximum activation over these top 10000 inputs (see the toy sketch after these lists)

Deriving a filter's influence:

  1. Identify the top 10000 best-predicted input sequences (ranked by cross entropy)
  2. Nullify each filter by replacing its output with its average activation over these 10000 inputs
  3. A filter's influence is the change in average cross entropy from before nullification to after
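
As a toy illustration of step 3 of the motif derivation (made-up activations, forward strand only, independent of the pipeline), the windows that clear the half-maximum threshold are simply averaged into a position frequency matrix:

# Toy illustration of half-maximum thresholding and window averaging
toy_acts = np.array([0.1, 0.9, 0.8, 0.2, 1.0])   # one filter's activation at 5 window positions
toy_seq = dna_to_one_hot(["ACGTACGT"])[0]        # toy one-hot sequence of length 8
toy_width = 4                                    # toy filter width, giving 5 windows

threshold = 0.5 * np.max(toy_acts)               # half of the maximum activation
keep = np.where(toy_acts >= threshold)[0]        # window positions 1, 2, and 4

# The filter's PFM is the average of the qualifying windows
toy_pfm = np.mean([toy_seq[q : q + toy_width] for q in keep], axis=0)
print(toy_pfm.shape)  # (4, 4): filter width x ACGT
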
In [10]:
compute_motifs = True
if activation_motifs_cache_dir:
    # Import if it exists
    cache_path = os.path.join(activation_motifs_cache_dir, "filter_activation_motifs.h5")
    if os.path.exists(cache_path) and os.stat(cache_path).st_size:
        filter_pfms, filter_influences = load_activation_motifs(cache_path)
        compute_motifs = False

if compute_motifs:
    # Extract PFMs of highly-activating sequences
    filter_pfms = extract_filter_activation_motifs(filter_activations_path, reference_genome_path)

    # Compute influence of each filter
    filter_influences = compute_filter_influence(filter_activations_path)

    if activation_motifs_cache_dir:
        save_activation_motifs(filter_pfms, filter_influences, cache_path)
Importing coordinates...
Fetching one-hot sequences...
Extracting motif for filter 0...
	Computing maximum activation...
[... the same two progress lines repeat for filters 1 through 63 ...]
Reading in cross entropies...

Extract motifs from filter weights

In [11]:
# Import the filter weights themselves
filter_weights = np.load(filter_weights_path)
assert len(filter_weights.shape) == 3
assert filter_weights.shape[:2] == (filter_width, 4)
filter_weights = np.transpose(filter_weights, axes=(2, 0, 1))  # Shape: F x W x 4

Motifs derived from filter-activating sequences

For each filter, its motif is constructed by averaging all of the sequence windows that activate it to at least half of its maximal activation; we show the resulting PWMs. The filters are ranked by influence, i.e. the average change in prediction cross entropy when the filter is nullified (replaced with its average activation).
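
The conversion from PFM to PWM uses pfm_to_pwm from the project's motif.read_motifs module, whose implementation is not shown in this notebook. As a hedged sketch of what such a conversion typically looks like (a pseudocounted log-odds score against a uniform background; the project's function may differ in its details):

# Hedged sketch of a PFM -> PWM conversion; the project's pfm_to_pwm may differ
def pfm_to_pwm_sketch(pfm, pseudocount=0.001):
    probs = (pfm + pseudocount) / np.sum(pfm + pseudocount, axis=1, keepdims=True)
    return np.log2(probs / 0.25)  # log-odds vs. a uniform 25% background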

In [12]:
colgroup = vdomh.colgroup(
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "85%"})
)
header = vdomh.thead(
    vdomh.tr(
        vdomh.th("Rank", style={"text-align": "center"}),
        vdomh.th("Filter index", style={"text-align": "center"}),
        vdomh.th("Influence", style={"text-align": "center"}),
        vdomh.th("PWM", style={"text-align": "center"})
    )
)

body = []
for i, filter_index in enumerate(np.flip(np.argsort(filter_influences))):
    pwm = pfm_to_pwm(filter_pfms[filter_index])
    if np.sum(pwm[:, [0, 2]]) < 0.5 * np.sum(pwm):
        # Flip to purine-rich version
        pwm = np.flip(pwm, axis=(0, 1))
    fig = viz_sequence.plot_weights(pwm, figsize=(20, 4), return_fig=True)
    fig.tight_layout()
    
    body.append(
        vdomh.tr(
            vdomh.td(str(i + 1)),
            vdomh.td(str(filter_index)),
            vdomh.td("%.3f" % filter_influences[filter_index]),
            vdomh.td(figure_to_vdom_image(fig))
        )
    )
    
    if activation_motifs_cache_dir:
        # Save motif PWM
        fig.savefig(os.path.join(activation_motifs_cache_dir, "filter_activation_motif_%d.png" % filter_index))

display(vdomh.table(colgroup, header, vdomh.tbody(*body)))
plt.close("all")
/users/amtseng/tfmodisco/src/plot/viz_sequence.py:152: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  fig = plt.figure(figsize=figsize)
Rank   Filter index   Influence   PWM
(The PWM column contains sequence-logo images, which are not reproduced in this text export.)
  1        32           0.030
  2         6           0.021
  3         8           0.017
  4        52           0.015
  5        44           0.012
  6        37           0.011
  7        54           0.011
  8        26           0.010
  9        40           0.009
 10        59           0.009
 11        60           0.008
 12        14           0.007
 13         2           0.007
 14        53           0.007
 15        45           0.006
 16        17           0.006
 17        43           0.006
 18        28           0.006
 19        55           0.005
 20        15           0.005
 21         5           0.005
 22        24           0.005
 23        20           0.005
 24         0           0.005
 25         9           0.005
 26        41           0.004
 27        62           0.004
 28        10           0.004
 29        13           0.004
 30        22           0.004
 31        42           0.004
 32        33           0.003
 33        36           0.003
 34        18           0.003
 35        58           0.003
 36         7           0.003
 37        47           0.003
 38        34           0.003
 39        46           0.003
 40        50           0.003
 41        57           0.002
 42        21           0.002
 43        19           0.002
 44        56           0.002
 45        23           0.002
 46        12           0.002
 47        30           0.002
 48        31           0.002
 49         4           0.002
 50        29           0.002
 51         1           0.002
 52        39           0.002
 53        16           0.002
 54        61           0.002
 55        11           0.002
 56        63           0.001
 57        38           0.001
 58        48           0.001
 59        51           0.001
 60        49           0.001
 61        35           0.001
 62        25           0.001
 63        27           0.001
 64         3           0.000

Motifs derived from filter weights

For each filter, we show its motif simply as the filter's (multiplicative) weights, mean-normalized across the four bases at each position. For consistency, the filters are ranked by influence, as above.
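
Concretely, "mean-normalized" means subtracting, at each filter position, the mean weight across the four bases, so only the relative base preferences remain (a toy illustration with made-up numbers):

# Toy illustration of mean-normalizing one filter position (made-up weights)
w_pos = np.array([0.9, -0.1, 0.3, -0.5])   # weights for A, C, G, T at one position
print(w_pos - np.mean(w_pos))              # [ 0.75 -0.25  0.15 -0.65]: sums to zero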

In [13]:
colgroup = vdomh.colgroup(
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "5%"}),
    vdomh.col(style={"width": "85%"})
)
header = vdomh.thead(
    vdomh.tr(
        vdomh.th("Rank", style={"text-align": "center"}),
        vdomh.th("Filter index", style={"text-align": "center"}),
        vdomh.th("Influence", style={"text-align": "center"}),
        vdomh.th("Mean-normalized filter weights", style={"text-align": "center"})
    )
)

body = []
for i, filter_index in enumerate(np.flip(np.argsort(filter_influences))):
    weights = filter_weights[filter_index]
    weights = weights - np.mean(weights, axis=1, keepdims=True)
    if np.sum(weights[:, [0, 2]]) < 0.5 * np.sum(weights):
        # Flip to purine-rich version
        weights = np.flip(weights, axis=(0, 1))
    fig = viz_sequence.plot_weights(weights, figsize=(20, 4), return_fig=True)
    fig.tight_layout()
    
    body.append(
        vdomh.tr(
            vdomh.td(str(i + 1)),
            vdomh.td(str(filter_index)),
            vdomh.td("%.3f" % filter_influences[filter_index]),
            vdomh.td(figure_to_vdom_image(fig))
        )
    )
    
    if activation_motifs_cache_dir:
        # Save motif PWM
        fig.savefig(os.path.join(activation_motifs_cache_dir, "filter_weight_motif_%d.png" % filter_index))

display(vdomh.table(colgroup, header, vdomh.tbody(*body)))
plt.close("all")
Rank   Filter index   Influence   Mean-normalized filter weights
(The ranking and influence values are identical to the table above; the final column shows each filter's mean-normalized weight logo as an image, not reproduced in this text export.)