from collections import OrderedDict
# Experiment identifier; used further down as a directory name below the
# models and figures roots.
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'

# Importance-score variant; appears in the `deeplift/<task>/out/<imp_score>` paths.
imp_score = 'profile/wn'

# Display label -> "<task>/<short pattern name>".
# The insertion order is kept (OrderedDict). Labels carrying a "/S", "/N",
# "/K" or "/O" suffix point at a pattern from the Sox2, Nanog, Klf4 or Oct4
# task respectively (presumably the same motif as seen by another task).
motifs = OrderedDict(
    (
        # --- Oct4-Sox2 ---
        ("Oct4-Sox2", 'Oct4/m0_p0'),
        ("Oct4-Sox2/S", 'Sox2/m0_p0'),
        ("Oct4-Sox2/N", 'Nanog/m0_p0'),
        ("Oct4-Sox2/K", 'Klf4/m0_p1'),
        # --- Oct4 ---
        ("Oct4", 'Oct4/m0_p1'),
        ("Oct4-Oct4", 'Oct4/m0_p6'),
        # --- B-Box ---
        ("B-Box", 'Oct4/m0_p5'),
        ("B-Box/S", 'Sox2/m0_p3'),
        ("B-Box/K", 'Klf4/m0_p11'),
        # --- Sox2 ---
        ("Sox2", 'Sox2/m0_p1'),
        ("Sox2/O", 'Oct4/m0_p3'),
        ("Sox2/N", 'Nanog/m0_p3'),
        ("Sox2/K", 'Klf4/m0_p8'),
        # --- Nanog ---
        ("Nanog", 'Nanog/m0_p1'),
        ("Nanog2", 'Nanog/m0_p4'),
        ("Nanog-mix", 'Nanog/m0_p5'),
        # TODO - Other Sox2
        # --- Zic3 ---
        ("Zic3", 'Nanog/m0_p2'),
        ("Zic3/K", 'Klf4/m0_p2'),
        # NOTE(review): "Essrb" may be a misspelling of "Esrrb"; left as-is
        # because the label is used at runtime.
        ("Essrb", 'Oct4/m0_p16'),
        # --- Klf4 ---
        ("Klf4", 'Klf4/m0_p0'),
        ("Klf4/O", 'Oct4/m0_p4'),
        ("Klf4/S", 'Sox2/m0_p4'),
        ("Klf4-Klf4", 'Klf4/m0_p5'),
    )
)

gpu = 2  # Set to None if GPU shouldn't be used

# Reverse lookup: "<task>/<short pattern name>" -> display label.
motifs_inv = dict(zip(motifs.values(), motifs.keys()))
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
from basepair.exp.paper.config import *
from basepair.seqmodel import SeqModel
from basepair.BPNet import model2tasks
from basepair.models import seq_bpnet_cropped_extra_seqlen
from basepair.preproc import resize_interval, parse_interval
from basepair.seqmodel import SeqModel
from basepair.utils import unflatten
from genomelake.extractors import FastaExtractor
from concise.preprocessing.sequence import one_hot2string, DNA
from kipoi.utils import unique_list
import pybedtools
from basepair.utils import flatten_list
# Project-wide setup; `paper_config` comes from the star-imports above
# (presumably plotting defaults - TODO confirm).
paper_config()

# Only create a TensorFlow session on the selected device when a GPU was
# requested (`gpu = None` skips this).
if gpu is not None:
    create_tf_session(gpu)

# Common paths
model_dir = models_dir / exp
# figures = f"{ddir}/figures/model-evaluation/chipnexus-bpnet/{exp}/known_enhancer_profiles"
figures = Path(f'{ddir}/figures/modisco/{exp}')

# Create the output directory used by the first batch of `savefig` calls
# before anything is written into it. pathlib is used instead of a shell
# `mkdir -p` so that the path (which contains characters such as "[1,50]")
# is never interpreted by a shell. `parents=True` also creates
# `known_enhancer_profiles` itself.
(figures / 'known_enhancer_profiles' / 'all-motifs').mkdir(parents=True, exist_ok=True)

