%matplotlib inline
from uuid import uuid4
from collections import OrderedDict
from kipoi.utils import unique_list
from basepair.math import mean
from basepair.stats import perc
from IPython.display import display, HTML
from basepair.plot.vdom import df2html, df2html_old, render_datatable
from basepair.modisco.core import patterns_to_df
from basepair.modisco.utils import longer_pattern, shorten_pattern, extract_name_long
from basepair.imports import *
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / f"modisco/all/deeplift/profile/"
output_dir = modisco_dir
mr = ModiscoResult(modisco_dir / "modisco.h5")
mr.open()
tasks = [x.split("/")[0] for x in mr.tasks()]
patterns = read_pkl(output_dir / 'patterns.pkl')
'align': {'use_rc': False, 'offset': 7},import holoviews as hv
hv.extension('bokeh')
from basepair.modisco.motif_clustering import to_colors, preproc_motif_table, motif_table_long, scale, preproc_df
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix
# specify motifs to use in the analysis (display name -> short pattern id)
motifs = OrderedDict([
    ("Oct4-Sox2", "m0_p0"),
    ("Oct4-Sox2-deg", "m6_p8"),
    ("Oct4", "m0_p18"),
    ("Sox2", "m0_p1"),
    ("Essrb", "m0_p2"),
    ("Nanog", "m0_p3"),
    ("Nanog-periodic", "m0_p9"),
    ("Klf4", "m2_p0"),
])

# Load motif instances (one row per called instance), indexed by example_idx
dfi = load_instances(f"{modisco_dir}/instances.parq", motifs, dedup=False)
dfi.set_index("example_idx", inplace=True)

# load seqlets discovered by modisco and index them the same way
# (the 'seqname' column holds the example index here)
df_seqlets = mr.seqlet_df_instances(trim_frac=0.08)
df_seqlets = df_seqlets.rename(columns={"seqname": "example_idx"}).set_index('example_idx')
dfi['pattern_strand'] = dfi['strand']
dfi.head()
df_seqlets.head()

# create a table with all possible pairs: trim each pattern by sequence
# information content and pad to a common width so they can be aligned
raw_patterns = [mr.get_pattern(pn).trim_seq_ic(0.08).pad(70) for pn in mr.patterns()]
# Get the major motifs (the subset named in `motifs` above)
major_motifs = [p for p in raw_patterns if p.name in [longer_pattern(x) for x in motifs.values()]]
def get_offset(d):
    """Return the alignment offset from an 'align' attrs dict,
    sign-flipped when the alignment used the reverse complement.

    Args:
        d: dict with a boolean 'use_rc' and a numeric 'offset' key
    """
    sign = -1 if d['use_rc'] else 1
    return sign * d['offset']
# align seqlet -> pattern and get the diff: for every (seqlet pattern, major
# motif) pair, record the alignment offset (sign-flipped for reverse-complement
# alignments, see get_offset)
shifts = pd.DataFrame([{"seqlet_pattern": p_seqlet.name,
                        "pattern": p_instance.name,
                        "offset": get_offset(p_seqlet.align(p_instance).attrs['align'])}
                       for p_seqlet in raw_patterns
                       for p_instance in major_motifs])
