from uuid import uuid4
from collections import OrderedDict
from basepair.math import mean
from basepair.stats import perc
from IPython.display import display, HTML
from basepair.plot.vdom import df2html, df2html_old, render_datatable
from basepair.modisco.core import patterns_to_df
from basepair.modisco.utils import longer_pattern, shorten_pattern, extract_name_short
from basepair.imports import *
# --- Load the trained model's TF-MoDISco results and derived tables ---
# NOTE(review): `ddir` is not defined in this file; presumably it comes from
# the `basepair.imports` star-import above — TODO confirm.
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / f"modisco/all/deeplift/profile/"
output_dir = Path("/srv/www/kundaje/avsec/chipnexus/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/deeplift/profile")
# Open the modisco HDF5 results; `mr` stays open for the rest of the script.
mr = ModiscoResult(modisco_dir / "modisco.h5")
mr.open()
# Materialize all discovered patterns, and task names (task = text before '/').
patterns = [mr.get_pattern(p) for p in mr.patterns()]
tasks = [x.split("/")[0] for x in mr.tasks()]
# read the property table
pattern_table = pd.read_csv(output_dir / "pattern_table.csv")
# read the footprints
footprints = read_pkl(output_dir / 'footprints.pkl')
pattern_table.head()  # notebook-style preview; no effect when run as a script
# Attach the measured footprint to each pattern, keyed by pattern name
# (assumes `add_profile` returns the pattern object — TODO confirm).
patterns = [p.add_profile(footprints[p.name]) for p in patterns]
# Motif heterogeneity, co-occurrence
from basepair.modisco.motif_clustering import to_colors, preproc_motif_table, motif_table_long, scale, preproc_df
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix
# TODO - speedup using IntervalIndex .. .loc[motif_center...]
# specify motifs to use in the analysis
# Maps human-readable motif name -> shortened modisco pattern id ("m<k>_p<j>").
motifs = OrderedDict([
    ("Oct4-Sox2", "m0_p0"),
    ("Oct4-Sox2-deg", "m6_p8"),
    ("Oct4", "m0_p18"),
    ("Sox2", "m0_p1"),
    ("Essrb", "m0_p2"),
    ("Nanog", "m0_p3"),
    ("Nanog-periodic", "m0_p9"),
    ("Klf4", "m2_p0"),
])
# Load scanned motif instances for the selected motifs and drop instances
# whose intervals overlap each other.
dfi = load_instances(f"{modisco_dir}/instances.parq", motifs)
dfi = filter_nonoverlapping_intervals(dfi)
# Number of distinct example regions that contain at least one instance.
total_examples = len(dfi.example_idx.unique())
total_examples  # notebook-style preview; no effect when run as a script
# Read seqlets from modisco
df_seqlets = mr.seqlet_df_instances(trim_frac=0.08)
# Rename columns to match the instance-table convention:
# seqname -> example_idx, and a seqlet_ prefix on the coordinate columns.
df_seqlets = df_seqlets.rename(columns=dict(seqname="example_idx", start='seqlet_start', end='seqlet_end', strand='seqlet_strand', center='seqlet_center'))
# Use shortened pattern ids so they are comparable with `motifs` values.
df_seqlets['pattern'] = df_seqlets['pattern'].map({k: shorten_pattern(k) for k in df_seqlets.pattern.unique()})
# The 'name' column is redundant after the mapping above.
del df_seqlets['name']
def non_overlapping(x, df_seqlets, min_dist=15):
    """Filter out scanned motif instances that coincide with a modisco seqlet.

    An instance is kept only if its center is strictly more than `min_dist`
    bp away from every same-pattern seqlet center in the same region.

    Args:
        x: pd.DataFrame of scanned motif instances for a single
            (pattern, region) group. Columns used: pattern_short,
            example_idx, pattern_center.
        df_seqlets: pd.DataFrame of seqlets.
            Columns: example_idx, pattern, seqlet_{start, end, strand, center}
        min_dist: center-to-center distance threshold in bp. Default of 15
            matches the previously hard-coded value.

    Returns:
        Row-subset of `x` (same columns) with overlapping instances removed.
    """
    if len(x) == 0:
        return x
    # Seqlets of the same pattern located in the same example region.
    # All rows in `x` share pattern_short/example_idx (it is one groupby group),
    # so looking at .iloc[0] is sufficient.
    dfs = df_seqlets[(df_seqlets.pattern == x['pattern_short'].iloc[0]) &
                     (df_seqlets.example_idx == x['example_idx'].iloc[0])]
    if len(dfs) == 0:
        # Nothing to collide with -> keep all instances.
        return x
    # Distance from each instance center to its nearest seqlet center
    # (broadcast to an (n_instances, n_seqlets) matrix, then min over seqlets).
    dist = np.abs(x['pattern_center'].values.reshape((-1, 1))
                  - dfs.seqlet_center.values.reshape([1, -1])).min(1)
    return x[dist > min_dist]