# Dataspec
ds = DataSpec.load(rdir / 'src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml')
from basepair.modisco.results import MultipleModiscoResult
from basepair.modisco.pattern_instances import (multiple_load_instances, load_instances, filter_nonoverlapping_intervals,
plot_coocurence_matrix, align_instance_center, dfi2seqlets, annotate_profile)
def shorten_te_pattern(s):
    """Shorten the pattern part of a "<task>/<pattern>" name.

    The string is split at the first "/"; the left part is kept as is and the
    right part is passed through `shorten_pattern` before both are re-joined
    with "/".

    Args:
      s: pattern name of the form "<task>/<pattern>"

    Returns:
      "<task>/<shortened pattern>"

    Raises:
      ValueError: if `s` contains no "/".
    """
    task_name, pattern_name = s.split("/", 1)
    return "/".join([task_name, shorten_pattern(pattern_name)])
# One modisco result file per task.
mr = MultipleModiscoResult({
    t: model_dir / f'deeplift/{t}/out/{imp_score}/modisco.h5'
    for t in tasks
})

# Per-task table read from `centroid_seqlet_matches.csv`.
centroid_seqlet_matches = {
    t: pd.read_csv(model_dir / f'deeplift/{t}/out/{imp_score}/centroid_seqlet_matches.csv')
    for t in tasks
}

# Keep only the patterns listed in `motifs`; each kept pattern is passed
# through `trim_seq_ic(0.08)`.
selected_pattern_names = set(motifs.values())
patterns = [
    p.trim_seq_ic(0.08)
    for p in mr.get_all_patterns()
    if shorten_te_pattern(p.name) in selected_pattern_names
]

# Trained model, loaded from the experiment directory.
bpnet = SeqModel.from_mdir(model_dir)
# Draw one 'seq_ic' plot per selected pattern, titled with its display label.
for p in patterns:
    p.plot('seq_ic', height=0.4, letter_width=0.15)
    sns.despine(top=True, bottom=True, right=True)
    display_label = motifs_inv[shorten_te_pattern(p.name)]
    plt.title(display_label)
    plt.ylim([0, 2])
from basepair.exp.paper.locus import *
# Colour specifications handed to `plot_tracks`.
#   colors_track_only: one (colour, colour + "80") pair per task
#   colors:            the same pair followed by None, per task
# The "80" suffix appended to the hex colour adds alpha=0.5.
colors_track_only = []
colors = []
for task in tasks:
    color_pair = (tf_colors[task], tf_colors[task] + "80")
    colors_track_only.append(color_pair)
    colors += [color_pair, None]
# Locus of interest: a 1 kb window on chr17.
interval = parse_interval("chr17:35503550-35504550")
viz_dict, seq, imp_scores = interval_predict(bpnet, ds, interval, tasks, incl_pred=True)

# Focus only on the 420 - 575 region
# (actual coordinates: 35503550 + 420 = 35503970 and 35503550 + 575 = 35504125)
xlim = [420, 575]
viz_dict = filter_tracks(viz_dict, xlim)

# instances: keep the matches with match_weighted_p above 0.2
all_instances = get_instances(patterns, seq, imp_scores, imp_score, centroid_seqlet_matches, motifs, tasks)
dfim = all_instances.query('match_weighted_p > .2')
seqlets = dfi2seqlets(dfim, motifs_inv)
# Express the seqlet coordinates relative to the start of the zoomed window.
seqlets2 = [s.shift(-xlim[0]) for s in seqlets]

