from collections import OrderedDict
# Experiment identifier (a comma-separated string); it is joined onto
# `models_dir` further down to locate the model directory.
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'

# Importance-score name; interpolated into the modisco output paths.
imp_score = 'profile/wn'

# Display name -> pattern identifier, kept in insertion order.
motifs = OrderedDict()
motifs["Oct4-Sox2"] = 'Oct4/m0_p0'
motifs["Oct4"] = 'Oct4/m0_p1'
# motifs["Strange-sym-motif"] = 'Oct4/m0_p5'
motifs["Sox2"] = 'Sox2/m0_p1'
motifs["Nanog"] = 'Nanog/m0_p1'
motifs["Zic3"] = 'Nanog/m0_p2'
motifs["Nanog-partner"] = 'Nanog/m0_p4'
motifs["Klf4"] = 'Klf4/m0_p0'
# align_instance_center
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
hv.extension('bokeh')
from basepair.exp.paper.config import *
paper_config()
%matplotlib inline
# Model directory of the experiment and the modisco output directory of the
# Nanog task for the chosen importance score.
model_dir = models_dir / exp
modisco_dir = model_dir / f'deeplift/Nanog/out/{imp_score}'
from basepair.modisco.pattern_instances import load_instances, align_instance_center, filter_nonoverlapping_intervals

# Gather the patterns of every task, renaming each one to
# "<task>/<shortened pattern name>".
all_patterns = []
for task in tasks:
    patterns = read_pkl(model_dir / f'deeplift/{task}/out/{imp_score}/patterns.pkl')
    patterns = [p.rename(task + "/" + shorten_patten(p.name)) for p in patterns]
    all_patterns += patterns

# Patterns stored in the modisco directory.  (The former
# `aligned_patterns = []` initialisation was a dead store: the name was
# always overwritten here before being read.)
aligned_patterns = read_pkl(modisco_dir / 'patterns.pkl')
aligned_patterns[3].plot('seq_ic');

# Open the modisco result file; it is queried for the original patterns later.
mr = ModiscoResult(modisco_dir / 'modisco.h5')
mr.open()
from basepair.cli.imp_score import ImpScoreFile
from kipoi.data_utils import get_dataset_item
# Contribution scores that belong to this modisco run.
imp = ImpScoreFile.from_modisco_dir(modisco_dir)
contrib_scores = imp.get_contrib()

# Patterns as stored in the modisco result: as a list and keyed by name.
orig_patterns = list(map(mr.get_pattern, mr.patterns()))
orig_patterns_d = dict((p.name, p) for p in orig_patterns)

# Pattern instances: no motif subset, no de-duplication.
dfi = load_instances(
    modisco_dir / 'instances.parq',
    motifs=None,
    dedup=False,
)
dfi.head()

# Align the instance centres using the original and the aligned patterns,
# then remove all the overlapping intervals.
dfi = filter_nonoverlapping_intervals(
    align_instance_center(dfi, orig_patterns, aligned_patterns, trim_frac=0.08)
)

# Fraction of instances whose strand equals the aligned pattern strand.
dfi.eval("strand == pattern_strand_aln").mean()
# pattern_strand_aln and the same pattern_center_aln
# NOTE: the sentence above is the tail of a notebook markdown cell.  It had
# been fused onto the `sizes = ...` assignment, which made the line invalid
# Python and left `sizes` undefined.

# Histogram of the number of instances found in each example.
sizes = dfi.groupby('example_idx').size()
sizes.plot.hist(30)
plt.xlabel("Number of instances per example")

# Keep instances with imp_weighted_cat == 'high' and match_weighted_cat != 'low',
# indexed by example, and redraw the same histogram for the filtered set.
dfi_filt = dfi[(dfi.imp_weighted_cat == 'high') & (dfi.match_weighted_cat != 'low')].set_index('example_idx')
dfi_filt.groupby('example_idx').size().plot.hist(30)
plt.xlabel("Number of instances per example");
def norm_matrix(s):
    """Build a square matrix whose i-th row repeats the i-th value of `s`.

    Example:
      print(norm_matrix(pd.Series([1,3,5])).to_string())
         0  1  2
      0  1  1  1
      1  3  3  3
      2  5  5  5

    Args:
      s: pandas series

    Returns:
      pandas DataFrame of shape (len(s), len(s)) that uses `s.index` for both
      its index and its columns.
    """
    column = s.values[:, None]
    # Multiplying the (n, 1) column by a (1, n) row of ones broadcasts every
    # value across its row.
    ones_row = np.ones((1, column.shape[0]), dtype=column.dtype)
    return pd.DataFrame(column * ones_row, index=s.index, columns=s.index)
