from basepair.imports import *
# Base directory from which the model and cache paths below are built
# (presumably the project's data root -- confirm against `get_data_dir`).
ddir = get_data_dir()
from basepair.BPNet import BPNetPredictor
from basepair.plot.profiles import extract_signal
from basepair.math import softmax
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile
from basepair.plot.profiles import plot_stranded_profile, multiple_plot_stranded_profile
from basepair.plot.tracks import plot_tracks, filter_tracks
from basepair.preproc import rc_seq
from basepair.exp.chipnexus.simulate import (insert_motif, generate_sim, plot_sim, generate_seq,
model2tasks, motif_coords, interactive_tracks, plot_motif_table,
plot_sim_motif_col)
from scipy.fftpack import fft, ifft
# Set up the TensorFlow session before the model is loaded.
# NOTE(review): the meaning of the argument `1` is not visible from here
# (GPU index? device count?) -- confirm against `create_tf_session`.
create_tf_session(1)

# Directory of the trained model used throughout this analysis; the MoDISco
# results opened further down live underneath it as well.
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")

# Load the serialized network and wrap it in the predictor object, which is
# what provides `sim_pred` and the task list used below.
model = BPNetPredictor(load_model(model_dir / "model.h5"))
tasks = model.tasks
# --- Scratch checks of the simulation helpers --------------------------------
# Both names start out as the same sequence; `central_motif` is re-assigned
# for every central-motif section further down.
central_motif = side_motif = "ATTTGCATAACAAAG"

# Exercise `insert_motif` with the last argument stepping through 1, 2, 3.
# Return values are intentionally discarded.
for position in (1, 2, 3):
    insert_motif("ACGT", "GG", position)

# Exercise `generate_seq` with two different single-element position lists.
for positions in ([5], [14]):
    generate_seq("..", "--", positions, 20)

# Prediction tracks for an empty sequence argument.
plot_tracks(model.sim_pred(""), fig_width=15, fig_height_per_track=1.5);

# Prediction tracks for "TTTGCATAACAA", requesting only the 'weighted'
# importance and cropping the x-range to [400, 600].
weighted_pred = model.sim_pred("TTTGCATAACAA", importance=['weighted'])
plot_tracks(
    filter_tracks(weighted_pred, xlim=[400, 600]),
    fig_width=15,
    fig_height_per_track=1.5,
);
# Open the MoDISco result file stored under the model directory for Oct4;
# `mr` is later handed to `plot_motif_table`.
modisco_path = model_dir / "modisco/by_peak_tasks/weighted/Oct4/modisco.h5"
mr = ModiscoResult(modisco_path)
mr.open()
# Simulation results keyed by central motif name:
#   df_d[name]       -> stacked DataFrame over all side motifs
#   res_dict_d[name] -> OrderedDict of the raw `generate_sim` outputs
df_d = {}
res_dict_d = {}

cache_path = f"{ddir}/cache/chipnexus/simulation/spacing.pkl"
os.makedirs(os.path.dirname(cache_path), exist_ok=True)

# Set to False to recompute the simulations in the sections below instead of
# loading them from `cache_path`.
cached = True
if cached:
    df_d, res_dict_d = read_pkl(cache_path)
# After recomputing, persist the results manually with:
# write_pkl((df_d, res_dict_d), cache_path)

# Report the size of the cache file. This replaces the IPython-only
# `!du -sh {cache_path}` shell escape, which is a syntax error when the file
# is run as plain Python. The size is the apparent file size in MiB.
if os.path.exists(cache_path):
    print(f"{os.path.getsize(cache_path) / 2**20:.1f}M\t{cache_path}")
# Side motifs to scan: name -> (pattern id, forward-strand sequence).
# NOTE(review): "Sox2" and "Nanog" both carry the pattern id "m0_p2" --
# confirm this is intended.
side_motifs = OrderedDict()
side_motifs["Oct4-Sox2"] = ("m0_p0", "TTTGCATAACAA")
side_motifs["Sox2"] = ("m0_p2", "CCATTGTT")
side_motifs["Err2"] = ("m0_p1", "TCAAGGTCA")
side_motifs["Nanog"] = ("m0_p2", "TTGATGGC")
side_motifs["Klf4"] = ("m2_p0", "GGGTGTGG")

