from collections import OrderedDict
# Experiment identifier (presumably an encoded hyper-parameter string); it is
# used below to build both the model directory and the figure output directory.
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'
# Importance-score key; in this view it is only referenced by commented-out code.
imp_score = 'profile/wn'
# Display name -> TF-MoDISco pattern id (insertion order preserved).
# Passed to load_instances(motifs=...) further down.
motifs = OrderedDict()
motifs["Oct4-Sox2"] = 'Oct4/m0_p0'
motifs["Oct4"] = 'Oct4/m0_p1'
# motifs["Strange-sym-motif"] = 'Oct4/m0_p5'  # deliberately disabled
motifs["Sox2"] = 'Sox2/m0_p1'
motifs["Nanog"] = 'Nanog/m0_p1'
motifs["Zic3"] = 'Nanog/m0_p2'
motifs["Nanog-partner"] = 'Nanog/m0_p4'
motifs["Klf4"] = 'Klf4/m0_p0'
# Imports
# IPython magics: inline matplotlib output, rendered at high DPI ("retina").
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# NOTE(review): these star imports presumably provide many names used later
# (e.g. Path, DataSpec, plt, sns, hv, tf_colors, plot_tracks, ...) -- the exact
# contents are not visible here; confirm before refactoring.
from basepair.imports import *
from basepair.exp.paper.config import *
from basepair.extractors import Variant, extract_seq
# Select the bokeh backend for holoviews (hv presumably comes from a star import).
hv.extension('bokeh')
import pybedtools
from basepair.utils import flatten_list
# Presumably applies the paper-wide (plotting) configuration -- confirm.
paper_config()
# NOTE(review): presumably creates a TensorFlow session on device 0 -- confirm.
create_tf_session(0)
# figures dir
# Model directory for this experiment and output directory for its figures.
model_dir = models_dir / exp
fdir = Path(f'{ddir}/figures/modisco/{exp}/in-vivo-perturb/oct4-enhancer')
# IPython shell escape: create the figure directory (no error if it exists).
!mkdir -p {fdir}
from basepair.seqmodel import SeqModel
# Data specification; provides per-task specs (task_specs) and the fasta path.
ds = DataSpec.load(rdir / 'src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml')
# Get counts
# 1 kb region on chr17 that is analysed throughout this notebook.
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])
# Observed counts per task for the region ([0] = the single interval requested).
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
# Genomic coordinates of the window later cropped by plot_region
# (offsets 480-540 within the region); evaluated only for notebook display.
(420+60 + 35503550, 420+120 + 35503550)
from basepair.BPNet import BPNetSeqModel
bpnet = BPNetSeqModel.from_mdir(model_dir)
# NOTE(review): the definition below is commented out, yet input_seqlen is used
# a few lines further down; it must come from a star import or an earlier cell,
# otherwise that line raises NameError -- confirm.
# input_seqlen = 1000 - bpnet.body.get_len_change() - bpnet.heads[0].net.get_len_change()
from genomelake.extractors import FastaExtractor
fe = FastaExtractor(ds.fasta_file)
# Region sequence resized to input_seqlen (presumably the model input length).
# NOTE(review): `seq` is reassigned later in the notebook before any visible use.
seq = fe([resize_interval(interval, input_seqlen)])
# Re-creates exactly the same interval as above (redundant).
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])
# Hack
# Attach the fasta path and drop the bias model directly on the model object,
# presumably so predict_intervals can fetch sequence without a bias track -- confirm.
bpnet.fasta_file = ds.fasta_file
bpnet.bias_model = None
# Reference (no variants) prediction with DeepLIFT contribution scores.
# NOTE(review): this module-level `pred` is not used later in this view.
pred = bpnet.predict_intervals([interval], imp_method='deeplift')[0]
def to_neg(track):
    """Return a copy of ``track`` whose second column (index 1) is sign-flipped.

    The input object is not modified. In this notebook it is applied to the
    predicted-profile tracks, presumably so that column is drawn below the axis.
    NOTE(review): assumes a 2-D array-like that supports ``[:, 1]`` indexing.
    """
    flipped = track.copy()
    second_col = flipped[:, 1]
    flipped[:, 1] = -second_col
    return flipped
