from collections import OrderedDict
# Identifier of the trained experiment; used further down as the model
# sub-directory name and as part of the simulation cache path.
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'

# Importance-score key (presumably "<head>/<variant>" - confirm against the model).
imp_score = 'profile/wn'

# Motif display name -> pattern identifier. Insertion order is significant.
motifs = OrderedDict()
motifs["Oct4-Sox2"] = 'Oct4/m0_p0'
motifs["Oct4"] = 'Oct4/m0_p1'
# motifs["Strange-sym-motif"] = 'Oct4/m0_p5'
motifs["Sox2"] = 'Sox2/m0_p1'
motifs["Nanog"] = 'Nanog/m0_p1'
motifs["Zic3"] = 'Nanog/m0_p2'
motifs["Nanog-partner"] = 'Nanog/m0_p4'
motifs["Klf4"] = 'Klf4/m0_p0'
from basepair.imports import *
# Base data directory; used below as the root of the simulation cache path.
ddir = get_data_dir()
from basepair.BPNet import BPNetPredictor
from basepair.plot.profiles import extract_signal
from basepair.math import softmax
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile
from basepair.plot.profiles import plot_stranded_profile, multiple_plot_stranded_profile
from basepair.plot.tracks import plot_tracks, filter_tracks
from basepair.preproc import rc_seq
from basepair.exp.chipnexus.simulate import (insert_motif, generate_sim, plot_sim, generate_seq,
model2tasks, motif_coords, interactive_tracks, plot_motif_table,
plot_sim_motif_col)
from scipy.fftpack import fft, ifft
from basepair.exp.paper.config import *
# Presumably sets up the TensorFlow session before the model is loaded.
# NOTE(review): the meaning of the argument `1` is not visible here - confirm
# against create_tf_session.
create_tf_session(1)
# The model directory is named after the experiment identifier `exp`.
model_dir = models_dir / exp
model = SeqModel.from_mdir(model_dir)
from basepair.exp.chipnexus.simulate import generate_seq, postproc, average_profiles, flatten
from basepair.exp.chipnexus.simulate import *
# Smoke test: encode 100 simulated sequences of length 1000 carrying the
# central motif "ACGAT" and the side motif "ACG" (side_distances=[10]).
seqs = encodeDNA([
    generate_seq("ACGAT", side_motif="ACG", side_distances=[10], seqlen=1000)
    for _ in range(100)
])

# Inspect which outputs the model predicts.
model.predict(seqs).keys()

# Compute the importance scores *before* inspecting them: the original cell
# order called `imps.keys()` ahead of this assignment, which raises a
# NameError when the file is executed top to bottom.
imps = model.imp_score_all(seqs, 'deeplift')
imps.keys()
def sim_pred(self, central_motif, side_motif=None, side_distances=None, repeat=128, importance=None):
    """Predict on simulated sequences and average the resulting tracks.

    `repeat` sequences are generated with `generate_seq`, run through
    `self.predict`, and the per-task outputs are averaged with
    `average_profiles`.

    Args:
      central_motif: central motif sequence handed to `generate_seq`
      side_motif: optional side motif sequence handed to `generate_seq`
      side_distances: list of side-motif distances handed to `generate_seq`;
        `None` (the default) is treated as an empty list
      repeat: number of sequences to simulate; also used as the batch size
      importance: list of importance scores (e.g. `['profile/wn']`);
        `None` or an empty list skips the importance-score computation

    Returns:
      `average_profiles(flatten(out, "/"))`, where `out` is either the
      scaled predictions per task or, when `importance` is given,
      `{"imp": ..., "profile": ...}`.
    """
    # TODO - update?
    from basepair.exp.chipnexus.simulate import generate_seq, average_profiles, flatten

    # Fresh containers per call instead of shared mutable default arguments.
    if side_distances is None:
        side_distances = []
    if importance is None:
        importance = []

    batch_size = repeat
    seqlen = self.seqlen
    tasks = self.tasks

    # simulate sequence
    seqs = encodeDNA([generate_seq(central_motif, side_motif=side_motif,
                                   side_distances=side_distances, seqlen=seqlen)
                      for _ in range(repeat)])

    # get predictions
    preds = self.predict(seqs, batch_size=batch_size)
    # Scale each profile by exp(counts); presumably the counts head is in
    # log space - confirm against the model definition.
    scaled_preds = {t: preds[f'{t}/profile'] * np.exp(preds[f'{t}/counts'][:, np.newaxis])
                    for t in tasks}

    if importance:
        # get the importance scores
        imp_scores_all = self.imp_score_all(seqs)
        # NOTE: entries are keyed by the part before the first "/", so two
        # requested scores sharing that prefix overwrite each other.
        imp_scores = {t: {imp_score_name.split("/")[0]: seqs * imp_scores_all[f'{t}/{imp_score_name}']
                          for imp_score_name in importance}
                      for t in tasks}
        # merge and aggregate the profiles
        out = {"imp": imp_scores, "profile": scaled_preds}
    else:
        out = scaled_preds
    return average_profiles(flatten(out, "/"))