# Get the sequence of the zoomed window
zoomed_seq = seq[:, xlim[0]:xlim[1]]
print(one_hot2string(zoomed_seq, DNA)[0])
# All tracks except those whose name ends in "Obs".
viz_dict_pred = OrderedDict(
    (k, v) for k, v in viz_dict.items() if not k.endswith("Obs")
)
fig = plot_tracks(
    viz_dict_pred,
    title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
    fig_height_per_track=0.5,
    rotate_y=0,
    fig_width=get_figsize(frac=1)[0],
    use_spine_subset=True,
    seqlets=seqlets2,  # motif instances, shifted into the zoomed window
    color=colors,
    ylim=get_ylim(viz_dict_pred, tasks),
    legend=False,
)
sns.despine(top=True, right=True, bottom=True)
fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs/distal_oct4.predicted+importance.pdf")
# Only the tracks whose name ends in "Pred".
viz_dict_pred_only = OrderedDict(
    (k, v) for k, v in viz_dict.items() if k.endswith("Pred")
)
fig = plot_tracks(
    viz_dict_pred_only,
    title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
    fig_height_per_track=0.5,
    rotate_y=0,
    fig_width=get_figsize(frac=1)[0],
    use_spine_subset=True,
    seqlets=seqlets2,  # motif instances, shifted into the zoomed window
    color=colors_track_only,  # one colour pair per track, no None entries
    ylim=get_ylim(viz_dict_pred_only, tasks),
    legend=False,
)
sns.despine(top=True, right=True, bottom=True)
fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs/distal_oct4.predicted.pdf")
# Only the tracks whose name ends in "Imp profile".
viz_dict_imp_only = OrderedDict(
    (k, v) for k, v in viz_dict.items() if k.endswith("Imp profile")
)
# Same layout as the other locus figures, but without seqlet annotations
# and without explicit colours.
fig = plot_tracks(
    viz_dict_imp_only,
    title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
    fig_height_per_track=0.5,
    rotate_y=0,
    fig_width=get_figsize(frac=1)[0],
    use_spine_subset=True,
    color=None,
    ylim=get_ylim(viz_dict_imp_only, tasks),
    legend=False,
)
sns.despine(top=True, right=True, bottom=True)
fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs/distal_oct4.imp.pdf")
# "Imp profile" tracks once more, this time with the motif instances drawn.
fig = plot_tracks(
    viz_dict_imp_only,
    title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
    fig_height_per_track=0.5,
    rotate_y=0,
    fig_width=get_figsize(frac=1)[0],
    use_spine_subset=True,
    seqlets=seqlets2,  # motif instances, shifted into the zoomed window
    color=None,
    ylim=get_ylim(viz_dict_imp_only, tasks),
    legend=False,
)
sns.despine(top=True, right=True, bottom=True)
fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs/distal_oct4.imp+instances.pdf")
# All tracks except those whose name ends in "Imp profile".
viz_dict_obs_pred = OrderedDict(
    (k, v) for k, v in viz_dict.items() if not k.endswith("Imp profile")
)

# Two identical colour pairs per task (presumably one for the observed and one
# for the predicted track of that task). "80" adds alpha=0.5.
colors_track_only2 = []
for task in tasks:
    color_pair = (tf_colors[task], tf_colors[task] + "80")
    colors_track_only2 += [color_pair, color_pair]

# Half-width figure, no seqlet annotations.
fig = plot_tracks(
    viz_dict_obs_pred,
    title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
    fig_height_per_track=0.5,
    rotate_y=0,
    fig_width=get_figsize(frac=.5)[0],
    use_spine_subset=True,
    color=colors_track_only2,
    ylim=get_ylim(viz_dict_obs_pred, tasks),
    legend=False,
)
sns.despine(top=True, right=True, bottom=True)
fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs/distal_oct4.observed+pred.pdf")
# Make sure the figure directory exists *before* listing it (the listing used
# to come first and fails on a fresh run). pathlib replaces the shell
# commands, so this no longer depends on IPython automagic or on how a shell
# treats the characters in the path.
all_motifs_dir = figures / 'known_enhancer_profiles' / 'all-motifs'
all_motifs_dir.mkdir(parents=True, exist_ok=True)

# Show the files written there so far.
for saved_file in sorted(all_motifs_dir.iterdir()):
    print(saved_file.name)
from basepair.plot.tracks import *
def plot_seqlet_underscore(seqlet, ax, add_label=False):
    """Mark a seqlet with a horizontal line at the bottom of an axis.

    Args:
      seqlet: object with start, end, name and strand attributes
      ax: axis to draw on; its get_xlim, get_ylim and text methods are used
      add_label: if True, also write `strand + name` above the axis

    Nothing is drawn when the seqlet ends left of 0 or starts beyond the
    width of the axis.
    """
    x_lo, x_hi = ax.get_xlim()
    # Both ends are shifted by half a position (presumably so that the line
    # lines up with letters centred on integer positions - TODO confirm).
    line_start = seqlet.start + 0.5
    line_end = seqlet.end + 0.5

    # NOTE(review): the right-hand test uses the axis width (x_hi - x_lo), not
    # x_hi itself; the two only agree when the axis starts at 0 - confirm.
    out_of_view = line_end < 0 or line_start > x_hi - x_lo
    if out_of_view:
        return

    # TODO - it would be nice to also have some different colors here
    draw_hline(line_start, line_end, ax.get_ylim()[0], col='r', linewidth=5, alpha=0.3)

    if not add_label:
        return
    # Read the limits after drawing the line; the label goes 15% of the
    # y-range above the top of the axis.
    y_lo, y_hi = ax.get_ylim()
    label_y = y_hi + (y_hi - y_lo) * 0.15
    ax.text(line_start + 0.5, label_y,
            s=seqlet.strand + str(seqlet.name),
            fontsize=5)