# running integer id per seqlet row
df_seqlets['seqlet_idx'] = np.arange(len(df_seqlets))
df_seqlets.head()
def seqlet_minstance_coocurrence(df_seqlets, dfi, min_dist=15, max_dist=150):
    """Compute the co-occurence of the seqlets with the motif instances.

    Pairs every seqlet with every motif instance in the same example, shifts
    the seqlet center by the pattern-alignment offset (module-level `shifts`
    table), keeps pairs whose center distance lies in (min_dist, max_dist),
    and counts, per seqlet pattern, how many seqlets co-occur with each
    motif. A shuffled-label version of the counts serves as the null.

    Args:
        df_seqlets: pd.DataFrame returned by mr.seqlet_df_instances(trim_frac=0.08) with example_idx index.
            NOTE: mutated in place (a 'seqlet_idx' column is (re-)assigned).
        dfi: pd.DataFrame loaded from the parquet file with example_idx index
        min_dist: minimum distance between the seqlet and motif instance (exclusive)
        max_dist: maximum distance between the seqlet and motif instance (exclusive)

    Returns:
        tuple (c, cr) of pd.DataFrames indexed by seqlet pattern name with one
        column per motif 'pattern_name': observed counts and shuffled (null) counts
    """
    assert df_seqlets.index.name == 'example_idx'
    assert dfi.index.name == 'example_idx'
    df_seqlets['seqlet_idx'] = np.arange(len(df_seqlets))
    # cross-join seqlets x motif instances within each example (outer merge on index)
    dfi_crossp = pd.merge(df_seqlets[['center', 'pattern', 'strand', 'seqlet_idx']].rename(columns=dict(pattern="seqlet_pattern",
                                                                                                        strand='seqlet_strand',
                                                                                                        center='seqlet_center')),
                          dfi[["pattern_center", "pattern", "pattern_name", 'pattern_strand']],
                          how='outer', left_index=True, right_index=True).reset_index()
    # Note that this table can have None-values where one of both is not found
    # attach the alignment offset for each (seqlet_pattern, pattern) pair
    dfi_crossp = pd.merge(dfi_crossp, shifts, on=['seqlet_pattern', 'pattern'], how='left')
    # shift the seqlet accordingly (offset direction depends on seqlet strand)
    dfi_crossp['shifted_center'] = dfi_crossp.seqlet_center - dfi_crossp.offset * dfi_crossp.seqlet_strand.map({"+": 1, "-": -1})
    # TODO - this is wrong - we ditch seqlet instances that don't exist
    # consider only valid pairs (center distance strictly within (min_dist, max_dist))
    dfi_crossp = dfi_crossp[(np.abs(dfi_crossp.shifted_center - dfi_crossp.pattern_center) > min_dist) &
                            (np.abs(dfi_crossp.shifted_center - dfi_crossp.pattern_center) < max_dist)]
    # binary indicator per (seqlet, motif): 1 iff at least one non-null pair exists
    counts = dfi_crossp.pivot_table(index=['seqlet_idx', 'seqlet_pattern'],
                                    columns='pattern_name',
                                    values='shifted_center',
                                    aggfunc=lambda x: int(np.sum(x.isnull()==False) > 0),
                                    fill_value=0)
    # NOTE(review): unseeded shuffle -> the null counts differ between runs
    dfi_crossp['pattern_name_random'] = dfi_crossp.pattern_name.sample(frac=1).values
    # TODO - index should be seqlet_idx not example_idx
    counts_random = dfi_crossp.pivot_table(index=['seqlet_idx', 'seqlet_pattern'],
                                           columns='pattern_name_random',
                                           values='shifted_center',
                                           aggfunc=lambda x: int(np.sum(x.isnull()==False) > 0),
                                           fill_value=0)
    # return counts, counts_random
    # aggregate the per-seqlet indicators by seqlet pattern (index level 1)
    c = counts.groupby(level=1).sum()
    c.columns.name = None
    c.index.name = 'pattern'
    cr = counts_random.groupby(level=1).sum()
    cr.columns.name = None
    cr.index.name = 'pattern'
    return c, cr
# patterns in the 'nte' (non transposable-element) group
nte_patterns = [p.name for p in patterns if p.attrs['pattern_group'] == 'nte']

# match_weighted_cat != "low", imp_weighted_cat == "high", distance range [15, 80]
# Question: does pattern occur significantly more often with seqlet_pattern or not?
# (the two lines above were a markdown cell fused into the code by the notebook export)
c, cr = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat!="low"').query('imp_weighted_cat=="high"'),
                                     min_dist=15, max_dist=80)
# standardize
# dfco = (dfco - dfco.mean(0)) / dfco.std(0)
from scipy.stats import binom_test

# number of seqlets per pattern -> turn counts into co-occurrence frequencies
n_seqlets = df_seqlets.groupby("pattern").size()
p0 = cr.divide(n_seqlets, axis='rows')  # null (shuffled-label) frequency
p1 = c.divide(n_seqlets, axis='rows')   # observed frequency

# one-sided binomial test: is the observed count larger than expected under
# the shuffled-label frequency p0?
dfp = pd.DataFrame([{tf: binom_test(c.loc[pattern, tf], n_seqlets.loc[pattern], p0.loc[pattern, tf], alternative='greater')
                     for tf in c.columns}
                    for pattern in c.index], index=c.index)
# Control for multiple testing (Bonferroni).
# NOTE(review): the original divided by the number of tests (dfp / dfp.size),
# which makes p-values *smaller*; Bonferroni multiplies by the test count.
dfp = (dfp * dfp.size).clip(upper=1)
dfco = p1
# sort by metacluster and pattern (extract_name_long yields those two columns)
dfco = pd.concat([dfco, pd.DataFrame(pd.Series(dfco.index).apply(extract_name_long).tolist(), index=dfco.index)], axis=1).sort_values(['metacluster', 'pattern'])
del dfco['metacluster']
del dfco['pattern']