def input_seqlen(self):
    """Return the input sequence length stored in `self.seqlen`."""
    seqlen = self.seqlen
    return seqlen
# add the method
import types
# Bind the module-level helpers to this model instance so that they can be
# called as `model.sim_pred(...)` and `model.input_seqlen()`.
model.sim_pred = types.MethodType(sim_pred, model)
model.input_seqlen = types.MethodType(input_seqlen, model)
tasks = model.tasks
# Example motif sequences; `central_motif` is reassigned in the sections below.
central_motif = "ATTTGCATAACAAAG"
side_motif = "ATTTGCATAACAAAG"
# Ad-hoc calls to the simulation helpers; the return values are discarded.
insert_motif("ACGT", "GG", 1)
insert_motif("ACGT", "GG", 2)
insert_motif("ACGT", "GG", 3)
generate_seq("..", "--", [5], 20)
generate_seq("..", "--", [14], 20)
# Averaged tracks for an empty central motif.
plot_tracks(model.sim_pred(""), fig_width=15, fig_height_per_track=1.5);
# Averaged prediction and 'profile/wn' importance tracks for a single motif,
# filtered with xlim=[400, 600].
plot_tracks(filter_tracks(model.sim_pred("TTTGCATAACAA", importance=['profile/wn']), xlim=[400, 600]),
fig_width=15, fig_height_per_track=1.5);
# Results per central motif name:
#   df_d[name]       -> stacked DataFrame of the simulation results
#   res_dict_d[name] -> OrderedDict of raw `generate_sim` outputs per side motif
df_d = {}
res_dict_d = {}

cache_path = f"{ddir}/cache/chipnexus/{exp}/simulation/spacing.pkl"
os.makedirs(os.path.dirname(cache_path), exist_ok=True)

# Set to True to load previously stored results instead of re-running the
# simulations in the sections below.
cached = False
if cached:
    # NOTE(review): read_pkl presumably unpickles the file - only load cache
    # files you have written yourself.
    df_d, res_dict_d = read_pkl(cache_path)
# write_pkl((df_d, res_dict_d), cache_path)

# Report the size of the cache file. The original `!du -sh {cache_path}` is an
# IPython shell escape and a SyntaxError in plain Python; check=False keeps the
# best-effort behaviour when the file does not exist yet.
import subprocess
subprocess.run(["du", "-sh", cache_path], check=False)
# Side motifs: display name -> (pattern id, motif sequence).
side_motifs = OrderedDict()
side_motifs["Oct4-Sox2"] = ("m0_p0", "TTTGCATAACAA")
side_motifs["Sox2"] = ("m0_p2", "CCATTGTT")
side_motifs["Err2"] = ("m0_p1", "TCAAGGTCA")
side_motifs["Nanog"] = ("m0_p2", "TTGATGGC")
side_motifs["Klf4"] = ("m2_p0", "GGGTGTGG")

# Sanity check of the reverse-complement helper.
assert rc_seq("TATCG") == "CGATA"

# Reverse-complemented variant of every side motif, stored under "<name>/rc".
rc_side_motifs = OrderedDict(
    (name + "/rc", (entry[0], rc_seq(entry[1])))
    for name, entry in side_motifs.items()
)

# Forward motifs first, followed by their reverse complements.
all_side_motifs = OrderedDict(side_motifs)
all_side_motifs.update(rc_side_motifs)
# One entry per central motif:
#   (central motif name, central motif sequence, side motif shown in the
#    interactive track viewer at the end of the section)
# Observation from the Oct4 section: Nanog strongly interacts with the
# Oct4-Sox2 complex.
central_motif_specs = [
    ("Oct4", "TTTGCATAACAA", "Nanog"),
    ("Sox2", "CCATTGTT", "Sox2"),
    ("Nanog", "TTGATGGC", "Nanog"),
    ("Klf4", "GGGTGTGG", "Nanog"),
]

