In [1]:
%load_ext autoreload
%autoreload 2
%reset -f
import os
import util
import moods
import h5py
import viz_sequence
import numpy as np
import pandas as pd
import pomegranate
import sklearn.cluster
import scipy.cluster.hierarchy
import scipy.stats
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import subprocess
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
tqdm.tqdm_notebook()
/users/anusri/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:21: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
Out[1]:
<tqdm.notebook.tqdm_notebook at 0x7f27c2285650>
In [2]:
# Plotting defaults
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.weight": "bold"
}
plt.rcParams.update(plot_params)

Define constants and paths

In [3]:
# Define parameters/fetch arguments
tfm_results_path = os.environ["TFM_TFM_PATH"]
shap_scores_path = os.environ["TFM_SHAP_PATH"]
peak_bed_paths = [os.environ["TFM_PEAKS_PATH"]]
moods_dir = os.environ["TFM_MOODS_DIR"]
reference_fasta = os.environ["TFM_REFERENCE_PATH"]
tomtom_path = os.environ["TFM_TOMTOM_PATH"]

#tfm_results_path = "/oak/stanford/groups/akundaje/projects/chrombpnet_paper_new/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/SIGNAL/modisco_crop_500/modisco_results_allChroms_counts.hdf5"
#shap_scores_path = "/mnt/lab_data2/anusri/chrombpnet/results/chrombpnet/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/interpret/merged.HEPG2.counts_scores.h5"
#peak_bed_paths = ["/mnt/lab_data2/anusri/chrombpnet/results/chrombpnet/ATAC_PE/HEPG2/nautilus_runs_jun16/HEPG2_05.09.2022_bias_128_4_1234_0.8_fold_0/interpret/merged.HEPG2.interpreted_regions.bed"]
#moods_dir = "/mnt/lab_data3/anusri/chrombpnet/results/chrombpnet/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/06_22_2022_motif_scanning/moods_baseairmodels/"
#reference_fasta = "/mnt/lab_data2/anusri/chrombpnet/reference/hg38.genome.fa"
#tomtom_path="/oak/stanford/groups/akundaje/projects/chrombpnet_paper_new/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/SIGNAL/modisco_crop_500/counts.tomtom.tsv"



print("TF-MoDISco results path: %s" % tfm_results_path)
print("TF-MoDISco tomtom results path: %s" % tomtom_path)
print("DeepSHAP scores path: %s" % shap_scores_path)
print("Peaks path: %s" % peak_bed_paths[0])
print("MOODS directory: %s" % moods_dir)
print("Reference genome path: %s" % reference_fasta)
TF-MoDISco results path: /oak/stanford/groups/akundaje/projects/chrombpnet_paper_new/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/SIGNAL/modisco_crop_500/modisco_results_allChroms_counts.hdf5
TF-MoDISco tomtom results path: /oak/stanford/groups/akundaje/projects/chrombpnet_paper_new/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/SIGNAL/modisco_crop_500/counts.tomtom.tsv
DeepSHAP scores path: /mnt/lab_data2/anusri/chrombpnet/results/chrombpnet/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/interpret//merged.HEPG2.counts_scores.h5
Peaks path: /mnt/lab_data2/anusri/chrombpnet/results/chrombpnet/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/interpret//merged.HEPG2.interpreted_regions.bed
MOODS directory: /mnt/lab_data3/anusri/chrombpnet/results/chrombpnet/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/06_22_2022_motif_scanning/moods_baseairmodels/
Reference genome path: /mnt/lab_data2/anusri/chrombpnet/reference/hg38.genome.fa
In [4]:
# Constants
input_length = 2114
hyp_score_key = "hyp_scores"
motif_fdr_cutoff = 0.10

Helper functions

For plotting and organizing things

In [5]:
def import_tfmodisco_motifs(tfm_results_path, trim=True, only_pos=True):
    """
    Imports the PFMs to into a dictionary, mapping `(x, y)` to the PFM,
    where `x` is the metacluster index and `y` is the pattern index.
    Arguments:
        `tfm_results_path`: path to HDF5 containing TF-MoDISco results
        `out_dir`: where to save motifs
        `trim`: if True, trim the motif flanks based on information content
        `only_pos`: if True, only return motifs with positive contributions
    Returns the dictionary of PFMs.
    """ 
    pfms = {}
    with h5py.File(tfm_results_path, "r") as f:
        metaclusters = f["metacluster_idx_to_submetacluster_results"]
        num_metaclusters = len(metaclusters.keys())
        for metacluster_i, metacluster_key in enumerate(metaclusters.keys()):
            metacluster = metaclusters[metacluster_key]
            if "patterns" not in metacluster["seqlets_to_patterns_result"]:
                continue
            patterns = metacluster["seqlets_to_patterns_result"]["patterns"]
            num_patterns = len(patterns["all_pattern_names"][:])
            for pattern_i, pattern_name in enumerate(patterns["all_pattern_names"][:]):
                pattern_name = pattern_name.decode()
                pattern = patterns[pattern_name]
                pfm = pattern["sequence"]["fwd"][:]
                cwm = pattern["task0_contrib_scores"]["fwd"][:]
                
                # Check that the contribution scores are overall positive
                if only_pos and np.sum(cwm) < 0:
                    continue
                    
                if trim:
                    pfm = util.trim_motif(pfm, pfm)
                    
                pfms["%d_%d" % (metacluster_i,pattern_i)] = pfm
    return pfms