# annotate the heatmap cells that pass the significance threshold with '*'
signif_threshold = 1e-5
signif = dfp.loc[nte_patterns] < signif_threshold
a = np.zeros_like(signif).astype(str)
a[signif] = "*"
a[~signif] = ""
np.fill_diagonal(a, '')  # don't star self-co-occurrence
fig, ax = plt.subplots(figsize=(6, 22))
# log_odds = np.log10(dfco / (p0 + 1e-6))
# log_odds = dfco / (p0 + 1e-3)
# NOTE(review): this plots the plain ratio p1/p0 (pseudo-count 0.001), not its log2
sns.heatmap((p1 / (p0 + 0.001)).loc[nte_patterns], annot=a, fmt="", vmin=0, vmax=4, center=1, cmap='RdBu_r')
plt.title(f"Log2 odds-ratio. (*: p<{signif_threshold})");
# NOTE(review): returns a (counts, counts_random) tuple which is assigned to a
# misleadingly-named variable and never used again below - candidate for removal
dfi_crossp = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat!="low"').query('imp_weighted_cat=="high"'),
                                          min_dist=15, max_dist=150)
# standardize
# dfco = (dfco- dfco.mean(0)) / dfco.std(0)
# observed / shuffled co-occurrence counts for the [15, 150] distance range
dfco, dfco_sum = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat!="low"').query('imp_weighted_cat=="high"'),
                                              min_dist=15, max_dist=150)
# standardize
# dfco = (dfco- dfco.mean(0)) / dfco.std(0)
# sort by metacluster and pattern (extract_name_long yields those two columns)
dfco = pd.concat([dfco, pd.DataFrame(pd.Series(dfco.index).apply(extract_name_long).tolist(), index=dfco.index)], axis=1).sort_values(['metacluster', 'pattern'])
del dfco['metacluster']
del dfco['pattern']
# heatmap restricted to 'nte' patterns (isin keeps the sorted order)
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco[dfco.index.isin(nte_patterns)].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
# same data, ordered by nte_patterns instead
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco.loc[nte_patterns].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
match_weighted_cat!="low", imp_weighted_cat=="high", [15, 50]¶dfco = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat!="low"').query('imp_weighted_cat=="high"'),
min_dist=15, max_dist=50)
# standardize
#dfco = (dfco- dfco.mean(0)) / dfco.std(0)
# sort by metacluster and pattern
dfco = pd.concat([dfco, pd.DataFrame(pd.Series(dfco.index).apply(extract_name_long).tolist(), index=dfco.index)], axis=1).sort_values(['metacluster', 'pattern'])
del dfco['metacluster']
del dfco['pattern']
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco[dfco.index.isin(nte_patterns)].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco.loc[nte_patterns].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
match_weighted_cat!="low", imp_weighted_cat!="low", [15, 150]¶dfco = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat!="low"').query('imp_weighted_cat!="low"'),
min_dist=15, max_dist=150)
# standardize
#dfco = (dfco- dfco.mean(0)) / dfco.std(0)
# sort by metacluster and pattern
dfco = pd.concat([dfco, pd.DataFrame(pd.Series(dfco.index).apply(extract_name_long).tolist(), index=dfco.index)], axis=1).sort_values(['metacluster', 'pattern'])
del dfco['metacluster']
del dfco['pattern']
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco[dfco.index.isin(nte_patterns)].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco.loc[nte_patterns].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
match_weighted_cat=="high", imp_weighted_cat!="low", [15, 150]¶dfco = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat=="high"').query('imp_weighted_cat!="low"'),
min_dist=15, max_dist=150)
# standardize
#dfco = (dfco- dfco.mean(0)) / dfco.std(0)
# sort by metacluster and pattern
dfco = pd.concat([dfco, pd.DataFrame(pd.Series(dfco.index).apply(extract_name_long).tolist(), index=dfco.index)], axis=1).sort_values(['metacluster', 'pattern'])
del dfco['metacluster']
del dfco['pattern']
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco[dfco.index.isin(nte_patterns)].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco.loc[nte_patterns].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
match_weighted_cat=="low", imp_weighted_cat=="high", [15, 150]¶dfco = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat=="low"').query('imp_weighted_cat=="high"'),
min_dist=15, max_dist=150)
# sort by metacluster and pattern
dfco = pd.concat([dfco, pd.DataFrame(pd.Series(dfco.index).apply(extract_name_long).tolist(), index=dfco.index)], axis=1).sort_values(['metacluster', 'pattern'])
del dfco['metacluster']
del dfco['pattern']
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco[dfco.index.isin(nte_patterns)].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
match_weighted_cat=="low", imp_weighted_cat=="high", [15, 50]¶dfco = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat=="low"').query('imp_weighted_cat=="high"'),
min_dist=15, max_dist=50)
# sort by metacluster and pattern
dfco = pd.concat([dfco, pd.DataFrame(pd.Series(dfco.index).apply(extract_name_long).tolist(), index=dfco.index)], axis=1).sort_values(['metacluster', 'pattern'])
del dfco['metacluster']
del dfco['pattern']
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco[dfco.index.isin(nte_patterns)].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
%opts HeatMap [xrotation=90] (cmap='Blues')
dfco.loc[nte_patterns].reset_index().melt(id_vars='pattern').hvplot.heatmap(x='variable', y='pattern', C='value', width=600, height=800, )
# NOTE(review): returns a (counts, counts_random) tuple assigned to a single
# variable; `dfco` is not used again below - candidate for removal
dfco = seqlet_minstance_coocurrence(df_seqlets, dfi.query('match_weighted_cat!="low"').query('imp_weighted_cat=="high"'),
                                    min_dist=15, max_dist=50)
