-
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
import pybedtools
from basepair.utils import flatten_list
# Notebook-wide setup. `paper_config`, `create_tf_session`, `ddir`, `Path` and
# `DataSpec` are not defined in this file; presumably they are provided by the
# star import of `basepair.imports` -- TODO confirm.
paper_config()
create_tf_session(0)

# Common paths
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
# model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf-sall/models/default/")
modisco_dir = model_dir / "modisco/all/profile/"
figures = f"{ddir}/figures/model-evaluation/chipnexus-bpnet"

# Create the figure output directory (and any missing parents). Done with
# pathlib rather than a `!mkdir -p` shell escape so that it also works outside
# IPython and with paths containing whitespace.
Path(figures, "known_enhancer_profiles").mkdir(parents=True, exist_ok=True)

# Dataset specification stored next to the trained model.
ds = DataSpec.load(model_dir / "dataspec.yaml")
# Get counts
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])

# Load the model and compute its predictions for the interval.
# NOTE(review): these two assignments were commented out, but `bpnet` and
# `pred` are read by the statements below and by the later plotting code, so a
# fresh top-to-bottom run failed with NameError. `BPNet` is assumed to be in
# scope via `basepair.imports` -- TODO confirm.
bpnet = BPNet.from_mdir(model_dir)
# Get predictions
pred = bpnet.predict_intervals([interval], imp_method='deeplift')[0]

# Observed counts for every task of the model, loaded for the single interval.
tasks = bpnet.tasks
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
# Assemble the tracks to plot: for every task the prediction, the observed
# counts and the profile importance scores multiplied by `pred['seq']`.
track_items = []
for task in bpnet.tasks:
    track_items.append((f"{task} Pred", pred['pred'][task]))
    track_items.append((f"{task} Obs", obs[task]))
    track_items.append(
        (f"{task} Imp profile", pred['imp_score'][f"{task}/weighted"] * pred['seq'])
    )
viz_dict = filter_tracks(OrderedDict(track_items), [420, 575])

# Per-feature extrema across all tasks, so that tracks showing the same
# feature share a common y-axis range.
shared_features = ['Imp profile', 'Obs', 'Pred']
fmax = {
    feature: max(viz_dict[f"{task} {feature}"].max() for task in bpnet.tasks)
    for feature in shared_features
}
fmin = {
    feature: min(viz_dict[f"{task} {feature}"].min() for task in bpnet.tasks)
    for feature in shared_features
}

# One (low, high) pair per track, in track order. Importance tracks span
# their full min..max range; all other tracks start at zero.
ylim = [
    (fmin[feature], fmax[feature]) if "Imp" in feature else (0, fmax[feature])
    for feature in (track_name.split(" ", 1)[1] for track_name in viz_dict)
]

plot_title = "{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval)
fig = plot_tracks(
    viz_dict,
    # seqlets=shifted_seqlets,
    title=plot_title,
    fig_height_per_track=0.5,
    rotate_y=0,
    fig_width=get_figsize(frac=1)[0],
    ylim=ylim,
    legend=False,
)
# # mdir2 = Path('/users/avsec//workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/gt/run-20190219_113014-dgg341ac')
# # bpnet2_model = load_model(str(mdir2/ "model.h5"))
# m2_config = read_json(mdir2/ 'config.gin.json')
# input_seqlen=seq_bpnet_cropped_extra_seqlen(conv1_kernel_size=m2_config['conv1_kernel_size'],
# n_dil_layers=m2_config['n_dil_layers'],
# target_seqlen=m2_config['target_seqlen'],
# tconv_kernel_size=m2_config['tconv_kernel_size'])
# bpnet2 = BPNet(bpnet2_model, fasta_file=ds.fasta_file)
# bpnet2.imp_score(seq, task='Oct4', strand='both', method='deeplift', pred_summary='l2')
# mdir2 = Path('/users/avsec//workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/gt/run-20190204_113238-do8kugq4')
from basepair.BPNet import model2tasks
from basepair.models import seq_bpnet_cropped_extra_seqlen
from basepair.preproc import resize_interval
from basepair.seqmodel import SeqModel
from basepair.utils import unflatten
from genomelake.extractors import FastaExtractor
# Directory of the SeqModel run to inspect. Alternative runs are kept below
# for reference.
# mdir2 = Path('../train/seqmodel/output/seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE')
# mdir2 = Path('../train/seqmodel/output/nexus,peaks,OSNK,0,10,1,TRUE,valid,0.5,64,25,0.004,9,FALSE')
# mdir2 = Path('../train/seqmodel/output/nexus,peaks,OSNK,0,10,1,FALSE,valid,0.5,64,25,0.004,9,')
mdir2 = Path('../train/seqmodel/output/nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE,TRUE')
sm = SeqModel.from_mdir(mdir2)

# Input length = 1000 minus the length changes reported by the model body and
# by the net of the first head.
body_len_change = sm.body.get_len_change()
head_len_change = sm.heads[0].net.get_len_change()
input_seqlen = 1000 - body_len_change - head_len_change

# Extract the sequence for the interval resized to the model's input length.
fe = FastaExtractor(ds.fasta_file)
resized_interval = resize_interval(interval, input_seqlen)
seq = fe([resized_interval])

# Importance scores for the extracted sequence.
imp_scores = sm.imp_score_all(seq, preact_only=True)