def plot_region(interval, variants=None, *, tasks=('Oct4', 'Nanog'), window=(420 + 60, 420 + 120)):
    """Plot BPNet predictions and profile contribution scores for ``interval``.

    Two tracks are drawn per task, in this order:
      * ``"<task> Pred"``: the predicted profile, with column 1 negated by
        ``to_neg``;
      * ``"<task> Imp profile"``: the weighted profile importance scores
        multiplied by the one-hot sequence.
    The tracks are cropped with ``filter_tracks`` and drawn with hard-coded
    y-limits. Presumably this keeps reference and variant figures on a
    comparable scale.

    Args:
      interval: pybedtools interval to predict on.
      variants: optional list of ``Variant`` objects forwarded to
        ``bpnet.predict_intervals``. ``None`` plots the reference sequence.
      tasks: task names to plot, in order. Each must be a key of the model's
        predictions and of ``tf_colors``.
      window: ``(start, end)`` crop passed to ``filter_tracks``. The default
        equals the previously hard-coded ``[480, 540]``.

    Returns:
      The figure returned by ``plot_tracks``.
    """
    pred = bpnet.predict_intervals([interval], variants=variants, imp_method='deeplift')[0]

    # Track order matters: the positional `colors` list below mirrors it.
    viz_dict = OrderedDict()
    for task in tasks:
        viz_dict[f"{task} Pred"] = to_neg(pred['pred'][task])
        viz_dict[f"{task} Imp profile"] = pred['imp_score'][f"{task}/weighted"] * pred['seq']
    viz_dict = filter_tracks(viz_dict, list(window))

    # Hard-coded y-range per feature. Keys are "<task> <feature>" and the limits
    # are looked up by <feature>. Both track kinds use the same lookup.
    fmax = {'Imp profile': 0.265742, 'Pred': 6.5}
    fmin = {'Imp profile': -0.16751188, 'Pred': -6.5}
    ylim = [(fmin[feature], fmax[feature])
            for feature in (k.split(" ", 1)[1] for k in viz_dict)]

    # For each task: (colour, colour + "80") for the prediction track, where the
    # "80" hex alpha suffix gives ~0.5 opacity. None (default) is used for the
    # contribution track.
    colors = []
    for task in tasks:
        colors.append((tf_colors[task], tf_colors[task] + "80"))
        colors.append(None)

    fig = plot_tracks(viz_dict,
                      title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                      fig_height_per_track=0.5,
                      rotate_y=0,
                      fig_width=get_figsize(frac=.5)[0],
                      color=colors,
                      ylim=ylim,
                      use_spine_subset=True,
                      legend=False)
    sns.despine(top=True, right=True, left=False, bottom=True)
    return fig
# Show the figure output directory (notebook display only).
fdir
# Plot the wild-type region and two in-silico mutants, saving each figure.
# A variant position is region start + 420 (window offset) + motif offset + 1;
# the +1 presumably converts the 0-based offset into a 1-based position.
_perturbations = [
    ('wt.pdf', None),
    ('dOct4-Sox2.pdf', [Variant('chr17', 35503550 + 420 + 80 + 1, 'ATGCATAACAA', 'GTTCGCTCGTG')]),
    ('dNanog.pdf', [Variant('chr17', 35503550 + 420 + 98 + 1, 'TGAT', 'ATGT')]),
]
for _fig_name, _fig_variants in _perturbations:
    plot_region(interval, _fig_variants)
    plt.savefig(fdir / _fig_name)
# Show the TF colour palette (notebook display only).
tf_colors
def plot_region_diff(interval, variants, *, window=(420, 575)):
    """Plot variant predictions next to their reference-minus-variant differences.

    Two predictions are made for ``interval``: one on the reference sequence
    and one with ``variants`` applied. For every task in ``bpnet.tasks`` four
    tracks are drawn, in this order:
      * ``"<task> Pred"``: the variant prediction;
      * ``"<task> ref - alt pred"``: reference minus variant prediction;
      * ``"<task> Imp profile"``: variant contribution (weighted profile
        importance times one-hot sequence);
      * ``"<task> ref - alt Imp"``: reference minus variant contribution.
    The contribution tracks share a fixed y-range. The prediction tracks use a
    fixed symmetric range.
    NOTE(review): unlike plot_region, predictions are not passed through to_neg
    here -- confirm this is intended.

    Args:
      interval: pybedtools interval to predict on.
      variants: list of ``Variant`` objects forwarded to ``bpnet.predict_intervals``.
      window: ``(start, end)`` crop passed to ``filter_tracks``. The default
        equals the previously hard-coded ``[420, 575]``.

    Returns:
      The figure returned by ``plot_tracks``.
    """
    ref_pred = bpnet.predict_intervals([interval], variants=None, imp_method='deeplift')[0]
    pred = bpnet.predict_intervals([interval], variants=variants, imp_method='deeplift')[0]

    viz_dict = OrderedDict()
    for task in bpnet.tasks:
        imp_key = f"{task}/weighted"
        ref_contrib = ref_pred['imp_score'][imp_key] * ref_pred['seq']
        # Computed once; used both as its own track and in the difference.
        alt_contrib = pred['imp_score'][imp_key] * pred['seq']
        viz_dict[f"{task} Pred"] = pred['pred'][task]
        viz_dict[f"{task} ref - alt pred"] = ref_pred['pred'][task] - pred['pred'][task]
        viz_dict[f"{task} Imp profile"] = alt_contrib
        viz_dict[f"{task} ref - alt Imp"] = ref_contrib - alt_contrib
    viz_dict = filter_tracks(viz_dict, list(window))

    # Hard-coded ranges. Any track whose feature part contains "Imp" gets the
    # contribution range. The other tracks (predictions and their differences)
    # get the symmetric prediction range.
    imp_range = (-0.16751188, 0.265742)
    pred_max = 34.442215
    pred_range = (-pred_max, pred_max)
    ylim = [imp_range if "Imp" in k.split(" ", 1)[1] else pred_range for k in viz_dict]

    fig = plot_tracks(viz_dict,
                      title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                      fig_height_per_track=0.5,
                      rotate_y=0,
                      fig_width=get_figsize(frac=1)[0],
                      ylim=ylim,
                      legend=False)
    return fig
