Goal

  • make sub-plots for Figure 5

Open questions

  • should we make a larger co-occurrence plot?
    • plot more motifs (one from each cluster at least)
In [1]:
# Imports
from basepair.imports import *
from basepair.exp.paper.config import motifs, profile_mapping
from basepair.exp.chipnexus.perturb.vdom import vdom_motif_pair, plot_spacing_hist
from basepair.exp.chipnexus.spacing import remove_edge_instances, get_motif_pairs, motif_pair_dfi
from basepair.exp.paper.config import tasks
from plotnine import *
import plotnine
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
paper_config()
Using TensorFlow backend.
In [2]:
odir = Path('../../chipnexus/train/seqmodel/output/')
In [141]:
# Name of the trained model directory (relative to `odir`)
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'

# Custom motif name -> short modisco pattern name
motifs = OrderedDict()
motifs['Oct4-Sox2'] = 'm0_p0'
motifs['Sox2'] = 'm0_p3'
motifs['Nanog'] = 'm0_p1'
# motifs['Klf4'] = 'm1_p0'   (left disabled, as in the original cell)

# Common paths
model_dir = odir / exp
modisco_dir = model_dir / "deeplift" / "Sox2" / "out" / "profile" / "wn"
In [143]:
mr = ModiscoResult(modisco_dir / 'modisco.h5')
In [149]:
!ls {modisco_dir}
centroid_seqlet_matches.csv	     patterns.pkl
cluster-patterns.html		     pattern_table.csv
cluster-patterns.ipynb		     pattern_table.html
footprints.pkl			     pattern_table.sorted.csv
hparams.yaml			     pattern_table.sorted.html
instances.parq			     plots
kwargs.json			     report.html
log				     report.ipynb
modisco.h5			     results.html
modisco-instances.smk-benchmark.tsv  results.ipynb
modisco.smk-benchmark.tsv	     seqlets
motif_clustering
In [150]:
patterns = read_pkl(modisco_dir / 'patterns.pkl')
In [152]:
name2pattern = {p.name: p for p in patterns}
In [163]:
# for p in patterns:
#     p.plot(["profile"]);
In [162]:
name2pattern['metacluster_0/pattern_3'].plot("profile");
In [ ]:
## Note - seems that 
In [144]:
p = mr.get_pattern('metacluster_0/pattern_3')
In [126]:
# Task whose DeepLIFT/modisco run is analysed
tf = 'Nanog'

# All model runs this cell used to cycle through. They were previously written as
# consecutive assignments to `exp` / `motifs` / `model_dir` / `modisco_dir`, each
# overwriting the previous one, so only the last run ('seq-peaks') ever took effect.
# Keeping them in one table makes the active run explicit and the others selectable.
#   exp            - model directory name under `odir`
#   modisco_subdir - modisco output location below `deeplift/{tf}/`
#   motifs         - custom motif name -> short modisco pattern name
EXPERIMENTS = OrderedDict([
    ('nexus-peaks', dict(
        exp='nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE',
        modisco_subdir='out/profile/wn',
        motifs=OrderedDict([('Oct4-Sox2', 'm0_p0'),
                            ('Sox2', 'm0_p3'),
                            ('Nanog', 'm0_p1')]),
    )),
    ('nexus-gw', dict(
        exp='nexus,gw,OSNK,1,0,0,FALSE,same,0.5,64,25,0.001,9,FALSE',
        # NOTE(review): the original cell overwrote `exp` before ever building a
        # modisco_dir for this run; 'out/class/pre-act' is an assumption - confirm.
        modisco_subdir='out/class/pre-act',
        motifs=OrderedDict([('Oct4-Sox2', 'm0_p0'),
                            ('Sox2', 'm0_p5'),
                            ('Nanog', 'm0_p2'),
                            ('Nanog2', 'm0_p8'),
                            ('Klf4', 'm0_p1')]),
    )),
    # chipseq
    ('seq-gw', dict(
        exp='seq,gw,OSN,1,0,0,FALSE,same,0.5,64,50,0.001,9,FALSE',
        modisco_subdir='out/class/pre-act',
        motifs=OrderedDict([('Oct4-Sox2', 'm0_p0'),
                            ('Nanog', 'm0_p3')]),
    )),
    ('seq-peaks', dict(
        # the trailing "/" is kept exactly as in the original string
        exp='seq,peaks,OSN,0,10,1,FALSE,same,0.5,64,50,0.004,9,FALSE,[1,50],TRUE/',
        modisco_subdir='out/profile/wn',
        motifs=OrderedDict([('Oct4-Sox2', 'm0_p0'),
                            ('Nanog', 'm0_p1')]),
    )),
])

# The run that the rest of the notebook operates on (same as the last
# assignment in the original cell).
ACTIVE_EXPERIMENT = 'seq-peaks'

_run = EXPERIMENTS[ACTIVE_EXPERIMENT]
exp = _run['exp']
motifs = _run['motifs']

model_dir = odir / exp
# Common paths
modisco_dir = model_dir / f"deeplift/{tf}/{_run['modisco_subdir']}"

# figures dir
fdir = Path(f'{ddir}/figures/modisco/{exp}/spacing/')
In [127]:
# Create the output directory from Python instead of `!mkdir`: `fdir` embeds the
# run name, which contains `[1,50]`, and an unquoted shell argument would treat
# those brackets as a glob pattern.
(fdir / 'individual').mkdir(parents=True, exist_ok=True)
In [128]:
!mkdir -p {fdir}
In [129]:
# define the global set of distances
# Inclusive upper edges of the distance bins. The query strings and the plot
# labels are both derived from this list so they cannot drift apart (the first
# label used to read 'dist < 35' although the query is 'center_diff<=35').
dist_bin_edges = [35, 70, 150]

dist_subsets = [f'center_diff<={dist_bin_edges[0]}']
dist_subset_labels = [f'dist <= {dist_bin_edges[0]}']
for lo, hi in zip(dist_bin_edges[:-1], dist_bin_edges[1:]):
    dist_subsets.append(f'(center_diff>{lo})&(center_diff<={hi})')
    dist_subset_labels.append(f'{lo} < dist <= {hi}')