# standardize
# dfco = (dfco- dfco.mean(0)) / dfco.std(0)
# inspect the stored motif-frequency attribute of the first pattern
patterns[0].attrs['motif_freq']
def get_values(dfco, pname):
    """Return row `pname` of `dfco` as a {column: value} dict.

    Falls back to an all-zero dict when the row is absent, so every
    pattern gets a value for every column.
    """
    if pname in dfco.index:
        return dict(dfco.loc[pname])
    return dict.fromkeys(dfco.columns, 0)
# Attach the co-occurrence odds ratios and their p-values as pattern attributes
patterns = [p.add_attr('motif_odds', get_values(p1 / (p0 + 0.001), p.name)) for p in patterns]
patterns = [p.add_attr('motif_odds_p', get_values(dfp, p.name)) for p in patterns]
# Write back the pickle file
write_pkl(patterns, output_dir / 'patterns.pkl')
# split the patterns back into the groups
patterns_nte_clustered = [x for x in patterns if x.attrs['pattern_group'] == 'nte']
patterns_te_clustered = [x for x in patterns if x.attrs['pattern_group'] == 'te']
from basepair.modisco.motif_clustering import *

# NOTE(review): the original inspected `pattern_table_nte_seq.columns` here,
# *before* the table is created below (out-of-order notebook cell); removed,
# since it would raise NameError when run top-to-bottom.

# Render the HTML tables (logos + footprints) for each pattern group
pattern_table_nte_seq = create_pattern_table(patterns_nte_clustered,
                                             logo_len=50,
                                             seqlogo_kwargs=dict(width=420),
                                             n_jobs=20,
                                             footprint_width=120,
                                             footprint_kwargs=dict(figsize=(3, 1.5)))
pattern_table_te_seq = create_pattern_table(patterns_te_clustered,
                                            logo_len=70,
                                            seqlogo_kwargs=dict(width=420),
                                            n_jobs=20,
                                            footprint_width=120,
                                            footprint_kwargs=dict(figsize=(3, 1.5)))
# motifs whose '/odds' columns are included in the exported tables
background_motifs = ['Essrb', 'Klf4', 'Nanog','Oct4','Oct4-Sox2', 'Sox2']
# column order: pattern info, logos, per-task footprints and distances, odds
colorder = ['pattern', 'cluster', 'n seqlets', 'logo_imp', 'logo_seq'] + [task+'/f' for task in tasks] + [t + '/d_p' for t in tasks] + [m+'/odds' for m in background_motifs]
(output_dir / 'motif_clustering').mkdir(exist_ok=True)
# columns dropped from the CSV export (plot/logo columns don't serialize to csv)
remove = [task+'/f' for task in tasks] + ['logo_imp', 'logo_seq']
# export the non-TE cluster table as HTML
html_str_contrib = df2html(pattern_table_nte_seq[colorder], "table_contrib" + str(uuid4()), "")
with open(output_dir / 'motif_clustering/smaller.seq-cluster.v3.html', 'w') as fo:
    fo.write(html_str_contrib)
# HTML(html_str_contrib)
# export the TE cluster table as HTML
html_str_te_contrib = df2html(pattern_table_te_seq[colorder], "table_seq" + str(uuid4()), "")
with open(output_dir / 'motif_clustering/te-cluster.v3.html', 'w') as fo:
    fo.write(html_str_te_contrib)
# html_str_seq
# use short pattern names and write the CSV (without the plot columns)
pattern_table_nte_seq['pattern'] = [shorten_pattern(p.name) for p in patterns_nte_clustered]
pattern_table_nte_seq.iloc[:, ~pattern_table_nte_seq.columns.isin(remove)].to_csv(output_dir / 'motif_clustering/smaller.seq-cluster.v3.csv')