# The four sections were identical apart from the values above, so they are
# run as one loop. The loop variables are module-level names and keep the
# values of the last (Klf4) section afterwards, as in the original.
for central_motif_name, central_motif, side_motif_name in central_motif_specs:
    # Averaged prediction tracks for the central motif alone.
    plot_tracks(model.sim_pred(central_motif), fig_width=15, fig_height_per_track=1.5,
                title=f"{central_motif_name}: {central_motif}", same_ylim=True)
    plt.xlabel("Position")

    # find the right profile crop
    center_coords = [485, 520]
    plot_tracks(filter_tracks(model.sim_pred(central_motif), center_coords),
                fig_width=5, fig_height_per_track=1.5)

    # get the motifs
    if not cached:
        # Simulate every side motif (forward and reverse complement) at each
        # of the 150 values in range(511, 661).
        res_dict = OrderedDict(
            (motif, generate_sim(model, central_motif, side_seq,
                                 list(range(511, 511 + 150, 1)),
                                 center_coords=center_coords,
                                 importance=['counts/pre-act', 'profile/wn']))
            for motif, (pattern, side_seq) in all_side_motifs.items()
        )
        df = pd.concat([v[0].assign(motif=k) for k, v in res_dict.items()])  # stack the dataframes
        df_d[central_motif_name] = df
        res_dict_d[central_motif_name] = res_dict

    # Always read back from the per-motif stores so that the cached and the
    # freshly computed paths look the same from here on.
    df = df_d[central_motif_name]
    res_dict = res_dict_d[central_motif_name]

    # NOTE(review): `mr` is not defined anywhere in the visible part of this
    # file - confirm it is provided by one of the star imports or by an
    # earlier cell.
    display(plot_motif_table(mr, side_motifs))
    plot_sim_motif_col(df, tasks, ['profile/max_frac', 'profile/counts_frac', 'profile/simmetric_kl',
                                   'imp/count', 'imp/weighted'],
                       motifs=list(all_side_motifs), subfigsize=(6, 3), alpha=1)
    interactive_tracks(res_dict[side_motif_name][1], central_motif,
                       all_side_motifs[side_motif_name][1])
# Stack the per-central-motif result tables into a single frame, tagging each
# row with the central motif it belongs to.
dfa = pd.concat([
    frame.assign(central_motif_name=name)
    for name, frame in df_d.items()
])
# From here on `side_motifs` is only the list of (forward) side motif names.
side_motifs = list(side_motifs)
df_d.keys()
# Grid layout:
#   row    = side motif
#   column = main (central) motif
#   plot   = counts of which protein are looked at (here: the Nanog task)
central_motifs = ['Sox2', 'Oct4', 'Nanog', 'Klf4']
n_rows, n_cols = len(side_motifs), len(central_motifs)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(25, 10), sharex=True, sharey='row')
for col, central_motif in enumerate(central_motifs):
    for row, side_name in enumerate(side_motifs):
        ax = axes[row, col]
        if row == 0:
            ax.set_title(central_motif)
        # Rows of the Nanog task for this (side motif, central motif) pair.
        selected = dfa[(dfa.task == "Nanog")
                       & (dfa.motif == side_name)
                       & (dfa.central_motif_name == central_motif)]
        # Un-normalised power spectrum, restricted to rfft bins 3..73.
        raw_power = np.abs(np.fft.rfft(selected.counts_frac)[3:74]) ** 2
        ax.plot(raw_power, "-o")
        ax.set_xlabel("Frequency [bp]")
        if col == 0:
            ax.set_ylabel(side_name)
fig.subplots_adjust(wspace=0, hspace=0)
# Normalised power spectra of `counts_frac`, one figure per task.
# Each figure is a grid: row = side motif, column = central motif.
# The original repeated this cell four times, changing only the task name.
central_motifs = ['Sox2', 'Oct4', 'Nanog', 'Klf4']
for spectrum_task in ["Oct4", "Sox2", "Nanog", "Klf4"]:
    fig, axes = plt.subplots(len(side_motifs), len(central_motifs),
                             figsize=(25, 10), sharex=True, sharey='row')
    for j, central_motif in enumerate(central_motifs):
        for i, m in enumerate(side_motifs):
            ax = axes[i, j]
            if i == 0:
                ax.set_title(central_motif)
            selected = dfa[(dfa.task == spectrum_task)
                           & (dfa.motif == m)
                           & (dfa.central_motif_name == central_motif)]
            power_spec = np.abs(np.fft.rfft(selected.counts_frac)) ** 2
            # Normalise by the total power excluding bin 0 (the DC component).
            power_spec = power_spec / power_spec[1:].sum()
            # Bins 0 and 1 are not drawn, so x position k is rfft bin k + 2.
            ax.plot(power_spec[2:], "-o")
            ax.set_xlabel("Frequency [bp]")
            if j == 0:
                ax.set_ylabel(m)
    fig.subplots_adjust(wspace=0, hspace=0)