In [6]:
def estimate_mode(x_values, bins=200, levels=1):
    """
    Estimates the mode of the distribution using `levels`
    iterations of histograms.
    """
    hist, edges = np.histogram(x_values, bins=bins)
    bin_mode = np.argmax(hist)
    left_edge, right_edge = edges[bin_mode], edges[bin_mode + 1]
    if levels <= 1:
        return (left_edge + right_edge) / 2
    else:
        return estimate_mode(
            x_values[(x_values >= left_edge) & (x_values < right_edge)],
            bins=bins,
            levels=(levels - 1)
        )
In [7]:
def fit_tight_exponential_dist(x_values, mode=0, percentiles=np.arange(0.05, 1, 0.05)):
    """
    Given an array of x-values and a set of percentiles of the distribution,
    computes the set of lambda values for an exponential distribution if the
    distribution were fit to each percentile of the x-values. Returns an array
    of lambda values parallel to `percentiles`. The exponential distribution
    is assumed to have the given mean/mode, and all data less than this mode
    is tossed out when doing this computation.
    """
    assert np.min(percentiles) >= 0 and np.max(percentiles) <= 1
    x_values = x_values[x_values >= mode]
    per_x_vals = np.percentile(x_values, percentiles * 100)
    return -np.log(1 - percentiles) / (per_x_vals - mode)
In [8]:
def exponential_pdf(x_values, lamb):
    return lamb * np.exp(-lamb * x_values)
def exponential_cdf(x_values, lamb):
    return 1 - np.exp(-lamb * x_values)
In [9]:
def filter_peak_hits_by_fdr(hit_table, fdr_cutoff=0.05):
    """
    Filters the table of peak hits (as imported by `moods.import_moods_hits`)
    by the importance score fraction by fitting a mixture model to the score
    distribution, taking the exponential component, and then fitting a
    percentile-tightened exponential distribution to this component.
    p-values are computed using this null, and then the FDR-cutoff is applied
    using Benjamini-Hochberg.
    Returns a reduced hit table of the same format. This will also generate
    plots for the score distribution and the FDR cutoffs.
    """
    scores = hit_table["imp_frac_score"].values
    scores_finite = scores[np.isfinite(scores)]
    
    mode = estimate_mode(scores_finite)

    # Fit mixture of models to scores (mode-shifted)
    over_mode_scores = scores_finite[scores_finite >= mode] - mode
    mixed_model = pomegranate.GeneralMixtureModel.from_samples(
        [
            pomegranate.ExponentialDistribution,
            pomegranate.NormalDistribution,
            pomegranate.NormalDistribution
        ],
        3, over_mode_scores[:, None]
    )
    mixed_model = mixed_model.fit(over_mode_scores)
    mixed_model_exp_dist = mixed_model.distributions[0]
    
    # Obtain a distribution of scores that belong to the exponential distribution
    exp_scores = over_mode_scores[mixed_model.predict(over_mode_scores[:, None]) == 0]
    
    # Fit a tight exponential distribution based on percentiles
    lamb = np.max(fit_tight_exponential_dist(exp_scores))
    
    # Plot score distribution and fit
    
    fig, ax = plt.subplots(nrows=3, figsize=(20, 20))

    x = np.linspace(np.min(scores_finite), np.max(scores_finite), 200)[1:]  # Skip first bucket (it's usually very large
    mix_dist_pdf = mixed_model.probability(x)
    mixed_model_exp_dist_pdf = mixed_model_exp_dist.probability(x)

    perc_dist_pdf = exponential_pdf(x, lamb)
    perc_dist_cdf = exponential_cdf(x, lamb)

    # Plot mixed model
    ax[0].hist(over_mode_scores + mode, bins=500, density=True, alpha=0.3)
    ax[0].axvline(mode)
    ax[0].plot(x + mode, mix_dist_pdf, label="Mixed model")
    ax[0].plot(x + mode, mixed_model_exp_dist_pdf, label="Exponential component")
    ax[0].legend()

    # Plot fitted PDF
    ax[1].hist(exp_scores, bins=500, density=True, alpha=0.3)
    ax[1].plot(x + mode, perc_dist_pdf, label="Percentile-fitted")

    # Plot fitted CDF
    ax[2].hist(exp_scores, bins=500, density=True, alpha=1, cumulative=True, histtype="step")
    ax[2].plot(x + mode, perc_dist_cdf, label="Percentile-fitted")

    ax[0].set_title("Motif hit scores")
    plt.show()
    
    # Compute p-values
    score_range = np.linspace(np.min(scores_finite), np.max(scores_finite), 1000000)
    inverse_cdf = 1 - exponential_cdf(score_range, lamb)
    assignments = np.digitize(scores - mode, score_range, right=True)
    assignments[~np.isfinite(scores)] = 0  # If score was NaN, give it a p-value of ~1
    pvals = inverse_cdf[assignments]
    pvals_sorted = np.sort(pvals)

    # Plot FDR cut-offs of various levels
    fdr_levels = [0.05, 0.1, 0.2, 0.3]
    pval_threshes = []
    fig, ax = plt.subplots(figsize=(20, 8))
    ranks = np.arange(1, len(pvals_sorted) + 1)
    ax.plot(ranks, pvals_sorted, color="black", label="p-values")
    for fdr in fdr_levels:
        bh_crit_vals = ranks / len(ranks) * fdr
        ax.plot(ranks, bh_crit_vals, label=("Crit values (FDR = %.2f)" % fdr))
        inds = np.where(pvals_sorted <= bh_crit_vals)[0]
        if not len(inds):
            pval_threshes.append(-1)
        else:
            pval_threshes.append(pvals_sorted[np.max(inds)])
    ax.set_title("Step-up p-values and FDR corrective critical values")
    plt.legend()
    plt.show()
    
    # Show table of number of hits at each FDR level
    header = vdomh.thead(
        vdomh.tr(
            vdomh.th("FDR level", style={"text-align": "center"}),
            vdomh.th("Number of hits kept", style={"text-align": "center"}),
            vdomh.th("% hits kept", style={"text-align": "center"})
        )
    )
    rows = []
    for i, fdr in enumerate(fdr_levels):
        num_kept = np.sum(pvals <= pval_threshes[i])
        frac_kept = num_kept / len(pvals)
        rows.append(vdomh.tr(
            vdomh.td("%.2f" % fdr), vdomh.td("%d" % num_kept), vdomh.td("%.2f%%" % (frac_kept * 100))
        ))
    body = vdomh.tbody(*rows)
    display(vdomh.table(header, body))

    # Perform filtering
    bh_crit_vals = fdr_cutoff * ranks / len(ranks)
    inds = np.where(pvals_sorted <= bh_crit_vals)[0]
    if not len(inds):
        pval_thresh = -1
    else:
        pval_thresh = pvals_sorted[np.max(inds)]
    return hit_table.iloc[pvals <= pval_thresh]