# open-ended last bin
dist_subsets.append(f'center_diff>{dist_bin_edges[-1]}')
dist_subset_labels.append(f'{dist_bin_edges[-1]} < dist')
In [130]:
ls {model_dir}/deeplift/Nanog/out/profile/wn/instances.parq/pattern=metacluster_0/pattern_0/
part.0.parquet
In [131]:
!ls ../../chipnexus/train/seqmodel/output/nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE/deeplift/Oct4/out/profile/wn/instances.parq
_common_metadata  _metadata  pattern=metacluster_0

Load motif pairs

In [132]:
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix, align_instance_center
In [133]:
pairs = get_motif_pairs(motifs)

# ordered names
pair_names = ["<>".join(x) for x in pairs]
In [134]:
def load_instances(parq_file, motifs=None, dedup=True):
    """Load pattern instances from the parquet file

    Args:
      parq_file: parquet file of motif instances, or an already loaded
        pd.DataFrame of instances (it is not modified; a copy is returned)
      motifs: dictionary of motifs of interest.
        key: custom motif name, value: short pattern name (e.g. 'm0_p3').
        If None, all patterns are kept.
      dedup: if True, drop instances that share the same pattern, chromosome,
        absolute start/end and strand, and print how many were removed

    Returns:
      pd.DataFrame of instances with the added columns `pattern_short`,
      `pattern_name` (only when `motifs` is given), `pattern_start_abs`
      and `pattern_end_abs`
    """
    if motifs is not None:
        incl_motifs = {longer_pattern(m) for m in motifs.values()}
    else:
        incl_motifs = None

    if isinstance(parq_file, pd.DataFrame):
        dfi = parq_file
    elif motifs is not None:
        from fastparquet import ParquetFile
        # Selectively load only the relevant patterns
        pf = ParquetFile(str(parq_file))
        if 'dir0' in pf.cats:
            # fix the wrong patterns: the pattern name was split across two
            # partition levels (dir0 = 'pattern=metacluster_X', dir1 = 'pattern_Y')
            metaclusters = list({'pattern=' + x.split("/")[0] for x in incl_motifs})
            patterns = list({x.split("/")[1] for x in incl_motifs})
            dfi = pf.to_pandas(filters=[("dir0", "in", metaclusters),
                                        ("dir1", "in", patterns)])
            dfi['pattern'] = dfi['dir0'].str.replace("pattern=", "").astype(str) + "/" + dfi['dir1'].astype(str)
            del dfi['dir0']
            del dfi['dir1']
        else:
            # was `pd.to_pandas(...)`, which does not exist in pandas
            dfi = pf.to_pandas(filters=[('pattern', 'in', list(incl_motifs))])
    else:
        dfi = pd.read_parquet(str(parq_file), engine='fastparquet')

    # filter
    if motifs is not None:
        # NOTE this should already be removed when loading from the file.
        # .copy() so that the new columns below are written to an independent
        # frame rather than to a slice of the loaded one.
        dfi = dfi[dfi.pattern.isin(incl_motifs)].copy()
        dfi['pattern_short'] = dfi['pattern'].map({k: shorten_pattern(k) for k in incl_motifs})
        dfi['pattern_name'] = dfi['pattern_short'].map({v: k for k, v in motifs.items()})
    else:
        if dfi is parq_file:
            # do not add columns to the caller's DataFrame
            dfi = dfi.copy()
        dfi['pattern_short'] = dfi['pattern'].map({k: shorten_pattern(k)
                                                   for k in dfi.pattern.unique()})

    # add some columns
    dfi['pattern_start_abs'] = dfi['example_start'] + dfi['pattern_start']
    dfi['pattern_end_abs'] = dfi['example_start'] + dfi['pattern_end']

    if dedup:
        # deduplicate
        dfi_dedup = dfi.drop_duplicates(['pattern',
                                         'example_chrom',
                                         'pattern_start_abs',
                                         'pattern_end_abs',
                                         'strand'])

        # number of removed duplicates (guard against an empty table)
        d = len(dfi) - len(dfi_dedup)
        pct = d / len(dfi) * 100 if len(dfi) > 0 else 0.0
        print("number of de-duplicated instances:", d, f"({pct}%)")

        # use de-duplicated instances from now on
        dfi = dfi_dedup
    return dfi

TODO

  • [ ] save load_instances once this works
  • [ ] load a single dfi for multiple modisco runs
In [20]:
from fastparquet import ParquetFile
# Selectively load only the relevant patterns
pf = ParquetFile(str(modisco_dir / 'instances.parq'))

# TODO - debug loading from a data subset
In [ ]:
# dfi = pd.read_parquet(str(modisco_dir / 'instances.parq'), engine='fastparquet')
In [95]:
!ls '../../chipnexus/train/seqmodel/output/nexus,gw,OSNK,1,0,0,FALSE,same,0.5,64,25,0.001,9,FALSE/deeplift/Nanog/out/profile/wn/'
ls: cannot access '../../chipnexus/train/seqmodel/output/nexus,gw,OSNK,1,0,0,FALSE,same,0.5,64,25,0.001,9,FALSE/deeplift/Nanog/out/profile/wn/': No such file or directory
In [135]:
dfi = load_instances(modisco_dir / 'instances.parq', motifs, dedup=False)
dfi = filter_nonoverlapping_intervals(dfi)
In [136]:
dfi_subset = dfi.query('match_weighted_p > .2').query('imp_weighted_p > 0')
In [137]:
# create motif pairs
dfab = pd.concat([motif_pair_dfi(dfi_subset, pair).assign(motif_pair="<>".join(pair)) for pair in pairs], axis=0)

Co-occurrence test

<150bp

In [62]:
from basepair.exp.chipnexus.spacing import coocurrence_plot
from basepair.exp.chipnexus.spacing import co_occurence_matrix
In [66]:
# debug
# Use the raw strand / center as the "aligned" columns. `dfi_subset` is the
# result of `.query(...)` on `dfi`, so the columns are added with `.assign`
# (returns a new frame) instead of being written into that slice.
dfi_subset = dfi_subset.assign(pattern_strand_aln=dfi_subset['strand'],
                               pattern_center_aln=dfi_subset['pattern_center'])
In [67]:
# co-occurrence
fig, ax = plt.subplots(figsize=get_figsize(.25, aspect=1))
dist = 150
# Draw explicitly on the axes created above (as the all-distance-ranges cell
# does) so that the labels and the saved figure refer to the same plot.
coocurrence_plot(dfi_subset, list(motifs),
                 query_string=f"(abs(pattern_center_aln_x- pattern_center_aln_y) <= {dist})",
                 ax=ax)
