from basepair.imports import *
import pybedtools
from genomelake.extractors import BigwigExtractor
from basepair.extractors import StrandedBigWigExtractor, bw_extract
from kipoiseq.transforms import ResizeInterval
from basepair.modisco.table import ModiscoData
from pybedtools import BedTool
from basepair.cli.modisco import load_ranges
from basepair.plot.heatmaps import heatmap_importance_profile, normalize
from basepair.plot.profiles import plot_stranded_profile
# Load the tab-separated table of external data files and display it.
# `MultiAssayExtractor` below reads its `path` and `assay` columns.
external_data_tsv = f"{ddir}/processed/chipnexus/external-data.tsv"
df = pd.read_csv(external_data_tsv, sep='\t')
df
import dask
from dask.diagnostics import ProgressBar
#from joblib import Parallel, delayed
class MultiAssayExtractor:
    """Run `bw_extract` on every file listed in a table and sum the results per assay.

    Args:
        df: table with one row per file. Must provide a `path` column (the
            file name handed to `bw_extract`) and an `assay` column (the
            grouping key for the summed output).
        interval_transform: forwarded unchanged to `bw_extract`.
        use_strand: forwarded unchanged to `bw_extract`.
        n_jobs: number of dask workers used by `extract`.
    """

    def __init__(self, df, interval_transform, use_strand=True, n_jobs=1):
        self.n_jobs = n_jobs
        self.df = df
        self.interval_transform = interval_transform
        self.use_strand = use_strand

    def extract(self, intervals, progbar=False):
        """Extract all files for `intervals` in parallel and sum them per assay.

        Args:
            intervals: forwarded unchanged to `bw_extract`.
            progbar: if True, show a dask progress bar while computing.

        Returns:
            dict mapping each unique value of `self.df.assay` (in order of
            first appearance) to the sum of the `bw_extract` results of the
            files belonging to that assay.
        """
        # One delayed task per file, in the row order of `self.df`.
        tasks = [dask.delayed(bw_extract)(fname, intervals,
                                          self.interval_transform,
                                          self.use_strand)
                 for fname in self.df.path]
        with dask.config.set(num_workers=self.n_jobs, scheduler='multiprocessing'):
            if progbar:
                with ProgressBar():
                    extracted = dask.compute(*tasks)
            else:
                extracted = dask.compute(*tasks)
        return self._sum_by_assay(extracted)

    def _sum_by_assay(self, extracted):
        """Sum per-file results into one entry per assay.

        `extracted` must hold one result per row of `self.df`, in row order.
        Rows are matched by position rather than by index label, so the table
        may carry any index (filtered, re-ordered, non-integer, ...).
        """
        summed = {}
        for assay in self.df.assay.unique():
            # Positional boolean mask: True where the row belongs to `assay`.
            belongs = (self.df.assay == assay).values
            summed[assay] = sum([x for x, keep in zip(extracted, belongs)
                                 if keep])
        return summed