# Sanity check of the reverse-complement helper before relying on it.
assert rc_seq("TATCG") == "CGATA"

# Reverse-complement variant of every side motif, stored under "<name>/rc"
# with the same pattern id.
rc_side_motifs = OrderedDict(
    (f"{name}/rc", (pattern_id, rc_seq(seq)))
    for name, (pattern_id, seq) in side_motifs.items()
)

# Forward motifs first, then their reverse complements.
all_side_motifs = OrderedDict(side_motifs)
all_side_motifs.update(rc_side_motifs)
# ---- Central motif: Oct4 -----------------------------------------------------
central_motif_name = "Oct4"
central_motif = "TTTGCATAACAA"

# Prediction tracks for the central motif alone.
plot_tracks(
    model.sim_pred(central_motif),
    title=f"{central_motif_name}: {central_motif}",
    fig_width=15,
    fig_height_per_track=1.5,
    same_ylim=True,
);
plt.xlabel("Position");

# Crop window, checked visually here and then passed to `generate_sim`.
center_coords = [485, 520]
plot_tracks(
    filter_tracks(model.sim_pred(central_motif), center_coords),
    fig_width=5,
    fig_height_per_track=1.5,
);

# Recompute the simulation only when the results were not loaded from cache.
if not cached:
    # One `generate_sim` call per side motif (forward and reverse complement),
    # each given the positions 511..660 in steps of 1.
    res_dict = OrderedDict(
        (name, generate_sim(model, central_motif, seq, list(range(511, 511 + 150)),
                            center_coords=center_coords))
        for name, (_pattern, seq) in all_side_motifs.items()
    )
    # Element 0 of every result is a DataFrame; tag each with its side-motif
    # name and stack them.
    df = pd.concat([sim[0].assign(motif=name) for name, sim in res_dict.items()])
    df_d[central_motif_name] = df
    res_dict_d[central_motif_name] = res_dict

# From here on use the stored results, whether cached or freshly computed.
df = df_d[central_motif_name]
res_dict = res_dict_d[central_motif_name]

display(plot_motif_table(mr, side_motifs))
plot_sim_motif_col(
    df,
    tasks,
    ['profile/max_frac', 'profile/counts_frac', 'profile/simmetric_kl',
     'imp/count', 'imp/weighted'],
    motifs=list(all_side_motifs),
    subfigsize=(6, 3),
    alpha=1,
)

# Interactive tracks for one chosen side motif (element 1 of its result).
side_motif_name = 'Nanog'
interactive_tracks(res_dict[side_motif_name][1],
                   central_motif,
                   all_side_motifs[side_motif_name][1])
# Nanog strongly interacts with the Oct4-Sox2 complex.
# ---- Central motif: Sox2 -----------------------------------------------------
central_motif_name = "Sox2"
central_motif = "CCATTGTT"

# Prediction tracks for the central motif alone.
plot_tracks(
    model.sim_pred(central_motif),
    title=f"{central_motif_name}: {central_motif}",
    fig_width=15,
    fig_height_per_track=1.5,
    same_ylim=True,
);
plt.xlabel("Position");

# Crop window, checked visually here and then passed to `generate_sim`.
center_coords = [485, 520]
plot_tracks(
    filter_tracks(model.sim_pred(central_motif), center_coords),
    fig_width=5,
    fig_height_per_track=1.5,
);

# Recompute the simulation only when the results were not loaded from cache.
if not cached:
    # One `generate_sim` call per side motif (forward and reverse complement),
    # each given the positions 511..660 in steps of 1.
    res_dict = OrderedDict(
        (name, generate_sim(model, central_motif, seq, list(range(511, 511 + 150)),
                            center_coords=center_coords))
        for name, (_pattern, seq) in all_side_motifs.items()
    )
    # Element 0 of every result is a DataFrame; tag each with its side-motif
    # name and stack them.
    df = pd.concat([sim[0].assign(motif=name) for name, sim in res_dict.items()])
    df_d[central_motif_name] = df
    res_dict_d[central_motif_name] = res_dict

# From here on use the stored results, whether cached or freshly computed.
df = df_d[central_motif_name]
res_dict = res_dict_d[central_motif_name]