ax.set_ylabel("Motif of interest")
ax.set_xlabel("Motif partner");
fig.savefig(fdir / f'coocurrence.test.center_diff<{dist}.pdf')

All distance ranges

In [68]:
# subsets = ['center_diff <= 35', 'center_diff > 35', 'center_diff > 70']
subsets = dist_subsets
n_subsets = len(subsets)

# one panel per distance range, sharing the y axis
fig, axes = plt.subplots(1, n_subsets,
                         figsize=get_figsize(0.25*n_subsets, .8/n_subsets),
                         sharey=True)
for i, (subset, ax, subset_label) in enumerate(zip(subsets, axes, dist_subset_labels)):
    # only the right-most panel gets a colorbar
    cbar = (i == n_subsets - 1)
    coocurrence_plot(dfi_subset, list(motifs), query_string=subset, ax=ax, cbar=cbar)

    # axis labels only on the left-most panel
    if i == 0:
        ax.set_ylabel("Motif of interest")
        ax.set_xlabel("Motif partner");

    ax.set_title(subset_label)
# plt.tight_layout()
fig.savefig(fdir / f'coocurrence.test.all-dist.pdf')

Pairwise distance distribution

Seq - profile

In [138]:
plotnine.options.figure_size = get_figsize(.5, aspect=2)  # (10, 10)
max_dist = 100

# motif pairs at most `max_dist` apart, leaving out the Oct4-Sox2<>Sox2 pair
is_close = dfab.center_diff <= max_dist
is_shown_pair = dfab.motif_pair != 'Oct4-Sox2<>Sox2'
dfab_hist = dfab[is_close & is_shown_pair]

# stacked histogram of pairwise distances, one facet row per motif pair
fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab_hist)
       + geom_histogram(bins=max_dist)
       + facet_grid("motif_pair~ .")
       + theme_classic(base_size=10, base_family='Arial')
       + theme(strip_text=element_text(rotation=0), legend_position='top')
       + xlim([0, max_dist])
       + ylim([0, 1000])
       + xlab("Pairwise distance")
       + scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'histogram.center_diff.all.pdf')
fig
Out[138]:
<ggplot: (-9223363291122037623)>
In [139]:
plotnine.options.figure_size = get_figsize(.3, aspect=2/10*4)
max_dist = 100

# Nanog<>Nanog pairs at most `max_dist` apart
is_close = dfab.center_diff <= max_dist
is_nanog_pair = dfab.motif_pair.isin(["Nanog<>Nanog"])
dfab_nanog = dfab[is_close & is_nanog_pair]

# faint guide lines every 10 bp, drawn underneath the histogram
guide_positions = [10, 20, 30, 40]

fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab_nanog)
       + geom_vline(xintercept=guide_positions, alpha=0.1)
       + geom_histogram(bins=max_dist)
       + facet_grid("strand_combination~.")
       + theme_classic(base_size=10, base_family='Arial')
       + theme(strip_text=element_text(rotation=0), legend_position='top')
       + xlim([0, max_dist])
       + xlab("Pairwise distance")
       + ggtitle("Nanog<>Nanog")
       + scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'individual/nanog.spacing.pdf')
fig
Out[139]:
<ggplot: (-9223363291121854916)>

Seq - binary

In [124]:
plotnine.options.figure_size = get_figsize(.5, aspect=2)# (10, 10)
max_dist = 100
fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab[(dfab.center_diff <= max_dist) & (dfab.motif_pair != 'Oct4-Sox2<>Sox2')]) + 
 geom_histogram(bins=max_dist) + 
 facet_grid("motif_pair~ .") + 
 theme_classic(base_size=10, base_family='Arial') + 
 theme(strip_text = element_text(rotation=0), legend_position='top') + 
 xlim([0, max_dist]) + 
 ylim([0, 1000]) + 
 xlab("Pairwise distance") + 
 scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'histogram.center_diff.all.pdf')
fig 
Out[124]:
<ggplot: (-9223363291567501943)>
In [125]:
plotnine.options.figure_size = get_figsize(.3, aspect=2/10*4)
max_dist = 100
fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab[(dfab.center_diff <= max_dist) & (dfab.motif_pair.isin(["Nanog<>Nanog"]))]) + 
 geom_vline(xintercept=10, alpha=0.1) + 
 geom_vline(xintercept=20, alpha=0.1) + 
 geom_vline(xintercept=30, alpha=0.1) + 
 geom_vline(xintercept=40, alpha=0.1) + 
 geom_histogram(bins=max_dist) + facet_grid("strand_combination~.") + 
 theme_classic(base_size=10, base_family='Arial') + 
 theme(strip_text = element_text(rotation=0), legend_position='top') + 
 xlim([0, max_dist]) + 
 xlab("Pairwise distance") +
 ggtitle("Nanog<>Nanog") + 
 scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'individual/nanog.spacing.pdf')
fig
Out[125]:
<ggplot: (-9223363291567658510)>

Nexus - binary

In [110]:
plotnine.options.figure_size = get_figsize(.5, aspect=2)# (10, 10)
max_dist = 100
fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab[(dfab.center_diff <= max_dist) & (dfab.motif_pair != 'Oct4-Sox2<>Sox2')]) + 
 geom_histogram(bins=max_dist) + 
 facet_grid("motif_pair~ .") + 
 theme_classic(base_size=10, base_family='Arial') + 
 theme(strip_text = element_text(rotation=0), legend_position='top') + 
 xlim([0, max_dist]) + 
 ylim([0, 1000]) + 
 xlab("Pairwise distance") + 
 scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'histogram.center_diff.all.pdf')
fig 
Out[110]:
<ggplot: (8746067683988)>
In [109]:
plotnine.options.figure_size = get_figsize(.3, aspect=2/10*4)
max_dist = 100
fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab[(dfab.center_diff <= max_dist) & (dfab.motif_pair.isin(["Nanog<>Nanog"]))]) + 
 geom_vline(xintercept=10, alpha=0.1) + 
 geom_vline(xintercept=20, alpha=0.1) + 
 geom_vline(xintercept=30, alpha=0.1) + 
 geom_vline(xintercept=40, alpha=0.1) + 
 geom_histogram(bins=max_dist) + facet_grid("strand_combination~.") + 
 theme_classic(base_size=10, base_family='Arial') + 
 theme(strip_text = element_text(rotation=0), legend_position='top') + 
 xlim([0, max_dist]) + 
 xlab("Pairwise distance") +
 ggtitle("Nanog<>Nanog") + 
 scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'individual/nanog.spacing.pdf')