pattern = 'm0_p0'
mdir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile/")
!ls {mdir}
mr = ModiscoResult(mdir / 'modisco.h5')
mr.open()
extractor = MultiAssayExtractor(df, ResizeInterval(6000), use_strand=True, n_jobs=10)
center = extractor.interval_transform.width // 2
from basepair.modisco.utils import shorten_pattern, longer_pattern
len(mr.patterns())
mr.plot_pattern(longer_pattern(pattern), 'seq');
from basepair.extractors import Interval
from basepair.plot.profiles import multiple_plot_stranded_profile
from basepair.plot.heatmaps import multiple_heatmap_importance_profile
cache = True
profiles = {}
pattern = "m0_p0"
mr.plot_pattern(longer_pattern(pattern), ['seq', 'contrib']);
intervals = mr.get_seqlet_intervals(longer_pattern(pattern))
%%time
if pattern not in profiles or not cache:
profiles[pattern] = extractor.extract(intervals, progbar=True)
o = profiles[pattern]
multiple_plot_stranded_profile(o);
sort_idx = np.argsort(-o['DNase'].sum(axis=1))
multiple_heatmap_importance_profile({k: normalize(v, 10,99) for k,v in o.items()},
sort_idx=sort_idx, figsize=(25,25), tick_step=1000, aspect=1);
# DNase footprint
plt.plot(o['DNase'].mean(0)[(center -100):(center+100)]);
heatmap_importance_profile(normalize(o['DNase'][sort_idx[:1000], (center -100):(center+100)], pmin=50, pmax=99), figsize=(10,10))
pattern = "m3_p0"
mr.plot_pattern(longer_pattern(pattern), ['seq', 'contrib']);
intervals = mr.get_seqlet_intervals(longer_pattern(pattern))
%%time
if pattern not in profiles or not cache:
profiles[pattern] = extractor.extract(intervals, progbar=True)
o = profiles[pattern]
multiple_plot_stranded_profile(o);
sort_idx = np.argsort(-o['DNase'].sum(axis=1))
multiple_heatmap_importance_profile({k: normalize(v, 10,99) for k,v in o.items()},
sort_idx=sort_idx, figsize=(25,25), tick_step=1000, aspect=1);
# DNase footprint
plt.plot(o['DNase'].mean(0)[(center -100):(center+100)]);
heatmap_importance_profile(normalize(o['DNase'][sort_idx[:1000], (center -100):(center+100)], pmin=50, pmax=99), figsize=(10,10))
pattern = "m0_p1"
mr.plot_pattern(longer_pattern(pattern), ['seq', 'contrib']);
intervals = mr.get_seqlet_intervals(longer_pattern(pattern))
%%time
if pattern not in profiles or not cache:
profiles[pattern] = extractor.extract(intervals, progbar=True)
o = profiles[pattern]
multiple_plot_stranded_profile(o);
sort_idx = np.argsort(-o['DNase'].sum(axis=1))
multiple_heatmap_importance_profile({k: normalize(v, 10,99) for k,v in o.items()},
sort_idx=sort_idx, figsize=(25,25), tick_step=1000, aspect=1);
# DNase footprint
plt.plot(o['DNase'].mean(0)[(center -100):(center+100)]);
heatmap_importance_profile(normalize(o['DNase'][sort_idx[:1000], (center -100):(center+100)], pmin=50, pmax=99), figsize=(10,10))
pattern = "m0_p3"
mr.plot_pattern(longer_pattern(pattern), ['seq', 'contrib']);
intervals = mr.get_seqlet_intervals(longer_pattern(pattern))
%%time
if pattern not in profiles or not cache:
profiles[pattern] = extractor.extract(intervals, progbar=True)
o = profiles[pattern]
multiple_plot_stranded_profile(o);
sort_idx = np.argsort(-o['DNase'].sum(axis=1))
multiple_heatmap_importance_profile({k: normalize(v, 10,99) for k,v in o.items()},
sort_idx=sort_idx, figsize=(25,25), tick_step=1000, aspect=1);
# DNase footprint
plt.plot(o['DNase'].mean(0)[(center -100):(center+100)]);
heatmap_importance_profile(normalize(o['DNase'][sort_idx[:1000], (center -100):(center+100)], pmin=50, pmax=99), figsize=(10,10))
pattern = "m1_p0"
mr.plot_pattern(longer_pattern(pattern), ['seq', 'contrib']);
intervals = mr.get_seqlet_intervals(longer_pattern(pattern))
%%time
if pattern not in profiles or not cache:
profiles[pattern] = extractor.extract(intervals, progbar=True)
o = profiles[pattern]
multiple_plot_stranded_profile(o);
sort_idx = np.argsort(-o['DNase'].sum(axis=1))
multiple_heatmap_importance_profile({k: normalize(v, 10,99) for k,v in o.items()},
sort_idx=sort_idx, figsize=(25,25), tick_step=1000, aspect=1);
# DNase footprint
plt.plot(o['DNase'].mean(0)[(center -100):(center+100)]);
heatmap_importance_profile(normalize(o['DNase'][sort_idx[:1000], (center -100):(center+100)], pmin=50, pmax=99), figsize=(10,10))
pattern = "m2_p0"
mr.plot_pattern(longer_pattern(pattern), ['seq', 'contrib']);
intervals = mr.get_seqlet_intervals(longer_pattern(pattern))
%%time
if pattern not in profiles or not cache:
profiles[pattern] = extractor.extract(intervals, progbar=True)
o = profiles[pattern]
multiple_plot_stranded_profile(o);
sort_idx = np.argsort(-o['DNase'].sum(axis=1))
multiple_heatmap_importance_profile({k: normalize(v, 10,99) for k,v in o.items()},
sort_idx=sort_idx, figsize=(25,25), tick_step=1000, aspect=1);
# DNase footprint
plt.plot(o['DNase'].mean(0)[(center -100):(center+100)]);
heatmap_importance_profile(normalize(o['DNase'][sort_idx[:1000], (center -100):(center+100)], pmin=50, pmax=99), figsize=(10,10))
pattern = "m0_p2"
mr.plot_pattern(longer_pattern(pattern), ['seq', 'contrib']);
intervals = mr.get_seqlet_intervals(longer_pattern(pattern))
%%time
if pattern not in profiles or not cache:
profiles[pattern] = extractor.extract(intervals, progbar=True)
o = profiles[pattern]
multiple_plot_stranded_profile(o);
sort_idx = np.argsort(-o['DNase'].sum(axis=1))
multiple_heatmap_importance_profile({k: normalize(v, 10,99) for k,v in o.items()},
sort_idx=sort_idx, figsize=(25,25), tick_step=1000, aspect=1);
# DNase footprint
plt.plot(o['DNase'].mean(0)[(center -100):(center+100)]);
heatmap_importance_profile(normalize(o['DNase'][sort_idx[:1000], (center -100):(center+100)], pmin=50, pmax=99), figsize=(10,10))
from basepair.utils import write_pkl
write_pkl(profiles, mdir / "pattern-meta-profiles.pkl")
!du -sh {mdir}/*.pkl