Goal

  • investigate the effect of motif perturbations

Tasks

  • [x] find all high-confidence instances in the peaks
  • [x] systematically perturb all the instances and visualize the results
In [ ]:
# Imports
from basepair.imports import *
hv.extension('bokeh')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
paper_config()
In [ ]:
# Common paths
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / f"modisco/all/deeplift/profile/"
output_dir = modisco_dir
In [26]:
create_tf_session(0)
Out[26]:
<tensorflow.python.client.session.Session at 0x7face7351be0>
In [5]:
patterns = read_pkl(modisco_dir / "patterns.pkl")  # aligned patterns
In [27]:
bpnet = BPNet.from_mdir(model_dir)
In [ ]:
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix, align_instance_center
from basepair.exp.paper.config import motifs, profile_mapping
from basepair.utils import flatten
from kipoi.writers import HDF5BatchWriter
from basepair.exp.chipnexus.perturb import random_seq_onehot, PerturbSeqDataset, DoublePerturbSeqDataset, PerturbDataset, SingleMotifPerturbDataset, DoubleMotifPerturbDataset, OtherMotifPerturbDataset, PerturbationDataset
from basepair.exp.chipnexus.spacing import motif_pair_dfi, plot_spacing, get_motif_pairs

pairs = get_motif_pairs(motifs)
In [8]:
dfi = load_instances(modisco_dir / 'instances.parq', motifs=motifs, dedup=False)
dfi = filter_nonoverlapping_intervals(dfi)
In [9]:
mr = ModiscoResult(modisco_dir / 'modisco.h5')
mr.open()
In [10]:
# Add aligned instances
orig_patterns = [mr.get_pattern(pname) for pname in mr.patterns()]
dfi = align_instance_center(dfi, orig_patterns, patterns, trim_frac=0.08)
TF-MoDISco is using the TensorFlow backend.
In [11]:
from basepair.cli.imp_score import ImpScoreFile
imp_scores = ImpScoreFile.from_modisco_dir(modisco_dir)

profiles = imp_scores.get_profiles()

Estimate the total output size

In [18]:
# NOTE(review): this cell appears out of order — `dfi_subset` is only defined in a
# later cell (In [12]) and the first three lines are a near-duplicate of cell In [18]
# further down; presumably a leftover from interactive editing. Confirm before
# Restart-&-Run-All.
alt_seqs = PerturbSeqDataset(dfi_subset, seqs).load_all()
alt_preds = bpnet.predict(alt_seqs)
alt_imp_scores = bpnet.imp_score_all(alt_seqs, method='deeplift', aggregate_strand=True, batch_size=256)
# Turn importance scores into contribution scores by masking with the one-hot sequence.
# NOTE(review): this multiplies EVERY entry, including the `<task>/count` scores —
# confirm their array shape is broadcast-compatible with `alt_seqs`.
alt_imp_scores_contrib = {k: v * alt_seqs for k,v in alt_imp_scores.items()}

Implement a function to go from seqlet -> ref / alt profile

In [248]:
pattern = 'metacluster_0/pattern_0'
In [262]:
from basepair.exp.chipnexus.simulate import profile_sim_metrics
from basepair.exp.chipnexus.perturb import get_reference_profile
ref_profiles = {p: get_reference_profile(mr, longer_pattern(sn), profiles, tasks) for p,sn in motifs.items()}

Dataset structure

  • ref
    • task
      • motif
        • obs
        • pred
        • imp
          • profile
          • count
      • total
        • pred
        • imp
          • profile
          • count
  • dA
  • dB
  • dAdB

Reference

In [16]:
# get predictions
seqs = imp_scores.get_seq()
%time preds = bpnet.predict(seqs)
# get importance scores
%time ref_imp_scores = bpnet.imp_score_all(seqs, method='deeplift', aggregate_strand=True)
ref_imp_scores_contrib = {k: v * seqs for k,v in ref_imp_scores.items()}
CPU times: user 30.9 s, sys: 13.6 s, total: 44.5 s
Wall time: 58.3 s
In [ ]:
# Reference (unperturbed) dataset: observed profiles, model predictions, and
# importance scores per task, indexed by example_idx.
# BUG FIX: the comprehension variable is `t`, but `profiles[task]` / `preds[task]`
# referenced an undefined `task` (would NameError on a fresh kernel, or silently use a
# stale kernel variable) — use `t`.
ref = NumpyDataset({t: {
        "obs": profiles[t],
        "pred": preds[t],
        "imp": {
            "profile": ref_imp_scores_contrib[f"{t}/weighted"],
            "count": ref_imp_scores[f"{t}/count"],
        }} for t in tasks}, attrs={'index': 'example_idx'})

Single motif perturbation

In [12]:
# get the interesting motif locations: high-confidence matches with positive importance
# FIX: take an explicit copy so the `row_idx` column assignment below does not trigger
# pandas' SettingWithCopyWarning (query() may return a view-like object).
dfi_subset = (dfi.query('match_weighted_p > 0.2')
                 .query('imp_weighted_p > 0')
                 .copy())