fig
Out[109]:
<ggplot: (-9223363290624282869)>

Profile models

In [80]:
plotnine.options.figure_size = get_figsize(.5, aspect=2)# (10, 10)
max_dist = 100
fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab[(dfab.center_diff <= max_dist) & (dfab.motif_pair != 'Oct4-Sox2<>Sox2')]) + 
 geom_histogram(bins=max_dist) + 
 facet_grid("motif_pair~ .") + 
 theme_classic(base_size=10, base_family='Arial') + 
 theme(strip_text = element_text(rotation=0), legend_position='top') + 
 xlim([0, max_dist]) + 
 xlab("Pairwise distance") + 
 scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'histogram.center_diff.all.pdf')
fig 
Out[80]:
<ggplot: (-9223363290629183821)>

Nanog

In [83]:
plotnine.options.figure_size = get_figsize(.3, aspect=2/10*4)
max_dist = 100
fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab[(dfab.center_diff <= max_dist) & (dfab.motif_pair.isin(["Nanog<>Nanog"]))]) + 
 geom_vline(xintercept=10, alpha=0.1) + 
 geom_vline(xintercept=20, alpha=0.1) + 
 geom_vline(xintercept=30, alpha=0.1) + 
 geom_vline(xintercept=40, alpha=0.1) + 
 geom_histogram(bins=max_dist) + facet_grid("strand_combination~.") + 
 theme_classic(base_size=10, base_family='Arial') + 
 theme(strip_text = element_text(rotation=0), legend_position='top') + 
 xlim([0, max_dist]) + 
 xlab("Pairwise distance") +
 ggtitle("Nanog<>Nanog") + 
 scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'individual/nanog.spacing.pdf')
fig
Out[83]:
<ggplot: (8746225588165)>
In [23]:
plotnine.options.figure_size = get_figsize(.3, aspect=2/10*4)
max_dist = 100
fig = (ggplot(aes(x='center_diff', fill='strand_combination'), dfab[(dfab.center_diff <= max_dist) & (dfab.motif_pair.isin(["Nanog<>Nanog"]))]) + 
 geom_vline(xintercept=10, alpha=0.1) + 
 geom_vline(xintercept=20, alpha=0.1) + 
 geom_vline(xintercept=30, alpha=0.1) + 
 geom_vline(xintercept=40, alpha=0.1) + 
 geom_histogram(bins=max_dist) + facet_grid("strand_combination~.") + 
 theme_classic(base_size=10, base_family='Arial') + 
 theme(strip_text = element_text(rotation=0), legend_position='top') + 
 xlim([0, max_dist]) + 
 xlab("Pairwise distance") +
 ggtitle("Nanog<>Nanog") + 
 scale_fill_brewer(type='qual', palette=3))
fig.save(fdir / 'individual/nanog.spacing.pdf')
fig
Out[23]:
<ggplot: (-9223363310659608965)>

In [33]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
hv.extension('bokeh')
In [40]:
from basepair.plot.profiles import extract_signal
from basepair.math import softmax
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile
from basepair.plot.profiles import  plot_stranded_profile, multiple_plot_stranded_profile
from basepair.modisco.results import Seqlet, resize_seqlets
from basepair.modisco.pattern_instances import dfi2seqlets, annotate_profile
from basepair.cli.modisco import load_profiles
from basepair.BPNet import BPNetPredictor
from basepair.plot.tracks import plot_tracks, filter_tracks
from basepair.preproc import rc_seq
from copy import deepcopy
from basepair.preproc import dfint_no_intersection
from basepair.exp.chipnexus.simulate import (insert_motif, generate_sim, plot_sim, generate_seq, 
                                             model2tasks, motif_coords, interactive_tracks, plot_motif_table,
                                             plot_sim_motif_col)

from scipy.fftpack import fft, ifft

# interval columns in dfi
interval_cols = ['example_chrom', 'pattern_start_abs', 'pattern_end_abs']
In [35]:
paper_config()
In [36]:
# create_tf_session(0)
In [37]:
from basepair.exp.paper.config import motifs, side_motifs, modisco_dir, model_dir, imp_score_file
In [12]:
figures = Path(f"{ddir}/figures/modisco/spacing")
In [27]:
!mkdir -p {figures}

Load the required data

In [14]:
mr = ModiscoResult(modisco_dir / "modisco.h5")
mr.open()

# Load the data
d = HDF5Reader(imp_score_file)
d.open()

patterns = read_pkl(modisco_dir / "patterns.pkl")  # aligned patterns
In [17]:
# plot them
for tf, p in motifs.items():
    mr.get_pattern(longer_pattern(p)).trim_seq_ic(0.08).plot("seq");
    plt.title(tf)
In [38]:
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix, align_instance_center
In [19]:
dfi_full = pd.read_parquet(f"{modisco_dir}/instances.parq", engine='fastparquet')
In [20]:
# get transposable element locations
te_patterns = pd.read_csv(modisco_dir / "motif_clustering/patterns_long.csv").pattern.map(longer_pattern).unique()
In [21]:
dfi_te = load_instances(dfi_full[dfi_full.pattern.isin(te_patterns)], None)
dfi_te = dfi_te[(dfi_te.match_weighted_p > 0.1) & (dfi_te.seq_match > 20)]
number of deduplicatd instances: 16611 (54.877927913046356%)
In [22]:
dfi = load_instances(dfi_full, motifs)

dfi = filter_nonoverlapping_intervals(dfi)

total_examples = len(dfi.example_idx.unique())
total_examples
number of deduplicatd instances: 1594551 (35.14740608536545%)
Out[22]:
60345
In [23]:
# annotate with profiles
profiles = load_profiles(modisco_dir, imp_score_file)
dfi_anno = annotate_profile(dfi, mr, profiles, profile_width=70, trim_frac=0.08)
100%|██████████| 4/4 [40:51<00:00, 664.52s/it]
In [24]:
dfi = dfi_anno
In [25]:
# add pattern_center_aln features
orig_patterns = [mr.get_pattern(pname) for pname in mr.patterns()]
dfi = align_instance_center(dfi, orig_patterns, patterns, trim_frac=0.08)
In [28]:
dfi.pattern_name.value_counts()
Out[28]:
Nanog        899892
Oct4-Sox2    294321
Klf4         285485
Sox2         250038
Name: pattern_name, dtype: int64