# exclude noisy motifs
exclude = ['metacluster_2/pattern_22', 'metacluster_1/pattern_16', 'metacluster_6/pattern_9', 'metacluster_6/pattern_12']
dfi_filt = dfi_filt[~dfi_filt.pattern.isin(exclude)]

# normalization: number of filtered instances per pattern, expanded into a
# matrix by norm_matrix
total_number = dfi_filt.groupby(['pattern']).size()
norm_counts = norm_matrix(total_number)

# cross-product: self-join on the index, pairing every instance with every
# instance that shares its index value (suffixes _x / _y)
_pair_cols = ['pattern', 'pattern_center_aln', 'pattern_strand_aln', 'pattern_center']
dfi_filt_crossp = pd.merge(dfi_filt[_pair_cols],
                           dfi_filt[_pair_cols],
                           how='outer', left_index=True, right_index=True).reset_index()

# remove self-matches: same pattern, same aligned centre and same aligned
# strand.  The strand test used to compare pattern_strand_aln_x with itself,
# which is always true; it now compares the two sides of the pair.
dfi_filt_crossp = dfi_filt_crossp.query('~((pattern_x == pattern_y) & (pattern_center_aln_x == pattern_center_aln_y) & (pattern_strand_aln_x == pattern_strand_aln_y))')

# order the matrix by names (aligned-pattern order, restricted to patterns
# that survived the filtering); the membership set is built once
_present_patterns = set(dfi_filt.pattern.unique())
idx_order = [p.name for p in aligned_patterns if p.name in _present_patterns]
norm_counts.head()
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) == 0) & (pattern_strand_aln_x == pattern_strand_aln_x)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
norm_count_matrix = count_matrix / norm_counts# .truediv(min_counts, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern
%opts HeatMap [xrotation=90] (cmap='Blues')
norm_count_matrix[idx_order].loc[idx_order].stack().reset_index().hvplot.heatmap(x='level_1', y='level_0', C="0", width=1000, height=1000, colorbar=True) # TODO add vmax
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(norm_count_matrix[idx_order].loc[idx_order], ax=ax, cmap='Blues', cbar_kws={'label': 'Overlap fraction: #A==B / min(#A, #B)'}, vmax=1)
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) <= 5) & (abs(pattern_center_aln_x- pattern_center_aln_y) != 0) &(pattern_strand_aln_x == pattern_strand_aln_x)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
norm_count_matrix = count_matrix / norm_counts# .truediv(min_counts, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(norm_count_matrix[idx_order].loc[idx_order], ax=ax, cmap='Blues', cbar_kws={'label': 'Overlap fraction: #A==B / min(#A, #B)'}, vmax=1)
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) < 15) & (abs(pattern_center_aln_x- pattern_center_aln_y) > 5) &(pattern_strand_aln_x == pattern_strand_aln_x)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
norm_count_matrix = count_matrix / norm_counts# .truediv(min_counts, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern
%opts HeatMap [xrotation=90] (cmap='Blues')
norm_count_matrix[idx_order].loc[idx_order].stack().reset_index().hvplot.heatmap(x='level_1', y='level_0', C="0", width=1000, height=1000, colorbar=True) # TODO add vmax
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) < 50) & (abs(pattern_center_aln_x- pattern_center_aln_y) > 16) &(pattern_strand_aln_x == pattern_strand_aln_x)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
norm_count_matrix = count_matrix / norm_counts# .truediv(min_counts, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern
%opts HeatMap [xrotation=90] (cmap='Blues')
norm_count_matrix[idx_order].loc[idx_order].stack().reset_index().hvplot.heatmap(x='pattern_y', y='pattern_x', C="0", width=1000, height=1000, colorbar=True) # TODO add vmax
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) < 100) & (abs(pattern_center_aln_x- pattern_center_aln_y) > 50) &(pattern_strand_aln_x == pattern_strand_aln_x)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
norm_count_matrix = count_matrix / norm_counts# .truediv(min_counts, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern
%opts HeatMap [xrotation=90] (cmap='Blues')
norm_count_matrix[idx_order].loc[idx_order].stack().reset_index().hvplot.heatmap(x='pattern_y', y='pattern_x', C="0", width=1000, height=1000, colorbar=True) # TODO add vmax
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(norm_count_matrix[idx_order].loc[idx_order], ax=ax, cmap='Blues', cbar_kws={'label': 'Overlap fraction: #A==B / min(#A, #B)'}, vmax=1)
# abs_distance < 5 or abs_distance != 5 we get almost the same constraints
# NOTE: the sentence above is notebook markdown that had been fused onto the
# query line below, which made that line invalid Python.