display(plot_motif_table(mr, side_motifs))
plot_sim_motif_col(
    df,
    tasks,
    ['profile/max_frac', 'profile/counts_frac', 'profile/simmetric_kl',
     'imp/count', 'imp/weighted'],
    motifs=list(all_side_motifs),
    subfigsize=(6, 3),
    alpha=1,
)

# Interactive tracks for one chosen side motif (element 1 of its result).
side_motif_name = 'Sox2'
interactive_tracks(res_dict[side_motif_name][1],
                   central_motif,
                   all_side_motifs[side_motif_name][1])
# ---- Central motif: Nanog ----------------------------------------------------
central_motif_name = "Nanog"
central_motif = "TTGATGGC"

# Prediction tracks for the central motif alone.
plot_tracks(
    model.sim_pred(central_motif),
    title=f"{central_motif_name}: {central_motif}",
    fig_width=15,
    fig_height_per_track=1.5,
    same_ylim=True,
);
plt.xlabel("Position");

# Crop window, checked visually here and then passed to `generate_sim`.
center_coords = [485, 520]
plot_tracks(
    filter_tracks(model.sim_pred(central_motif), center_coords),
    fig_width=5,
    fig_height_per_track=1.5,
);

# Recompute the simulation only when the results were not loaded from cache.
if not cached:
    # One `generate_sim` call per side motif (forward and reverse complement),
    # each given the positions 511..660 in steps of 1.
    res_dict = OrderedDict(
        (name, generate_sim(model, central_motif, seq, list(range(511, 511 + 150)),
                            center_coords=center_coords))
        for name, (_pattern, seq) in all_side_motifs.items()
    )
    # Element 0 of every result is a DataFrame; tag each with its side-motif
    # name and stack them.
    df = pd.concat([sim[0].assign(motif=name) for name, sim in res_dict.items()])
    df_d[central_motif_name] = df
    res_dict_d[central_motif_name] = res_dict

# From here on use the stored results, whether cached or freshly computed.
df = df_d[central_motif_name]
res_dict = res_dict_d[central_motif_name]

display(plot_motif_table(mr, side_motifs))
plot_sim_motif_col(
    df,
    tasks,
    ['profile/max_frac', 'profile/counts_frac', 'profile/simmetric_kl',
     'imp/count', 'imp/weighted'],
    motifs=list(all_side_motifs),
    subfigsize=(6, 3),
    alpha=1,
)

# Interactive tracks for one chosen side motif (element 1 of its result).
side_motif_name = 'Nanog'
interactive_tracks(res_dict[side_motif_name][1],
                   central_motif,
                   all_side_motifs[side_motif_name][1])
# ---- Central motif: Klf4 -----------------------------------------------------
central_motif_name = "Klf4"
central_motif = "GGGTGTGG"

# Prediction tracks for the central motif alone.
plot_tracks(
    model.sim_pred(central_motif),
    title=f"{central_motif_name}: {central_motif}",
    fig_width=15,
    fig_height_per_track=1.5,
    same_ylim=True,
);
plt.xlabel("Position");

# Crop window, checked visually here and then passed to `generate_sim`.
center_coords = [485, 520]
plot_tracks(
    filter_tracks(model.sim_pred(central_motif), center_coords),
    fig_width=5,
    fig_height_per_track=1.5,
);

# Recompute the simulation only when the results were not loaded from cache.
if not cached:
    # One `generate_sim` call per side motif (forward and reverse complement),
    # each given the positions 511..660 in steps of 1.
    res_dict = OrderedDict(
        (name, generate_sim(model, central_motif, seq, list(range(511, 511 + 150)),
                            center_coords=center_coords))
        for name, (_pattern, seq) in all_side_motifs.items()
    )
    # Element 0 of every result is a DataFrame; tag each with its side-motif
    # name and stack them.
    df = pd.concat([sim[0].assign(motif=name) for name, sim in res_dict.items()])
    df_d[central_motif_name] = df
    res_dict_d[central_motif_name] = res_dict

# From here on use the stored results, whether cached or freshly computed.
df = df_d[central_motif_name]
res_dict = res_dict_d[central_motif_name]