Plot the profiles - sanity check

In [29]:
dfi_subset = dfi[(dfi.pattern_center > 400) & (dfi.pattern_center < 600)].query('match_weighted_p > .5').\
                  query('imp_weighted_p > 0').query('pattern_name=="Oct4-Sox2"')
In [30]:
seqlets = dfi2seqlets(dfi_subset)
seqlets = resize_seqlets(seqlets, 70, seqlen=1000)
seqlet_profiles = {k: extract_signal(v, seqlets) for k,v in profiles.items()}
In [31]:
multiple_plot_stranded_profile({p:v for p,v in seqlet_profiles.items()}, figsize_tmpl=(2.55,2))
multiple_heatmap_stranded_profile(seqlet_profiles, sort_idx=np.arange(1000), figsize=(10,10));
In [60]:
p = [p for p in patterns if p.name == 'metacluster_0/pattern_0'][0]
In [70]:
seqlet_profiles = OrderedDict([(k, v[:, 65:135]) for k,v in p.attrs['stacked_seqlet_imp'].profile.items()])
In [68]:
multiple_plot_stranded_profile({p:v for p,v in seqlet_profiles.items()}, figsize_tmpl=(2.55,2))
multiple_heatmap_stranded_profile(seqlet_profiles, sort_idx=np.arange(1000), figsize=(10,10));
In [57]:
# pattern lengths
dfi.groupby(['pattern_name', 'pattern_len']).size()
Out[57]:
pattern_name  pattern_len
Klf4          10             168219
Nanog         9              587276
Oct4-Sox2     15             206314
Sox2          22             169672
dtype: int64

Plots

Co-occurrence matrix

In [72]:
# TODO - allow self
In [59]:
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
plot_coocurence_matrix(dfi[(dfi.pattern_center > 450) & 
                            (dfi.pattern_center < 550)].
                             query('match_weighted_cat!="low"').
                             query('imp_weighted_cat=="high"'), total_examples, ax=ax)
In [71]:
fig.savefig(figures / "co-occurence-test.pdf")