# Predict with the model's neutral bias inputs plus the extracted sequence.
x = sm.neutral_bias_inputs(1000, 1000)
x['seq'] = seq
preds = sm.predict(x)
# TODO - put this somewhere
def trim_seq(seq_width, peak_width):
    """Return the ``(start, end)`` bounds that centre-crop a sequence.

    Slicing a sequence of length ``seq_width`` with ``[start:end]`` keeps the
    central ``peak_width`` positions, removing the same number of positions
    from both ends.

    Args:
        seq_width: length of the full sequence.
        peak_width: length of the central window to keep.

    Returns:
        Tuple ``(trim_start, trim_end)`` with
        ``trim_end - trim_start == peak_width``.

    Raises:
        ValueError: if ``seq_width`` is smaller than ``peak_width``, or if the
            excess ``seq_width - peak_width`` is odd and therefore cannot be
            split evenly between the two ends.
    """
    if seq_width < peak_width:
        raise ValueError("seq_width < peak_width")

    excess = seq_width - peak_width
    # Validate with an exception rather than `assert`: asserts are removed when
    # Python runs with -O, which would silently return an off-centre window.
    if excess % 2 != 0:
        raise ValueError(
            f"seq_width - peak_width must be even to trim symmetrically, got {excess}"
        )

    trim_start = excess // 2
    trim_end = seq_width - trim_start
    return trim_start, trim_end
# TODO have the function to get the right trimming
# Bounds that crop the model input (length `input_seqlen`) to its central
# 1000 positions.
trim_i, trim_j = trim_seq(input_seqlen, 1000)
trim_i, trim_j  # bare expression kept from the notebook (cell display)

from basepair.utils import unflatten

tasks = sm.tasks
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}

# Importance scores regrouped by splitting their keys on "/"; indexed below as
# [task]['profile'][score_name].
nested_imp_scores = unflatten(imp_scores, "/")

# Tracks per task: predicted profile, observed counts, then one track per
# profile importance score (score * seq, first batch element, cropped to
# trim_i:trim_j).
track_items = []
for task in tasks:
    track_items.append((f"{task} Pred", preds[f"{task}/profile"][0]))
    track_items.append((f"{task} Obs", obs[task]))
    for imp_score, v in nested_imp_scores[task]['profile'].items():
        track_items.append(
            (f"{task} {imp_score} Imp profile", (v * seq)[0, trim_i:trim_j])
        )
viz_dict = filter_tracks(OrderedDict(track_items), [420, 575])

from kipoi.utils import unique_list

# Feature = track name without its leading "<task> " prefix.
track_features = [track_name.split(" ", 1)[1] for track_name in viz_dict]
features = unique_list(track_features)

# Per-feature extrema across tasks, so tracks of one feature share a y-range.
fmax = {
    feature: max(viz_dict[f"{task} {feature}"].max() for task in tasks)
    for feature in features
}
fmin = {
    feature: min(viz_dict[f"{task} {feature}"].min() for task in tasks)
    for feature in features
}

# Importance tracks span min..max; every other track starts at zero.
ylim = [
    (fmin[feature], fmax[feature]) if "Imp" in feature else (0, fmax[feature])
    for feature in track_features
]
# First entry stored under the 'Oct4' key of `sm.all_heads`.
a = sm.all_heads['Oct4'][0]
seq.shape  # bare expression kept from the notebook (cell display)

# Plot the SeqModel tracks for the interval.
# NOTE(review): this call used to be pasted eleven times in a row with
# byte-identical arguments, each one rebinding `fig`. Nothing between the
# copies changed the inputs, so a single call is kept.
fig = plot_tracks(
    viz_dict,
    # seqlets=shifted_seqlets,
    title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
    fig_height_per_track=0.5,
    rotate_y=0,
    fig_width=get_figsize(frac=1)[0],
    ylim=ylim,
    legend=False,
)
# Tracks per task: observed counts plus the profile ("weighted") and count
# importance scores, each multiplied by `pred['seq']`. The predicted profile
# is intentionally left out of this figure.
# NOTE(review): `obs` must contain an entry for every task in `bpnet.tasks`.
track_items = []
for task in bpnet.tasks:
    track_items.append((f"{task} Obs", obs[task]))
    track_items.append(
        (f"{task} Imp profile", pred['imp_score'][f"{task}/weighted"] * pred['seq'])
    )
    track_items.append(
        (f"{task} Imp counts", pred['imp_score'][f"{task}/count"] * pred['seq'])
    )
viz_dict = filter_tracks(OrderedDict(track_items), [420, 575])

# Per-feature extrema across tasks, so tracks of one feature share a y-range.
shared_features = ['Imp profile', 'Imp counts', 'Obs']  # 'Pred',
fmax = {
    feature: max(viz_dict[f"{task} {feature}"].max() for task in bpnet.tasks)
    for feature in shared_features
}
fmin = {
    feature: min(viz_dict[f"{task} {feature}"].min() for task in bpnet.tasks)
    for feature in shared_features
}

# Importance tracks span min..max; every other track starts at zero.
ylim = [
    (fmin[feature], fmax[feature]) if "Imp" in feature else (0, fmax[feature])
    for feature in (track_name.split(" ", 1)[1] for track_name in viz_dict)
]

plot_title = "{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval)
fig = plot_tracks(
    viz_dict,
    # seqlets=shifted_seqlets,
    title=plot_title,
    fig_height_per_track=0.5,
    rotate_y=0,
    fig_width=get_figsize(frac=1)[0],
    ylim=ylim,
    legend=False,
)