# Enhancer table, downloaded as CSV from a Google spreadsheet.
df_enhancers = pd.read_csv("https://docs.google.com/spreadsheets/d/1nIRLv3tWq_3BjorP_pEAyJKWtVClX6_OrE2waN94ECc/export?gid=0&format=csv")

# Optional restriction to a subset of regions (currently disabled):
# names = ['sall1 downstream_1',
#          'sall1 downstream_2',
#          'Sall1_as in fig 1',
#          'Dpp5a uptream',
#          'Fbxo15']
# df_enhancers = df_enhancers[df_enhancers.Name.isin(names)]

# Keep the rows that have both a name and mm10 coordinates.
df_enhancer_intervals = df_enhancers[['Name', 'mm10 coordinates']].dropna()
# assert len(names) == len(df_enhancer_intervals)

# (name, "chrom:start-end") pairs; surrounding whitespace is stripped from
# the coordinate string.
intervals = [
    (region_name, str(region_coords).strip())
    for region_name, region_coords in zip(df_enhancer_intervals['Name'],
                                          df_enhancer_intervals['mm10 coordinates'])
]
# For every enhancer interval: predict, call motif instances, plot the whole
# region, then zoom into the 150 bp window with the highest summed importance,
# plot that window and remember its sequence. The sequences are written to a
# FASTA file at the end.

# Imported once here instead of on every loop iteration.
from basepair.preproc import moving_average
from concise.utils.fasta import write_fasta

# Zoom window: 150 bp wide, i.e. 75 bp on either side of the chosen centre.
zoom_width = 150
zoom_half = zoom_width // 2

sequences = dict()
for i, (name, interval_str) in enumerate(intervals):
    print(f"{i}/{len(intervals)}", name)
    interval = parse_interval(interval_str)
    # NOTE(review): the interval is used at the width given in the table even
    # though the outputs are labelled "1kb" - confirm the coordinates already
    # span 1 kb.
    viz_dict, seq, imp_scores = interval_predict(bpnet, ds, interval, tasks)
    dfim = get_instances(patterns, seq, imp_scores, imp_score, centroid_seqlet_matches, motifs, tasks).query('match_weighted_p > .2')
    seqlets = dfi2seqlets(dfim, motifs_inv)

    # --- whole region ---
    xlim = None
    fig = plot_tracks(viz_dict,
                      title=str_interval(interval, xlim) + name,
                      fig_height_per_track=0.5,
                      rotate_y=0,
                      fig_width=get_figsize(frac=2)[0],
                      use_spine_subset=True,
                      seqlets=seqlets,
                      color=colors,
                      ylim=get_ylim(viz_dict, tasks, True),
                      # seqlet_plot_fn=plot_seqlet_underscore,
                      legend=False)
    sns.despine(top=True, right=True, bottom=True)
    fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs/{name},{interval_str}.1kb.pdf")
    plt.close()

    # --- most interesting 150 bp ---
    # Pick the window centre where the moving average (150 bp) of the absolute
    # importance scores, summed over all tasks, is largest.
    # center = np.argmax(moving_average(sum([np.abs(viz_dict[f'{task} Obs']).sum(axis=-1) for task in tasks]), 150))
    total_importance = sum([np.abs(viz_dict[f'{task} Imp profile']).sum(axis=-1) for task in tasks])
    center = np.argmax(moving_average(total_importance, zoom_width))
    # Keep the whole window inside the region. Without this, a peak closer
    # than 75 bp to the left edge gives a negative window start, which the
    # slice below would interpret as "counted from the end".
    # Axis 1 of `seq` is the position axis (it is sliced with xlim below).
    region_len = seq.shape[1]
    center = int(max(zoom_half, min(center, region_len - zoom_half)))
    xlim = [center - zoom_half, center + zoom_half]

    # Seqlet coordinates relative to the start of the window.
    seqlets2 = [s.shift(-xlim[0]) for s in seqlets]
    viz_dict2 = filter_tracks(viz_dict, xlim)
    print(str_interval(interval, xlim), name)
    seq_str = one_hot2string(seq[:, slice(*xlim)], DNA)[0]
    print(seq_str)  # Get the sequence
    sequences[f"{name},{interval_str}"] = seq_str

    fig = plot_tracks(viz_dict2,
                      title=str_interval(interval, xlim) + " (" + str(xlim) + ")",
                      fig_height_per_track=0.5,
                      rotate_y=0,
                      fig_width=get_figsize(frac=1)[0],
                      use_spine_subset=True,
                      seqlets=seqlets2,
                      color=colors,
                      ylim=get_ylim(viz_dict2, tasks, True),
                      legend=False)
    sns.despine(top=True, right=True, bottom=True)
    fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs/{name},{interval_str}.150bp.pdf")
    plt.close()

