from basepair.imports import *
from basepair.samplers import top_sum_count, top_max_count
from basepair.plot.tracks import plot_tracks, filter_tracks
from kipoi.data_utils import get_dataset_item
from basepair.modisco.table import ModiscoData
from plotnine import *
import plotnine
# Model and modisco output locations (Oct4 task first, then Nanog).
model_dir = Path(f"{ddir}/processed//chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / "modisco/by_peak_tasks/weighted/Oct4"
mr = ModiscoResult(modisco_dir / "modisco.h5")
mr.open()
# load all of modisco data (importance scores / gradients)
md = ModiscoData.load(modisco_dir, model_dir / "grad.all.h5")
# modisco result computed on the validation set (overrides `mr` above)
mr = ModiscoResult(model_dir / "modisco/valid/modisco.h5")
mr.open()
model_dir = Path(f"{ddir}/processed//chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / "modisco/by_peak_tasks/weighted/Nanog"
# Match each seqlet against its cluster centroid.
# Fixed: the original dump assigned the bound method (missing `()`) and had a
# dangling `dfm =` line, which was a syntax error.
dfm = md.get_centroid_seqlet_matches()
dfm.head()
# TODO - get the whole table
dfm = md.get_centroid_seqlet_matches()
dfm.to_csv(modisco_dir / "centroid_seqlet_matches.csv.bak")
# Pattern of interest plus the full list of (motif name, modisco pattern id) pairs.
pattern = 'metacluster_0/pattern_0'
pattern_names = [
    ("Oct4-Sox2", "metacluster_0/pattern_0"),
    ("Errb", "metacluster_0/pattern_1"),
    ("Sox2", "metacluster_0/pattern_2"),
    ("Nanog", "metacluster_0/pattern_3"),
    ("Klf4", "metacluster_2/pattern_0"),
]
# Cache location for the pre-computed genome-wide hit scores.
cached = True
cache_dir = Path(ddir) / "cache/dev"
cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / "hit-scoring-contrib-only.hdf5"
# Load the cached genome-wide hit scores and derive per-TF score categories.
dfp = pd.read_hdf(cache_file, "/scores")
# add some additional features to dfp
dfp['idx'] = np.arange(len(dfp))
dfp = dfp[dfp.score_seq_match > 0]  # drop hits with a poor sequence match
# -----------------------------------------------------------------
# log1p-transform the weighted scores to tame the long right tail
# (np.log1p is the numerically-stable form of np.log(1 + x))
dfp['log_imp_weighted'] = np.log1p(dfp['imp_weighted'])
dfp['log_match_weighted'] = np.log1p(dfp['match_weighted'])
# partition the scores into 4 categories (according to the per-TF median)
dfp_medians = pd.DataFrame({'log_match_weighted': dfp.groupby("tf").log_match_weighted.median(),
                            'log_imp_weighted': dfp.groupby("tf").log_imp_weighted.median()}).reset_index()
log_imp_weighted_median = dfp.tf.map(dfp.groupby("tf").log_imp_weighted.median())
log_match_weighted_median = dfp.tf.map(dfp.groupby("tf").log_match_weighted.median())
dfp['imp_cat'] = (dfp.log_imp_weighted > log_imp_weighted_median).map({False: 'low', True: 'high'})
dfp['match_cat'] = (dfp.log_match_weighted > log_match_weighted_median).map({False: 'low', True: 'high'})
dfp_subset = dfp.sample(10000)  # focus on the subset of the data
# TODO: get the matrix of scores and scan these sequences using the scores
# (fixed: the two lines above were bare prose in the notebook dump — syntax errors)
# ---------------------------------------------------------------
# Scan the seqlet regions of one pattern and compare the score
# distributions against the cached genome-wide hits (dfp).
pattern_name = 'metacluster_0/pattern_0'
task = 'Oct4'
# Fixed: `tasks` was defined AFTER its first use in the comprehensions
# below, which raised a NameError in the original dump.
tasks = md.tasks
pattern = md.mr.get_pattern(pattern_name).trim_seq_ic(0.08)
i, j = md.get_trim_idx(pattern_name)
seq = md.get_seq(pattern_name)[:, i:j]
profile = {task: md.get_profile_wide(pattern_name, task) for task in tasks}
contrib = {task: md.get_imp(pattern_name, task, 'profile')[:, i:j] for task in tasks}
# sanity check: re-computed average contribution should match the stored one
plot_tracks(dict(recomputed=contrib[task].mean(axis=0), orig=pattern.contrib[task]));
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile
from basepair.plot.profiles import plot_stranded_profile
plot_stranded_profile(profile[task].mean(axis=0))
heatmap_stranded_profile(profile[task][np.argsort(-profile[task].sum((1,2)))]);
# Scan the seqlets with the pattern: importance-weighted match + raw seq match
match, importance = pattern.scan_importance(contrib, hyp_contrib=None, tasks=tasks,
                                            n_jobs=1, verbose=True, pad_mode=None)
seq_match = pattern.scan_seq(seq, n_jobs=1, verbose=True, pad_mode=None)
dfm = pattern.get_instances(tasks, match, importance, seq_match, fdr=1, verbose=True)
print(f"Discarding {np.mean(dfm.score_seq_match < 0):.2%} of points with poor sequence match")
dfm = dfm[dfm.score_seq_match > 0]
dfm.head()
# Score histograms for the seqlet-derived hits
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
dfm.match_weighted.plot.hist(100, ax=axes[0]);
axes[0].set_xlabel("Match");
dfm.imp_weighted.plot.hist(100, ax=axes[1]);
axes[1].set_xlabel("Importance");
dfm.score_seq_match.plot.hist(100, ax=axes[2]);
axes[2].set_xlabel("Sequence match");
plt.tight_layout()
# Same histograms for the cached genome-wide hits of this pattern
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
# NOTE(review): the second assignment (full dfp, threshold 0.2) overrides the
# first — presumably intentional; confirm which subset was meant.
dfmp = dfp_subset[(dfp_subset.pattern == pattern.name) & (dfp_subset.match_weighted > 0.4)]
dfmp = dfp[(dfp.pattern == pattern.name) & (dfp.match_weighted > 0.2)]
dfmp.match_weighted.plot.hist(100, ax=axes[0]);
axes[0].set_xlabel("Match");
dfmp.imp_weighted.plot.hist(100, ax=axes[1]);
axes[1].set_xlabel("Importance");
dfmp.score_seq_match.plot.hist(100, ax=axes[2]);
axes[2].set_xlabel("Sequence match");
plt.tight_layout()
from tqdm import tqdm
def get_all_dfm_seqlets(md, trim_frac=0.08, n_jobs=1):
    """Scan every modisco pattern against its own seqlets and stack the results."""
    frames = []
    for pattern in tqdm(md.mr.patterns()):
        frames.append(get_dfm_seqlets(md, pattern, trim_frac, n_jobs))
    return pd.concat(frames)
def get_dfm_seqlets(md, pattern_name, trim_frac=0.08, n_jobs=1, verbose=False):
    """Scan the seqlets of a single modisco pattern with that same pattern.

    Args:
      md: ModiscoData instance
      pattern_name: modisco pattern id, e.g. 'metacluster_0/pattern_0'
      trim_frac: information-content fraction used to trim the pattern
      n_jobs: number of parallel jobs for the scanning calls
      verbose: also print/plot diagnostics from `get_instances`

    Returns:
      pd.DataFrame of pattern instances, restricted to score_seq_match > 0
    """
    tasks = md.tasks
    pattern = md.mr.get_pattern(pattern_name).trim_seq_ic(trim_frac)
    i, j = md.get_trim_idx(pattern_name)
    seq = md.get_seq(pattern_name)[:, i:j]
    # NOTE: the original also built a `profile` dict via md.get_profile_wide,
    # but it was never used — removed as dead work.
    contrib = {task: md.get_imp(pattern_name, task, 'profile')[:, i:j] for task in tasks}
    match, importance = pattern.scan_importance(contrib, hyp_contrib=None, tasks=tasks,
                                                n_jobs=n_jobs, verbose=False, pad_mode=None)
    seq_match = pattern.scan_seq(seq, n_jobs=n_jobs, verbose=False, pad_mode=None)
    dfm = pattern.get_instances(tasks, match, importance, seq_match, fdr=1,
                                verbose=verbose, plot=verbose)
    # keep only instances with a positive sequence-match score
    return dfm[dfm.score_seq_match > 0]
# --- quantile-bucketing scratch pad -------------------------------------
d = pd.Series(np.arange(10))
quantiles = [0, .33, .66, 1.0]
labels = ['low', 'medium', 'high']
# Percentile grid kept under its own name.
# Fixed: the original overwrote `labels` with the 100 percentile values,
# which made pd.cut below fail (3 bins need exactly 3 labels).
percentile_steps = np.arange(0, 1.01, 0.01)
percentile_labels = percentile_steps[1:]
dfm = pd.read_csv(f"{modisco_dir}/../Nanog/centroid_seqlet_matches.csv")
q = d.quantile(quantiles)
pd.cut(d, q.values, labels=labels, include_lowest=True, right=True)
# TODO - get the quantile compared to another distribution - p_seqlet
# (removed the dangling `d.buck` autocomplete fragment, which raised AttributeError)
dfm = get_all_dfm_seqlets(md)
md.mr.fpath
dfm.tail()
def plot(dfm_seqlets, dfm_orig, title):
    """Plot score histograms for seqlet hits vs genome-wide hits.

    First figure: match / importance / seq-match histograms of the
    seqlet-derived instances. Second figure: the same for the genome-wide
    instances, with red vertical lines at the low/high category boundaries.

    Args:
      dfm_seqlets: instance table from scanning the pattern's own seqlets
      dfm_orig: genome-wide instance table with 'imp_cat' / 'match_cat' columns
      title: motif name, used as the title of the first panel

    Raises:
      ValueError: if `dfm_orig` is empty (no genome-wide hits for the pattern).
    """
    dfm = dfm_seqlets
    fig, axes = plt.subplots(1, 3, figsize=(12, 3))
    dfm.match_weighted.plot.hist(100, ax=axes[0]);
    axes[0].set_xlabel("Match");
    axes[0].set_title(title);
    dfm.imp_weighted.plot.hist(100, ax=axes[1]);
    axes[1].set_xlabel("Importance");
    dfm.score_seq_match.plot.hist(100, ax=axes[2]);
    axes[2].set_xlabel("Sequence match");
    plt.tight_layout()
    fig, axes = plt.subplots(1, 3, figsize=(12, 3))
    dfmp = dfm_orig
    if len(dfmp) == 0:
        # Fixed: was a left-over `pdb.set_trace()` debugging breakpoint
        raise ValueError(f"No genome-wide hits provided for motif {title!r}")
    # red line = smallest score still classified as 'high'
    threshold = dfmp[dfmp['match_cat'] == 'high'].match_weighted.min()
    dfmp.match_weighted.plot.hist(100, ax=axes[0]);
    axes[0].set_xlabel("Match");
    axes[0].axvline(threshold, color='red')
    dfmp.imp_weighted.plot.hist(100, ax=axes[1]);
    # Fixed: the importance threshold was computed from match_weighted
    threshold = dfmp[dfmp['imp_cat'] == 'high'].imp_weighted.min()
    axes[1].set_xlabel("Importance");
    axes[1].axvline(threshold, color='red')
    dfmp.score_seq_match.plot.hist(100, ax=axes[2]);
    axes[2].set_xlabel("Sequence match");
    plt.tight_layout()
# Compare seqlet vs genome-wide score distributions for each motif.
pattern_names = [
    ("Oct4-Sox2", "metacluster_0/pattern_0"),
    ("Errb", "metacluster_0/pattern_1"),
    ("Sox2", "metacluster_0/pattern_2"),
    ("Nanog", "metacluster_0/pattern_3"),
    ("Klf4", "metacluster_2/pattern_0"),
]
for motif, pattern_name in pattern_names:
    print(motif)
    seqlet_hits = get_dfm_seqlets(md, pattern_name, verbose=True)
    genome_hits = dfp_subset[(dfp_subset.pattern == pattern_name)]
    plot(seqlet_hits, genome_hits, motif)
# Manually chosen per-motif match_weighted thresholds (read off the plots above).
thresholds = {"Klf4": 0.55,
              "Nanog": 0.45,
              "Sox2": 0.55,
              "Errb": 0.4,
              "Oct4-Sox2": 0.4
              }