# stable integer row id, used later as the index into the single-mutation dataset
dfi_subset['row_idx'] = np.arange(len(dfi_subset)).astype(int)
In [1293]:
# co-occurrence of motifs within the same example
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
# BUG FIX: pass ax=ax so the heatmap lands on the axes we label below
# (consistent with the other coocurrence_plot calls in this notebook)
coocurrence_plot(dfi_subset, list(motifs), ax=ax)
ax.set_ylabel("Motif of interest")
ax.set_xlabel("Motif partner");
In [18]:
alt_seqs = PerturbSeqDataset(dfi_subset, seqs).load_all()
alt_preds = bpnet.predict(alt_seqs)
alt_imp_scores = bpnet.imp_score_all(alt_seqs, method='deeplift', aggregate_strand=True, batch_size=256)
In [ ]:
# Single-mutation dataset, indexed by row_idx of dfi_subset.
# CONSISTENCY FIX: the reference dataset stores "count" from the RAW importance scores
# (ref_imp_scores), not the sequence-masked contributions — do the same here so
# ref/count and alt/count are comparable.
single_mut = NumpyDataset({t: {
    "pred": alt_preds[t],
    "imp": {
        "profile": alt_imp_scores_contrib[f"{t}/weighted"],
        "count": alt_imp_scores[f"{t}/count"],
    }} for t in tasks}, attrs={'index': 'row_idx'})

Double motif perturbation

In [ ]:
# construct the matrix of all interesting motif pairs
pair_tables = []
for motif_pair in pairs:
    pair_name = '<>'.join(motif_pair)
    pair_tables.append(motif_pair_dfi(dfi_subset, motif_pair).assign(motif_pair=pair_name))
dfab = pd.concat(pair_tables, axis=0)
# unique id per (pair, instance) row — used as the index of the double-mutation dataset
dfab['motif_pair_idx'] = np.arange(len(dfab))
In [ ]:
len(dfab)
In [ ]:
# Double-perturbation predictions + importance scores for all motif pairs.
dpdata_seqs = DoublePerturbDatasetSeq(dfab, seqs).load_all(num_workers=10)
double_alt_preds = bpnet.predict(dpdata_seqs)
# BUG FIX: importance scores must be computed on the perturbed SEQUENCES, not on the
# predictions (cf. the single-perturbation cell: bpnet.imp_score_all(alt_seqs, ...)).
double_alt_imp_scores = bpnet.imp_score_all(dpdata_seqs, method='deeplift', aggregate_strand=True, batch_size=256)
# BUG FIX: contributions = scores * the MATCHING one-hot sequences; `alt_seqs` are the
# single-mutation sequences (different array, different length) — use `dpdata_seqs`.
double_alt_imp_contrib = {k: v * dpdata_seqs for k,v in double_alt_imp_scores.items()}
In [ ]:
# Double-mutation dataset, indexed by motif_pair_idx of dfab.
# CONSISTENCY FIX: the reference dataset stores "count" from the RAW importance scores,
# not the sequence-masked contributions — use double_alt_imp_scores here as well so
# the count entries are comparable across ref / single_mut / double_mut.
double_mut = NumpyDataset({t: {
    "imp": {
        "profile": double_alt_imp_contrib[f"{t}/weighted"],
        "count": double_alt_imp_scores[f"{t}/count"],
    }} for t in tasks}, attrs={'index': 'motif_pair_idx'})

Store to disk