df_seqlets.head()  # notebook-style preview; no effect when run as a script
# Keep only confident instances: match score not "low" AND importance "high".
dfi_subset = dfi.query('match_weighted_cat!="low"').query('imp_weighted_cat=="high"')
# get rid of overlapping intervals
# Within each (pattern, region) group, drop scanned instances sitting on top
# of a modisco seqlet of the same pattern (within 15 bp of its center).
dfi_subset_subset = dfi_subset.groupby(['pattern_short', 'example_idx']).apply(lambda x: non_overlapping(x, df_seqlets))
# region x pattern count matrix; aggfunc=len counts rows, so 'pattern_len'
# serves only as the value column being counted.
counts = pd.pivot_table(dfi_subset_subset, 'pattern_len', "example_idx", "pattern_name", aggfunc=len, fill_value=0)
c = counts >0 # True or False per interval per pattern
# Annotate every seqlet with the motif presence/absence flags of its region.
df_seqlets_c = c.merge(df_seqlets, on='example_idx')
df_seqlets_c.head(2)  # notebook-style preview; no effect when run as a script
# Per pattern: fraction of its seqlets whose region also contains each motif
# (mean of the boolean columns), plus the total seqlet count.
df_seqlets_c_agg = df_seqlets_c.groupby(['pattern'])[list(c)].mean()
df_seqlets_c_agg_size = df_seqlets_c.groupby(['pattern']).size()
# Append the per-pattern seqlet count as an 'n_seqlets' column
# (the size Series concatenates in as column 0, hence the rename).
df_seqlets_c_agg = pd.concat([df_seqlets_c_agg, df_seqlets_c_agg_size], axis=1).reset_index().set_index('pattern').rename(columns={0: 'n_seqlets'})
# sort using the classical names
# extract_name_short appears to yield 'metacluster' and 'pattern' components
# used purely as sort keys — TODO confirm its exact output.
df_seqlets_c_agg = pd.concat([df_seqlets_c_agg, pd.DataFrame(df_seqlets_c_agg.index.to_series().apply(extract_name_short).tolist(), index=df_seqlets_c_agg.index)], axis=1)
df_seqlets_c_agg = df_seqlets_c_agg.sort_values(['metacluster', 'pattern'])
# Drop the helper sort columns again.
del df_seqlets_c_agg['metacluster']
del df_seqlets_c_agg['pattern']
# Keep the seqlet counts aside (e.g. as a column annotation for plotting).
col_anno = df_seqlets_c_agg[['n_seqlets']]
del df_seqlets_c_agg['n_seqlets']
# Add other motifs
# Attach the per-pattern co-occurrence frequencies and the property-table row
# to each pattern object (assumes `add_attr` returns the pattern — TODO confirm).
patterns = [p.add_attr('motif_freq', dict(df_seqlets_c_agg.loc[shorten_pattern(p.name)])) for p in patterns]
# patterns = read_pkl(output_dir / 'patterns.pkl')
patterns = [p.add_attr('features', OrderedDict(pattern_table[pattern_table.pattern == shorten_pattern(p.name)].iloc[0])) for p in patterns]
# check that the pattern names match
assert patterns[4].attrs['features']['pattern'] == shorten_pattern(patterns[4].name)
# Persist the enriched pattern objects for downstream notebooks.
write_pkl(patterns, output_dir / 'patterns.pkl')