# Write out the fasta file: one record per region, keyed "<name>,<interval>".
write_fasta(f"{figures}/known_enhancer_profiles/all-motifs/sequences.150bp.new.fa", list(sequences.values()), list(sequences))
# intervals = [
# ("Klf2 E1 upstream enhancer", "chr8:72311216-72311616"),
# ("Klf4 E2 upstream enhancer", "chr4:55475488-55475688"),
# ("Prdm14 E3 upstream enhancer", "chr1:13084919-13085299"),
# ("Zfp281 downstream enhancer", "chr1:136680205-136680605"),
# ("Lefty1 upstream", "chr1:180924752-180925152"),
# ("Oct4 distal enhancer", "chr17:35504453-35504603")
# ]
# Enhancer table, downloaded (again) as CSV from the Google spreadsheet.
df_enhancers = pd.read_csv("https://docs.google.com/spreadsheets/d/1nIRLv3tWq_3BjorP_pEAyJKWtVClX6_OrE2waN94ECc/export?gid=0&format=csv")
# Keep the rows that have both a name and mm10 coordinates.
df_enhancer_intervals = df_enhancers[['Name', 'mm10 coordinates']].dropna()
# (name, "chrom:start-end") pairs; surrounding whitespace is stripped from
# the coordinate string.
intervals = [
    (region_name, str(region_coords).strip())
    for region_name, region_coords in zip(df_enhancer_intervals['Name'],
                                          df_enhancer_intervals['mm10 coordinates'])
]
# Hand-picked region sets as (name, "chrom:start-end") pairs.
# Both sets used to be assigned to `intervals` back to back, so the first one
# was silently discarded; they are now named and the active set is selected
# explicitly below.

# Set 1 (not active). The Nanog end coordinate previously contained
# thousands separators ("122,707,721"), which is not a plain integer.
intervals_known_enhancers = [
    ('Nanog upstream enhancer', 'chr6:122707295-122707721'),
    ('Fbxo15 enhancer', 'chr18:84934293-84934692'),
    ('Nr0b1 enhancer', 'chrX:86187475-86187504'),
]

# Set 2 (active).
intervals_emsa_and_distal = [
    ('Nanog_EMSA_1', 'chr13:3712945-3712985'),
    ('Nanog_EMSA_2', 'chr19:21785406-21785446'),
    ('Nanog_EMSA_3', 'chr4:40856724-40856764'),
    ('Nanog_EMSA_4', 'chr6:112885434-112885474'),
    ('Nanog_EMSA_5', 'chr5:142415570-142415610'),
    ('Tbx3_distal_1', 'chr5:119579664-119580351'),
    ('Tbx3_distal_2', 'chr5:119579408-119579566'),
    ('Tbx3_distal_3', 'chr5:119584867-119585462'),
    ('dsp_distal', 'chr13:38109043-38109501'),
    ('dsp_proximal', 'chr13:38123730-38124166'),
    ('cdh1_inragenic', 'chr8:106609099-106609356'),
]