display(plot_motif_table(mr, side_motifs))
plot_sim_motif_col(
    df,
    tasks,
    ['profile/max_frac', 'profile/counts_frac', 'profile/simmetric_kl',
     'imp/count', 'imp/weighted'],
    motifs=list(all_side_motifs),
    subfigsize=(6, 3),
    alpha=1,
)

# Interactive tracks for one chosen side motif (element 1 of its result).
side_motif_name = 'Nanog'
interactive_tracks(res_dict[side_motif_name][1],
                   central_motif,
                   all_side_motifs[side_motif_name][1])
# Stack the per-central-motif frames into one long frame, tagging every row
# with the central motif it belongs to.
dfa = pd.concat(
    [frame.assign(central_motif_name=name) for name, frame in df_d.items()]
)

# Keep only the side-motif names (forward strand; the "/rc" entries live in
# `all_side_motifs`).
# NOTE(review): this rebinds `side_motifs` from an OrderedDict to a plain
# list, so earlier statements that call `side_motifs.items()` cannot be
# re-run after this line.
side_motifs = list(side_motifs)

# Central motifs for which results are available.
df_d.keys()
def plot_power_spectra(dfa, task, side_motifs, central_motifs,
                       bins=slice(2, None), normalize=True):
    """Plot power spectra of ``counts_frac`` on a side-motif x central-motif grid.

    Grid layout:
      row    = side motif
      column = central (main) motif
      panel  = power spectrum of `counts_frac` for the protein `task`

    For every (side motif, central motif) pair the matching rows of `dfa` are
    selected, their `counts_frac` column is transformed with a real FFT and
    the squared magnitude is drawn.

    Args:
      dfa: DataFrame with the columns `task`, `motif`, `central_motif_name`
        and `counts_frac`.
        NOTE(review): rows are assumed to be ordered by side-motif position
        within each selection -- confirm against `generate_sim`.
      task: value of `dfa.task` to select.
      side_motifs: sized iterable of `dfa.motif` values, one grid row each.
      central_motifs: sized iterable of `dfa.central_motif_name` values, one
        grid column each.
      bins: slice of FFT bins to draw. The default skips bins 0 and 1.
      normalize: if True, divide the spectrum by its total power excluding
        the zero-frequency bin.

    Returns:
      (fig, axes): the matplotlib figure and the 2-D array of its axes.
    """
    # squeeze=False keeps `axes` 2-D even for a single row or column, so the
    # `axes[i, j]` indexing below always works.
    fig, axes = plt.subplots(len(side_motifs), len(central_motifs),
                             figsize=(25, 10), sharex=True, sharey='row',
                             squeeze=False)

    # The task filter is the same for every panel; apply it once.
    task_rows = dfa[dfa.task == task]

    for j, central_name in enumerate(central_motifs):
        for i, side_name in enumerate(side_motifs):
            ax = axes[i, j]
            if i == 0:
                ax.set_title(central_name)

            selected = task_rows[(task_rows.motif == side_name)
                                 & (task_rows.central_motif_name == central_name)]
            power_spec = np.abs(np.fft.rfft(selected.counts_frac)) ** 2
            if normalize:
                power_spec = power_spec / power_spec[1:].sum()

            # Plot against the real FFT bin index. Plotting the sliced
            # spectrum on its own would restart the x-axis at 0 and shift
            # every peak by the number of skipped bins.
            bin_index = np.arange(len(power_spec))
            ax.plot(bin_index[bins], power_spec[bins], "-o")
            ax.set_xlabel("Frequency bin (cycles per series)")
            if j == 0:
                ax.set_ylabel(side_name)

    fig.subplots_adjust(wspace=0, hspace=0)
    return fig, axes


central_motifs = ['Sox2', 'Oct4', 'Nanog', 'Klf4']

# Raw (un-normalized) spectrum for Nanog, restricted to bins 3..73.
fig, axes = plot_power_spectra(dfa, "Nanog", side_motifs, central_motifs,
                               bins=slice(3, 74), normalize=False)

# Normalized spectra, one figure per protein. The loops now live inside the
# helper, so the module-level `central_motif` sequence is no longer
# overwritten by a loop variable.
for task in ["Oct4", "Sox2", "Nanog", "Klf4"]:
    fig, axes = plot_power_spectra(dfa, task, side_motifs, central_motifs)