In [3]:
from collections import OrderedDict
# Reference experiment (model config string) and importance-score key.
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'
imp_score = 'profile/wn'

# Display name -> modisco pattern id (e.g. 'Oct4/m0_p0'), kept in this order.
_motif_pairs = [
    ("Oct4-Sox2", 'Oct4/m0_p0'),
    ("Oct4", 'Oct4/m0_p1'),
    # ("Strange-sym-motif", 'Oct4/m0_p5'),
    ("Sox2", 'Sox2/m0_p1'),
    ("Nanog", 'Nanog/m0_p1'),
    ("Zic3", 'Nanog/m0_p2'),
    ("Nanog-partner", 'Nanog/m0_p4'),
    ("Klf4", 'Klf4/m0_p0'),
]
motifs = OrderedDict(_motif_pairs)

Goal

  • cluster all motifs for modisco-based methods

Question

  • [X] how to cluster the motifs
    • have a specific distance threshold
      • figure it out by eye (add to the motif title)
  • [X] Compute the Venn diagram for these motifs
    • also make sure that the motifs in these clusters are meaningful
      • e.g. assign a name to the cluster
  • quantify other properties of a modisco run

Which motifs are more consolidated?

In [4]:
from basepair.imports import *
from basepair.exp.paper.config import *
from basepair.utils import flatten
from basepair.modisco.core import Pattern
from plotnine import *
import plotnine
import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation warnings and matplotlib's
# "too many open figures" warning raised by later cells.
warnings.filterwarnings("ignore")