# Regions processed by the code that follows (same value as before).
intervals = intervals_emsa_and_distal
# Output directory for the figures written below. Created with pathlib rather
# than an unquoted shell `mkdir -p`, so that characters in the path (e.g. the
# "[1,50]" in the experiment name) are never interpreted by a shell.
(figures / 'known_enhancer_profiles' / 'all-motifs-individual-y-scale-profile').mkdir(parents=True, exist_ok=True)

# Generate the right colors: per task, a (colour, colour + "80") pair followed
# by None. The "80" suffix adds alpha=0.5.
colors = []
for task in tasks:
    colors.append((tf_colors[task], tf_colors[task] + "80"))
    colors.append(None)
# For every interval: resize to 1 kb, predict, call motif instances, plot the
# whole region, then zoom into the 150 bp window with the highest summed
# importance, plot that window and remember its sequence. The sequences are
# written to a FASTA file at the end.
from basepair.preproc import moving_average
from concise.utils.fasta import write_fasta

# Keyword arguments shared by both `plot_tracks` calls in the loop.
shared_plot_kwargs = dict(
    fig_height_per_track=0.5,
    rotate_y=0,
    use_spine_subset=True,
    color=colors,
    legend=False,
)

sequences = dict()
for i, (name, interval_str) in enumerate(intervals):
    print(f"{i}/{len(intervals)}", name)
    interval = resize_interval(parse_interval(interval_str), 1000)
    viz_dict, seq, imp_scores = interval_predict(bpnet, ds, interval, tasks)
    all_instances = get_instances(patterns, seq, imp_scores, imp_score, centroid_seqlet_matches, motifs, tasks)
    dfim = all_instances.query('match_weighted_p > .2')
    seqlets = dfi2seqlets(dfim, motifs_inv)

    # --- whole 1 kb region ---
    xlim = None
    fig = plot_tracks(
        viz_dict,
        title=str_interval(interval, xlim) + name,
        fig_width=get_figsize(frac=2)[0],
        seqlets=seqlets,
        ylim=get_ylim(viz_dict, tasks, True),
        # seqlet_plot_fn=plot_seqlet_underscore,
        **shared_plot_kwargs,
    )
    sns.despine(top=True, right=True, bottom=True)
    fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs-individual-y-scale-profile/{name},{interval_str}.1kb.pdf")
    plt.close()

    # --- most interesting 150 bp in the entire 1kb region ---
    # Chosen as the position where the 150 bp moving average of the absolute
    # importance scores, summed over all tasks, is largest.
    # center = np.argmax(moving_average(sum([np.abs(viz_dict[f'{task} Obs']).sum(axis=-1) for task in tasks]), 150))
    total_importance = sum([np.abs(viz_dict[f'{task} Imp profile']).sum(axis=-1) for task in tasks])
    center = np.argmax(moving_average(total_importance, 150))
    # Clamp so that the window [center - 75, center + 75] stays within 0..1000.
    center = min(max(center, 75), 925)
    xlim = [center - 75, center + 75]

    # Seqlet coordinates relative to the start of the window.
    seqlets2 = [s.shift(-xlim[0]) for s in seqlets]
    viz_dict2 = filter_tracks(viz_dict, xlim)
    print(str_interval(interval, xlim), name)
    seq_str = one_hot2string(seq[:, slice(*xlim)], DNA)[0]
    print(seq_str)  # Get the sequence
    sequences[f"{name},{interval_str}"] = seq_str

    fig = plot_tracks(
        viz_dict2,
        title=str_interval(interval, xlim) + " (" + str(xlim) + ")",
        fig_width=get_figsize(frac=1)[0],
        seqlets=seqlets2,
        ylim=get_ylim(viz_dict2, tasks, True),
        **shared_plot_kwargs,
    )
    sns.despine(top=True, right=True, bottom=True)
    fig.savefig(f"{figures}/known_enhancer_profiles/all-motifs-individual-y-scale-profile/{name},{interval_str}.150bp.pdf")
    plt.close()

# Write out the fasta file: one record per region, keyed "<name>,<interval>".
fasta_names = list(sequences)
fasta_seqs = list(sequences.values())
write_fasta(f"{figures}/known_enhancer_profiles/all-motifs-individual-y-scale-profile/sequences.150bp.new.fa", fasta_seqs, fasta_names)