Spacing plot (exclude TE's)

Re-implement spacing

In [47]:
# filtered instances
dfi_filtered = (dfi.query('match_weighted_p > 0.1').query('imp_weighted_p > 0'))
keep_nonte = dfint_no_intersection(dfi_filtered[interval_cols], dfi_te[interval_cols])
print("not overlapping TE", keep_nonte.mean())  # almost all were kept
dfi_filtered = dfi_filtered[keep_nonte]
not overlapping TE 0.9622691422132954
In [42]:
# filtered
dfi_filtered.pattern_name.value_counts()
Out[42]:
Nanog        30601
Klf4         18633
Oct4-Sox2    10364
Sox2          6944
Name: pattern_name, dtype: int64
In [43]:
# original
dfi.pattern_name.value_counts()
Out[43]:
Nanog        899892
Oct4-Sox2    294321
Klf4         285485
Sox2         250038
Name: pattern_name, dtype: int64
In [44]:
# setup config
from basepair.modisco.pattern_instances import construct_motif_pairs

# all unordered motif pairs, including each motif paired with itself,
# in the order the motifs are listed
motif_names = list(motifs)
pairs = [[motif_names[i], motif_names[j]]
         for i in range(len(motif_names))
         for j in range(i, len(motif_names))]

# Relabelling of the strand combination that motif_pair_dfi applies to pairs
# whose second motif lies upstream of the first (negative center_diff)
comp_strand_compbination = {
    "++": "--",
    "--": "++",
    "-+": "-+",
    "+-": "+-"
}

strand_combinations = ["++", "--", "+-", "-+"]

# motif name -> task whose profile columns are used for that motif
profile_mapping = {
    "Oct4-Sox2": "Oct4",
    "Sox2": "Sox2",
    "Nanog": "Nanog",
    "Klf4": "Klf4",
    "Essrb": "Oct4"
}

def motif_pair_dfi(dfi_filtered, motif_pair):
    """Construct the matrix of motif pairs
    
    All instances of motif_pair[0] are paired with all instances of
    motif_pair[1] found in the same example (`example_idx`).

    Args:
      dfi_filtered: dfi filtered to the desired property
      motif_pair: tuple of two pattern_name's 
    Returns:
      pd.DataFrame with columns from dfi_filtered with _x and _y suffix, plus
      `center_diff`, `center_diff_aln` and `strand_combination`. For two
      different motifs the distances are absolute; for a motif paired with
      itself only center_diff > 0 is kept (each pair once) and "--" is
      reported as "++". Pairs with center_diff_aln == 0 are dropped.
    """
    dfa = dfi_filtered[dfi_filtered.pattern_name == motif_pair[0]]
    dfb = dfi_filtered[dfi_filtered.pattern_name == motif_pair[1]]

    dfab = pd.merge(dfa, dfb, on='example_idx', how='outer')
    # keep only examples that contain both motifs; .copy() so the columns added
    # below are written to an independent frame
    dfab = dfab[~dfab[['pattern_x', 'pattern_y']].isnull().any(axis=1)].copy()

    dfab['center_diff'] = dfab.pattern_center_y - dfab.pattern_center_x
    dfab['center_diff_aln'] = dfab.pattern_center_aln_y - dfab.pattern_center_aln_x
    dfab['strand_combination'] = dfab.strand_x + dfab.strand_y
    # assure the right strand combination
    # (.loc is required here: `dfab[mask]['col'] = ...` assigns to a temporary
    # copy and leaves dfab unchanged)
    upstream = dfab.center_diff < 0
    dfab.loc[upstream, 'strand_combination'] = (
        dfab.loc[upstream, 'strand_combination'].map(comp_strand_compbination))

    if motif_pair[0] == motif_pair[1]:
        dfab.loc[dfab.strand_combination == "--", 'strand_combination'] = "++"
        dfab = dfab[dfab.center_diff > 0]
    else:
        dfab['center_diff'] = np.abs(dfab.center_diff)
        dfab['center_diff_aln'] = np.abs(dfab.center_diff_aln)
    dfab = dfab[dfab.center_diff_aln != 0]  # exclude perfect matches
    return dfab


def plot_spacing(dfab, 
                 alpha_scatter=0.01, 
                 y_feature='profile_counts', 
                 center_diff_variable='center_diff', 
                 figsize=(3.42519, 6.85038)):
    """Plot spacing histograms and per-motif signal for one motif pair

    For every strand combination present in `dfab` two stacked panels are
    drawn: a histogram of the pairwise distance and, below it, a scatter plot
    with a lowess-smoothed trend of log10(1 + signal) for both motifs.
    Only pairs with center_diff < 150 are used.

    Args:
      dfab: table of motif pairs with `_x`/`_y` suffixed columns; the motif
        names are read from the first row (`pattern_name_x`, `pattern_name_y`)
      alpha_scatter: transparency of the scatter points
      y_feature: 'profile_counts' (uses the `<task>/profile_counts_x|y`
        columns, task taken from `profile_mapping`) or 'imp_weighted'
      center_diff_variable: column used as the distance on the x axis
      figsize: figure size passed to plt.subplots

    Returns:
      matplotlib figure

    Raises:
      ValueError: for any other value of `y_feature`
    """
    from basepair.stats import smooth_window_agg, smooth_lowess, smooth_gam

    first_row = dfab.iloc[0]
    motif_pair = (first_row.pattern_name_x, first_row.pattern_name_y)
    strand_combinations = dfab.strand_combination.unique()
    n_combinations = len(strand_combinations)

    # two rows (histogram + signal) per strand combination
    fig_profile, axes = plt.subplots(2 * n_combinations, 1, figsize=figsize,
                                     sharex=True, sharey='row')
    axes[0].set_title("<>".join(motif_pair), fontsize=7)

    close_pairs = dfab[(dfab.center_diff < 150)]
    for i, sc in enumerate(strand_combinations):
        sc_pairs = close_pairs[close_pairs.strand_combination == sc]

        # columns holding the signal of the first / second motif
        if y_feature == 'profile_counts':
            col_first = profile_mapping[motif_pair[0]] + "/profile_counts_x"
            col_second = profile_mapping[motif_pair[1]] + "/profile_counts_y"
        elif y_feature == 'imp_weighted':
            col_first = 'imp_weighted_x'
            col_second = 'imp_weighted_y'
        else:
            raise ValueError(f"Unkown y_feature: {y_feature}")

        y1 = np.log10(1 + sc_pairs[col_first])
        y2 = np.log10(1 + sc_pairs[col_second])
        x = sc_pairs[center_diff_variable]

        dm1, ym1, confi1 = smooth_lowess(x, y1, frac=0.15)
        dm2, ym2, confi2 = smooth_lowess(x, y2, frac=0.15)

        # first plot: distance histogram
        hist_ax = axes[2 * i]
        hist_ax.hist(x, np.arange(10, 150, 1));
        hist_ax.set_ylabel(sc)
        hist_ax.set_xlim([0, 150])

        # second plot: signal of both motifs against the distance
        signal_ax = axes[2 * i + 1]
        for y, dm, ym, confi in ((y1, dm1, ym1, confi1),
                                 (y2, dm2, ym2, confi2)):
            signal_ax.scatter(x, y, alpha=alpha_scatter, s=8)
            if confi is not None:
                signal_ax.fill_between(dm, confi[:, 0], confi[:, 1], alpha=0.2)
            signal_ax.plot(dm, ym, linewidth=2, alpha=0.8)
        signal_ax.set_ylabel(sc)
        signal_ax.xaxis.set_minor_locator(plt.MultipleLocator(10))
        signal_ax.xaxis.set_major_locator(plt.MultipleLocator(20))
        if i == n_combinations - 1:
            signal_ax.set_xlabel("Distance between motifs")
    fig_profile.subplots_adjust(wspace=0, hspace=0) 
    return fig_profile
In [48]:
dfab = motif_pair_dfi(dfi_filtered, ['Oct4-Sox2', 'Sox2'])
In [49]:
fig = plot_spacing(dfab, alpha_scatter=0.05, y_feature='profile_counts', figsize=get_figsize(.4, aspect=2))
In [ ]:
 
In [235]:
!mkdir -p {figures}/'individual'

Save all pairs to pdf

In [ ]:
for motif_pair in tqdm(pairs):
    dfab = motif_pair_dfi(dfi_filtered, motif_pair)
    mp_name = "<>".join(motif_pair)
    if len(dfab) == 0:
        # plot_spacing reads the motif names from the first row and cannot
        # handle an empty table
        print(f"no instances for {mp_name}; skipping")
        continue
    for yf in ['profile_counts', 'imp_weighted']:
        fig = plot_spacing(dfab, alpha_scatter=0.05, y_feature=yf, figsize=get_figsize(.4, aspect=2))
        fig.savefig(figures/ f'individual/{mp_name}.{yf}.pdf')
        # two figures are created per pair; close each one after saving so
        # they do not pile up in memory
        plt.close(fig)

Simulation data

In [78]:
# plt.style.use('default')
In [50]:
in_silico = read_pkl(f"{ddir}/processed/chipnexus/simulation/spacing.pkl")
In [51]:
sim_df_d, sim_res_dict_d = in_silico
In [52]:
sim_df_d.keys()
Out[52]:
dict_keys(['Oct4-Sox2', 'Sox2', 'Essrb', 'Nanog', 'Klf4', 'Oct4-Sox2/rc', 'Sox2/rc', 'Essrb/rc', 'Nanog/rc', 'Klf4/rc'])
In [53]:
def swap_orientation(o):
    """Return the opposite strand symbol: '+' -> '-' and '-' -> '+'.

    Raises:
      ValueError: if `o` is neither '+' nor '-'
    """
    if o=='+':
        return "-"
    if o=="-":
        return "+"
    # name the offending value (the message used to be empty)
    raise ValueError(f"Unknown orientation {o!r}; expected '+' or '-'")

# Very much a hacked function to map simulation to in vivo data
def get_xy_sim_single(sim_df_d, motif_pair, feature, orientation):
    """Look up one simulated distance curve for a (central, side) motif pair

    Args:
      sim_df_d: dict of simulation tables keyed by central motif name
        ('<name>' or '<name>/rc')
      motif_pair: (central motif, side motif)
      feature: column of the simulation table to return
      orientation: two-character strand string, e.g. '+-'
        (first character: central motif, second: side motif)

    Returns:
      (distances, feature values) as numpy arrays, restricted to
      distance < 150 and to the task mapped to the central motif
    """
    central_motif, side_motif = motif_pair[0], motif_pair[1]
    central_strand, side_strand = orientation[0], orientation[1]

    # HACK! Nanog orientation didn't properly match the orientation,
    # so for Nanog the orientation is always explicitly swapped
    if central_motif == "Nanog":
        central_strand = swap_orientation(central_strand)
    if side_motif == "Nanog":
        side_strand = swap_orientation(side_strand)

    # a motif on the '-' strand is stored under '<name>/rc'
    central_key = central_motif + "/rc" if central_strand == "-" else central_motif
    side_key = side_motif + "/rc" if side_strand == "-" else side_motif

    df = sim_df_d[central_key]  # choose the central motif
    keep = ((df.motif == side_key)  # choose the side motif
            & (df.distance < 150)
            & (df.task == profile_mapping[central_motif]))  # select the task
    df = df[keep]
    return df.distance.values, df[feature].values
    
def get_xy_sim(sim_df_d, motif_pair, feature, orientation):
    """Fetch the simulated response of both motifs of a pair on a shared distance grid.

    The first lookup uses `motif_pair` as given; the second lookup swaps the
    two motifs and uses the strand combination that `comp_strand_compbination`
    maps `orientation` to.

    Args:
      sim_df_d: dict of simulation tables (see `get_xy_sim_single`).
      motif_pair: two motif names.
      feature: column of the simulation table to return.
      orientation: strand combination, a key of `comp_strand_compbination`.

    Returns:
      Tuple (distances, y of the first lookup, y of the swapped lookup).

    Raises:
      AssertionError: if the two lookups do not share exactly the same distances.
    """
    x1, y1 = get_xy_sim_single(sim_df_d, motif_pair, feature, orientation)
    x2, y2 = get_xy_sim_single(sim_df_d, list(reversed(motif_pair)), feature,
                               comp_strand_compbination[orientation])
    # `np.array_equal` also compares the shapes. `np.all(x1 == x2)` would accept
    # broadcastable arrays of different length and, in recent NumPy versions,
    # raise a non-assertion error on other shape mismatches.
    assert np.array_equal(x1, x2), (
        f"Simulated distances differ between {list(motif_pair)} and the swapped pair"
    )
    return x1, y1, y2
In [54]:
# Peek at the simulation table stored under the 'Oct4-Sox2' key.
sim_df_d['Oct4-Sox2'].head()
Out[54]:
central_motif distance imp/count imp/count_frac imp/weighted imp/weighted_frac position profile/counts profile/counts_frac profile/max profile/max_frac profile/simmetric_kl side_motif task motif
0 TTTGCATAACAA 11 0.9609 0.6405 0.9823 0.8519 511 185.7749 2.3054 12.2156 2.2955 0.1719 TTTGCATAACAA Oct4 Oct4-Sox2
1 TTTGCATAACAA 11 0.5739 0.6229 0.6065 0.7806 511 56.8762 1.9200 5.4221 2.1760 0.3281 TTTGCATAACAA Sox2 Oct4-Sox2
2 TTTGCATAACAA 11 0.6132 0.6085 0.3001 0.6910 511 37.6371 1.3918 1.5615 1.9475 0.1569 TTTGCATAACAA Nanog Oct4-Sox2
3 TTTGCATAACAA 11 0.2450 0.5783 0.2800 0.6075 511 14.7951 1.0222 0.7564 1.4350 0.1079 TTTGCATAACAA Klf4 Oct4-Sox2
4 TTTGCATAACAA 12 0.5861 0.3907 0.4954 0.4297 512 111.2185 1.3802 4.5952 0.8635 0.1345 TTTGCATAACAA Oct4 Oct4-Sox2
In [85]:
# Keep motif instances with 400 < pattern_center < 600, a match score
# percentile above 0.2 and a positive importance percentile.
dfi_filtered = (dfi[(dfi.pattern_center > 400) & (dfi.pattern_center < 600)]
                .query('match_weighted_p > 0.2')
                .query('imp_weighted_p > 0'))
# Drop instances whose interval intersects one of the `dfi_te` intervals.
keep_nonte = dfint_no_intersection(dfi_filtered[interval_cols], dfi_te[interval_cols])
print("not overlapping TE", keep_nonte.mean())  # almost all were kept
dfi_filtered = dfi_filtered[keep_nonte]

# Features carried into the pair tables (identical for every pair).
pair_features = (['match_weighted_p', 'imp_weighted_p', 'imp_weighted']
                 + [f for f in dfi if "profile" in f])

d_dfi_filtered = []
for motif_pair in tqdm(pairs):
    dftw_filt = construct_motif_pairs(dfi_filtered, motif_pair, features=pair_features)

    # Assure the right strand combination: pairs with a negative center_diff
    # get their strand combination mapped through `comp_strand_compbination`
    # before the sign of center_diff is dropped.
    # The previous `df[mask]['strand_combination'] = ...` assigned to a
    # temporary copy and therefore never changed `dftw_filt`.
    flipped = dftw_filt.center_diff < 0
    strand_comb = dftw_filt['strand_combination']
    dftw_filt['strand_combination'] = strand_comb.where(
        ~flipped, strand_comb.map(comp_strand_compbination))
    dftw_filt.center_diff = np.abs(dftw_filt.center_diff)

    # Same-motif pairs get an index suffix (e.g. "X0", "X1") so that the two
    # members can be told apart by the plotting cells.
    if motif_pair[0] == motif_pair[1]:
        pair_label = [f"{m}{i}" for i, m in enumerate(motif_pair)]
    else:
        pair_label = motif_pair
    d_dfi_filtered.append((pair_label, dftw_filt))
  0%|          | 0/10 [00:00<?, ?it/s]
not overlapping TE 0.9445685825817741
100%|██████████| 10/10 [00:11<00:00,  1.19s/it]

Response = Profile

In [86]:
# One column per motif pair; for every strand combination three stacked rows:
#   1. histogram of the pairwise motif distances
#   2. observed log10(1 + profile counts) of both motifs vs. distance, with a lowess trend
#   3. simulated 'profile/counts_frac' curves of both motifs
fig_profile, all_axes = plt.subplots(3*len(strand_combinations), len(d_dfi_filtered),
                                     figsize=(25,15), sharex=True, sharey='row')

for j, (motif_pair, dftw_filt) in enumerate(d_dfi_filtered):
    axes = all_axes[:,j]
    # Same-motif pairs carry a one-character index suffix in the pair table;
    # strip it to get the names used by `profile_mapping` and the simulation.
    if motif_pair[0][:-1] == motif_pair[1][:-1]:
        axes[0].set_title("{f}<>{f}".format(f=motif_pair[0][:-1]), fontsize=7)
        motif_pair_c = [mp[:-1] for mp in motif_pair]
    else:
        motif_pair_c = motif_pair
        axes[0].set_title("<>".join(motif_pair), fontsize=7)
    # pairs closer than 150 where at least one motif has imp_weighted_p > 0.3
    dftw_filt = dftw_filt[(dftw_filt.center_diff < 150) & (dftw_filt.imp_weighted_p.max(1) > 0.3)]

    # profile-count feature of the task assigned to each motif of the pair
    count_cols = [profile_mapping[mp] + "/profile_counts" for mp in motif_pair_c]

    for i, sc in enumerate(strand_combinations):
        df_sc = dftw_filt[dftw_filt.strand_combination==sc]  # select once, reuse below
        x = df_sc['center_diff']
        y1 = np.log10(1 + df_sc[count_cols[0]][motif_pair[0]])
        y2 = np.log10(1 + df_sc[count_cols[1]][motif_pair[1]])

        # lowess trend (other smoothers tried here: average_distance, smooth_gam)
        dm1,ym1,confi1 = smooth_lowess(x,y1, frac=0.15)
        dm2,ym2,confi2 = smooth_lowess(x,y2, frac=0.15)

        # first plot, distance histogram
        ax = axes[3*i]
        ax.hist(x, np.arange(10, 150, 1))
        if j == 0:
            ax.set_ylabel(sc)

        # second plot, observed response: first motif, then second motif
        ax = axes[3*i+1]
        for y, dm, ym, confi in ((y1, dm1, ym1, confi1), (y2, dm2, ym2, confi2)):
            ax.scatter(x, y, alpha=0.05, s=8)
            if confi is not None:
                ax.fill_between(dm, confi[:,0], confi[:,1], alpha=0.2)
            ax.plot(dm, ym, linewidth=2, alpha=0.8)
        if j == 0:
            ax.set_ylabel(sc)

        # third plot, simulated
        ax = axes[3*i+2]
        sim_x, sim_y1, sim_y2 = get_xy_sim(sim_df_d, motif_pair_c, 'profile/counts_frac', sc)
        ax.axhline(1, linestyle="--", color='grey', alpha=0.2)
        ax.plot(sim_x, sim_y1, linewidth=1, alpha=0.8)
        ax.plot(sim_x, sim_y2, linewidth=1, alpha=0.8)

        ax.xaxis.set_minor_locator(plt.MultipleLocator(10))
        if j == 0:
            ax.set_ylabel(sc)

# Panel spacing is a figure-level setting: apply it once, after all columns are drawn.
fig_profile.subplots_adjust(wspace=0, hspace=0)

Response = Importance

In [87]:
# One column per motif pair; for every strand combination three stacked rows:
#   1. histogram of the pairwise motif distances
#   2. observed 'imp_weighted' of both motifs vs. distance, with a lowess trend
#   3. simulated 'imp/weighted_frac' curves of both motifs
fig_imp, all_axes = plt.subplots(3*len(strand_combinations), len(d_dfi_filtered),
                                 figsize=(25,15), sharex=True, sharey='row')

for j, (motif_pair, dftw_filt) in enumerate(d_dfi_filtered):
    axes = all_axes[:,j]
    # Same-motif pairs carry a one-character index suffix in the pair table;
    # strip it to get the names used by the simulation lookup.
    if motif_pair[0][:-1] == motif_pair[1][:-1]:
        axes[0].set_title("{f}<>{f}".format(f=motif_pair[0][:-1]), fontsize=7)
        motif_pair_c = [mp[:-1] for mp in motif_pair]
    else:
        motif_pair_c = motif_pair
        axes[0].set_title("<>".join(motif_pair), fontsize=7)
    # pairs closer than 150 where at least one motif has imp_weighted_p > 0.3
    dftw_filt = dftw_filt[(dftw_filt.center_diff < 150) & (dftw_filt.imp_weighted_p.max(1) > 0.3)]

    for i, sc in enumerate(strand_combinations):
        df_sc = dftw_filt[dftw_filt.strand_combination==sc]  # select once, reuse below
        x = df_sc['center_diff']
        y1 = df_sc['imp_weighted'][motif_pair[0]]
        y2 = df_sc['imp_weighted'][motif_pair[1]]

        # lowess trend (other smoothers tried here: average_distance, smooth_gam)
        dm1,ym1,confi1 = smooth_lowess(x,y1, frac=0.15)
        dm2,ym2,confi2 = smooth_lowess(x,y2, frac=0.15)

        # first plot, distance histogram
        ax = axes[3*i]
        ax.hist(x, np.arange(10, 150, 1))
        if j == 0:
            ax.set_ylabel(sc)

        # second plot, observed response: first motif, then second motif
        ax = axes[3*i+1]
        for y, dm, ym, confi in ((y1, dm1, ym1, confi1), (y2, dm2, ym2, confi2)):
            ax.scatter(x, y, alpha=0.05, s=8)
            if confi is not None:
                ax.fill_between(dm, confi[:,0], confi[:,1], alpha=0.2)
            ax.plot(dm, ym, linewidth=2, alpha=0.8)
        if j == 0:
            ax.set_ylabel(sc)

        # third plot, simulated
        ax = axes[3*i+2]
        sim_x, sim_y1, sim_y2 = get_xy_sim(sim_df_d, motif_pair_c, 'imp/weighted_frac', sc)
        ax.axhline(1, linestyle="--", color='grey', alpha=0.2)
        ax.plot(sim_x, sim_y1, linewidth=1, alpha=0.8)
        ax.plot(sim_x, sim_y2, linewidth=1, alpha=0.8)

        ax.xaxis.set_minor_locator(plt.MultipleLocator(10))
        if j == 0:
            ax.set_ylabel(sc)

# Panel spacing is a figure-level setting: apply it once, after all columns are drawn.
fig_imp.subplots_adjust(wspace=0, hspace=0)
In [88]:
# Write the two multi-panel figures (profile response and importance response) to PDF.
fig_profile.savefig(figures / "non-TE.profile_counts.pdf")
fig_imp.savefig(figures / "non-TE.importance_weighted.pdf")