%matplotlib inline
# render inline figures as high-resolution ('retina') PNGs
%config InlineBackend.figure_format = 'retina'
# presumably applies the paper's plotting defaults (star-imported from
# basepair.exp.paper.config) -- TODO confirm
paper_config()
Using TensorFlow backend.
In [5]:
from basepair.utils import pd_first_cols, flatten
from basepair.exp.chipnexus.comparison import read_peakxus_dfi, read_chexmix_dfi, read_fimo_dfi, read_meme_motifs, read_transfac
In [154]:
# models = {
#     'nexus/binary': 'nexus,gw,OSNK,1,0,0,FALSE,same,0.5,64,25,0.001,9,FALSE',
#     'nexus/profile': 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE',
#     # 'nexus/profile.peaks.bias-corrected.augm': 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE,TRUE',
#     # 'nexus/profile.peaks.non-bias-corrected': 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE-2',
#     'seq/binary': 'seq,gw,OSN,1,0,0,FALSE,same,0.5,64,50,0.001,9,FALSE',
#     #'seq/profile': 'seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE',
#     #'seq/profile.peaks.bias-corrected.augm': 'seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE,TRUE',
#     'seq/profile': 'seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE,TRUE',
#     # 'seq/profile.peaks.non-bias-corrected': 'seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE',
#     'nexus/profile.peaks-union': 'nexus,nexus-seq-union,OSN,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE',
#     'seq/profile.peaks-union': 'seq,nexus-seq-union,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE'
# }
# Models compared in this notebook: display name -> experiment config string.
# Insertion order is kept and drives the loading/plotting order below.
_model_configs = [
    ('nexus/binary', 'nexus,gw,OSNK,1,0,0,FALSE,same,0.5,64,25,0.001,9,FALSE'),
    ('nexus/profile', 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'),
    ('seq/profile', 'seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE,TRUE'),
    ('seq/binary', 'seq,gw,OSN,1,0,0,FALSE,same,0.5,64,50,0.001,9,FALSE'),
    ('nexus/profile.peaks-union', 'nexus,nexus-seq-union,OSN,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE,TRUE,0'),
    ('seq/profile.peaks-union', 'seq,nexus-seq-union,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE,TRUE,0'),
]
# Configurations that were tried but are currently disabled:
#   'nexus/binary+profile':                    'nexus,gw,OSNK,1,0.1,0.01,FALSE,same,0.5,64,25,0.001,9,FALSE'
#   'nexus/profile.gw':                        'nexus,gw,OSNK,0,10,1,FALSE,same,0.5,64,25,0.001,9,FALSE'
#   'nexus/profile.peaks.bias-corrected.augm': 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE,TRUE'
#   'nexus/profile.peaks.non-bias-corrected':  'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE-2'
#   'seq/profile' (without trailing ',TRUE'):  'seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE'
#   'seq/profile.peaks.bias-corrected.augm':   'seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE,TRUE'
#   'seq/profile.peaks.non-bias-corrected':    'seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE'
#   'nexus/profile.peaks-union.no-aug':        'nexus,nexus-seq-union,OSN,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE,FALSE,0'
#   'basset/binary':                           'binary-basset,nexus,gw,OSNK,0.5,64,0.001,FALSE,0.6'
#   'factorized-basset/binary':                'factorized-basset,nexus,gw,OSNK,0.5,64,0.001,FALSE,0.5'
models = dict(_model_configs)

# Reverse lookup: config string -> display name (all active config strings are distinct).
models_inv = {config: name for name, config in models.items()}
In [155]:
# TODO - add the nexus/profile.peaks-union
In [156]:
from config import experiments
In [157]:
# Output directory for this notebook's figures. `parents=True` also creates
# missing intermediate directories, so a fresh checkout does not fail here.
fdir = Path(ddir) / 'figures' / 'method-comparison' / 'modisco'
fdir.mkdir(parents=True, exist_ok=True)
In [158]:
from basepair.exp.paper.config import get_tasks
In [159]:
def load_stats(exp, model):
    """Collect per-task modisco summary statistics for one experiment.

    Args:
      exp: experiment config string; used as a key of `experiments` and as a
        sub-directory of `models_dir`.
      model: display name written into the "model" column.

    Returns:
      pd.DataFrame with one row per task: "task", "model", plus every key of
      that task's `stats(verbose=False)` dict.
    """
    modisco_files = {}
    for task in get_tasks(exp):
        score_key = experiments[exp]["imp_score"]
        modisco_files[task] = models_dir / exp / f'deeplift/{task}/out/{score_key}/modisco.h5'
    mr = MultipleModiscoResult(modisco_files)

    rows = []
    for task, task_result in mr.mr_dict.items():
        rows.append({"task": task, "model": model, **task_result.stats(verbose=False)})
    return pd.DataFrame(rows)
In [160]:
# TODO - add seq/binary to experiments
In [148]:
# Per-task modisco stats for every model whose experiment is registered in `experiments`.
dfs = pd.concat([load_stats(config, name)
                 for name, config in models.items()
                 if config in experiments])
dfs = dfs.assign(seqlets_per_pattern=dfs['clustered_seqlets'] / dfs['patterns'])
In [164]:
# Bar chart of the `clustered_seqlets_frac` statistic per task, one bar per model.
# The figure size set here also applies to the following plotnine cells.
plotnine.options.figure_size = get_figsize(.4)
fig = ggplot(aes(x='task', fill='model', y='clustered_seqlets_frac'), dfs)
fig += geom_bar(stat='identity', position='dodge')
fig += theme_classic()
fig += scale_fill_brewer(type='qual', palette=3)
fig.save(fdir / 'modisco-stats.clustered_seqlets_frac.pdf')
fig
Out[164]:
<ggplot: (8774878491582)>
In [165]:
# Bar chart of the `all_seqlets` count per task, one bar per model.
fig = ggplot(aes(x='task', fill='model', y='all_seqlets'), dfs)
fig += geom_bar(stat='identity', position='dodge')
fig += theme_classic()
fig += scale_fill_brewer(type='qual', palette=3)
fig.save(fdir / 'modisco-stats.n_all_seqlets.pdf')
fig
Out[165]:
<ggplot: (8774876419567)>
In [166]:
# Bar chart of the `clustered_seqlets` count per task, one bar per model.
fig = ggplot(aes(x='task', fill='model', y='clustered_seqlets'), dfs)
fig += geom_bar(stat='identity', position='dodge')
fig += theme_classic()
fig += scale_fill_brewer(type='qual', palette=3)
fig.save(fdir / 'modisco-stats.clustered_seqlets.pdf')
fig
Out[166]:
<ggplot: (-9223363261976868092)>
In [168]:
# Bar chart of the number of modisco patterns per task, one bar per model.
fig = ggplot(aes(x='task', fill='model', y='patterns'), dfs)
fig += geom_bar(stat='identity', position='dodge')
fig += theme_classic()
fig += scale_fill_brewer(type='qual', palette=3)
fig.save(fdir / 'modisco-stats.patterns.pdf')
fig
Out[168]:
<ggplot: (-9223363261978677851)>
In [169]:
# Bar chart of clustered seqlets per pattern per task, one bar per model.
fig = ggplot(aes(x='task', fill='model', y='seqlets_per_pattern'), dfs)
fig += geom_bar(stat='identity', position='dodge')
fig += theme_classic()
fig += scale_fill_brewer(type='qual', palette=3)
fig.save(fdir / 'modisco-stats.seqlets_per_pattern.pdf')
fig
Out[169]:
<ggplot: (8774876555079)>
In [170]:
# Output directory of the chipnexus method-comparison runs (relative path).
comparison_dir = Path('../../chipnexus') / 'comparison' / 'output'
In [171]:
# main_motifs = [mr.get_pattern(pattern_name).rename(name)
#                for name, pattern_name in motifs.items()]
In [172]:
def rename(name):
    """Map '<task>/<x>_<i>/<y>_<j>' to 'BPNet/<task>/m<i>_p<j>'.

    Raises ValueError unless `name` has exactly three '/'-separated parts, and
    IndexError if either of the last two parts contains no '_'.
    """
    task, mc_part, pattern_part = name.split("/")
    mc_idx, pattern_idx = (part.split("_")[1] for part in (mc_part, pattern_part))
    return "BPNet/" + task + "/m" + mc_idx + "_p" + pattern_idx
In [173]:
def load_patterns(exp, model, min_seqlets=100):
    """Load the modisco patterns of one experiment, tagged with the model name.

    Args:
      exp: experiment config string; used as a key of `experiments` and as a
        sub-directory of `models_dir`.
      model: display name prepended to each pattern name ("<model>/<pattern name>").
      min_seqlets: keep only patterns with strictly more than this many seqlets
        (default 100, the previously hard-coded cutoff).

    Returns:
      list of renamed patterns, each carrying an "n_seqlets" attribute.
    """
    mr = MultipleModiscoResult({t: models_dir / exp / f'deeplift/{t}/out/{experiments[exp]["imp_score"]}/modisco.h5'
                                for t in get_tasks(exp)})
    patterns = []
    for p in mr.get_all_patterns():
        # query the seqlet count once per pattern (the original queried it twice)
        n_seqlets = mr.n_seqlets(p.name)
        if n_seqlets > min_seqlets:
            patterns.append(p.rename(model + '/' + p.name).add_attr("n_seqlets", n_seqlets))
    return patterns
In [174]:
# bpnet_patterns = [p.rename(rename(p.name)).trim_seq_ic(0.08) for p in mr.get_all_patterns()
#                   if mr.n_seqlets(p.name) > 100]

All motifs

In [18]:
# Load the patterns of every model. A model that fails to load is reported and
# skipped (best effort), but `except Exception` (not a bare `except`) lets
# KeyboardInterrupt/SystemExit still stop the cell, and the cause is printed.
all_patterns = []
for model, exp in models.items():
    try:
        all_patterns += load_patterns(exp, model)
    except Exception as e:
        print("Unable to load ", model, exp, f"({type(e).__name__}: {e})")
In [19]:
from basepair.exp.paper.fig4 import cluster_align_patterns
In [20]:
# number of patterns loaded across all models (after the seqlet-count filter in load_patterns)
len(all_patterns)
Out[20]:
264
In [21]:
# Cluster patterns
from basepair.exp.paper.fig4 import *
# NOTE(review): similarity_matrix, linkage, cut_tree, leaves_list and
# align_clustered_patterns come from star imports (basepair.exp.paper.fig4 /
# basepair.imports); linkage & co. are presumably scipy.cluster.hierarchy -- TODO confirm.
patterns = all_patterns
# number of flat clusters used to group patterns before alignment
n_clusters = 20
# pairwise pattern similarity computed on the 'seq_ic' track
sim_seq = similarity_matrix(patterns, track='seq_ic')

# cluster: Ward linkage on the distance 1 - similarity, passed as the strict
# upper triangle (diagonal excluded) of the similarity matrix
iu = np.triu_indices(len(sim_seq), 1)
lm_seq = linkage(1 - sim_seq[iu], 'ward', optimal_ordering=True)

# determine clusters and the cluster order
# (argsort of the leaf order gives each pattern's position among the dendrogram leaves)
cluster = cut_tree(lm_seq, n_clusters=n_clusters)[:, 0]
cluster_order = np.argsort(leaves_list(lm_seq))

# align patterns, using the cluster assignment and leaf order computed above
aligned_patterns = align_clustered_patterns(patterns, cluster_order, cluster,
                                              align_track='seq_ic',
                                              metric='continousjaccard',
                                              trials=20,
                                              # don't shift the major patterns
                                              # by more than 15 when aligning
                                              max_shift=15)
100%|██████████| 264/264 [04:29<00:00,  1.01s/it]
Number of seqlets not provided, using random number of seqlets
100%|██████████| 264/264 [00:01<00:00, 245.96it/s]
In [22]:
# Save the aligned patterns to a pickle (absolute, user-specific path).
aligned_patterns_pkl = "/users/avsec/gdrive/projects/chipnexus/data/motif_clustering.patterns.pkl"
write_pkl(aligned_patterns, aligned_patterns_pkl)
In [23]:
from basepair.modisco.motif_clustering import to_colors
In [24]:
def shorten_name(pn):
    """Abbreviate '<a>/<b>/<tf>/<x>_<i>/<y>_<j>' to '<a>/<b>/<tf>/m<i>_p<j>'.

    Raises ValueError unless `pn` has exactly five '/'-separated parts, and
    IndexError if either of the last two parts contains no '_'.
    """
    first, second, tf, mc_part, pattern_part = pn.split("/")
    return "{}/{}/{}/m{}_p{}".format(first, second, tf,
                                     mc_part.split("_")[1],
                                     pattern_part.split("_")[1])
In [25]:
# Short display names, in the same order as all_patterns.
pnames = list(map(shorten_name, (pat.name for pat in all_patterns)))
In [26]:
# Per-pattern annotation indexed by short name:
# method = first two name parts, tf = third part.
_name_parts = [pn.split("/") for pn in pnames]
df_anno = pd.DataFrame({"method": ["/".join(parts[:2]) for parts in _name_parts],
                        "tf": [parts[2] for parts in _name_parts]},
                       index=pnames)
In [27]:
# Alternative code
# Build `row_colors`: one color column per annotation column of df_anno
# ('method' and 'tf'), used as the clustermap's row color bars.
# NOTE(review): the loop variables `c`, `categories`, `cmap` and `cmap_dict` stay in
# the notebook namespace; `cmap` ends up as the palette of the last non-'tf'
# column, which is easy to pick up by accident in later cells.
import matplotlib.colors
df = df_anno
cat_cmap='RdBu_r'  # seaborn palette name for the non-TF annotation columns
import matplotlib.pyplot as plt
import seaborn as sns
cols = {}
for c in df.columns:
    # distinct values of this column, in order of first appearance
    categories = list(df[c].unique())
    if c == 'tf':
        # TFs use the fixed TF -> color mapping `tf_colors` (not defined in this
        # notebook; presumably star-imported from the paper config)
        # cmap_dict = {tf: matplotlib.colors.to_rgb(hex_color)
        #             for tf, hex_color in tf_colors.items()}
        cmap_dict = tf_colors
    else:
        # one palette color per category
        cmap = sns.color_palette(cat_cmap, len(categories))
        cmap_dict = dict(zip(categories, cmap))
    # for a plain dict, values without an entry in cmap_dict map to NaN
    cols[c] = df[c].map(cmap_dict)
row_colors = pd.DataFrame(cols)
In [28]:
# Color map for the 'method' annotation. The palette is built here and assigned,
# instead of zipping the `cmap` left over from the row-color loop (the original
# computed the palette and discarded it).
c = 'method'
categories = list(df[c].unique())
cmap = sns.color_palette(cat_cmap, len(categories))
cmap_dict_method = dict(zip(categories, cmap))
In [29]:
from basepair.plot.utils import plot_colormap
In [30]:
# Presumably draws a color legend for the 'method' row colors (plot_colormap is
# imported from basepair.plot.utils; its behavior is not visible here).
plot_colormap(cmap_dict_method, "Method")
# plt.savefig(fdir / 'cmap.method.pdf')
In [31]:
# Color legend for the TF row colors, leaving out Esrrb.
tf_legend = dict((tf, color) for tf, color in tf_colors.items() if tf != 'Esrrb')
plot_colormap(tf_legend, 'TF')
# plt.savefig(fdir / 'cmap.TF.pdf')
In [32]:
# Similarity heatmap of all patterns, rows/columns ordered by the Ward linkage
# `lm_seq`, with method/TF color bars on the rows.
# NOTE(review): rebinds `df`, which previously held the annotation table.
df = pd.DataFrame(sim_seq, index=pnames, columns=pnames)
cm = sns.clustermap(df,
                    cmap='Blues',
                    figsize=get_figsize(1, 1),
                    row_colors=row_colors,
                    row_linkage=lm_seq,
                    col_linkage=lm_seq)
heatmap_pdf = fdir / 'clustered-motifs.heatmap.pdf'
plt.savefig(heatmap_pdf)
In [33]:
# aligned_patterns
In [34]:
# Flat cluster assignments at several cut levels of the same linkage.
cluster20, cluster30, cluster40, cluster50, cluster60 = (
    cut_tree(lm_seq, n_clusters=k)[:, 0] for k in (20, 30, 40, 50, 60))
# position of each pattern among the dendrogram leaves
cluster_order = np.argsort(leaves_list(lm_seq))
In [35]:
# Cluster membership at several cut levels, one row per pattern, with rows
# reordered to follow `aligned_patterns`. Exported as CSV; the labelled table
# loaded further down carries the same cluster columns plus curated labels.
cut_levels = [20, 30, 40, 50, 60]
dfp = pd.DataFrame({"name": pnames,
                    **{f"cluster{k}": cut_tree(lm_seq, n_clusters=k)[:, 0] for k in cut_levels}})
aligned_names = np.array([shorten_name(p.name) for p in aligned_patterns])
dfp = dfp.set_index('name').loc[aligned_names].reset_index()
# `index` is a bool flag in to_csv (the original passed None)
dfp.to_csv("/users/avsec/gdrive/projects/chipnexus/data/motif_clustering.csv", index=False)

Load the clustered table

In [196]:
 
Out[196]:
0            TE1
1            TE1
2            TE1
         ...    
261    Klf4-long
262    Klf4-long
263    Klf4-long
Name: label, Length: 264, dtype: object
In [197]:
# Create the summary table
df_n_seqlets = pd.DataFrame({
    "name": [shorten_name(p.name) for p in all_patterns],
    "n_seqlets": [p.attrs['n_seqlets'] for p in all_patterns],
})

# dfp = pd.read_csv("https://docs.google.com/spreadsheets/d/1enDvKamBZOs4-EIKy--zUebekHvhJekAmNy1yExf1VQ/export?gid=481460659&format=csv")
# dfp = pd.read_csv("https://docs.google.com/spreadsheets/d/1umWAMk-FKN2u4k8noKoaX3FZXmU3_-IurnO922cKUko/export?gid=1666626001&format=csv")
dfp = pd.read_csv("https://docs.google.com/spreadsheets/d/1G_vC2Y9j45WuN08Erzm8M57uNWsXl4Obl9WR2zKA6dc/export?gid=47535911&format=csv")
dfp['label'] = dfp.label.str.replace('"', '')
dfp = pd.merge(dfp, df_n_seqlets, on='name')
dfp['method'] = dfp.name.str.rsplit("/", n=2, expand=True)[0]
dfpm = pd.pivot_table(dfp, values='n_seqlets', index='label', columns='method', aggfunc=np.sum, fill_value=0)
In [198]:
# Number of patterns per (label, erronious) pair, printed in full.
label_counts = dfp.groupby(['label', 'erronious']).size()
print(label_counts.to_string())
label              erronious
B-Box              False         4
CG                 True          3
Essrb              False        14
Klf4               False        12
Klf4-long          False         5
Nanog              False         9
Nanog-alt          False         7
Oct4               False         5
Oct4-Oct4          False         8
Oct4-Sox2          False        20
Sox2               False        23
TE1                False         5
TE10               False         2
TE11               False         1
TE12               False         1
TE13               False         1
TE14               False         1
TE15               False         3
TE16               False         3
TE17               False         5
TE18               False         5
TE2                False        12
TE20               False         3
TE21               False         6
TE22               False        14
TE23               False         1
TE24               False        10
TE25               False         8
TE26               False         6
TE27               False         4
TE3                False         1
TE4                False         1
TE5                False         1
TE6                False         2
TE7                False         1
TE8                False         3
TE9                False         2
Zic3               False        15
deg. Oct4          False         1
deg. Oct4-Sox2     False         6
deg. Oct4-Sox2 v2  False         7
noisy Nanog        True          1
noisy Nanog-alt    True          1
noisy Oct4-Sox2    True          2
noisy Sox2         True          3
noisy deg. Oct4    True          2
noisy deg. Sox2    True          1
poly-A             True          1
random             True         12
In [199]:
# Labels of motifs marked as noisy in the curated sheet.
_dfpm_labels = dfpm.index.unique()
noisy_motifs = _dfpm_labels[_dfpm_labels.str.startswith("noisy")].tolist()
In [200]:
# Distinct labels flagged in the sheet's boolean 'erronious' column.
erronious = dfp.loc[dfp['erronious'], 'label'].unique()
In [201]:
# Order labels by total seqlets across all methods (largest first), then move
# the erroneous labels to the end, keeping them sorted by total as well. This is
# the same rule `remove_zeroes` applies, and it cannot raise KeyError for a
# label that is absent from dfpm (the original `.loc[erronious]` could).
dfpm = dfpm.loc[dfpm.sum(1).sort_values(ascending=False).index]
is_erroneous = dfpm.index.isin(erronious)
dfpm = pd.concat([dfpm[~is_erroneous], dfpm[is_erroneous]])
In [202]:
def remove_zeroes(df, erroneous_labels=None):
    """Drop empty rows and order the rest for the motif heatmaps.

    Rows whose maximum is 0 (all-zero for non-negative counts) are removed.
    The remaining rows are sorted by their row sum, largest first, and rows
    whose label is erroneous are then moved to the end (relative order kept).

    Args:
      df: label x method table of seqlet counts (e.g. a column subset of `dfpm`).
      erroneous_labels: labels to move to the end. Defaults to the notebook-level
        `erronious` array, which is what the original version always used.

    Returns:
      A new DataFrame; `df` itself is not modified.
    """
    if erroneous_labels is None:
        erroneous_labels = erronious
    nonzero = df[df.max(axis=1) != 0]
    ordered = nonzero.loc[nonzero.sum(1).sort_values(ascending=False).index]
    is_erroneous = ordered.index.isin(erroneous_labels)
    return pd.concat([ordered[~is_erroneous], ordered[is_erroneous]])
In [203]:
# All motifs
fig, ax = plt.subplots(figsize=get_figsize(1.6, 1/10))
sns.heatmap(np.log10(remove_zeroes(dfpm[['nexus/profile', 'nexus/binary',
                           'seq/profile', 'seq/binary',
                           'nexus/profile.peaks-union', 'seq/profile.peaks-union'
                           ]]) + 1).T);
fig.savefig(fdir / 'motif-venn.heatmap.all.pdf')
In [204]:
# Nexus profile vs. binary model, compact variant (default x tick labels).
# Saved under its own file name: the xticklabels=True variant writes
# 'nexus.profile-vs-binary.motif-venn.heatmap.pdf', which used to overwrite this figure.
fig, ax = plt.subplots(figsize=get_figsize(1.4, 1/19))
sns.heatmap(np.log10(remove_zeroes(dfpm[['nexus/profile', 'nexus/binary']]) + 1).T, ax=ax);
fig.savefig(fdir / 'nexus.profile-vs-binary.motif-venn.heatmap.compact.pdf')
In [223]:
# Nexus profile vs. binary model, with every column name (cluster label) as an x tick.
profile_vs_binary = remove_zeroes(dfpm[['nexus/profile', 'nexus/binary']])
fig, ax = plt.subplots(figsize=get_figsize(1.1, 1/19))
sns.heatmap(np.log10(profile_vs_binary + 1).T, ax=ax, xticklabels=True);
fig.savefig(fdir / 'nexus.profile-vs-binary.motif-venn.heatmap.pdf')
In [224]:
# Peak-union models of both assays vs. the seq models.
nexus_vs_seq_methods = ['nexus/profile.peaks-union',
                        'seq/profile.peaks-union',
                        'seq/profile',
                        'seq/binary']
fig, ax = plt.subplots(figsize=get_figsize(1.2, 1/12))
sns.heatmap(np.log10(remove_zeroes(dfpm[nexus_vs_seq_methods]) + 1).T, ax=ax);
fig.savefig(fdir / 'nexus-vs-seq.motif-venn.heatmap.pdf')
In [218]:
# figure output directory (NOTE(review): the saved output shows an absolute, user-specific path)
fdir
Out[218]:
PosixPath('/users/avsec/workspace/basepair/data/figures/method-comparison/modisco')
In [104]:
## TODO - was it more difficult to call motifs on the peak union?
## - number of seqlets?
In [76]:
# NOTE(review): the saved output (266) disagrees with the 264 patterns/rows shown
# elsewhere in this notebook; presumably left over from an earlier run -- re-run to refresh.
len(dfp)
Out[76]:
266
In [78]:
# NOTE(review): saved output (266) is stale; a later cell reports 264 for the same expression.
len(aligned_patterns)
Out[78]:
266
In [83]:
# Inspect one row of `dfp` as a dict.
# NOTE(review): `row` is not defined by any cell above; it is only assigned by the
# pattern-plotting loop at the end of the notebook (execution count 217 vs. 83 here),
# so this cell fails on Restart & Run All. Move it after that loop or define `row` here.
dict(row)
Out[83]:
{'name': 'nexus/profile.peaks-union/Nanog/m0_p13',
 'cluster20': 8,
 'cluster30': 29,
 'cluster40': 33,
 'cluster50': 41,
 'cluster60': 44,
 'label': 'TE1',
 'n_seqlets': 115,
 'method': 'nexus/profile.peaks-union'}

Plot all the aligned patterns

In [37]:
# number of aligned patterns (compare with len(all_patterns))
len(aligned_patterns)
Out[37]:
264
In [206]:
# curated cluster table (expected one row per pattern); consider dfp.head() to keep saved output small
dfp
Out[206]:
name cluster20 cluster30 cluster40 cluster50 cluster60 label erronious n_seqlets method
0 seq/profile.peaks-uni... 8 29 37 42 46 TE1 False 115 seq/profile.peaks-union
1 seq/binary/Nanog/m0_p16 8 29 37 42 46 TE1 False 128 seq/binary
2 nexus/profile/Nanog/m... 8 29 37 42 46 TE1 False 127 nexus/profile
... ... ... ... ... ... ... ... ... ... ...
261 nexus/binary/Nanog/m0_p6 2 2 7 7 7 Klf4-long False 747 nexus/binary
262 nexus/binary/Oct4/m0_p8 2 2 7 7 7 Klf4-long False 497 nexus/binary
263 nexus/profile/Klf4/m0_p5 2 2 7 7 7 Klf4-long False 434 nexus/profile

264 rows × 10 columns

In [217]:
# One 'seq_ic' plot per aligned pattern, titled with its curated label and short name
# (prefixed with 'erronious ' when the sheet flags it).
for p in aligned_patterns:
    p.plot("seq_ic");
    row = dfp[dfp['name'] == shorten_name(p.name)].iloc[0]
    err = 'erronious ' if row.erronious else ''
    plt.title(err + f"{row.label} {row['name']}")
    plt.ylim([0, 2])
    sns.despine(bottom=True, right=True, top=True)
    # Display and close each figure now (inline backend) rather than keeping hundreds
    # of figures open until the cell ends; matplotlib's "too many open figures"
    # warning would not show, because warnings are silenced globally.
    plt.show()