# Reference-vs-variant difference plots for the two perturbations.
plot_region_diff(interval, [Variant('chr17', 35503550 +420+80 + 1, 'ATGCATAACAA', 'GTTCGCTCGTG')]);
plot_region_diff(interval, [Variant('chr17', 35503550 +420+98 + 1, 'TGAT', 'ATGT')]);

# Open the TF-MoDISco result.
# NOTE(review): `mr` is not closed in this view; it may be used by later cells.
mr = ModiscoResult(modisco_dir / 'modisco.h5')
mr.open()
from basepair.modisco.pattern_instances import (multiple_load_instances, load_instances,
                                                filter_nonoverlapping_intervals, plot_coocurence_matrix,
                                                align_instance_center, dfi2seqlets, annotate_profile)
dfi = load_instances(modisco_dir / 'instances.parq', motifs=motifs, dedup=True)
# Alternative: load per-task instance files instead.
# instance_parq_paths = {t: model_dir / f'deeplift/{t}/out/{imp_score}/instances.parq'
#                        for t in tasks}
# dfi = multiple_load_instances(instance_parq_paths, motifs)
dfi.head()

# Analysed window: offsets 420-575 within the chr17 region starting at 35503550.
region_chrom = "chr17"
window_abs_start = 35503550 + 420
window_abs_end = 35503550 + 575

# Optional quality filter on motif instances. In the original notebook this
# filter was computed and then immediately overwritten by an unfiltered region
# query, so it never took effect. It stays off by default to preserve the
# displayed table and the hard-coded `.loc[26]` lookup below.
apply_motif_filter = False
candidate_instances = (dfi.query('match_weighted_p > .2').query('imp_weighted_p > .01')
                       if apply_motif_filter else dfi)

# Instances fully inside the window. `.copy()` makes the subset an independent
# frame, so adding columns below does not trigger chained-assignment warnings.
dfi_subset = candidate_instances.query(
    'example_chrom == @region_chrom & pattern_start_abs > @window_abs_start & pattern_end_abs < @window_abs_end'
).copy()
# Motif centre relative to the window start, and motif width.
dfi_subset['rel_center'] = dfi_subset['pattern_center'] + dfi_subset['example_start'] - window_abs_start
dfi_subset['pattern_width'] = dfi_subset['pattern_end'] - dfi_subset['pattern_start']
dfi_subset[['pattern_name', 'strand', 'rel_center', 'pattern_width']]
# NOTE(review): 26 is a hard-coded index label (presumably the Oct4-Sox2
# instance); it breaks if the loaded instances change.
oct_sox_row = dfi_subset.loc[26]
oct_sox_row

# Reference sequence of the analysed window; offsets below index into it.
trimmed_interval = pybedtools.create_interval_from_list([region_chrom, window_abs_start, window_abs_end])
seq = extract_seq(trimmed_interval, None, bpnet.fasta_file)
# Oct4-Sox2
seq[80:91]
# Nanog
seq[98:102]

import random
# Seeded local RNG, so the replacement sequences are reproducible and the global
# `random` state is left untouched. (The sequences hard-coded in the Variant
# calls above came from an earlier unseeded run and will not be reproduced.)
_rng = random.Random(0)
random_os = ''.join(_rng.choices("ACGT", k=11))
random_nanog = ''.join(_rng.choices("ACGT", k=4))
random_os
random_nanog