%load_ext autoreload
%autoreload 2
%reset -f
import os
import sys
sys.path.append(os.path.abspath("/users/amtseng/tfmodisco/src/"))
from tfmodisco.run_tfmodisco import import_shap_scores
from motif.read_motifs import trim_motif_by_ic, pfm_to_pwm, pfm_info_content
from motif.moods import import_moods_hits
from motif.tfmodisco_hit_scoring import import_tfmodisco_hits
from util import figure_to_vdom_image, import_peak_table
import plot.viz_sequence as viz_sequence
from modisco.util import compute_per_position_ic, cpu_sliding_window_sum
import h5py
import numpy as np
import pandas as pd
import pomegranate
import sklearn.cluster
import scipy.cluster.hierarchy
import scipy.signal  # Used by get_imp_score_dist for sliding-window sums
import scipy.stats
import sklearn.isotonic
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import vdom.helpers as vdomh
from IPython.display import display
import tqdm
import tqdm.notebook  # The helpers below use tqdm.notebook.tqdm/trange
# Plotting defaults
# Register custom fonts; createFontList was deprecated in Matplotlib 3.2,
# so add each font file to the font manager directly
for font_file in font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts"):
    font_manager.fontManager.addfont(font_file)
plot_params = {
"figure.titlesize": 22,
"axes.titlesize": 22,
"axes.labelsize": 20,
"legend.fontsize": 18,
"font.size": 13,
"xtick.labelsize": 16,
"ytick.labelsize": 16,
"font.family": "Roboto",
"font.weight": "bold"
}
plt.rcParams.update(plot_params)
# Define parameters/fetch arguments
tf_name = os.environ["TFM_TF_NAME"]
tfm_results_path = os.environ["TFM_TFM_PATH"]
shap_scores_path = os.environ["TFM_SHAP_PATH"]
hyp_score_key = os.environ["TFM_HYP_SCORE_KEY"]
if "TFM_TASK_INDEX" in os.environ:
task_index = int(os.environ["TFM_TASK_INDEX"])
else:
task_index = None
if "TFM_PEAKS" in os.environ:
# If provided, this overrides the peaks defined by the TF name and task index
peak_bed_paths = os.environ["TFM_PEAKS"].split(",")
else:
peak_bed_paths = []
motif_hits_path = os.environ["TFM_HITS_PATH"]
if "TFM_HITS_CACHE" in os.environ:
hits_cache_dir = os.environ["TFM_HITS_CACHE"]
else:
hits_cache_dir = None
print("TF name: %s" % tf_name)
print("TF-MoDISco results path: %s" % tfm_results_path)
print("DeepSHAP scores path: %s" % shap_scores_path)
print("Importance score key: %s" % hyp_score_key)
print("Task index: %s" % task_index)
print("Peak BED paths: %s" % ",".join(peak_bed_paths))
print("Motif hits path: %s" % motif_hits_path)
print("Saved motif hits cache: %s" % hits_cache_dir)
TF name: MAX
TF-MoDISco results path: /users/amtseng/tfmodisco/results/tfmodisco/multitask_profile_finetune/MAX_multitask_profile_finetune_fold1/MAX_multitask_profile_finetune_fold1_count_tfm.h5
DeepSHAP scores path: /users/amtseng/tfmodisco/results/importance_scores/multitask_profile_finetune/MAX_multitask_profile_finetune_fold1/MAX_multitask_profile_finetune_fold1_imp_scores.h5
Importance score key: count_hyp_scores
Task index: None
Peak BED paths:
Motif hits path: /users/amtseng/tfmodisco/results/tfmodisco_hit_scoring/multitask_profile_finetune/MAX_multitask_profile_finetune_fold1_count/tfm_matches_collapsed-all.bed
Saved motif hits cache: /users/amtseng/tfmodisco/results/reports/motif_hits//cache/tfm/multitask_profile_finetune/MAX_multitask_profile_finetune_fold1/MAX_multitask_profile_finetune_fold1_count
# Constants
input_length = 2114 if "TFM_INPUT_LEN" not in os.environ else int(os.environ["TFM_INPUT_LEN"])
motif_moods_imp_perc_cutoff = 0.10 # For MOODS hits
motif_tfm_imp_prob_cutoff = 0.5 # For TF-MoDISco hits
motif_tfm_save_imp_prob_cutoff = 0.99 # For TF-MoDISco hits
motif_tfm_sim_prob_cutoff = 0.8 # For TF-MoDISco hits
seed = 20210412
# Paths to original called peaks
if not peak_bed_paths:
# Use TF name and task index
base_path = "/users/amtseng/tfmodisco/"
data_path = os.path.join(base_path, "data/processed/ENCODE/")
labels_path = os.path.join(data_path, "labels/%s" % tf_name)
all_peak_beds = sorted([item for item in os.listdir(labels_path) if item.endswith(".bed.gz")])
if task_index is None:
peak_bed_paths = [os.path.join(labels_path, item) for item in all_peak_beds]
else:
peak_bed_paths = [os.path.join(labels_path, all_peak_beds[task_index])]
if hits_cache_dir:
os.makedirs(hits_cache_dir, exist_ok=True)
Helper functions for plotting and organizing the motif hits
def import_tfmodisco_motifs(tfm_results_path, only_pos=True):
"""
Imports the PFMs to into a dictionary, mapping `(x, y)` to the PFM,
where `x` is the metacluster index and `y` is the pattern index.
Arguments:
`tfm_results_path`: path to HDF5 containing TF-MoDISco results
`only_pos`: if True, only return motifs with positive contributions
Returns the dictionary of PFMs, CWMs, and hCWMs.
"""
pfms, cwms, hcwms = {}, {}, {}
with h5py.File(tfm_results_path, "r") as f:
metaclusters = f["metacluster_idx_to_submetacluster_results"]
num_metaclusters = len(metaclusters.keys())
for metacluster_i, metacluster_key in enumerate(metaclusters.keys()):
metacluster = metaclusters[metacluster_key]
if "patterns" not in metacluster["seqlets_to_patterns_result"]:
continue
patterns = metacluster["seqlets_to_patterns_result"]["patterns"]
num_patterns = len(patterns["all_pattern_names"][:])
for pattern_i, pattern_name in enumerate(patterns["all_pattern_names"][:]):
pattern_name = pattern_name.decode()
pattern = patterns[pattern_name]
pfm = pattern["sequence"]["fwd"][:]
cwm = pattern["task0_contrib_scores"]["fwd"][:]
hcwm = pattern["task0_hypothetical_contribs"]["fwd"][:]
# Check that the contribution scores are overall positive
if only_pos and np.sum(cwm) < 0:
continue
pfms["%d_%d" % (metacluster_i,pattern_i)] = pfm
cwms["%d_%d" % (metacluster_i,pattern_i)] = cwm
hcwms["%d_%d" % (metacluster_i,pattern_i)] = hcwm
return pfms, cwms, hcwms
def import_motif_hits(motif_hits_path):
"""
Imports the motif hits, which may be an output of MOODS scanning
or TF-MoDISco hit scanning. Depending on the number of columns, the
hits are imported appropriately.
"""
with open(motif_hits_path, "r") as f:
cols = next(f).split("\t")
if len(cols) == 10:
print("MOODS hits")
hit_table = import_moods_hits(motif_hits_path)
elif len(cols) == 16:
print("TF-MoDISco hits")
hit_table = import_tfmodisco_hits(motif_hits_path)
# Sort by aggregate similarity and drop duplicates (by strand)
hit_table = hit_table.sort_values("agg_sim")
hit_table = hit_table.drop_duplicates(["chrom", "start", "end", "peak_index"], keep="last")
else:
raise ValueError("Motif hits file of unknown format/source: %s" % motif_hits_path)
return hit_table
def estimate_mode(x_values, bins=200, levels=1):
"""
Estimates the mode of the distribution using `levels`
iterations of histograms.
"""
hist, edges = np.histogram(x_values, bins=bins)
bin_mode = np.argmax(hist)
left_edge, right_edge = edges[bin_mode], edges[bin_mode + 1]
if levels <= 1:
return (left_edge + right_edge) / 2
else:
return estimate_mode(
x_values[(x_values >= left_edge) & (x_values < right_edge)],
bins=bins,
levels=(levels - 1)
)
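A quick sanity check for `estimate_mode` (synthetic data, invented for illustration; not part of the pipeline):
_mode_rng = np.random.RandomState(0)
_mode_samples = _mode_rng.normal(loc=3, scale=1, size=100000)  # Unimodal around 3
print(estimate_mode(_mode_samples, bins=200, levels=2))  # Expect a value near 3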
def fit_tight_exponential_dist(x_values, mode=0, percentiles=np.arange(0.05, 1, 0.05)):
"""
Given an array of x-values and a set of percentiles of the distribution,
computes the set of lambda values for an exponential distribution if the
distribution were fit to each percentile of the x-values. Returns an array
of lambda values parallel to `percentiles`. The exponential distribution
is assumed to start at the given mode (its location), and all data less
than this mode is tossed out when doing this computation.
"""
assert np.min(percentiles) >= 0 and np.max(percentiles) <= 1
x_values = x_values[x_values >= mode]
per_x_vals = np.percentile(x_values, percentiles * 100)
return -np.log(1 - percentiles) / (per_x_vals - mode)
def exponential_pdf(x_values, lamb):
return lamb * np.exp(-lamb * x_values)
def exponential_cdf(x_values, lamb):
return 1 - np.exp(-lamb * x_values)
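As a sanity check (synthetic data, invented for illustration), samples from a true exponential with rate lambda = 2 should yield percentile-derived lambda estimates that are all close to 2:
_exp_rng = np.random.RandomState(0)
_exp_samples = _exp_rng.exponential(scale=0.5, size=100000)  # NumPy's scale = 1 / lambda
print(fit_tight_exponential_dist(_exp_samples, mode=0))  # Every entry should be ~2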
def filter_moods_peak_hits(hit_table, score_column="imp_frac_score", imp_perc_cutoff=0.05):
"""
Filters the table of peak hits by the score defined by
`score_column` by fitting a mixture model to the score
distribution, taking the exponential component, and then fitting a
percentile-tightened exponential distribution to this component.
The lowest percentile specified by `imp_perc_cutoff` of this null is
cut out. Returns a reduced hit table of the same format, and a figure of
the score distribution.
"""
scores = hit_table[score_column].values
scores_finite = scores[np.isfinite(scores)]
mode = estimate_mode(scores_finite)
# Fit mixture of models to scores (mode-shifted)
over_mode_scores = scores_finite[scores_finite >= mode] - mode
mixed_model = pomegranate.GeneralMixtureModel.from_samples(
[
pomegranate.ExponentialDistribution,
pomegranate.NormalDistribution,
pomegranate.NormalDistribution
],
3, over_mode_scores[:, None]
)
mixed_model = mixed_model.fit(over_mode_scores[:, None])  # Same 2D shape as in from_samples
mixed_model_exp_dist = mixed_model.distributions[0]
# Obtain a distribution of scores that belong to the exponential distribution
exp_scores = over_mode_scores[mixed_model.predict(over_mode_scores[:, None]) == 0]
# Fit a tight exponential distribution based on percentiles
lamb = np.max(fit_tight_exponential_dist(exp_scores))
# Plot score distribution and fit
score_fig, ax = plt.subplots(nrows=3, figsize=(20, 20))
x = np.linspace(np.min(scores_finite), np.max(scores_finite), 200)[1:]  # Skip the first bucket (it's usually very large)
mix_dist_pdf = mixed_model.probability(x)
mixed_model_exp_dist_pdf = mixed_model_exp_dist.probability(x)
perc_dist_pdf = exponential_pdf(x, lamb)
perc_dist_cdf = exponential_cdf(x, lamb)
thresh = scipy.stats.expon.ppf(imp_perc_cutoff, loc=mode, scale=(1 / lamb))
# Plot mixed model
ax[0].hist(over_mode_scores + mode, bins=500, density=True, alpha=0.3)
ax[0].axvline(mode)
ax[0].plot(x + mode, mix_dist_pdf, label="Mixed model")
ax[0].plot(x + mode, mixed_model_exp_dist_pdf, label="Exponential component")
ax[0].legend()
# Plot fitted PDF
ax[1].hist(exp_scores, bins=500, density=True, alpha=0.3)
ax[1].plot(x + mode, perc_dist_pdf, label="Percentile-fitted")
ax[1].axvline(thresh)
# Plot fitted CDF
ax[2].hist(exp_scores, bins=500, density=True, alpha=1, cumulative=True, histtype="step")
ax[2].plot(x + mode, perc_dist_cdf, label="Percentile-fitted")
ax[0].set_title("Motif hit scores")
plt.show()
return hit_table.loc[hit_table[score_column] >= thresh].reset_index(drop=True), score_fig
def get_imp_score_dist(
act_scores, window_length, score_type="imp_frac_score", center_cut_size=400, sample=10000
):
"""
Computes the set of importance scores as a fraction or sum of absolute
or signed importance, using windows of the given length. Focuses on the
central bases defined by `center_cut_size`. Returns a NumPy array of up
to `sample` randomly chosen window scores.
`act_scores` is an N x L x 4 array.
"""
assert score_type in ("imp_total_signed_score", "imp_frac_signed_score", "imp_total_score", "imp_frac_score")
start = (act_scores.shape[1] // 2) - (center_cut_size // 2)
end = start + center_cut_size
cut_scores = np.sum(act_scores, axis=2)[:, start:end] # Shape: N x L'
if score_type == "imp_total_signed_score":
scores = cut_scores
elif score_type == "imp_total_score":
scores = np.abs(cut_scores)
elif score_type == "imp_frac_signed_score":
totals = np.sum(cut_scores, axis=1, keepdims=True)
scores = np.divide(
cut_scores, totals,
out=np.zeros_like(cut_scores), where=(totals != 0)
)
else:
abs_cut_scores = np.abs(cut_scores)
totals = np.sum(abs_cut_scores, axis=1, keepdims=True)
scores = np.divide(
abs_cut_scores, totals,
out=np.zeros_like(abs_cut_scores), where=(totals != 0)
)
window_sums = scipy.signal.correlate(scores, np.ones((1, window_length)), mode="valid")
return np.random.choice(np.ravel(window_sums), size=min(sample, window_sums.size), replace=False)
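A hypothetical usage sketch for `get_imp_score_dist` (random toy scores, invented for this example; the window and cut sizes are shrunk to fit the short tracks):
np.random.seed(0)  # The window sampling inside uses the global NumPy state
_toy_act = np.random.randn(3, 50, 4)  # N x L x 4 toy importance scores
print(get_imp_score_dist(_toy_act, window_length=8, score_type="imp_frac_score", center_cut_size=20, sample=5))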
def l1_normalize(seq, axis=None):
if axis is None:
total = np.sum(np.abs(seq))
return seq if not total else seq / total
else:
total = np.sum(np.abs(seq), axis=axis, keepdims=True)
return np.divide(seq, total, out=np.zeros_like(seq), where=(total != 0))
def dot_product_vec(query_seq, target_seqs, normalize=True, revcomp=True):
"""
Takes an I x D query seq and N x I x D target seqs, and computes the
similarity between the query and each target using a simple dot product.
Returns an N-array. If `revcomp` is True, also computes the similarity
using the reverse complement of the query and returns the per-target maximum.
"""
query_seq = np.expand_dims(query_seq, axis=0) # 1 x I x D
# L1-normalize
if normalize:
query_seq = l1_normalize(query_seq)
target_seqs = l1_normalize(target_seqs)
sim = np.sum(query_seq * target_seqs, axis=(1, 2))
if revcomp:
query_seq_rc = np.empty_like(query_seq)
for i in range(query_seq.shape[2] // 4):
query_seq_rc[:, :, (i * 4):((i + 1) * 4)] = np.flip(query_seq[:, :, (i * 4):((i + 1) * 4)])
rc_sim = np.sum(query_seq_rc * target_seqs, axis=(1, 2))
return np.maximum(sim, rc_sim)
else:
return sim
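A toy illustration of `dot_product_vec` (arrays invented for this example): with `revcomp=True`, a reverse-complement match scores the same as a forward match.
_query = np.array([[1., 0., 0., 0.], [0., 0., 1., 0.]])  # "AG" as a 2 x 4 one-hot
_rc = np.flip(_query)  # Flipping both axes gives the reverse complement, "CT"
_targets = np.stack([_query, _rc])  # 2 x 2 x 4
print(dot_product_vec(_query, _targets, normalize=True, revcomp=True))  # [0.25 0.25]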
def get_motif_score_cosine_sim_dist(
act_scores, cwm, hyp_scores=None, hcwm=None, sample=-1, window_inds=None,
normalize=True, center_cut_size=400
):
"""
Computes a sample of cosine similarities between the CWM and the
actual importance scores. If `hyp_scores` and `hcwm` are also provided, then
computes the similarity between the actual/hypothetical scores and motifs,
concatenated along the bases dimension. `act_scores` and `hyp_scores`
are N x L x 4 arrays. Samples only `sample` windows. If `window_inds` is
given, it must be an M x 2 array of sequence indices and window indices
to sample from (out of the original N and L); otherwise, random windows
are sampled from the central `center_cut_size` of the tracks.
Returns a NumPy array of values.
"""
if hyp_scores is not None:
assert act_scores.shape == hyp_scores.shape
assert hcwm is not None
motif_len = cwm.shape[0]
if window_inds is None:
if sample > 0:
num_samples = min(sample, act_scores.shape[0] * (center_cut_size - motif_len + 1))
seq_inds = np.random.choice(act_scores.shape[0], size=num_samples)
window_starts = np.random.choice(center_cut_size - motif_len + 1, size=num_samples)
window_starts = window_starts + (act_scores.shape[1] // 2) - (center_cut_size // 2)
else:
seq_inds = np.repeat(np.arange(act_scores.shape[0]), center_cut_size - motif_len + 1)
window_starts = np.tile(np.arange(center_cut_size - motif_len + 1), act_scores.shape[0])
else:
if sample > 0:
num_samples = min(sample, len(window_inds))
sample_inds = np.random.choice(len(window_inds), size=num_samples, replace=False)
seq_inds = window_inds[sample_inds, 0]
window_starts = window_inds[sample_inds, 1]
else:
seq_inds = window_inds[:, 0]
window_starts = window_inds[:, 1]
act_windows = act_scores[
np.expand_dims(seq_inds, axis=1),
np.linspace(window_starts, window_starts + motif_len - 1, motif_len, axis=1).astype(int)
]
if normalize:
act_windows = l1_normalize(act_windows, axis=(1, 2))
cwm = l1_normalize(cwm)
if hyp_scores is not None:
hyp_windows = hyp_scores[
np.expand_dims(seq_inds, axis=1),
np.linspace(window_starts, window_starts + motif_len - 1, motif_len, axis=1).astype(int)
]
if normalize:
hyp_windows = l1_normalize(hyp_windows, axis=(1, 2))
hcwm = l1_normalize(hcwm)
windows = np.concatenate([act_windows, hyp_windows], axis=2)
motif = np.concatenate([cwm, hcwm], axis=1)
else:
windows = act_windows
motif = cwm
return dot_product_vec(motif, windows, normalize=False)
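A hypothetical usage sketch for `get_motif_score_cosine_sim_dist` (random toy arrays, invented for this example; `center_cut_size` is shrunk to fit the short tracks):
np.random.seed(0)
_toy_scores = np.random.randn(2, 60, 4)  # N x L x 4 actual importance scores
_toy_cwm = np.random.randn(8, 4)  # An 8-bp toy CWM
# Scores 5 randomly sampled windows from the central 40 bp against the CWM
print(get_motif_score_cosine_sim_dist(_toy_scores, _toy_cwm, sample=5, center_cut_size=40))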
def get_ic_scaled_motif_score_sim_dist(
act_scores, pfm, hcwm, sample=-1, window_inds=None,
normalize=True, center_cut_size=400
):
"""
Computes a sample of dot-product similarities between the hCWM and the
actual importance scores. The hCWM is scaled by information content,
based on the given PFM. `act_scores` is an N x L x 4 array. Samples
only `sample` windows. If `window_inds` is given, it must be an M x 2
array of sequence indices and window indices to sample from (out of the
original N and L); otherwise, random windows are sampled from the central
`center_cut_size` of the tracks.
Returns a NumPy array of values.
"""
motif_len = hcwm.shape[0]
if window_inds is None:
if sample > 0:
num_samples = min(sample, act_scores.shape[0] * (center_cut_size - motif_len + 1))
seq_inds = np.random.choice(act_scores.shape[0], size=num_samples)
window_starts = np.random.choice(center_cut_size - motif_len + 1, size=num_samples)
window_starts = window_starts + (act_scores.shape[1] // 2) - (center_cut_size // 2)
else:
seq_inds = np.repeat(np.arange(act_scores.shape[0]), center_cut_size - motif_len + 1)
window_starts = np.tile(np.arange(center_cut_size - motif_len + 1), act_scores.shape[0])
else:
if sample > 0:
num_samples = min(sample, len(window_inds))
sample_inds = np.random.choice(len(window_inds), size=num_samples, replace=False)
seq_inds = window_inds[sample_inds, 0]
window_starts = window_inds[sample_inds, 1]
else:
seq_inds = window_inds[:, 0]
window_starts = window_inds[:, 1]
act_windows = act_scores[
np.expand_dims(seq_inds, axis=1),
np.linspace(window_starts, window_starts + motif_len - 1, motif_len, axis=1).astype(int)
]
if normalize:
act_windows = l1_normalize(act_windows, axis=(1, 2))
hcwm = l1_normalize(hcwm)
hcwm = np.expand_dims(pfm_info_content(pfm), axis=1) * hcwm
return dot_product_vec(hcwm, act_windows, normalize=False)
def filter_tfm_peak_hits(
hit_table, shap_coords, act_scores, cwms, imp_score_column="imp_frac_score",
imp_thresh=0.5, sim_thresh=0.99, save_imp_thresh=0.8
):
"""
Filters the table of peak hits based on importance and similarity.
`imp_score_column` defines the importance score column to filter on.
`imp_thresh`, `sim_thresh`, and `save_imp_thresh` are probability cutoffs
on isotonic-regression fits of hit vs. background scores, from which
per-motif score thresholds are derived.
`shap_coords` is an N x 3 object array denoting coordinates
of importance scores. The `peak_index` column of `hit_table` must index
into these coordinates. `act_scores` is a parallel N x L x 4 array of
actual importance scores. `cwms` is a dictionary mapping motif keys to
CWMs, and must match the motifs used by the motif hit scorer exactly.
Returns a reduced hit table of the same format, and a dictionary of
figures of the distributions used for filtering, one for each motif key.
"""
np.random.seed(seed)
filter_mask = np.zeros(len(hit_table), dtype=bool) # All False
motif_lengths = dict(zip(hit_table["key"], hit_table["end"] - hit_table["start"]))
motif_keys = sorted(motif_lengths.keys())
filter_figs = {}
for motif_key in motif_keys:
cwm = cwms[motif_key]
motif_hit_table = hit_table[hit_table["key"] == motif_key]
filter_fig, ax = plt.subplots(nrows=4, figsize=(20, 16))
# Importance
hit_imp_scores = motif_hit_table[imp_score_column].values
hit_imp_scores_finite = hit_imp_scores[np.isfinite(hit_imp_scores)]
bg_imp_scores = get_imp_score_dist(act_scores, motif_lengths[motif_key], score_type=imp_score_column, sample=len(hit_imp_scores))
bg_imp_scores_finite = bg_imp_scores[np.isfinite(bg_imp_scores)]
x = np.linspace(0, 1.0, 2000) # Restrict to positive and under 1
reg_x = np.concatenate([bg_imp_scores_finite, hit_imp_scores_finite])
reg_y = np.concatenate([np.zeros(len(bg_imp_scores_finite)), np.ones(len(hit_imp_scores_finite))])
iso_reg_model = sklearn.isotonic.IsotonicRegression()
iso_reg_model.fit(reg_x, reg_y)
iso_preds = iso_reg_model.predict(x)
pass_inds = np.where(iso_preds >= imp_thresh)[0]
imp_cutoff = x[np.min(pass_inds)] if pass_inds.size else np.max(hit_imp_scores_finite)
pass_inds = np.where(iso_preds >= save_imp_thresh)[0]
save_imp_cutoff = x[np.min(pass_inds)] if pass_inds.size else np.max(hit_imp_scores_finite)
ax[0].scatter(np.clip(bg_imp_scores_finite, 0, 1), np.zeros(len(bg_imp_scores_finite)), alpha=0.05, label="Background importance scores")
ax[0].scatter(np.clip(hit_imp_scores_finite, 0, 1), np.ones(len(hit_imp_scores_finite)), alpha=0.05, label="Hit importance scores")
ax[0].plot(x, iso_preds)
ax[0].axvline(imp_cutoff)
ax[0].axvline(save_imp_cutoff)
ymin, ymax = ax[0].get_ylim()
ax[0].annotate("%f" % imp_cutoff, xy=(imp_cutoff, (ymin + ymax) * 0.25))
ax[0].annotate("%f" % save_imp_cutoff, xy=(save_imp_cutoff, (ymin + ymax) * 0.75))
ax[0].set_title(imp_score_column)
ax[0].legend()
ax[1].hist(bg_imp_scores_finite, bins=x, density=True, alpha=0.3, label="Background importance scores")
ax[1].hist(hit_imp_scores_finite, bins=x, density=True, alpha=0.3, label="Hit importance scores")
ax[1].axvline(imp_cutoff)
ax[1].axvline(save_imp_cutoff)
ymin, ymax = ax[1].get_ylim()
ax[1].annotate("%f" % imp_cutoff, xy=(imp_cutoff, (ymin + ymax) * 0.25))
ax[1].annotate("%f" % save_imp_cutoff, xy=(save_imp_cutoff, (ymin + ymax) * 0.75))
ax[1].legend()
# Similarity
window_inds = np.empty((len(motif_hit_table), 2), dtype=int)
window_inds[:, 0] = motif_hit_table["peak_index"].values
window_inds[:, 1] = motif_hit_table["start"].values - shap_coords[:, 1][window_inds[:, 0]]
hit_sim_scores = get_motif_score_cosine_sim_dist(act_scores, cwm, window_inds=window_inds)
hit_sim_scores_finite = hit_sim_scores[np.isfinite(hit_sim_scores)]
bg_sim_scores = get_motif_score_cosine_sim_dist(act_scores, cwm, sample=len(hit_sim_scores))
bg_sim_scores_finite = bg_sim_scores[np.isfinite(bg_sim_scores)]
x = np.linspace(
min(np.min(bg_sim_scores_finite), np.min(hit_sim_scores_finite)),
max(np.max(bg_sim_scores_finite), np.max(hit_sim_scores_finite)), 2000
)
reg_x = np.concatenate([bg_sim_scores_finite, hit_sim_scores_finite])
reg_y = np.concatenate([np.zeros(len(bg_sim_scores_finite)), np.ones(len(hit_sim_scores_finite))])
iso_reg_model = sklearn.isotonic.IsotonicRegression()
iso_reg_model.fit(reg_x, reg_y)
iso_preds = iso_reg_model.predict(x)
pass_inds = np.where(iso_preds >= sim_thresh)[0]
sim_cutoff = x[np.min(pass_inds)] if pass_inds.size else 1
ax[2].scatter(bg_sim_scores_finite, np.zeros(len(bg_sim_scores_finite)), alpha=0.05, label="Background similarity scores")
ax[2].scatter(hit_sim_scores_finite, np.ones(len(hit_sim_scores_finite)), alpha=0.05, label="Hit similarity scores")
ax[2].plot(x, iso_preds)
ax[2].axvline(sim_cutoff)
ymin, ymax = ax[2].get_ylim()
ax[2].annotate("%f" % sim_cutoff, xy=(sim_cutoff, (ymin + ymax) * 0.5))
ax[2].set_title("Actual cosine similarity")
ax[2].legend()
ax[3].hist(bg_sim_scores_finite, bins=500, density=True, alpha=0.3, label="Background similarity scores")
ax[3].hist(hit_sim_scores_finite, bins=500, density=True, alpha=0.3, label="Hit similarity scores")
ax[3].axvline(sim_cutoff)
ymin, ymax = ax[3].get_ylim()
ax[3].annotate("%f" % sim_cutoff, xy=(sim_cutoff, (ymin + ymax) * 0.5))
ax[3].legend()
filter_fig.suptitle("Filtering motif %s" % motif_key)
filter_figs[motif_key] = filter_fig
plt.show()
motif_mask = \
((hit_sim_scores >= sim_cutoff) | (hit_imp_scores >= save_imp_cutoff)) & \
(motif_hit_table[imp_score_column] >= imp_cutoff)
filter_mask[hit_table["key"] == motif_key] = motif_mask
# Show statistics on how many motifs were kept
orig_num = len(motif_hit_table)
orig_peak_num = len(np.unique(motif_hit_table["peak_index"]))
new_num = np.sum(motif_mask)
new_peak_num = len(np.unique(motif_hit_table.loc[motif_mask]["peak_index"]))
num_peaks = 1 + np.max(hit_table["peak_index"])
print("Hit number reduction: %d -> %d (%f)" % (orig_num, new_num, (new_num - orig_num) / orig_num))
print("Proportion of peaks reduction: %f -> %f" % (orig_peak_num / num_peaks, new_peak_num / num_peaks))
return hit_table.loc[filter_mask], filter_figs
def get_peak_hits(peak_table, hit_table):
"""
For each peak, extracts the set of motif hits that fall in that peak.
Returns a list mapping peak index to a subtable of `hit_table`. The index
of the list is the index of the peak table.
"""
peak_hits = [pd.DataFrame(columns=list(hit_table))] * len(peak_table)
for peak_index, matches in tqdm.notebook.tqdm(hit_table.groupby("peak_index")):
# Check that all of the matches are indeed overlapping the peak
peak_row = peak_table.iloc[peak_index]
chrom, start, end = peak_row["chrom"], peak_row["peak_start"], peak_row["peak_end"]
assert np.all(matches["chrom"] == chrom)
assert np.all((matches["start"] < end) & (start < matches["end"]))
peak_hits[peak_index] = matches
return peak_hits
def get_peak_motif_counts(peak_hits, motif_keys):
"""
From the peak hits (as returned by `get_peak_hits`), computes a count
array of size N x M, where N is the number of peaks and M is the number of
motifs. Each entry represents the number of times a motif appears in a peak.
`motif_keys` is a list of motif keys as they appear in `peak_hits`; the
order of the motifs M matches this list.
"""
motif_inds = {motif_keys[i] : i for i in range(len(motif_keys))}
counts = np.zeros((len(peak_hits), len(motif_keys)), dtype=int)
for i in tqdm.notebook.trange(len(peak_hits)):
hits = peak_hits[i]
for key, num in zip(*np.unique(hits["key"], return_counts=True)):
counts[i][motif_inds[key]] = num
return counts
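A toy example of `get_peak_motif_counts` (hit tables invented for this example):
_toy_peak_hits = [
    pd.DataFrame({"key": ["0_0", "0_0", "0_1"]}),  # Peak 0: two 0_0 hits, one 0_1 hit
    pd.DataFrame({"key": []})  # Peak 1: no hits
]
print(get_peak_motif_counts(_toy_peak_hits, ["0_0", "0_1"]))  # [[2 1], [0 0]]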
def cluster_matrix_indices(matrix, num_clusters):
"""
Clusters matrix using k-means. Always clusters on the first
axis. Returns the indices needed to optimally order the matrix
by clusters.
"""
if len(matrix) == 1:
# Don't cluster at all
return np.array([0])
num_clusters = min(num_clusters, len(matrix))
# Perform k-means clustering
kmeans = sklearn.cluster.MiniBatchKMeans(n_clusters=num_clusters)
cluster_assignments = kmeans.fit_predict(matrix)
# Perform hierarchical clustering on the cluster centers to determine optimal ordering
kmeans_centers = kmeans.cluster_centers_
cluster_order = scipy.cluster.hierarchy.leaves_list(
scipy.cluster.hierarchy.optimal_leaf_ordering(
scipy.cluster.hierarchy.linkage(kmeans_centers, method="centroid"), kmeans_centers
)
)
# Order the peaks so that the cluster assignments follow the optimal ordering
cluster_inds = []
for cluster_id in cluster_order:
cluster_inds.append(np.where(cluster_assignments == cluster_id)[0])
cluster_inds = np.concatenate(cluster_inds)
return cluster_inds
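A toy call to `cluster_matrix_indices` (random data, invented for illustration): two well-separated blobs should come back as two contiguous groups of row indices.
_clust_rng = np.random.RandomState(0)
_toy_mat = np.concatenate([_clust_rng.randn(5, 3), _clust_rng.randn(5, 3) + 10])
print(cluster_matrix_indices(_toy_mat, num_clusters=2))  # A permutation of 0-9, grouped by blob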
def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_keys, subsample=None):
"""
Plots a simple indicator heatmap of the motifs in each peak.
Returns the figure.
"""
# Subsample peaks
if subsample:
peak_hit_counts = peak_hit_counts[np.random.choice(
len(peak_hit_counts), size=min(len(peak_hit_counts), subsample), replace=False
)]
peak_hit_indicators = (peak_hit_counts > 0).astype(int)
# Order columns by prevalence (by number of peaks with that motif)
counts = np.sum(peak_hit_indicators, axis=0)
inds = np.flip(np.argsort(counts))
matrix = peak_hit_indicators[:, inds]
motif_keys = np.array(motif_keys)[inds]
# Order rows in "binary" order
places = np.power(2, np.flip(np.arange(matrix.shape[1])))
values = np.sum(matrix * places, axis=1)
inds = np.flip(np.argsort(values))
matrix = matrix[inds]
# Create a figure with the right dimensions
fig_height = min(len(peak_hit_indicators) * 0.004, 8)
fig, ax = plt.subplots(figsize=(16, fig_height))
# Plot the heatmap
ax.imshow(matrix, interpolation="nearest", aspect="auto", cmap="Greens")
# Set axes on heatmap
ax.set_yticks([])
ax.set_yticklabels([])
ax.set_xticks(np.arange(len(motif_keys)))
ax.set_xticklabels(motif_keys)
ax.set_xlabel("Motif")
fig.tight_layout()
plt.show()
return fig
def plot_homotypic_densities(peak_hit_counts, motif_keys):
"""
Plots a CDF of number of motif hits per peak, for each motif.
Returns a dictionary mapping motif key to figure.
"""
figs = {}
for i in range(len(motif_keys)):
counts = peak_hit_counts[:, i]
fig, ax = plt.subplots(figsize=(8, 8))
bins = np.concatenate([np.arange(np.max(counts)), [np.inf]])
ax.hist(counts, bins=bins, density=True, histtype="step", cumulative=True)
ax.set_title("Cumulative distribution of number of %s hits per peak" % motif_keys[i])
ax.set_xlabel("Number of motifs k in peak")
ax.set_ylabel("Proportion of peaks with at least k motifs")
plt.show()
figs[motif_keys[i]] = fig
return figs
def get_motif_cooccurrence_count_matrix(peak_hit_counts):
"""
From an N x M (peaks by motifs) array of hit counts, returns
an M x M array of counts (i.e. how many times two motifs occur
together in the same peak). For the diagonal entries, we require
that the motif occur at least twice in a peak to be counted.
"""
peak_hit_indicators = (peak_hit_counts > 0).astype(int)
num_motifs = peak_hit_indicators.shape[1]
count_matrix = np.zeros((num_motifs, num_motifs), dtype=int)
for i in range(num_motifs):
for j in range(i):
pair_col = np.sum(peak_hit_indicators[:, [i, j]], axis=1)
count = np.sum(pair_col == 2)
count_matrix[i, j] = count
count_matrix[j, i] = count
count_matrix[i, i] = np.sum(peak_hit_counts[:, i] >= 2)
return count_matrix
def compute_cooccurrence_pvals(peak_hit_counts):
"""
Given the number of motif hits in each peak, computes p-value of
co-occurrence for each pair of motifs, including self pairs.
Returns an M x M array of p-values for the M motifs.
"""
peak_hit_indicators = (peak_hit_counts > 0).astype(int)
num_peaks, num_motifs = peak_hit_counts.shape
pvals = np.ones((num_motifs, num_motifs))
# Significance is based on Fisher's exact test. If the motifs were
# present in peaks randomly, we'd expect independence of occurrence.
# For self-co-occurrence, the null model is not independence, but
# random collisions (balls in bins)
for i in range(num_motifs):
for j in range(i):
pair_counts = peak_hit_indicators[:, [i, j]]
peaks_with_1 = pair_counts[:, 0] == 1
peaks_with_2 = pair_counts[:, 1] == 1
# Contingency table (universe is set of all peaks):
# no motif 1 | has motif 1
# no motif 2 A | B
# -------------------------+--------------
# has motif 2 C | D
# The Fisher's exact test evaluates the significance of the
# association between the two classifications
cont_table = np.array([
[
np.sum(~(peaks_with_1) & (~peaks_with_2)),
np.sum(peaks_with_1 & (~peaks_with_2))
],
[
np.sum(~(peaks_with_1) & peaks_with_2),
np.sum(peaks_with_1 & peaks_with_2)
]
])
pval = scipy.stats.fisher_exact(
cont_table, alternative="greater"
)[1]
pvals[i, j] = pval
pvals[j, i] = pval
# Self-co-occurrence: Poissonize balls in bins
# Expected number of collisions (via linearity of expectations):
num_hits = np.sum(peak_hit_indicators[:, i]) # number of "balls"
expected_collisions = num_hits * (num_hits - 1) / (2 * num_peaks)
num_collisions = np.sum(peak_hit_counts[:, i] >= 2)
if num_collisions == 0:
pval = 1
else:
pval = 1 - scipy.stats.poisson.cdf(num_collisions, mu=expected_collisions)
pvals[i, i] = pval
return pvals
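A toy run of the two co-occurrence functions above (counts invented for this example): motifs 0 and 1 co-occur in three of six peaks, and motif 0 collides with itself in one peak, so the (0, 1) entry should get the smallest pairwise p-value.
_toy_counts = np.array([
    [1, 1, 0],
    [1, 1, 0],
    [1, 1, 0],
    [0, 0, 1],
    [0, 0, 1],
    [2, 0, 0]  # Motif 0 appears twice in this peak
])
print(get_motif_cooccurrence_count_matrix(_toy_counts))
print(compute_cooccurrence_pvals(_toy_counts))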
def plot_motif_cooccurrence_heatmaps(count_matrix, pval_matrix, motif_keys):
"""
Plots a heatmap showing the number of peaks that have both types of
each motif, as well as a heatmap showing the p-value of co-occurrence.
Returns the p-value figure and the count figure, as well as the indices
of motifs used for clustering.
"""
assert count_matrix.shape == pval_matrix.shape
num_motifs = pval_matrix.shape[0]
assert len(motif_keys) == num_motifs
# Cluster by p-value
inds = cluster_matrix_indices(pval_matrix, max(5, num_motifs // 4))
pval_matrix = pval_matrix[inds][:, inds]
count_matrix = count_matrix[inds][:, inds]
motif_keys = np.array(motif_keys)[inds]
# Plot the p-value matrix
fig_width = max(5, num_motifs)
p_fig, ax = plt.subplots(figsize=(fig_width, fig_width))
# Replace 0s with minimum value (we'll label them properly later)
zero_mask = pval_matrix == 0
non_zeros = pval_matrix[~zero_mask]
if not len(non_zeros):
logpval_matrix = np.tile(np.inf, pval_matrix.shape)
else:
min_val = np.min(pval_matrix[~zero_mask])
pval_matrix[zero_mask] = min_val
logpval_matrix = -np.log10(pval_matrix)
hmap = ax.imshow(logpval_matrix)
ax.set_xticks(np.arange(num_motifs))
ax.set_yticks(np.arange(num_motifs))
ax.set_xticklabels(motif_keys, rotation=45)
ax.set_yticklabels(motif_keys)
# Loop over data dimensions and create text annotations.
for i in range(num_motifs):
for j in range(num_motifs):
if zero_mask[i, j]:
text = "Inf"
else:
text = "%.2f" % np.abs(logpval_matrix[i, j])
ax.text(j, i, text, ha="center", va="center")
p_fig.colorbar(hmap, orientation="horizontal")
ax.set_title("-log(p) significance of peaks with both motifs")
p_fig.tight_layout()
plt.show()
# Plot the counts matrix
fig_width = max(5, num_motifs)
c_fig, ax = plt.subplots(figsize=(fig_width, fig_width))
hmap = ax.imshow(count_matrix)
ax.set_xticks(np.arange(num_motifs))
ax.set_yticks(np.arange(num_motifs))
ax.set_xticklabels(motif_keys, rotation=45)
ax.set_yticklabels(motif_keys)
# Loop over data dimensions and create text annotations.
for i in range(num_motifs):
for j in range(num_motifs):
ax.text(j, i, count_matrix[i, j], ha="center", va="center")
c_fig.colorbar(hmap, orientation="horizontal")
ax.set_title("Number of peaks with both motifs")
c_fig.tight_layout()
plt.show()
return p_fig, c_fig, inds
def create_violin_plot(ax, dist_list, colors):
"""
Creates a violin plot on the given instantiated axes.
`dist_list` is a list of vectors. `colors` is a parallel
list of colors for each violin.
"""
num_perfs = len(dist_list)
q1, med, q3 = np.stack([
np.nanpercentile(data, [25, 50, 75], axis=0) for data in dist_list
], axis=1)
iqr = q3 - q1
lower_outlier = q1 - (1.5 * iqr)
upper_outlier = q3 + (1.5 * iqr)
sorted_clipped_data = [ # Remove outliers based on outlier rule
np.sort(vec[(vec >= lower_outlier[i]) & (vec <= upper_outlier[i])])
for i, vec in enumerate(dist_list)
]
plot_parts = ax.violinplot(
sorted_clipped_data, showmeans=False, showmedians=False, showextrema=False
)
violin_parts = plot_parts["bodies"]
for i in range(num_perfs):
violin_parts[i].set_facecolor(colors[i])
violin_parts[i].set_edgecolor(colors[i])
violin_parts[i].set_alpha(0.7)
inds = np.arange(1, num_perfs + 1)
ax.vlines(inds, q1, q3, color="black", linewidth=5, zorder=1)
ax.scatter(inds, med, marker="o", color="white", s=30, zorder=2)
def plot_intermotif_distance_violins(peak_hits, motif_keys, pair_inds, cluster_inds):
"""
For each pair of motifs, plots a violin of distances between
motifs. Returns a dictionary mapping pairs of motif keys to arrays
of distances, and the figure.
"""
# First, compute the distribution of distances for each pair
distance_dict = {}
key_pairs = []
for i, j in tqdm.notebook.tqdm(pair_inds):
dists = []
for k in range(len(peak_hits)):
hits = peak_hits[k]
hits_1 = hits[hits["key"] == motif_keys[i]]
hits_2 = hits[hits["key"] == motif_keys[j]]
if hits_1.empty or hits_2.empty:
continue
pos_1 = np.array(hits_1["start"])
pos_2 = np.array(hits_2["start"])
len_1 = (hits_1["end"] - hits_1["start"]).values[0]
len_2 = (hits_2["end"] - hits_2["start"]).values[0]
# Differences between all pairs of positions
diffs = pos_2[None] - pos_1[:, None]
# For each instance of motif 1, take the distance to the nearest
# instance of motif 2, but only if the two hits do not overlap
for row in diffs:
row = row[row != 0]
if not row.size:
continue
dist = row[np.argmin(np.abs(row))]
if (dist < 0 and dist < -len_2) or (dist > 0 and dist > len_1):
dists.append(dist)
dists = np.array(dists)
if not dists.size:
continue
key_pair = (motif_keys[i], motif_keys[j])
key_pairs.append(key_pair)
distance_dict[key_pair] = np.abs(dists) # Take absolute value of distance
if not distance_dict:
print("No significantly co-occurring motifs")
return distance_dict, None
# Create the plot
fig, ax = plt.subplots(
nrows=len(motif_keys), ncols=len(motif_keys),
figsize=(len(motif_keys) * 4, len(motif_keys) * 4)
)
if type(ax) is not np.ndarray:
ax = np.array([[ax]])
# Map motif key to axis index
key_to_index = dict(zip(np.array(motif_keys)[cluster_inds], np.arange(len(motif_keys))))
def clean_subplot(ax):
# Do this instead of ax.axis("off"), which would also remove any
# axis labels
ax.set_yticks([])
ax.set_xticks([])
for orient in ("top", "bottom", "left", "right"):
ax.spines[orient].set_visible(False)
# Create violins
for i in range(len(motif_keys)):
for j in range(i, len(motif_keys)):
key_1, key_2 = motif_keys[i], motif_keys[j]
key_pair, rev_key_pair = (key_1, key_2), (key_2, key_1)
axis_1, axis_2 = key_to_index[key_1], key_to_index[key_2]
# Always plot lower triangle
if axis_1 < axis_2:
axis_1, axis_2 = axis_2, axis_1
if key_pair in distance_dict or rev_key_pair in distance_dict:
if rev_key_pair in distance_dict:
key_pair = rev_key_pair
dist = distance_dict[key_pair]
create_violin_plot(ax[axis_1, axis_2], [dist], ["mediumorchid"])
ax[axis_1, axis_2].set_xticks([]) # Remove x-axis labels, as they don't mean much
if axis_1 != axis_2:
# If off diagonal, clean the axes of the symmetric cell
clean_subplot(ax[axis_2, axis_1])
else:
clean_subplot(ax[axis_1, axis_2])
clean_subplot(ax[axis_2, axis_1])
# Make motif labels
for i in range(len(motif_keys)):
ax[i, 0].set_ylabel(motif_keys[cluster_inds[i]])
ax[-1, i].set_xlabel(motif_keys[cluster_inds[i]])
# Remove x-axis labels/ticks
ax[-1, -1].set_xticks([])
fig.suptitle("Distance distributions between co-occurring motifs")
fig.tight_layout(rect=[0, 0.03, 1, 0.98])
return distance_dict, fig
# Import the PFMs, CWMs, and hCWMs
pfms, cwms, hcwms = import_tfmodisco_motifs(tfm_results_path)
motif_keys = list(pfms.keys())
# Import peaks
peak_table = import_peak_table(peak_bed_paths)
# Expand to input length
peak_table["peak_start"] = \
(peak_table["peak_start"] + peak_table["summit_offset"]) - (input_length // 2)
peak_table["peak_end"] = peak_table["peak_start"] + input_length
# Import DeepSHAP scores
hyp_scores, act_scores, one_hot_seqs, shap_coords = import_shap_scores(
shap_scores_path, hyp_score_key, center_cut_size=None, remove_non_acgt=False
)
Importing SHAP scores: 100%|██████████| 204/204 [08:24<00:00, 2.47s/it]
# Limit SHAP coordinates/scores to only those with matching peak coordinates
shap_coords_table = pd.DataFrame(shap_coords, columns=["chrom", "start", "end"])
peak_coords_table = peak_table[["chrom", "peak_start", "peak_end"]]
ind_pairs = peak_coords_table.reset_index().merge(
shap_coords_table.reset_index(), how="left", left_on=["chrom", "peak_start", "peak_end"],
right_on=["chrom", "start", "end"]
)[["index_x", "index_y"]].values
# ind_pairs contains pairs (i, j) such that peak_table[i] matches shap_coords_table[j]
# If peak_table[i] did not match a SHAP coord, then j will be NaN
ind_pairs = ind_pairs[np.isfinite(ind_pairs[:, 1])].astype(int) # Remove peak indices with no matches
order_inds = np.full(len(peak_table), -1)
order_inds[ind_pairs[:, 0]] = ind_pairs[:, 1]
# If order_inds[i] == j, then peak_table[i] matches shap_coords_table[j];
# if peak_table[i] did not match any SHAP coord, then order_inds[i] == -1
shap_coords = shap_coords[order_inds]
hyp_scores = hyp_scores[order_inds]
act_scores = act_scores[order_inds]
one_hot_seqs = one_hot_seqs[order_inds]
bg_freq = np.mean(one_hot_seqs, axis=(0, 1))
# Whenever a peak had no matching SHAP coord, zero out the corresponding rows
# This ensures that when we search for matches of DeepSHAP scores that don't
# exist, we will find nothing
shap_coords[order_inds < 0] = 0
hyp_scores[order_inds < 0] = 0
act_scores[order_inds < 0] = 0
one_hot_seqs[order_inds < 0] = 0
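To make the merge-based matching above concrete, here is a toy two-peak illustration (coordinates invented for this example):
_toy_peaks = pd.DataFrame({"chrom": ["chr1", "chr1"], "peak_start": [100, 500], "peak_end": [300, 700]})
_toy_shap = pd.DataFrame({"chrom": ["chr1"], "start": [100], "end": [300]})
_toy_pairs = _toy_peaks.reset_index().merge(
    _toy_shap.reset_index(), how="left", left_on=["chrom", "peak_start", "peak_end"],
    right_on=["chrom", "start", "end"]
)[["index_x", "index_y"]].values
print(_toy_pairs)  # [[0. 0.], [1. nan]]: peak 0 matches SHAP coord 0; peak 1 has no match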
# Import motif hits results
hit_table = import_motif_hits(motif_hits_path)
TF-MoDISco hits
# Trim motifs properly
if "agg_sim" in list(hit_table):
# Perform trimming: first to 25 bp, then by IC (as when hit scoring was performed)
window_size = 25
min_ic = 0.2
for motif_key in motif_keys:
pfm = pfms[motif_key]
# First level trimming
ic = compute_per_position_ic(pfm, bg_freq, 0.001)
start = np.argmax(cpu_sliding_window_sum(ic, window_size))
end = start + window_size
pfm = pfm[start:end]
cwm = cwms[motif_key][start:end]
hcwm = hcwms[motif_key][start:end]
# Second level trimming
ic = compute_per_position_ic(pfm, bg_freq, 0.001)
pass_inds = np.where(ic >= min_ic)[0]
start, end = np.min(pass_inds), np.max(pass_inds) + 1
pfms[motif_key] = pfm[start:end]
cwms[motif_key] = cwm[start:end]
hcwms[motif_key] = hcwm[start:end]
else:
for motif_key in motif_keys:
pfm = pfms[motif_key]
pfms[motif_key] = trim_motif_by_ic(pfm, pfm)
cwms[motif_key] = trim_motif_by_ic(pfm, cwms[motif_key])
hcwms[motif_key] = trim_motif_by_ic(pfm, hcwms[motif_key])
# Filter motif hit table
if "agg_sim" in list(hit_table):
hit_table_filtered, filter_figs = filter_tfm_peak_hits(
hit_table, shap_coords, act_scores, cwms,
imp_thresh=motif_tfm_imp_prob_cutoff, sim_thresh=motif_tfm_sim_prob_cutoff,
save_imp_thresh=motif_tfm_save_imp_prob_cutoff
)
else:
hit_table_filtered, score_fig = filter_moods_peak_hits(
hit_table, imp_perc_cutoff=motif_moods_imp_perc_cutoff
)
Hit number reduction: 34280 -> 31965 (-0.067532) Proportion of peaks reduction: 0.151188 -> 0.143112
Hit number reduction: 116452 -> 82791 (-0.289055) Proportion of peaks reduction: 0.366954 -> 0.300859
Hit number reduction: 30940 -> 28296 (-0.085456) Proportion of peaks reduction: 0.135232 -> 0.125908
Hit number reduction: 59318 -> 36787 (-0.379834) Proportion of peaks reduction: 0.238885 -> 0.164104
Hit number reduction: 87 -> 66 (-0.241379) Proportion of peaks reduction: 0.000427 -> 0.000324
Hit number reduction: 213 -> 198 (-0.070423) Proportion of peaks reduction: 0.001037 -> 0.000968
Hit number reduction: 284 -> 6 (-0.978873) Proportion of peaks reduction: 0.001390 -> 0.000029
assert not hit_table_filtered.empty, "Filtered out all %d original hits" % len(hit_table)
# Match peaks to motif hits
peak_hits = get_peak_hits(peak_table, hit_table_filtered)
# Construct count array of peaks and hits
peak_hit_counts = get_peak_motif_counts(peak_hits, motif_keys)
# Construct count matrix of motif co-occurrence
motif_cooccurrence_count_matrix = get_motif_cooccurrence_count_matrix(peak_hit_counts)
# Construct the matrix of p-values for motif co-occurrence
motif_cooccurrence_pval_matrix = compute_cooccurrence_pvals(peak_hit_counts)
if hits_cache_dir:
# Save the filtered hits in the cache
hit_table_filtered.reset_index().to_csv(
os.path.join(hits_cache_dir, "filtered_hits.tsv"), sep="\t", header=True, index=False
)
# Save the peaks
peak_table.reset_index().to_csv(
os.path.join(hits_cache_dir, "peaks.tsv"), sep="\t", header=True, index=False
)
# Save a mapping between peak index and filtered motif indices
with open(os.path.join(hits_cache_dir, "peak_matched_hits.tsv"), "w") as f:
f.write("peak_index\tfiltered_hit_indices\n")
for i, table in enumerate(peak_hits):
f.write("%d\t%s\n" % (i, ",".join([str(x) for x in peak_hits[i].index])))
# Save score figures
if "agg_sim" in list(hit_table):
for motif_key, fig in filter_figs.items():
fig.savefig(os.path.join(hits_cache_dir, "filter_dists_%s.png" % motif_key))
else:
score_fig.savefig(os.path.join(hits_cache_dir, "imp_score_dist.png"))
# Save co-occurrence matrices
with h5py.File(os.path.join(hits_cache_dir, "cooccurrences.h5"), "w") as f:
f.create_dataset("counts", data=motif_cooccurrence_count_matrix, compression="gzip")
f.create_dataset("pvals", data=motif_cooccurrence_pval_matrix, compression="gzip")
motifs_per_peak = np.array([len(hits) for hits in peak_hits])
display(vdomh.p("Number of peaks: %d" % len(peak_table)))
display(vdomh.p("Number of motif hits before FDR filtering: %d" % len(hit_table)))
display(vdomh.p("Number of motif hits after FDR filtering: %d" % len(hit_table_filtered)))
Number of peaks: 203554
Number of motif hits before FDR filtering: 241574
Number of motif hits after FDR filtering: 180109
num_zero = np.sum(motifs_per_peak == 0)
display(vdomh.p("Number of peaks with 0 motif hits: %d" % num_zero))
display(vdomh.p("Percentage of peaks with 0 motif hits: %.1f%%" % (num_zero / len(peak_table) * 100)))
Number of peaks with 0 motif hits: 90995
Percentage of peaks with 0 motif hits: 44.7%
quants = [0, 0.25, 0.50, 0.75, 0.99, 1]
header = vdomh.thead(
vdomh.tr(
vdomh.th("Quantile", style={"text-align": "center"}),
vdomh.th("Number of hits/peak", style={"text-align": "center"})
)
)
body = vdomh.tbody(*[
    vdomh.tr(
        vdomh.td("%.1f%%" % (q * 100)), vdomh.td("%d" % v)
    ) for q, v in zip(quants, np.quantile(motifs_per_peak, quants))
])
vdomh.table(header, body)
Quantile | Number of hits/peak
--- | ---
0.0% | 0
25.0% | 0
50.0% | 1
75.0% | 1
99.0% | 4
100.0% | 10
fig, ax = plt.subplots(figsize=(10, 10))
bins = np.concatenate([np.arange(np.max(motifs_per_peak) + 1), [np.inf]])  # final inf bin catches the maximum count
ax.hist(motifs_per_peak, bins=bins, density=True, histtype="step", cumulative=True)
ax.set_title("Cumulative distribution of number of motif hits per peak")
ax.set_xlabel("Number of motif hits k in peak")
ax.set_ylabel("Proportion of peaks with at most k motif hits")
plt.show()
if hits_cache_dir:
fig.savefig(os.path.join(hits_cache_dir, "peak_hit_count_cdf.png"))
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6662: RuntimeWarning: invalid value encountered in multiply tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
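The RuntimeWarning above comes from the infinite final bin: with density=True, matplotlib multiplies the per-bin densities by np.diff(bins) before the cumulative sum, and the infinite last-bin width produces a NaN for the empty bin. It is harmless here, but the same CDF can be computed warning-free, e.g.:

# Warning-free alternative: compute the CDF directly with np.bincount
counts = np.bincount(motifs_per_peak)
cdf = np.cumsum(counts) / len(motifs_per_peak)  # cdf[k] = P(hits per peak <= k)
fig, ax = plt.subplots(figsize=(10, 10))
ax.step(np.arange(len(cdf)), cdf, where="post")
ax.set_title("Cumulative distribution of number of motif hits per peak")
ax.set_xlabel("Number of motif hits k in peak")
ax.set_ylabel("Proportion of peaks with at most k motif hits")
plt.show()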
frac_peaks_with_motif = np.sum(peak_hit_counts > 0, axis=0) / len(peak_hit_counts)
labels = np.array(motif_keys)
sorted_inds = np.flip(np.argsort(frac_peaks_with_motif))
frac_peaks_with_motif = frac_peaks_with_motif[sorted_inds]
labels = labels[sorted_inds]
fig, ax = plt.subplots(figsize=(20, 8))
ax.bar(np.arange(len(labels)), frac_peaks_with_motif)
ax.set_title("Proportion of peaks with each motif")
ax.set_xticks(np.arange(len(labels)))
ax.set_xticklabels(labels)
plt.show()
if hits_cache_dir:
fig.savefig(os.path.join(hits_cache_dir, "peaks_with_each_motif.png"))
# Show some examples of sequences with motif hits
num_to_draw = 3
center_plot_size = 400
unique_counts = np.sort(np.unique(motifs_per_peak))
motif_nums = []
if 0 in motifs_per_peak:
motif_nums.append(0)
if 1 in motifs_per_peak:
motif_nums.append(1)
motif_nums.extend([
unique_counts[0], # Minimum
unique_counts[len(unique_counts) // 2], # Median
unique_counts[-1], # Maximum
])
for motif_num in np.sort(np.unique(motif_nums)):
display(vdomh.h4("Sequences with %d motif hits" % motif_num))
peak_inds = np.where(motifs_per_peak == motif_num)[0]
table_rows = []
rng = np.random.RandomState(seed)
for i in rng.choice(
peak_inds, size=min(num_to_draw, len(peak_inds)), replace=False
):
peak_coord = peak_table.iloc[i][["chrom", "peak_start", "peak_end"]].values
motif_hits = peak_hits[i]
chrom, peak_start, peak_end = peak_coord
# Recenter to a fixed-size window around the peak midpoint
mid = (peak_start + peak_end) // 2
peak_start = mid - (center_plot_size // 2)
peak_end = peak_start + center_plot_size
peak_len = peak_end - peak_start
mask = (shap_coords[:, 0] == chrom) & (shap_coords[:, 1] <= peak_start) & (shap_coords[:, 2] >= peak_end)
if not np.sum(mask):
fig = "No matching input sequence found"
table_rows.append(
vdomh.tr(
vdomh.td("%s:%d-%d" % (chrom, peak_start, peak_end)),
vdomh.td(fig)
)
)
continue
seq_index = np.where(mask)[0][0] # Pick one
imp_scores = act_scores[seq_index]
_, seq_start, seq_end = shap_coords[seq_index]
highlights = []
for _, row in motif_hits.iterrows():
start = row["start"] - peak_start
end = start + (row["end"] - row["start"])
highlights.append((start, end))
# Remove highlights that overrun the sequence
highlights = [(a, b) for a, b in highlights if a >= 0 and b < peak_len]
start = peak_start - seq_start
end = start + peak_len
imp_scores_peak = imp_scores[start:end]
fig = viz_sequence.plot_weights(
imp_scores_peak, subticks_frequency=(len(imp_scores_peak) + 1),
highlight={"red" : [pair for pair in highlights]},
return_fig=True
)
fig = figure_to_vdom_image(fig)
table_rows.append(
vdomh.tr(
vdomh.td("%s:%d-%d" % (chrom, peak_start, peak_end)),
vdomh.td(fig)
)
)
table = vdomh.table(*table_rows)
display(table)
plt.close("all")
[Importance score tracks with motif hits highlighted in red, for example peaks: chr2:42342477-42342877, chr12:26937822-26938222, chr12:57242589-57242989, chr1:35268378-35268778, chr5:139710424-139710824, chr2:60350798-60351198, chr16:49992723-49993123, chr19:50203326-50203726, chr17:842267-842667, chr10:1189751-1190151]
density_figs = plot_homotypic_densities(peak_hit_counts, motif_keys)
if hits_cache_dir:
for key, fig in density_figs.items():
fig.savefig(os.path.join(hits_cache_dir, "homotypic_density_%s.png" % key))
/users/amtseng/miniconda3/envs/tfmodisco-mini/lib/python3.7/site-packages/matplotlib/axes/_axes.py:6662: RuntimeWarning: invalid value encountered in multiply tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
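As a rough equivalent of what plot_homotypic_densities shows, assuming it plots, per motif, the distribution of hit counts among peaks that contain that motif at least once (an assumption about the helper, not its actual code):

# Assumed equivalent of plot_homotypic_densities: per-motif distribution of
# hit counts among motif-containing peaks (an illustrative assumption)
for k, key in enumerate(motif_keys):
    counts_k = peak_hit_counts[:, k]
    counts_k = counts_k[counts_k > 0]  # restrict to peaks with >= 1 hit of this motif
    if not len(counts_k):
        continue
    fig, ax = plt.subplots()
    ax.hist(counts_k, bins=np.arange(1, counts_k.max() + 2), density=True)
    ax.set_title("Homotypic hit density: %s" % key)
    ax.set_xlabel("Hits of %s per peak" % key)
    ax.set_ylabel("Proportion of motif-containing peaks")
    plt.show()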
fig = plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_keys, subsample=10000)
if hits_cache_dir:
fig.savefig(os.path.join(hits_cache_dir, "peak_motif_indicator_heatmap.png"))
p_fig, c_fig, cluster_inds = plot_motif_cooccurrence_heatmaps(
motif_cooccurrence_count_matrix, motif_cooccurrence_pval_matrix, motif_keys
)
if hits_cache_dir:
p_fig.savefig(os.path.join(hits_cache_dir, "cooccurrence_pvals.png"))
c_fig.savefig(os.path.join(hits_cache_dir, "cooccurrence_counts.png"))
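For context, an ordering like cluster_inds can be derived by hierarchically clustering the motifs. A minimal sketch using scipy (average linkage on the rows of the count matrix, as one plausible choice; not necessarily what plot_motif_cooccurrence_heatmaps does):

# One plausible way to derive a motif ordering like cluster_inds:
# hierarchically cluster the rows of the co-occurrence count matrix
linkage = scipy.cluster.hierarchy.linkage(
    motif_cooccurrence_count_matrix, method="average"
)
leaf_order = scipy.cluster.hierarchy.leaves_list(linkage)  # motif indices in cluster order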
When motifs co-occur in the same peak, show the distances between their instances
# Identify significantly co-occurring motif pairs (j <= i, so homotypic pairs are included)
sig_thresh = 1e-6
count_thresh = 100
pvals, sig_pairs = [], []
for i in range(len(motif_keys)):
for j in range(i + 1):
if motif_cooccurrence_pval_matrix[i, j] < sig_thresh and motif_cooccurrence_count_matrix[i, j] >= count_thresh:
sig_pairs.append((i, j))
pvals.append(motif_cooccurrence_pval_matrix[i, j])
inds = np.argsort(pvals)
sig_pairs = [sig_pairs[i] for i in inds]
distance_dict, fig = plot_intermotif_distance_violins(peak_hits, motif_keys, sig_pairs, cluster_inds)
if hits_cache_dir:
if fig is not None:
with h5py.File(os.path.join(hits_cache_dir, "intermotif_dists.h5"), "w") as f:
for key_pair, dists in distance_dict.items():
f.create_dataset("%s:%s" % key_pair, data=dists, compression="gzip")
fig.savefig(os.path.join(hits_cache_dir, "intermotif_dists.png"))
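For reference, the per-pair distances behind these violins could be computed along the lines of the sketch below, which assumes midpoint-to-midpoint distances within each peak and a hypothetical motif-identifier column "key" in the per-peak hit tables (only the "start"/"end" columns are confirmed by the code above).

# Sketch: center-to-center distances between hits of motifs i and j in each
# peak; the "key" column naming the motif is a hypothetical placeholder
def intermotif_distances_sketch(peak_hits, motif_keys, i, j):
    key_i, key_j = motif_keys[i], motif_keys[j]
    dists = []
    for hits in peak_hits:  # one hit table per peak
        centers_i = (hits.loc[hits["key"] == key_i, ["start", "end"]].sum(axis=1) // 2).values
        centers_j = (hits.loc[hits["key"] == key_j, ["start", "end"]].sum(axis=1) // 2).values
        for a in centers_i:
            for b in centers_j:
                # Note: for homotypic pairs (i == j), zero self-distances
                # would need to be excluded
                dists.append(abs(a - b))
    return np.array(dists)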