In [31]:
dataset_dir = output_dir / 'perturbation-analysis'
dataset_dir.mkdir(exist_ok=True)
In [88]:
!du -sh {dataset_dir}/*
20M	/users/avsec/workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/deeplift/profile/perturbation-analysis/dfab.csv.gz
18M	/users/avsec/workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/deeplift/profile/perturbation-analysis/dfi_subset.csv.gz
5.0G	/users/avsec/workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/deeplift/profile/perturbation-analysis/double_mut.h5
11G	/users/avsec/workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/deeplift/profile/perturbation-analysis/ref.h5
7.5G	/users/avsec/workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/deeplift/profile/perturbation-analysis/single_mut.h5
In [89]:
%time o = NumpyDataset.load(dataset_dir / 'double_mut.h5')
CPU times: user 1min 1s, sys: 19.2 s, total: 1min 20s
Wall time: 1min 16s
In [68]:
!zcat {dataset_dir}/dfab.csv.gz | wc -l
71590
In [ ]:
# store all files to disk
%time ref.save(dataset_dir / 'ref.h5')
%time single_mut.save(dataset_dir / 'single_mut.h5')
%time double_mut.save(dataset_dir / 'double_mut.h5')
%time dfi_subset.to_csv(dataset_dir / 'dfi_subset.csv.gz', compression='gzip')
%time dfab.to_csv(dataset_dir / 'dfab.csv.gz', compression='gzip')
In [ ]:
# Load all files from disk
%time ref = NumpyDataset.load(dataset_dir / 'ref.h5')
%time single_mut = NumpyDataset.load(dataset_dir / 'single_mut.h5')
%time double_mut = NumpyDataset.load(dataset_dir / 'double_mut.h5')
%time dfi_subset = pd.read_csv(dataset_dir / 'dfi_subset.csv.gz')
%time dfab = pd.read_csv(dataset_dir / 'dfab.csv.gz')

Create a single dataset

In [ ]:
opdata = OtherMotifPerturbDataset(smpdata, dfab).load_all(num_workers=20)
In [130]:
def remove_edge_instances(dfab, profile_width=70, total_width=1000):
    """Keep only rows whose x and y pattern centers are strictly more than
    ceil(profile_width / 2) away from both sequence edges, so that the full
    profile window fits inside [0, total_width]."""
    half = -(-profile_width // 2)  # ceil division
    keep_x = (dfab.pattern_center_x > half) & (dfab.pattern_center_x < total_width - half)
    keep_y = (dfab.pattern_center_y > half) & (dfab.pattern_center_y < total_width - half)
    return dfab[keep_x & keep_y]
In [369]:
from basepair.plot.profiles import  plot_stranded_profile, multiple_plot_stranded_profile
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile, heatmap_stranded_profile
In [153]:
%matplotlib inline
paper_config()
In [ ]:
motif_pair_lpdata = {}
for motif_pair in pairs:
    motif_pair_name = "<>".join(motif_pair)
    # drop instances whose profile window would run off the sequence edge
    dfab_subset = remove_edge_instances(dfab[dfab.motif_pair == motif_pair_name], profile_width=profile_width)
    # BUG FIX: `ParturbationDataset` is a typo — the class imported above (In [ ]) is
    # `PerturbationDataset`.
    pdata = PerturbationDataset(dfab_subset, ref, single_mut, double_mut, profile_width=profile_width)
    motif_pair_lpdata[motif_pair_name] = pdata.load_all(num_workers=1)

    # store also dfab alongside the loaded data
    motif_pair_lpdata[motif_pair_name]['dfab'] = dfab_subset
# NOTE(review): `profile_width` is not defined in any visible cell — confirm it is set
# before this cell on a fresh run.
In [459]:
write_pkl(motif_pair_lpdata, dataset_dir / 'motif_pair_lpdata.pkl')
In [ ]:
# add also the double perturbations
# NOTE(review): `dfab_pairs` is only built in later cells — this cell is out of order.
# It differs from the near-identical loop further down only by the extra
# `profile_mapping` argument — confirm which DoubleMotifPerturbDataset signature is
# current and delete the stale copy.
for k, dfab in dfab_pairs.items():
    # NOTE(review): the explicit import above brings in `DoublePerturbSeqDataset`;
    # `DoublePerturbDatasetSeq` presumably comes from the star import — verify it exists.
    dpdata_seqs = DoublePerturbDatasetSeq(dfab, seqs).load_all(num_workers=10)
    dalt_preds = bpnet.predict(dpdata_seqs)
    dpdata = DoubleMotifPerturbDataset(dfab, dalt_preds, ref_profiles, profile_mapping).load_all(num_workers=20)
    # flatten the nested metric dict into columns aligned with dfab's rows
    df_dpdata = pd.DataFrame(flatten(dpdata), index=dfab.index)
    dfab = pd.concat([dfab, df_dpdata], axis=1)
    dfab_pairs[k] = dfab  # override the new dfab
In [1341]:
%tqdm_restart
In [1342]:
# Bundle the reference and single-mutation arrays into one dataset object used by the
# SingleMotifPerturbDataset / OtherMotifPerturbDataset loaders below.
# NOTE(review): `alt_dataset` is not defined in any visible cell — confirm it exists in
# the kernel (or remove the argument) before a fresh Restart-&-Run-All.
smpdata = PerturbDataset( dfi_subset, seqs, preds, profiles, ref_imp_scores_contrib,
                 alt_dataset, alt_seqs, alt_preds, alt_imp_scores_contrib, 
                 ref_profiles,
                 profile_mapping)
In [1343]:
spdata = SingleMotifPerturbDataset(smpdata).load_all(num_workers=20)
100%|██████████| 3381/3381 [00:12<00:00, 278.50it/s]
In [1344]:
dfsm = pd.DataFrame(flatten(spdata), index=dfi_subset.index)
dfsm = pd.concat([dfsm, dfi_subset], axis=1)
In [ ]:
motif_pair = ['Nanog', 'Sox2']
In [ ]:
dfab_pairs = {}
for i, motif_pair in enumerate(pairs):
    print(f"{i+1}/{len(pairs)}")
    dfab = motif_pair_dfi(dfi_subset, motif_pair)
    opdata = OtherMotifPerturbDataset(smpdata, dfab).load_all(num_workers=20)
In [570]:
from basepair.config import test_chr

Take into account motif pairs

In [1314]:
motif_pair = ['Nanog', 'Klf4']
In [1315]:
dfab_pairs_bak = deepcopy(dfab_pairs)
In [ ]:
dfab_pairs = {}
for i, motif_pair in enumerate(pairs):
    print(f"{i+1}/{len(pairs)}")
    dfab = motif_pair_dfi(dfi_subset, motif_pair)
    opdata = OtherMotifPerturbDataset(smpdata, dfab).load_all(num_workers=20)
    dfab_sm = pd.DataFrame(flatten(opdata), index=dfab.index)
    dfab_sm = pd.concat([dfab, dfab_sm], axis=1)
    dfab_pairs["<>".join(motif_pair)] = dfab_sm
In [941]:
dfab_pairs_bak = deepcopy(dfab_pairs)
In [1000]:
dfab_pairs = deepcopy(dfab_pairs_bak)
In [1002]:
%tqdm_restart
In [ ]:
# add also the double perturbations
# NOTE(review): near-duplicate of an earlier cell that passes an extra
# `profile_mapping` argument to DoubleMotifPerturbDataset — confirm which signature is
# current and delete the stale copy.
for k, dfab in dfab_pairs.items():
    dpdata_seqs = DoublePerturbDatasetSeq(dfab, seqs).load_all(num_workers=10)
    dalt_preds = bpnet.predict(dpdata_seqs)
    dpdata = DoubleMotifPerturbDataset(dfab, dalt_preds, ref_profiles).load_all(num_workers=20)
    # flatten the nested metric dict into columns aligned with dfab's rows
    df_dpdata = pd.DataFrame(flatten(dpdata), index=dfab.index)
    dfab = pd.concat([dfab, df_dpdata], axis=1)
    dfab_pairs[k] = dfab  # override the new dfab
In [1347]:
# append A|dA
# Join each motif's single-mutation metrics onto the pair table twice: once for the
# x motif (columns prefixed dx_x_) and once for the y motif (dy_y_), so every pair row
# carries both partners' single-perturbation results.
dfsm_prefixed_x = dfsm.copy()
dfsm_prefixed_y = dfsm.copy()
dfsm_prefixed_x.columns = ["dx_x_" + c for c in dfsm_prefixed_x.columns]
dfsm_prefixed_y.columns = ["dy_y_" + c for c in dfsm_prefixed_y.columns]
for k, dfab in dfab_pairs.items():
    # left-join keeps every pair row even if a single-mutation row is missing
    dfab = pd.merge(dfab, dfsm_prefixed_x, how='left', left_on='row_idx_x', right_on='dx_x_row_idx')
    dfab = pd.merge(dfab, dfsm_prefixed_y, how='left', left_on='row_idx_y', right_on='dy_y_row_idx')
    dfab_pairs[k] = dfab  # override the new dfab
In [1348]:
# store the pairs
write_pkl(dfab_pairs, modisco_dir / 'dfab_pairs.pkl')
In [20]:
dfab_pairs = read_pkl(modisco_dir / 'dfab_pairs.pkl')

TODO - show also the importance of other factors

  • shall we display it using 4 different heatmaps or shall we always just use a single metric to display it?

    • [ ] which metric to use
      • is the counts at the profile enough?
      • can we have local importance scores?
        • aggregate importance for prediction inside and outside of the profile
          • make the interval rather tight
  • focus on Oct4-Sox2 interactions

In [474]:
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
In [478]:
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
In [482]:
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
In [611]:
mkdir -p {ddir}/figures/modisco/spacing/preturb
In [22]:
figures = Path(f"{ddir}/figures/modisco/spacing/preturb")
In [25]:
import warnings
warnings.filterwarnings("ignore")
In [ ]:
from basepair.exp.chipnexus.perturb import plot_scatter, compute_features, plt_diag, plot_scatter, plot_pairs

Perturbation scatterplots

Using pseudo-counts

In [32]:
plot_pairs(dfab_pairs, pairs, ['Total counts', 'Profile counts', 'Profile importance', 'Count importance', 'Profile match'], pseudo_count_quantile=.2, variable=None, pval=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)

TODO

  • [ ] scatterplot / parallel components / boxplot the importance scores (count + profile) for all the 4 factors on per-TF run
  • hierarchy - Klf4 is last
    • how does it work?
  • [ ] graph model with the motifs
    • can we predict it just using Klf4?

Without pseudo-counts

In [1419]:
plot_pairs(dfab_pairs, pairs, ['Total counts', 'Profile counts', 'Profile importance', 'Count importance', 'Profile match'], variable=None, pval=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)

Color by distance

In [1420]:
cvar = 'dist'
# NOTE(review): `plot_features` is defined only in the NEXT cell (In [1421]) — this
# cell fails on a fresh Restart-&-Run-All; move the definition above this cell.
plot_pairs(dfab_pairs, pairs, plot_features, variable=f"cat_{cvar}", pval=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.pdf', raster=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.png', raster=True, transparent=False)
In [1421]:
plot_features = ['Total counts', 'Profile counts', 'Profile importance', 'Count importance', 'Profile match']

Open questions

  • why is the sign inverted?

Color by importance

In [1422]:
cvar = 'imp'
plot_pairs(dfab_pairs, pairs, plot_features, variable=f"cat_{cvar}", pval=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.pdf', raster=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.png', raster=True, transparent=False)

Color by strand

In [1423]:
cvar = 'strand'
plot_pairs(dfab_pairs, pairs, plot_features, variable=f"cat_{cvar}", pval=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.pdf', raster=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.png', raster=True, transparent=False)

Heatmaps

In [47]:
total_examples = len(dfi.example_idx.unique())
total_examples
Out[47]:
64434
In [814]:
# TODO - generalize this table to also have the diagonal in
In [822]:
motif_pair = ['Nanog', 'Nanog']
In [828]:
dfiab = dfab_pairs["<>".join(motif_pair)]
x_total = dfi_subset[dfi_subset.pattern_name == motif_pair[0]].shape[0]
xy_total = len(dfiab[dfiab.center_diff < 150].row_idx_x.unique())
In [829]:
x_total  # total number of instances of motif A
Out[829]:
47990
In [832]:
xy_total
Out[832]:
13321
In [837]:
dfab_pairs_filt = {k: v[v.center_diff < 150] for k,v in dfab_pairs.items()}
In [ ]:
from basepair.exp.chipnexus.perturb import plot_mutation_heatmap
from basepair.exp.chipnexus.spacing import co_occurence_matrix, fisher_test_coc, coocurrence_plot

Co-occurrence

In [59]:
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset, list(motifs), ax=ax)
In [130]:
df = pd.read_csv(f"{ddir}/processed/chipnexus/external-data.tsv", sep='\t')
dfs = df[df.assay.isin(['PolII', 'H3K27ac'])]
In [131]:
import pybedtools
from pybedtools import BedTool
from basepair.extractors import MultiAssayExtractor
from basepair.data import NumpyDataset
In [114]:
df_regions = dfi_subset[['example_chrom', 'example_start', 'example_end', 'example_idx']].drop_duplicates()
bt = BedTool.from_dataframe(df_regions)

extractor = MultiAssayExtractor(dfs, None, use_strand=False, n_jobs=10)

regions = extractor.extract(list(bt))
In [133]:
r = NumpyDataset(regions)
dfc = pd.DataFrame(r.aggregate(np.sum, axis=1))
dfc['example_idx'] = df_regions['example_idx'].values
In [134]:
dfc.head()
Out[134]:
H3K27ac PolII example_idx
0 12794.0 2184.0 1
1 220274.0 94643.0 2
2 130652.0 86907.0 3
3 76793.0 60257.0 4
4 26238.0 13179.0 5
In [135]:
high = np.quantile(dfc.H3K27ac, .9)
In [136]:
high
Out[136]:
60299.5
In [137]:
dfc.H3K27ac.plot.hist(30);
In [140]:
# Add H3K27 ac to the table
dfi_subset = pd.merge(dfi_subset, dfc, on='example_idx', how='left')
In [143]:
np.quantile(dfi_subset[['example_idx', 'H3K27ac']].drop_duplicates().H3K27ac, .9)
Out[143]:
60299.5
In [153]:
ls {figures}
pairwise_all.close,imp=high,color=match.pdf   pairwise_all.color=strand.png
pairwise_all.close,imp=high,color=match.png   pairwise_all.mid-range.pdf
pairwise_all.close,imp=high,color=strand.pdf  pairwise_all.mid-range.png
pairwise_all.close,imp=high,color=strand.png  pairwise_all.pdf
pairwise_all.close,imp=high.pdf               pairwise_all.png
pairwise_all.close,imp=high.png               pairwise_all.wilcox_pval.pdf
pairwise_all.color=dist.pdf                   pairwise_all.wilcox_pval.png
pairwise_all.color=dist.png                   pairwise_imp_inside.pdf
pairwise_all.color=imp.pdf                    pairwise_imp_inside.png
pairwise_all.color=imp.png                    pairwise_match_inside.pdf
pairwise_all.color=match.pdf                  pairwise_match_inside.png
pairwise_all.color=match.png                  pairwise_pred_inside.pdf
pairwise_all.color=strand.pdf                 pairwise_pred_inside.png
In [163]:
!mkdir -p {figures}/../co-occurence/
In [164]:
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset, list(motifs), ax=ax)
plt.savefig(f"{figures}/../co-occurence/all.pdf")
plt.savefig(f"{figures}/../co-occurence/all.png")
In [165]:
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset[dfi_subset.H3K27ac > np.quantile(dfc.H3K27ac, .9)], list(motifs), ax=ax)
plt.savefig(f"{figures}/../co-occurence/H3K27ac>90percentile.pdf")
plt.savefig(f"{figures}/../co-occurence/H3K27ac>90percentile.png")
In [166]:
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset[dfi_subset.PolII > np.quantile(dfc.PolII, .9)], list(motifs), ax=ax)
plt.savefig(f"{figures}/../co-occurence/PolII>90percentile.pdf")
plt.savefig(f"{figures}/../co-occurence/PolII>90percentile.png")
In [227]:
dfi_subset.pattern_name.head()
Out[227]:
0    Oct4-Sox2
1    Oct4-Sox2
2    Oct4-Sox2
3    Oct4-Sox2
4    Oct4-Sox2
Name: pattern_name, dtype: object
In [ ]:
dfi_subset.pattern_name.isin(["Oct4-Sox2", "Klf4"])]
In [237]:
# FIX: the original cell was missing the closing ']' of the groupby key list
# (saved with its SyntaxError traceback) — count instances per (example, pattern)
dfi_subset[dfi_subset.pattern_name.isin(["Oct4-Sox2", "Klf4"])].groupby(["example_idx", 'pattern_name']).pattern_name.size()
In [243]:
counts = pd.pivot_table(dfi_subset[dfi_subset.pattern_name.isin(["Oct4-Sox2", "Klf4"])], index='example_idx', columns='pattern_name', values='H3K27ac', aggfunc=len, fill_value=0)
In [244]:
c = counts > 0
In [261]:
values = c.Klf4.map({False:"", True: "Klf4"}) + c['Oct4-Sox2'].map({False:"", True: "Oct4-Sox2"})
In [262]:
values.value_counts()
Out[262]:
Klf4             20095
Oct4-Sox2        11258
Klf4Oct4-Sox2     4441
dtype: int64
In [253]:
# FIX: `pd.where` does not exist (the cell was saved with its AttributeError traceback).
# Use np.where and wrap the result back into a Series so elementwise string
# concatenation works.
pd.Series(np.where(c.Klf4, "", "Klf4"), index=c.index) + "a"
In [258]:
dfs = dfi_subset[['example_idx', "H3K27ac", "PolII"]].drop_duplicates()
In [259]:
dfs = dfs.set_index("example_idx")
In [267]:
dfsj = pd.DataFrame({"feat": values}).join(dfs)
In [268]:
dfsj.unstack(id_vars=)
Out[268]:
feat H3K27ac PolII
example_idx
1 Oct4-Sox2 12794.0 2184.0
2 Oct4-Sox2 220274.0 94643.0
3 Oct4-Sox2 130652.0 86907.0
... ... ... ...
98413 Klf4 14296.0 2846.0
98417 Klf4 2000.0 1264.0
98418 Klf4 11231.0 2952.0

35794 rows × 3 columns

In [273]:
ggplot(aes(x='feat', y='H3K27ac'), data=dfsj) + geom_boxplot() + scale_y_continuous(trans='log10') + geom_violin()
Out[273]:
<ggplot: (-9223363268848167957)>
In [272]:
ggplot(aes(x='feat', y='PolII'), data=dfsj) + geom_boxplot() + scale_y_continuous(trans='log10')
Out[272]:
<ggplot: (8768006830414)>
In [260]:
values
Out[260]:
example_idx
1        Oct4-Sox2
2        Oct4-Sox2
3        Oct4-Sox2
           ...    
98413         Klf4
98417         Klf4
98418         Klf4
Length: 35794, dtype: object
In [245]:
c
Out[245]:
pattern_name Klf4 Oct4-Sox2
example_idx
1 False True
2 False True
3 False True
... ... ...
98413 True False
98417 True False
98418 True False

35794 rows × 2 columns

Effect of partner motif perturbation

In [1295]:
figsize_single = get_figsize(.5, aspect=1)
# NOTE(review): `features` is not defined in any visible cell — presumably set
# interactively; confirm before a fresh run.
# NOTE(review): the figure HEIGHT uses figsize_single[0] (the width) rather than [1] —
# confirm whether square panels are intended.
fig, axes = plt.subplots(1, len(features), figsize=(figsize_single[0]*len(features), figsize_single[0]))
for i, (feat, ax) in enumerate(zip(features, axes)):
    if i >= 2 :
        # later (importance-based) metrics get a tighter color scale
        max_frac = 1.5
        # smaller y-scale
    else:
        max_frac = 2
    plot_mutation_heatmap(dfab_pairs, pairs, list(motifs), feat, ax=ax, max_frac=max_frac)
plt.tight_layout()

TODO

  • [x] Re-order plots (TF pair X metric)
  • [X] Add deeplift profile * total counts
  • [x] Add deeplift total counts importances
  • [X] color the points according to the pairwise distance
    • < 35, 35-70, 70-150
    • strand (+-, --, ...)
    • strength of A, strength of B (high/low affinity)
  • [x] Add Wilcoxon test to the plot.
  • [x] compute A | dA & dB and add 2 new importance metrics to the plot (Including wilcoxon test)
    • add (A|dB - A|dA&dB ) / ( A - A|dA)
    • add (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA) plot
  • [x] Plot the heatmap of average effect (TFxTF, color=alt/ref, text=signif stars)
  • [~] generate 3 heatmaps
    • [X] co-occurence
    • [X] perturbation effect
    • [ ] 10bp periodicity
  • [ ] merge all into Figure 5

Open questions

  • [ ] can we ignore the strand and just aggregate all the results across 4 (or 3) strand combinations?
    • test: add them up and see if you are still getting qualitatively the same results
      • overplot or plot very close to each other (4 strands + aggregate):
        • simulations data (to see subtle differences)
        • histogram (as a chart step)
        • observed counts in the genome
  • [ ] is the simulated motif aligned with observed seqlet?
    • if not, then re-generate all the simulation data
  • [ ] how to test for significance for motifs of the same kind?

TODO

  • [x] why are the alt importance scores negative? now fixed
    • I was using the hypothetical contribs instead of contribs
  • [x] Dissect the contribution from flanks on counts / profile
  • [x] figure out the dependency graph
  • shall we weight the imp-scores with the total counts?

Analyze the dataset

  • for each motif individually, study the impact of mutations on the proximal and distal counts
    • how much of the local profile can be explained by the motif?
      • note: that point could be circular since the architecture is designed the 'local' way
  • for each motif pair, study the impact of the mutations on other motifs
    • measure the change in the following variables when perturbing motif A:
      • counts inside (A, B)x[ref, alt]
      • profile match (A,B) x [ref alt]
      • counts outside (A)x[ref, alt]
      • importance score (B)x[alt], (A,B)x[ref]
      • pairwise distance
      • strand orientation
  • [ ] compile the dataset such that you can extract the dataset for a probabilistic graphical model

Questions:

  • How much does the importance scores of B decrease when perturbing A?
  • How much do the profile counts of B decrease when perturbing A?

Stratification. How are the above questions influenced by

  • distance or orientation?
  • presence of other motifs (is there some redundancy)?
  • weak / strong affinity motifs of A?

Final goal:

  • paper figures for each motif pair
  • graph showing different connections between core motifs
    • visualize edges in networkx

Other plots

In [211]:
# Plot the profile instances
from basepair.modisco.pattern_instances import dfi2seqlets, annotate_profile
from basepair.modisco.results import Seqlet, resize_seqlets
from basepair.plot.profiles import extract_signal
from basepair.plot.profiles import  plot_stranded_profile, multiple_plot_stranded_profile
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile

# NOTE(review): this RE-ASSIGNS `dfi_subset`, restricting it to Oct4-Sox2 only — it
# shadows the dfi_subset used throughout the earlier cells, so re-running any of them
# after this cell gives different results (hidden-state hazard). Use a new name.
dfi_subset = (dfi.query('match_weighted_p > .2')
                 .query('imp_weighted_p > 0.0')
                 .query('pattern_name=="Oct4-Sox2"'))
seqlets = dfi2seqlets(dfi_subset)
# widen each seqlet to the 70-bp profile window, clipped to the 1000-bp sequence
seqlets = resize_seqlets(seqlets, 70, seqlen=1000)
seqlet_profiles = {k: extract_signal(v, seqlets) for k,v in profiles.items()}

multiple_plot_stranded_profile({p:v for p,v in seqlet_profiles.items()}, figsize_tmpl=(2.55,2))
multiple_heatmap_stranded_profile(seqlet_profiles, sort_idx=np.arange(1000), figsize=(10,10));

Explore ISM scores

In [1335]:
k = 'Oct4-Sox2<>Nanog'
dfab_sma = dfab_pairs[k]
dfab_sma = dfab_sma[dfab_sma.center_diff < 150]  # only consider nearby motif pairs

# distance / strand / score categories used to stratify the scatterplots below
cat_dist = pd.Categorical(pd.cut(dfab_sma.center_diff, [0, 35, 70, 150]))
cat_strand = pd.Categorical(dfab_sma.strand_combination)

match_threshold = .2

def hilo_pair(x, y, threshold=match_threshold):
    """Categorize an (x, y) score pair as e.g. 'high-low' w.r.t. `threshold`."""
    label = lambda s: (s > threshold).map({True: 'high', False: 'low'})
    return pd.Categorical(label(x) + "-" + label(y))

# CONSISTENCY FIX: the `_y` comparisons previously hard-coded .2 instead of
# match_threshold (same value, so behavior is unchanged — but now a single knob).
cat_match = hilo_pair(dfab_sma.match_weighted_p_x, dfab_sma.match_weighted_p_y)
cat_imp = hilo_pair(dfab_sma.imp_weighted_p_x, dfab_sma.imp_weighted_p_y)
In [1350]:
nc = 5
nr = 2
fig = plt.figure(figsize=(get_figsize(.5, aspect=1)[0]*nc, get_figsize(.5, aspect=1)[1]*2))

variable = 'cat_imp'
for k, cat in enumerate(eval(variable).categories):
    dfab_sm = dfab_sma[eval(variable) == cat]
    plt.subplot(2, nc, 1)
    plt.scatter(dfab_sm.dxy_y_pred_total, dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dx&dy")
    plt.ylabel("y_total | dx")
    plt_diag([0, 500])
    plt.title(k);

    plt.subplot(2, nc, 2)
    plt.scatter((dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside), dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total")
    plt.ylabel("y_total | dx")
    plt_diag([0, 1000])

    plt.subplot(2, nc, 3)
    plt.scatter((dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dy")
    plt.ylabel("y_total | dx")
    plt_diag([0, 1000])

    plt.subplot(2, nc, 4)
    plt.scatter((dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dy")
    plt.ylabel("y_total")
    plt_diag([0, 1000])

    plt.subplot(2, nc, nc+1)
    plt.scatter(dfab_sm.dxy_y_pred_inside, dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dx&dy")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 200])
    plt.title(k);

    plt.subplot(2, nc, nc+2)
    plt.scatter(dfab_sm.xy_ref_pred_inside , dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 400])

    plt.subplot(2, nc, nc+3)
    plt.scatter(dfab_sm.dy_y_alt_pred_inside, dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 400])

    plt.subplot(2, nc, nc+4)
    plt.scatter(dfab_sm.dy_y_alt_pred_inside, dfab_sm.xy_ref_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy")
    plt.ylabel("y_footprint")
    plt_diag([0, 400])
    
    plt.subplot(2, nc, nc+5)
    plt.scatter(dfab_sm.dxy_y_pred_inside, dfab_sm.xy_ref_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy&dx")
    plt.ylabel("y_footprint")
    plt_diag([0, 400])

plt.legend(scatterpoints=1, markerscale=10, columnspacing=0, handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()

Log-scale

In [1354]:
nc = 5
nr = 2
fig = plt.figure(figsize=(get_figsize(.5, aspect=1)[0]*nc, get_figsize(.5, aspect=1)[1]*2))

variable = 'cat_imp'
for k, cat in enumerate(eval(variable).categories):
    dfab_sm = dfab_sma[eval(variable) == cat]
    plt.subplot(2, nc, 1)
    plt.scatter(np.log(dfab_sm.dxy_y_pred_total), np.log(dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dx&dy")
    plt.ylabel("y_total | dx")
    plt_diag([4, 7])
    plt.title(k);

    plt.subplot(2, nc, 2)
    plt.scatter(np.log(dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside), np.log(dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total")
    plt.ylabel("y_total | dx")
    plt_diag([4, 7])

    plt.subplot(2, nc, 3)
    plt.scatter(np.log(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), np.log(dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dy")
    plt.ylabel("y_total | dx")
    plt_diag([4, 7])

    plt.subplot(2, nc, 4)
    plt.scatter(np.log(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), np.log(dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dy")
    plt.ylabel("y_total")
    plt_diag([4, 7])

    plt.subplot(2, nc, nc+1)
    plt.scatter(np.log(dfab_sm.dxy_y_pred_inside), np.log(dfab_sm.xy_alt_pred_inside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dx&dy")
    plt.ylabel("y_footprint | dx")
    plt_diag([2, 6])
    plt.title(k);

    plt.subplot(2, nc, nc+2)
    plt.scatter(np.log(dfab_sm.xy_ref_pred_inside) , np.log(dfab_sm.xy_alt_pred_inside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint")
    plt.ylabel("y_footprint | dx")
    plt_diag([2, 6])

    plt.subplot(2, nc, nc+3)
    plt.scatter(np.log(dfab_sm.dy_y_alt_pred_inside), np.log(dfab_sm.xy_alt_pred_inside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy")
    plt.ylabel("y_footprint | dx")
    plt_diag([2, 6])

    plt.subplot(2, nc, nc+4)
    plt.scatter(np.log(dfab_sm.dy_y_alt_pred_inside), np.log(dfab_sm.xy_ref_pred_inside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy")
    plt.ylabel("y_footprint")
    plt_diag([2, 6])
    
    plt.subplot(2, nc, nc+5)
    plt.scatter(np.log(dfab_sm.dxy_y_pred_inside), np.log(dfab_sm.xy_ref_pred_inside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy&dx")
    plt.ylabel("y_footprint")
    plt_diag([2, 6])

plt.legend(scatterpoints=1, markerscale=10, columnspacing=0, handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
In [1385]:
# Note: all quantities below are on the log scale.
# Scatter: corrected footprint effect of removing the y-motif, computed in the
# reference sequence (x-axis: "y - y|dy") vs. in the sequence with the x-motif
# already removed (y-axis: "y|dx - y|dx&dy").
fig = plt.figure(figsize=get_figsize(.5))
x_ref = np.log(dfab_sm.xy_ref_pred_inside) - np.log(dfab_sm.dy_y_alt_pred_inside)
x_alt = np.log(dfab_sm.xy_alt_pred_inside) - np.log(dfab_sm.dxy_y_pred_inside)
plt.scatter(x_ref, x_alt, s=2, alpha=0.2)
plt.xlabel("y - y|dy")
plt.ylabel("y|dx - y|dx&dy")
plt_diag([-.5, 2])

# Histogram of the ratio of the two corrected effects, restricted to (0, 2)
# to drop sign flips and extreme outliers.
fig = plt.figure(figsize=get_figsize(.5))
frac = x_alt / x_ref
in_range = (frac < 2) & (frac > 0)
plt.hist(frac[in_range], 30);
plt.ylabel("Frequency");
plt.xlabel("(y|dx - y|dx&dy) / (y - y|dy)\n(y=Nanog, x=Oct4-Sox2)");
In [ ]:
# Scatter of summed (inside + outside) predicted counts: perturbed (dy) on the
# x-axis vs. reference on the y-axis.
# NOTE(review): assumes inside + outside = total predicted counts — TODO confirm.
fig = plt.figure(figsize=get_figsize(.5))
total_dy = dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside
total_ref = dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside
plt.scatter(total_dy, total_ref, s=2, alpha=0.2)

TODO

  • [x] try ratio of ratios
  • [x] work on the log scale instead of the natural scale for the counts
  • [/] do the ISM score for exactly the same objective as DeepLIFT
  • [ ] explore different colors for match < .2
In [1355]:
# Compare the pairwise perturbation effect for every motif pair under nf
# different quantification schemes (one column per scheme, one row per pair).
nf = 7  # number of quantification schemes (= plot columns)
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    # Restrict to motif pairs whose centers are closer than 150 bp.
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]

    # Stratification variables (distance bins, strand combination, and
    # high/low confidence by match- and importance-weighted p-values).
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    # Single threshold applied to BOTH motifs of the pair (previously the
    # y-motif threshold was a hard-coded duplicate of the same 0.2 value).
    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > match_threshold).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > match_threshold).map({True: 'high', False: 'low'})))

    variable = 'cat_imp'
    dfab_sm = dfab_sma
    for j, ax in enumerate(axes[i]):
        # Each scheme produces (x_ref, x_alt) and (y_ref, y_alt):
        # the effect on one motif's signal without / with the partner perturbed.
        if j == 0:
            if i == 0:
                ax.set_title("Corrected total counts (ratio)")
            # (total_counts|dB) / (total_counts|dA&dB) vs
            # total_counts / (total_counts|dA)
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) / dfab_sm.dxy_y_pred_total
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) / (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) / dfab_sm.dxy_x_pred_total
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) / (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
        elif j == 1:
            if i == 0:
                ax.set_title("Corrected footprint counts (ratio)")
            # (A|dB) / (A|dA&dB) vs A / (A|dA), using footprint (inside) counts
            x_alt = dfab_sm.xy_alt_pred_inside / dfab_sm.dxy_y_pred_inside
            x_ref = dfab_sm.xy_ref_pred_inside / dfab_sm.dy_y_alt_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside / dfab_sm.dxy_x_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside / dfab_sm.dx_x_alt_pred_inside
        elif j == 2:
            if i == 0:
                ax.set_title("Log Corrected total counts")
            # Same as j == 0 but as log-differences (log of the ratio)
            x_alt = np.log(dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - np.log(dfab_sm.dxy_y_pred_total)
            x_ref = np.log(dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - np.log(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
            y_alt = np.log(dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - np.log(dfab_sm.dxy_x_pred_total)
            y_ref = np.log(dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - np.log(dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
        elif j == 3:
            if i == 0:
                ax.set_title("Log corrected footprint counts (ratio)")
            # Same as j == 1 but as log-differences
            x_alt = np.log(dfab_sm.xy_alt_pred_inside) - np.log(dfab_sm.dxy_y_pred_inside)
            x_ref = np.log(dfab_sm.xy_ref_pred_inside) - np.log(dfab_sm.dy_y_alt_pred_inside)
            y_alt = np.log(dfab_sm.yx_alt_pred_inside) - np.log(dfab_sm.dxy_x_pred_inside)
            y_ref = np.log(dfab_sm.yx_ref_pred_inside) - np.log(dfab_sm.dx_x_alt_pred_inside)
        elif j == 4:
            if i == 0:
                ax.set_title("Profile importance")
            # Profile importance scaled by the total predicted counts
            x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
            x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
            y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
            y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
        elif j == 5:
            if i == 0:
                ax.set_title("Norm. profile importance")
            x_alt = dfab_sm.xy_alt_imp_inside
            x_ref = dfab_sm.xy_ref_imp_inside
            y_alt = dfab_sm.yx_alt_imp_inside
            y_ref = dfab_sm.yx_ref_imp_inside
        elif j == 6:
            if i == 0:
                ax.set_title("Deeplift w.r.t. counts")
            x_alt = dfab_sm.xy_alt_impcount_inside
            x_ref = dfab_sm.xy_ref_impcount_inside
            y_alt = dfab_sm.yx_alt_impcount_inside
            y_ref = dfab_sm.yx_ref_impcount_inside
        # NOTE(review): `cat` is not defined in this cell — it leaks from an
        # earlier cell's kernel state (hidden-state dependency). Given
        # `variable = 'cat_imp'` above, `cat_imp` was likely intended — TODO confirm.
        plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat)
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
plt.tight_layout()

Decision

  • The ISM approach seems to be too noisy even when using the log-scale