# Imports
# Notebook setup: render matplotlib figures inline at high (retina) resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Wildcard import; presumably provides Path, BPNet, OrderedDict, plot_tracks,
# get_figsize, paper_config, ddir and hv used below — TODO confirm.
from basepair.imports import *
# Select the Bokeh backend (assumes `hv` is HoloViews — verify against basepair.imports).
hv.extension('bokeh')
# Common paths
# Root directory holding the trained models for this experiment.
models_dir = Path(f"{ddir}/processed/chipnexus/exp/models/osnk-pstat-sall-smad-zfp/models/")
# Model to analyse. An alternative that was previously assigned here and then
# immediately overwritten is kept as an option:
# model_dir = models_dir / "c_task_weight=5,filters=128/"
model_dir = models_dir / "best/"
# Presumably TF-MoDISco results computed from DeepLIFT profile importance scores.
modisco_dir = model_dir / "modisco/deeplift/profile/"
from basepair.cli.imp_score import ImpScoreFile
import pybedtools
from basepair.plot.tracks import filter_tracks
from basepair.utils import flatten_list
# Region to visualise: chr17:35,503,550-35,504,550.
interval = pybedtools.create_interval_from_list(['chr17', 35_503_550, 35_504_550])
# Apply the publication figure style, then load the selected model.
paper_config()
bpnet = BPNet.from_mdir(model_dir)
# Hard-coded y-axis ranges shared across all tasks, so tracks stay comparable
# between figures. Keys are the track feature names built in plot_region.
_DEFAULT_FMAX = {'Imp profile': 0.265742, 'Pred': 34.442215}
_DEFAULT_FMIN = {'Imp profile': -0.16751188, 'Pred': 0.07202636}
_TRACK_FEATURES = ('Pred', 'Imp profile')


def _track_feature(track_name):
    """Return the feature suffix ('Pred' or 'Imp profile') of a track name."""
    # Match on the known suffix instead of splitting on the first space, so
    # task names that themselves contain spaces are handled correctly.
    for feature in _TRACK_FEATURES:
        if track_name.endswith(" " + feature):
            return feature
    raise ValueError(f"Unrecognized track name: {track_name!r}")


def plot_region(bpnet, interval, variants=None, xlim=(420, 575),
                fmin=None, fmax=None, auto_scale=False):
    """Plot per-task predictions and importance tracks for one interval.

    Args:
        bpnet: trained model; its ``predict_intervals`` method and ``tasks``
            attribute are used.
        interval: interval to predict on (e.g. from
            ``pybedtools.create_interval_from_list``).
        variants: forwarded unchanged to ``bpnet.predict_intervals``.
        xlim: window passed to ``filter_tracks`` to crop every track;
            defaults to the previously hard-coded (420, 575).
        fmin, fmax: dicts mapping feature ('Pred', 'Imp profile') to the
            lower/upper y-axis bound. When None, the hard-coded ranges are
            used, or data-derived ranges if ``auto_scale`` is True.
        auto_scale: derive missing bounds from the cropped tracks (min/max
            across all tasks) instead of the hard-coded ranges.

    Returns:
        The figure returned by ``plot_tracks``.
    """
    pred = bpnet.predict_intervals([interval], variants=variants,
                                   imp_method='deeplift')[0]

    # Two tracks per task, in task order: prediction, then importance
    # scores multiplied elementwise by the sequence.
    viz_dict = OrderedDict()
    for task in bpnet.tasks:
        viz_dict[f"{task} Pred"] = pred['pred'][task]
        viz_dict[f"{task} Imp profile"] = (
            pred['imp_score'][f"{task}/weighted"] * pred['seq'])
    viz_dict = filter_tracks(viz_dict, list(xlim))

    if auto_scale:
        data_max = {feature: max(viz_dict[f"{task} {feature}"].max()
                                 for task in bpnet.tasks)
                    for feature in _TRACK_FEATURES}
        data_min = {feature: min(viz_dict[f"{task} {feature}"].min()
                                 for task in bpnet.tasks)
                    for feature in _TRACK_FEATURES}
    else:
        data_max, data_min = _DEFAULT_FMAX, _DEFAULT_FMIN
    fmax = data_max if fmax is None else fmax
    fmin = data_min if fmin is None else fmin

    ylim = []
    for name in viz_dict:
        feature = _track_feature(name)
        if "Imp" in feature:
            # Importance scores can be negative: use the full range.
            ylim.append((fmin[feature], fmax[feature]))
        else:
            # Predictions are plotted from 0 (presumably non-negative).
            ylim.append((0, fmax[feature]))

    fig = plot_tracks(viz_dict,
                      title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                      fig_height_per_track=0.5,
                      rotate_y=0,
                      fig_width=get_figsize(frac=1)[0],
                      ylim=ylim,
                      legend=False)
    return fig
# Render the region once; the trailing semicolon suppresses the notebook's
# echo of the returned figure object.
plot_region(bpnet, interval);