In [10]:
def get_peak_hits(peak_table, hit_table):
    """
    For each peak, extracts the set of motif hits that fall in that peak.
    Returns a list mapping peak index to a subtable of `hit_table`. The index
    of the list is the index of the peak table.
    """
    peak_hits = [pd.DataFrame(columns=list(hit_table),  dtype=object)] * len(peak_table)
    for peak_index, matches in tqdm.notebook.tqdm(hit_table.groupby("peak_index")):
        # Check that all of the matches are indeed overlapping the peak
        peak_row = peak_table.iloc[peak_index]
        chrom, start, end = peak_row["chrom"], peak_row["peak_start"], peak_row["peak_end"]
        assert np.all(matches["chrom"] == chrom)
        assert np.all((matches["start"] < end) & (start < matches["end"]))
        
        peak_hits[peak_index] = matches
    return peak_hits
In [11]:
def get_peak_motif_counts(peak_hits, motif_keys):
    """
    From the peak hits (as returned by `get_peak_hits`), computes a count
    array of size N x M, where N is the number of peaks and M is the number of
    motifs. Each entry represents the number of times a motif appears in a peak.
    `motif_keys` is a list of motif keys as they appear in `peak_hits`; the
    order of the motifs M matches this list.
    """
    motif_inds = {motif_keys[i] : i for i in range(len(motif_keys))}
    counts = np.zeros((len(peak_hits), len(motif_keys)), dtype=int)
    for i in tqdm.notebook.trange(len(peak_hits)):
        hits = peak_hits[i]
        for key, num in zip(*np.unique(hits["key"], return_counts=True)):
            counts[i][motif_inds[key]] = num
    return counts
In [12]:
def cluster_matrix_indices(matrix, num_clusters):
    """
    Clusters matrix using k-means. Always clusters on the first
    axis. Returns the indices needed to optimally order the matrix
    by clusters.
    """
    if len(matrix) == 1:
        # Don't cluster at all
        return np.array([0])

    num_clusters = min(num_clusters, len(matrix))
    
    # Perform k-means clustering
    kmeans = sklearn.cluster.KMeans(n_clusters=num_clusters)
    cluster_assignments = kmeans.fit_predict(matrix)

    # Perform hierarchical clustering on the cluster centers to determine optimal ordering
    kmeans_centers = kmeans.cluster_centers_
    cluster_order = scipy.cluster.hierarchy.leaves_list(
        scipy.cluster.hierarchy.optimal_leaf_ordering(
            scipy.cluster.hierarchy.linkage(kmeans_centers, method="centroid"), kmeans_centers
        )
    )

    # Order the peaks so that the cluster assignments follow the optimal ordering
    cluster_inds = []
    for cluster_id in cluster_order:
        cluster_inds.append(np.where(cluster_assignments == cluster_id)[0])
    cluster_inds = np.concatenate(cluster_inds)
    return cluster_inds