# Pairs on the same aligned strand with 0 < distance < 50.  The strand test
# used to compare pattern_strand_aln_x with itself (always true); it now
# compares the two sides of the pair.
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) < 50) & (abs(pattern_center_aln_x- pattern_center_aln_y) != 0) &(pattern_strand_aln_x == pattern_strand_aln_y)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
norm_count_matrix = count_matrix / norm_counts# .truediv(min_counts, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern
# Excluding exact matches
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(norm_count_matrix[idx_order].loc[idx_order], ax=ax, cmap='Blues', cbar_kws={'label': 'Overlap fraction: #A==B / min(#A, #B)'}, vmax=1)
out = []
for i in range(35):
dfi_sp = dfi_filt_crossp.query(f'(abs(pattern_center_aln_x- pattern_center_aln_y) == {i})')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
norm_count_matrix = count_matrix / norm_counts# .truediv(min_counts, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern
out.append(norm_count_matrix)
# Make a gif
vmax = 0.1
for i in range(len(count_matrices)):
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(count_matrices[i][idx_order].loc[idx_order], ax=ax, cmap='Blues', vmax=1)
ax.set_title(i)
fig.savefig(f"{modisco_dir}/spacing/heatmap_gif/figures/{i:02d}.png", transparent=False)
plt.close()
# make a gif
!convert -delay 100 -loop 0 {modisco_dir}/spacing/heatmap_gif/figures/*.png {modisco_dir}/spacing/heatmap_gif/heatmap.gif
def interactive_heatmap(count_matrices, idx_order, vmax=.5):
    """Return an ipywidgets control that draws one heatmap per matrix.

    Args:
      count_matrices: sequence of DataFrames; the slider selects the position.
      idx_order: labels used to select and order both columns and rows.
      vmax: upper colour limit passed to `sns.heatmap`.

    Returns:
      The object returned by `ipywidgets.interactive`, driven by an integer
      slider that ranges over the positions of `count_matrices`.
    """
    import ipywidgets as widgets
    from ipywidgets import interactive

    def fn(i):
        # Draw matrix `i`, scaled by 1000, on a fresh 10x10 figure.
        fig, ax = plt.subplots(figsize=(10, 10))
        ordered = count_matrices[i][idx_order].loc[idx_order]
        sns.heatmap(ordered*1000, ax=ax, cmap='Blues', vmax=vmax)

    slider = widgets.IntSlider(min=0,
                               max=len(count_matrices) - 1,
                               step=1,
                               value=0)
    return interactive(fn, i=slider)
mkdir -p {modisco_dir}/spacing/heatmap_gif/figures
from IPython.display import display
from vdom.helpers import img
display)
interactive_heatmap(out, idx_order, vmax=.1)
# Aligned patterns keyed by their name.
patterns_d = dict((p.name, p) for p in aligned_patterns)

# Boolean series over (row, column) pairs: True where the normalized count,
# scaled by 1000, exceeds 0.1.
v = (norm_count_matrix*1000 > .1).stack()

# Plot both patterns of every selected pair.
for a, b in v[v].index:
    for _name in (a, b):
        patterns_d[_name].plot('seq_ic', letter_width=.1, height=.5)
    plt.xlabel("Position")
##
# Pair counts on the same aligned strand at distance <= 2 and at distance <= 0.
# The strand test used to compare pattern_strand_aln_x with itself (always
# true); it now compares the two sides of the pair.
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) <= 2) & (pattern_strand_aln_x == pattern_strand_aln_y)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix_2 = match_sizes.unstack(fill_value=0)
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) <= 0) & (pattern_strand_aln_x == pattern_strand_aln_y)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix_0 = match_sizes.unstack(fill_value=0)

# Ten pairs with the largest number of instances at distance 1-2.  The two
# matrices need not share the same rows/columns; a pair missing from one of
# them counts as 0 instead of turning the difference into NaN.
v = (count_matrix_2.sub(count_matrix_0, fill_value=0)
     .fillna(0)
     .astype(int)
     .stack()
     .sort_values()
     .tail(10))
print(v.to_string())
for a, b in v.index:
    patterns_d[a].plot('seq_ic', letter_width=.1, height=.5)
    patterns_d[b].plot('seq_ic', letter_width=.1, height=.5)
    plt.xlabel("Position")
patterns_d['metacluster_2/pattern_2'].plot("seq_ic");
patterns_d['metacluster_3/pattern_3'].plot("seq_ic");
# metacluster_2/pattern_22
# (the line above was a notebook markdown heading; as bare code it raised a
#  NameError)
patterns_d['metacluster_2/pattern_22'].plot("seq_ic");
exclude
# Pairs on the same aligned strand within distance <= 20.  The strand test
# used to compare pattern_strand_aln_x with itself (always true).
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) <= 20) & (pattern_strand_aln_x == pattern_strand_aln_y)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
# Divide every count by the instance totals of both its column pattern and
# its row pattern.
norm_count_matrix = count_matrix.truediv(total_number, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern
# order the matrix by names; only names present on both axes are kept,
# because aligned patterns removed from dfi_filt (e.g. the `exclude` list)
# have no row/column and selecting them would raise a KeyError
idx_order = [p.name for p in aligned_patterns
             if p.name in norm_count_matrix.index and p.name in norm_count_matrix.columns]
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(norm_count_matrix[idx_order].loc[idx_order]*1000, ax=ax, cmap='Blues', vmax=.5)
# TODO - remove the
# m0_p0
# m0_p1
dfi_filt_crossp.head()
dfi_filt_crossp.query('(pattern_x == "metacluster_0/pattern_0") & (pattern_y == "metacluster_0/pattern_1")')

# Pairs of metacluster_2/pattern_0 with itself.  `.copy()` gives an
# independent frame so the column assignments below do not write into a
# slice of dfi_filt_crossp (pandas SettingWithCopy).
dfos = dfi_filt_crossp.query('(pattern_x == "metacluster_2/pattern_0") & (pattern_y == "metacluster_2/pattern_0")').copy()
dfos['distance'] = dfos.pattern_center_aln_x - dfos.pattern_center_aln_y  # aligned centres
dfos['old_distance'] = dfos.pattern_center_x - dfos.pattern_center_y  # un-aligned centres
dfos['abs_distance'] = np.abs(dfos.distance)
dfos[dfos.abs_distance < 50]['distance'].plot.hist(100)
dfos[dfos.abs_distance < 50]['old_distance'].plot.hist(100)
dfos[dfos.abs_distance < 20].distance.value_counts()
dfos[dfos.distance==3]
dfos[dfos.distance==-3]
# The central shift is due to the repetitive elements
plot_tracks(filter_tracks(get_dataset_item(contrib_scores, 103), [400, 600]));
patterns_d['metacluster_0/pattern_0'].plot("seq_ic");
patterns_d['metacluster_0/pattern_1'].plot("seq_ic");
orig_patterns_d['metacluster_1/pattern_0'].plot("seq_ic");
orig_patterns_d['metacluster_1/pattern_1'].plot("seq_ic");
# Same two original patterns after trimming with trim_frac=0.08.
orig_patterns_d['metacluster_1/pattern_0'].trim_seq_ic(trim_frac=0.08).plot("seq_ic");
orig_patterns_d['metacluster_1/pattern_1'].trim_seq_ic(trim_frac=0.08).plot("seq_ic");
len(orig_patterns_d['metacluster_1/pattern_1'].trim_seq_ic(trim_frac=0.08))
orig_patterns_d['metacluster_1/pattern_1']._trim_seq_ic_ij(trim_frac=0.08)
orig_patterns_d['metacluster_1/pattern_1']._trim_center_shift(trim_frac=0.08)
# Scratch arithmetic kept from the notebook.
70//2
(39 - 28) // 2
patterns_d['metacluster_1/pattern_0'].attrs['align']
patterns_d['metacluster_1/pattern_1'].attrs['align']
# Histogram of all signed distances (1000 bins) on a wide axis.
fig ,ax = plt.subplots(figsize=(15,2))
dfos.distance.plot.hist(1000, ax=ax);
# Pairs on the same aligned strand within distance <= 20.  The strand test
# used to compare pattern_strand_aln_x with itself (always true); it now
# compares the two sides of the pair.
dfi_sp = dfi_filt_crossp.query('(abs(pattern_center_aln_x- pattern_center_aln_y) <= 20) & (pattern_strand_aln_x == pattern_strand_aln_y)')
match_sizes = dfi_sp.groupby(['pattern_x', 'pattern_y']).size()
count_matrix = match_sizes.unstack(fill_value=0)
# Divide every count by the instance totals of both its column pattern and
# its row pattern.
norm_count_matrix = count_matrix.truediv(total_number, axis='columns').truediv(total_number, axis='index')
norm_count_matrix = norm_count_matrix.fillna(0) # these examples didn't have any paired pattern

# Plot both patterns of every pair whose normalized count, scaled by 1000,
# exceeds 0.2.
v = (norm_count_matrix*1000 > .2).stack()
for a, b in v[v].index:
    patterns_d[a].plot('seq_ic', letter_width=.1, height=.5)
    patterns_d[b].plot('seq_ic', letter_width=.1, height=.5)
    plt.xlabel("Position")
# ## unique_instances
# unique_instances should allow only a single position to be occupied by motifs from the same model group (?)