In [13]:
def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_keys):
    """
    Plots a simple indicator heatmap of the motifs in each peak.
    """
    peak_hit_indicators = (peak_hit_counts > 0).astype(int)
    # Cluster matrix by peaks
    inds = cluster_matrix_indices(peak_hit_indicators, max(5, len(peak_hit_indicators) // 10))
    matrix = peak_hit_indicators[inds]
    
    # Cluster matrix by motifs
    matrix_t = np.transpose(matrix)
    inds = cluster_matrix_indices(matrix_t, max(5, len(matrix_t) // 4))
    matrix = np.transpose(matrix_t[inds])
    motif_keys = np.array(motif_keys)[inds]

    # Create a figure with the right dimensions
    fig_height = min(len(peak_hit_indicators) * 0.004, 8)
    fig, ax = plt.subplots(figsize=(16, fig_height))

    # Plot the heatmap
    ax.imshow(matrix, interpolation="nearest", aspect="auto", cmap="Greens")

    # Set axes on heatmap
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.set_xticks(np.arange(len(motif_keys)))
    ax.set_xticklabels(motif_keys)
    ax.set_xlabel("Motif")

    fig.tight_layout()
    plt.show()
In [14]:
def plot_homotypic_densities(peak_hit_counts, motif_keys):
    """
    Plots a CDF of number of motif hits per peak, for each motif.
    """
    for i in range(len(motif_keys)):
        counts = peak_hit_counts[:, i]
        
        fig, ax = plt.subplots(figsize=(8, 8))
        bins = np.concatenate([np.arange(np.max(counts)), [np.inf]])
        ax.hist(counts, bins=bins, density=True, histtype="step", cumulative=True)
        ax.set_title("Cumulative distribution of number of %s hits per peak" % motif_keys[i])
        ax.set_xlabel("Number of motifs k in peak")
        ax.set_ylabel("Proportion of peaks with at least k motifs")
        plt.show()
In [15]:
def get_motif_cooccurrence_count_matrix(peak_hit_counts):
    """
    From an N x M (peaks by motifs) array of hit counts, returns
    an M x M array of counts (i.e. how many times two motifs occur
    together in the same peak). For the diagonal entries, we require
    that motif occur at least twice in a peak to be counted.
    """
    peak_hit_indicators = (peak_hit_counts > 0).astype(int)
    num_motifs = peak_hit_indicators.shape[1]
    count_matrix = np.zeros((num_motifs, num_motifs), dtype=int)
    for i in range(num_motifs):
        for j in range(i):
            pair_col = np.sum(peak_hit_indicators[:, [i, j]], axis=1)
            count = np.sum(pair_col == 2)
            count_matrix[i, j] = count
            count_matrix[j, i] = count
        count_matrix[i, i] = np.sum(peak_hit_counts[:, i] >= 2)
    return count_matrix
In [16]:
def compute_cooccurrence_pvals(peak_hit_counts):
    """
    Given the number of motif hits in each peak, computes p-value of
    co-occurrence for each pair of motifs, including self pairs.
    Returns an M x M array of p-values for the M motifs.
    """
    peak_hit_indicators = (peak_hit_counts > 0).astype(int)
    num_peaks, num_motifs = peak_hit_counts.shape
    
    pvals = np.ones((num_motifs, num_motifs))
    
    # Significance is based on a Fisher's exact test. If the motifs were
    # present in peaks randomly, we'd independence of occurrence.
    # For self-co-occurrence, the null model is not independence, but
    # collisions
    for i in range(num_motifs):
        for j in range(i):
            pair_counts = peak_hit_indicators[:, [i, j]]
            peaks_with_1 = pair_counts[:, 0] == 1
            peaks_with_2 = pair_counts[:, 1] == 1
            # Contingency table (universe is set of all peaks):
            #              no motif 1  |  has motif 1
            # no motif 2       A       |      B
            # -------------------------+--------------
            # has motif 2      C       |      D
            # The Fisher's exact test evaluates the significance of the
            # association between the two classifications
            cont_table = np.array([
                [
                    np.sum(~(peaks_with_1) & (~peaks_with_2)),
                    np.sum(peaks_with_1 & (~peaks_with_2))
                ],
                [
                    np.sum(~(peaks_with_1) & peaks_with_2),
                    np.sum(peaks_with_1 & peaks_with_2)
                ]
            ])
            pval = scipy.stats.fisher_exact(
                cont_table, alternative="greater"
            )[1]
            pvals[i, j] = pval
            pvals[j, i] = pval

        # Self-co-occurrence: Poissonize balls in bins
        # Expected number of collisions (via linearity of expectations):
        num_hits = np.sum(peak_hit_indicators[:, i])  # number of "balls"
        expected_collisions = num_hits * (num_hits - 1) / (2 * num_peaks)
        num_collisions = np.sum(peak_hit_counts[:, i] >= 2)
        pval = 1 - scipy.stats.poisson.cdf(num_collisions, mu=expected_collisions)
        pvals[i, i] = pval
    
    return pvals
In [17]:
def plot_motif_cooccurrence_heatmaps(count_matrix, pval_matrix, motif_keys):
    """
    Plots a heatmap showing the number of peaks that have both types of
    each motif, as well as a heatmap showing the p-value of co-occurrence.
    """
    assert count_matrix.shape == pval_matrix.shape
    num_motifs = pval_matrix.shape[0]
    assert len(motif_keys) == num_motifs

    # Cluster by p-value
    inds = cluster_matrix_indices(pval_matrix, max(5, num_motifs // 4))
    pval_matrix = pval_matrix[inds][:, inds]
    count_matrix = count_matrix[inds][:, inds]
    motif_keys = np.array(motif_keys)[inds]
    
    # Plot the p-value matrix

    fig_width = max(5, num_motifs)
    fig, ax = plt.subplots(figsize=(fig_width, fig_width))
    
    # Replace 0s with minimum value (we'll label them properly later)
    zero_mask = pval_matrix == 0
    min_val = np.min(pval_matrix[~zero_mask])
    pval_matrix[zero_mask] = min_val
    logpval_matrix = -np.log10(pval_matrix)
    
    hmap = ax.imshow(logpval_matrix)

    ax.set_xticks(np.arange(num_motifs))
    ax.set_yticks(np.arange(num_motifs))
    ax.set_xticklabels(motif_keys, rotation=90)
    ax.set_yticklabels(motif_keys)

    # Loop over data dimensions and create text annotations.
    for i in range(num_motifs):
        for j in range(num_motifs):
            if zero_mask[i, j]:
                text = "Inf"
            else:
                text = "%.2f" % np.abs(logpval_matrix[i, j])
            ax.text(j, i, text, ha="center", va="center")
    fig.colorbar(hmap, orientation="horizontal")

    ax.set_title("-log(p) significance of peaks with both motifs")
    fig.tight_layout()
    plt.show()
    
    # Plot the counts matrix

    fig_width = max(5, num_motifs)
    fig, ax = plt.subplots(figsize=(fig_width, fig_width))
    
    hmap = ax.imshow(count_matrix)

    ax.set_xticks(np.arange(num_motifs))
    ax.set_yticks(np.arange(num_motifs))
    ax.set_xticklabels(motif_keys, rotation=90)
    ax.set_yticklabels(motif_keys)

    # Loop over data dimensions and create text annotations.
    for i in range(num_motifs):
        for j in range(num_motifs):
            ax.text(j, i, count_matrix[i, j], ha="center", va="center")
    fig.colorbar(hmap, orientation="horizontal")

    ax.set_title("Number of peaks with both motifs")
    fig.tight_layout()
    plt.show()
In [18]:
def create_violin_plot(ax, dist_list, colors):
    """
    Creates a violin plot on the given instantiated axes.
    `dist_list` is a list of vectors. `colors` is a parallel
    list of colors for each violin.
    """
    num_perfs = len(dist_list)

    q1, med, q3 = np.stack([
        np.nanpercentile(data, [25, 50, 70], axis=0) for data in dist_list
    ], axis=1)
    iqr = q3 - q1
    lower_outlier = q1 - (1.5 * iqr)
    upper_outlier = q3 + (1.5 * iqr)


    sorted_clipped_data = [  # Remove outliers based on outlier rule
        np.sort(vec[(vec >= lower_outlier[i]) & (vec <= upper_outlier[i])])
        for i, vec in enumerate(dist_list)
    ]

    plot_parts = ax.violinplot(
        sorted_clipped_data, showmeans=False, showmedians=False, showextrema=False
    )
    violin_parts = plot_parts["bodies"]
    for i in range(num_perfs):
        violin_parts[i].set_facecolor(colors[i])
        violin_parts[i].set_edgecolor(colors[i])
        violin_parts[i].set_alpha(0.7)

    inds = np.arange(1, num_perfs + 1)
    ax.vlines(inds, q1, q3, color="black", linewidth=5, zorder=1)
    ax.scatter(inds, med, marker="o", color="white", s=30, zorder=2)
In [19]:
def plot_intermotif_distance_violins(peak_hits, motif_keys, pair_inds):
    """
    For each pair of motifs, plots a violin of distances beween
    motifs. 
    """
    # First, compute the distribution of distances for each pair
    distances = []
    for i, j in tqdm.notebook.tqdm(pair_inds):
        dists = []
        for k in range(len(peak_hits)):
            hits = peak_hits[k]

            hits_1 = hits[hits["key"] == motif_keys[i]]
            hits_2 = hits[hits["key"] == motif_keys[j]]

            if hits_1.empty or hits_2.empty:
                continue

            pos_1 = np.array(hits_1["start"])
            pos_2 = np.array(hits_2["start"])

            len_1 = (hits_1["end"] - hits_1["start"]).values[0]
            len_2 = (hits_2["end"] - hits_2["start"]).values[0]

            # Differences beteween all pairs of positions
            diffs = pos_2[None] - pos_1[:, None]
            # Take minimum distance for each instance of motif 2, but only
            # if the distance is an appropriate length
            for row in diffs:
                row = row[row != 0]
                if not row.size:
                    continue
                dist = row[np.argmin(np.abs(row))]
                if (dist < 0 and dist < -len_2) or (dist > 0 and dist > len_1):
                    dists.append(dist)
        dists = np.array(dists)
        if not dists.size:
            continue
        distances.append(np.abs(dists))  # Take absolute value of distance
    
    if not distances:
        print("No significantly co-occurring motifs")
        return
    
    # Plot the violins
    fig, ax = plt.subplots(figsize=(int(1.7 * len(pair_inds)), 8))
    create_violin_plot(ax, distances, ["mediumorchid"] * len(pair_inds))
    ax.set_title("Distance distributions between motif instances")
    ax.set_xticks(np.arange(1, len(pair_inds) + 1))
    ax.set_xticklabels(["%s/%s" % (motif_keys[i], motif_keys[j]) for i, j in pair_inds], rotation=90)
    plt.show()

Import hit results

In [ ]:
 
In [20]:
# Import the PFMs
print(tfm_results_path) 
pfms = import_tfmodisco_motifs(tfm_results_path)
motif_keys = list(pfms.keys())
/oak/stanford/groups/akundaje/projects/chrombpnet_paper_new/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/SIGNAL/modisco_crop_500/modisco_results_allChroms_counts.hdf5
In [21]:
motif_keys 
Out[21]:
['0_0',
 '0_1',
 '0_2',
 '0_3',
 '0_4',
 '0_5',
 '0_6',
 '0_7',
 '0_8',
 '0_9',
 '0_10',
 '0_11',
 '0_12',
 '0_13',
 '0_14',
 '0_15',
 '0_16',
 '0_17',
 '0_18',
 '0_19',
 '0_20',
 '0_21',
 '0_22',
 '0_23',
 '0_24',
 '0_25',
 '0_26',
 '0_27',
 '0_28',
 '0_29',
 '0_30',
 '0_31']
In [22]:
# Import peaks

peak_table = util.import_peak_table(peak_bed_paths)

print(peak_table.head())
# Expand to input length
peak_table["peak_start"] = \
    (peak_table["peak_start"] + peak_table["summit_offset"]) - (input_length // 2)
peak_table["peak_end"] = peak_table["peak_start"] + input_length
  chrom  peak_start   peak_end         name  score strand   signal      pval  \
0  chr1   100027088  100027511  Peak_148646    681      .  5.21938  68.17744   
1  chr1   100036823  100037403  Peak_159959    592      .  2.48719  59.27093   
2  chr1   100036823  100037403  Peak_162990    570      .  2.45537  57.05564   
3  chr1   100036823  100037403  Peak_249038    240      .  1.87774  24.09786   
4  chr1   100037575  100038997  Peak_291677    175      .  1.72480  17.52787   

       qval  summit_offset     summit  
0  66.18713            312  100027400  
1  57.31506            487  100037310  
2  55.10878            117  100036940  
3  22.34334            330  100037153  
4  15.84093             63  100037638  
In [23]:
# Import DeepSHAP scores

print(shap_scores_path)
hyp_scores, act_scores, one_hot_seqs, shap_coords = util.import_shap_scores_custom(
    shap_scores_path, peak_table, center_cut_size=input_length
)
/mnt/lab_data2/anusri/chrombpnet/results/chrombpnet/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/interpret//merged.HEPG2.counts_scores.h5
/mnt/lab_data2/anusri/chrombpnet/results/chrombpnet/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/interpret//merged.HEPG2.counts_scores.h5
466481
(466481, 2114, 4)
In [24]:
print(peak_bed_paths)
# Run MOODS; import results if they already exist
hits_path = os.path.join(moods_dir, "moods_filtered_collapsed_scored.bed")
if os.path.exists(hits_path) and os.stat(hits_path).st_size:
    hit_table = moods.import_moods_hits(hits_path)
else:
    hit_table = moods.get_moods_hits(
        pfms, reference_fasta, peak_bed_paths[0], shap_scores_path, peak_table,
        expand_peak_length=input_length, temp_dir=moods_dir
    )
['/mnt/lab_data2/anusri/chrombpnet/results/chrombpnet/DNASE_PE/HEPG2/HEPG2_06.08.2022_bias_128_4_1234_0.8_fold_0/interpret//merged.HEPG2.interpreted_regions.bed']
In [25]:
hit_table.head()
Out[25]:
chrom start end key strand score peak_index imp_frac_score
0 chr1 10471 10486 0_0 + 6.387017 24449 0.001503
1 chr1 11223 11238 0_0 + 13.861909 24449 0.000635
2 chr1 11281 11296 0_0 + 14.077542 24449 0.000328
3 chr1 11340 11355 0_0 + 7.514150 24449 0.000139
4 chr1 11402 11421 0_0 + 7.783216 24449 0.000064
In [26]:
tomtom = pd.read_csv(tomtom_path, sep="\t")
label_dict = {}
for index,row in tomtom.iterrows():
    keyd = row['Pattern'].replace("metacluster_","").replace("pattern_","").replace(".","_")
    label_dict[keyd] = row['Match_1']


hit_table['key'] = hit_table['key'].apply(lambda x: label_dict[x] if x in label_dict else x)
In [27]:
hit_table.head()

temp_keys = []
for keyd in motif_keys:
    if keyd in label_dict:
        temp_keys.append(label_dict[keyd])
    else:
        temp_keys.append(keyd)

motif_keys = temp_keys
In [28]:
# Filter motif hit table by p-value using FDR estimation
hit_table_filtered = filter_peak_hits_by_fdr(hit_table, fdr_cutoff=motif_fdr_cutoff)
# Save the results back to MOODS directory
filtered_hits_path = os.path.join(moods_dir, "moods_filtered_collapsed_scored_thresholded.bed")
hit_table_filtered.to_csv(filtered_hits_path, sep="\t", header=False, index=False)
FDR levelNumber of hits kept% hits kept
0.0562360613.90%
0.1071681115.98%
0.2084719918.89%
0.3095649521.32%
In [29]:
# Match peaks to motif hits
peak_hits = get_peak_hits(peak_table, hit_table_filtered)
In [30]:
# Construct count array of peaks and hits

peak_hit_counts = get_peak_motif_counts(peak_hits, motif_keys)
In [31]:
peak_hit_counts.shape
Out[31]:
(466481, 32)
In [32]:
# Construct count matrix of motif co-occurrence
motif_cooccurrence_count_matrix = get_motif_cooccurrence_count_matrix(peak_hit_counts)
In [33]:
# Construct the matrix of p-values for motif co-occurrence
motif_cooccurrence_pval_matrix = compute_cooccurrence_pvals(peak_hit_counts)

Proportion of peaks with hits

In [34]:
motifs_per_peak = np.array([len(hits) for hits in peak_hits])
In [35]:
display(vdomh.p("Number of peaks: %d" % len(peak_table)))
display(vdomh.p("Number of motif hits before FDR filtering: %d" % len(hit_table)))
display(vdomh.p("Number of motif hits after FDR filtering: %d" % len(hit_table_filtered)))

Number of peaks: 466481

Number of motif hits before FDR filtering: 4485404

Number of motif hits after FDR filtering: 716811

In [36]:
display(vdomh.p("Number of peaks with 0 motif hits: %d" % np.sum(motifs_per_peak == 0)))

Number of peaks with 0 motif hits: 192799

In [37]:
quants = [0, 0.25, 0.50, 0.75, 0.99, 1]
header = vdomh.thead(
    vdomh.tr(
        vdomh.th("Quantile", style={"text-align": "center"}),
        vdomh.th("Number of hits/peak", style={"text-align": "center"})
    )
)
body = vdomh.tbody(*([
    vdomh.tr(
        vdomh.td("%.1f%%" % (q * 100)), vdomh.td("%d" % v)
    ) for q, v in zip(quants, np.quantile(motifs_per_peak, quants))
]))
vdomh.table(header, body)
Out[37]:
QuantileNumber of hits/peak
0.0%0
25.0%0
50.0%1
75.0%2
99.0%8
100.0%27
In [38]:
fig, ax = plt.subplots(figsize=(10, 10))
bins = np.concatenate([np.arange(np.max(motifs_per_peak) + 1), [np.inf]])
ax.hist(motifs_per_peak, bins=bins, density=True, histtype="step", cumulative=True)
ax.set_title("Cumulative distribution of number of motif hits per peak")
ax.set_xlabel("Number of motifs k in peak")
ax.set_ylabel("Proportion of peaks with at least k motifs")
plt.show()
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
In [39]:
frac_peaks_with_motif = np.sum(peak_hit_counts > 0, axis=0) / len(peak_hit_counts)
labels = np.array(motif_keys)
sorted_inds = np.flip(np.argsort(frac_peaks_with_motif))
frac_peaks_with_motif = frac_peaks_with_motif[sorted_inds]
labels = labels[sorted_inds]

fig, ax = plt.subplots(figsize=(20, 8))
ax.bar(np.arange(len(labels)), frac_peaks_with_motif)
ax.set_title("Proportion of peaks with each motif")
ax.set_xticks(np.arange(len(labels)))
ax.set_xticklabels(labels)
plt.xticks(rotation=90)
plt.show()

Examples of motif hits in sequences

In [40]:
# Show some examples of sequences with motif hits
num_to_draw = 3
unique_counts = np.sort(np.unique(motifs_per_peak))
motif_nums = []
if 0 in motifs_per_peak:
    motif_nums.append(0)
if 1 in motifs_per_peak:
    motif_nums.append(1)
motif_nums.extend([
    unique_counts[0],  # Minimum
    unique_counts[len(unique_counts) // 2],  # Median
    unique_counts[-1],  # Maximum
])

for motif_num in np.sort(np.unique(motif_nums)):
    display(vdomh.h4("Sequences with %d motif hits" % motif_num))
    
    peak_inds = np.where(motifs_per_peak == motif_num)[0]
    table_rows = []
    for i in np.random.choice(
        peak_inds, size=min(num_to_draw, len(peak_inds)), replace=False
    ):
        peak_coord = peak_table.iloc[i][["chrom", "peak_start", "peak_end"]].values
        motif_hits = peak_hits[i]
        
        chrom, peak_start, peak_end = peak_coord
        peak_len = peak_end - peak_start
        mask = (shap_coords[:, 0] == chrom) & (shap_coords[:, 1] <= peak_start) & (shap_coords[:, 2] >= peak_end)
        if not np.sum(mask):
            fig = "No matching input sequence found"
            table_rows.append(
                vdomh.tr(
                    vdomh.td("%s:%d-%d" % (chrom, peak_start, peak_end)),
                    vdomh.td(fig)
                )
            )
            continue
            
        seq_index = np.where(mask)[0][0]  # Pick one
        imp_scores = act_scores[seq_index]
        _, seq_start, seq_end = shap_coords[seq_index]
        
        highlights = []
        for _, row in motif_hits.iterrows():
            start = row["start"] - peak_start
            end = start + (row["end"] - row["start"])
            highlights.append((start, end))
        
        # Remove highlights that overrun the sequence
        highlights = [(a, b) for a, b in highlights if a >= 0 and b < peak_len]
        
        start = peak_start - seq_start 
        end = start + peak_len
        imp_scores_peak = imp_scores[start:end]
        
        fig = viz_sequence.plot_weights(
            imp_scores_peak, subticks_frequency=(len(imp_scores_peak) + 1),
            highlight={"red" : [pair for pair in highlights]},
            return_fig=True
        )
        fig = util.figure_to_vdom_image(fig)
        
        table_rows.append(
            vdomh.tr(
                vdomh.td("%s:%d-%d" % (chrom, peak_start, peak_end)),
                vdomh.td(fig)
            )
        )

    table = vdomh.table(*table_rows)
    display(table)
    plt.close("all")

Sequences with 0 motif hits

chr12:6334416-6336530
chr3:186002227-186004341
chr14:90758070-90760184

Sequences with 1 motif hits

chr7:138244831-138246945
chr3:194554165-194556279
chr17:31445697-31447811

Sequences with 10 motif hits

chr22:41918554-41920668
chr6:151361941-151364055
chr19:3868258-3870372

Sequences with 27 motif hits

chr5:180258118-180260232

Homotypic motif densities

For each motif, show how many the motif occurs in each peak

In [41]:
plot_homotypic_densities(peak_hit_counts, motif_keys)
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
/users/anusri/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6653: RuntimeWarning: invalid value encountered in multiply
  tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]

Co-occurrence of motifs

Proportion of time that motifs co-occur with each other in peaks

In [42]:
#plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_keys)
motif_keys
Out[42]:
['CTCF_MA0139.1',
 'HNF4G_MA0484.1',
 'FOXM1_HUMAN.H11MO.0.A',
 'ATF3_MOUSE.H11MO.0.A',
 'CEBPB_MOUSE.H11MO.0.A',
 'FOXA1_MOUSE.H11MO.0.A',
 'KLF12_HUMAN.H11MO.0.C',
 'HNF1B_MA0153.2',
 'SOX9_HUMAN.H11MO.0.B',
 'Gabpa_MA0062.2',
 'NFIA_HUMAN.H11MO.0.C',
 'TCF7L2_MA0523.1',
 'TEAD1_MOUSE.H11MO.0.A',
 'NFYC_HUMAN.H11MO.0.A',
 'RXRA_MOUSE.H11MO.0.A',
 'COT1_MOUSE.H11MO.0.B',
 'FOXD2_forkhead_1',
 'FOXB1_MA0845.1',
 'NRF1_HUMAN.H11MO.0.A',
 'ZN143_MOUSE.H11MO.0.A',
 'FOSL2+JUND_MA1145.1',
 'ARNTL_bHLH_1',
 'TEAD4_MOUSE.H11MO.0.A',
 'FOXA1_MOUSE.H11MO.0.A',
 'RARA_MOUSE.H11MO.0.A',
 'ZBTB33_MA0527.1',
 'CEBPD_MOUSE.H11MO.0.B',
 'FOXM1_HUMAN.H11MO.0.A',
 'RFX3_MOUSE.H11MO.0.C',
 'REST_HUMAN.H11MO.0.A',
 'FOXA1_MA0148.3',
 'GATA6_MOUSE.H11MO.0.A']
In [43]:
plot_motif_cooccurrence_heatmaps(motif_cooccurrence_count_matrix, motif_cooccurrence_pval_matrix, motif_keys)

Distribution of distances between motifs

When motifs co-occur, show the distance between the instances

In [44]:
# # Get which pairs of motifs are significant
pvals, sig_pairs = [], []
for i in range(len(motif_keys)):
     for j in range(i + 1):
        if motif_cooccurrence_pval_matrix[i, j] < 1e-6:
            sig_pairs.append((i, j))
            pvals.append(motif_cooccurrence_pval_matrix[i, j])
inds = np.argsort(pvals)
sig_pairs = [sig_pairs[i] for i in inds]
In [45]:
#plot_intermotif_distance_violins(peak_hits, motif_keys, sig_pairs)
In [ ]: