Goal

  • Visualize how the predicted profiles and importance scores change when the motifs in the Oct4 enhancer are replaced with random sequence

Tasks

  • [x] get the trimmed seqlet locations
  • [x] sample the sequences from the background distribution
  • [x] make the prediction
    • also show the importance scores
In [1]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
hv.extension('bokeh')
Using TensorFlow backend.
In [9]:
import pybedtools
from basepair.utils import flatten_list
paper_config()
In [77]:
# Common paths
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / f"modisco/all/deeplift/profile/"
In [3]:
# Create the TensorFlow session the model runs in; the argument presumably selects GPU device 0 — confirm
create_tf_session(0)
Out[3]:
<tensorflow.python.client.session.Session at 0x7fab3abadeb8>
In [4]:
# Load the trained BPNet model stored in `model_dir`
bpnet = BPNet.from_mdir(model_dir)
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2018-12-31 07:55:38,327 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2018-12-31 07:55:50,911 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [7]:
# 1 kb window on chr17 containing the Oct4 enhancer studied in this notebook
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35503550 + 1000])
In [8]:
# Reference prediction and DeepLIFT importance scores for `interval`.
# NOTE(review): `pred` is not used by the cells below (plot_region re-runs the prediction).
pred = bpnet.predict_intervals([interval], imp_method='deeplift')[0]
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  False
In [70]:
def plot_region(interval, variants=None, xlim=(420, 575), fmin=None, fmax=None):
    """Plot BPNet predictions and profile importance scores for one region.

    For every task in `bpnet.tasks`, two tracks are drawn:
      * "<task> Pred": the predicted profile.
      * "<task> Imp profile": the DeepLIFT weighted profile importance
        multiplied by the one-hot sequence.
    Only the window `xlim` of each track is kept (via `filter_tracks`).

    Args:
      interval: pybedtools Interval to predict on.
      variants: optional list of `Variant`s applied before prediction
        (forwarded to `bpnet.predict_intervals`).
      xlim: (start, end) window passed to `filter_tracks`, in positions
        relative to the start of `interval`.
      fmin, fmax: dicts with keys 'Imp profile' and 'Pred' giving the y-axis
        limits. They default to fixed values so that the reference and the
        mutated sequences are drawn on the same scale. 'Pred' tracks always
        start at 0, so `fmin['Pred']` is not used.

    Returns:
      The figure returned by `plot_tracks`.
    """
    # Fixed default ranges keep the axes comparable across calls (they look
    # like the max/min over tasks of the reference prediction — confirm).
    if fmax is None:
        fmax = {'Imp profile': 0.265742, 'Pred': 34.442215}
    if fmin is None:
        fmin = {'Imp profile': -0.16751188, 'Pred': 0.07202636}

    pred = bpnet.predict_intervals([interval], variants=variants, imp_method='deeplift')[0]

    viz_dict = OrderedDict()
    for task in bpnet.tasks:
        viz_dict[f"{task} Pred"] = pred['pred'][task]
        viz_dict[f"{task} Imp profile"] = pred['imp_score'][f"{task}/weighted"] * pred['seq']
    viz_dict = filter_tracks(viz_dict, list(xlim))

    ylim = []
    for key in viz_dict:
        feature = key.split(" ", 1)[1]  # drop the "<task> " prefix
        if "Imp" in feature:
            ylim.append((fmin[feature], fmax[feature]))
        else:
            ylim.append((0, fmax[feature]))

    return plot_tracks(viz_dict,
                       title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                       fig_height_per_track=0.5,
                       rotate_y=0,
                       fig_width=get_figsize(frac=1)[0],
                       ylim=ylim,
                       legend=False)
In [71]:
# Reference sequence (no variants)
plot_region(interval);
  • [x] get the trimmed seqlet locations
  • [x] sample the sequences from the background distribution
  • [x] make the prediction
    • also show the importance scores
In [21]:
from basepair.exp.paper.config import motifs
In [22]:
# Open the TF-MoDISco results for this model.
# NOTE(review): `mr` is not used by the cells shown below.
mr = ModiscoResult(modisco_dir / 'modisco.h5')
mr.open()
In [23]:
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix, align_instance_center
In [24]:
# Load the motif instances, named via the paper's `motifs` mapping (presumably) and de-duplicated
dfi = load_instances(modisco_dir  / 'instances.parq', motifs=motifs, dedup=True)
number of deduplicatd instances: 1594551 (35.14740608536545%)
In [25]:
# Preview the instance table
dfi.head()
Out[25]:
pattern example_idx pattern_start pattern_end strand pattern_len pattern_center match_weighted match_weighted_p match_weighted_cat match_max match_max_task imp_weighted imp_weighted_p imp_weighted_cat imp_max imp_max_task seq_match seq_match_p seq_match_cat match/Klf4 match/Nanog match/Oct4 match/Sox2 imp/Klf4 imp/Nanog imp/Oct4 imp/Sox2 example_chrom example_start example_end example_strand example_interval_from_task pattern_short pattern_name pattern_start_abs pattern_end_abs
0 metacluster_0/pattern_0 0 104 119 + 15 111 0.2469 0.0010 low 0.4484 Oct4 0.0365 NaN NaN 0.0524 Oct4 3.9825 0.0085 low 0.2654 -0.3593 0.4484 0.3725 0.0235 0.0246 0.0524 0.0313 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482648 143482663
1 metacluster_0/pattern_0 0 263 278 - 15 270 0.2518 0.0011 low 0.2893 Oct4 0.0217 NaN NaN 0.0342 Nanog 1.1066 0.0007 low 0.1618 0.2170 0.2893 0.2803 0.0143 0.0342 0.0200 0.0203 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482807 143482822
2 metacluster_0/pattern_0 0 282 297 + 15 289 0.3466 0.0069 low 0.3870 Sox2 0.0272 NaN NaN 0.0296 Sox2 5.0100 0.0188 low 0.2257 0.3454 0.3723 0.3870 0.0233 0.0258 0.0279 0.0296 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482826 143482841
3 metacluster_0/pattern_0 0 445 460 + 15 452 0.2399 0.0008 low 0.2469 Sox2 0.1852 0.0008 high 0.4145 Nanog 1.7775 0.0012 low 0.2458 0.2408 0.2314 0.2469 0.1234 0.4145 0.0949 0.1896 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482989 143483004
4 metacluster_0/pattern_0 0 475 490 - 15 482 0.2903 0.0025 low 0.3322 Oct4 0.2889 0.0179 high 0.5196 Nanog 3.3390 0.0057 low 0.2092 0.2681 0.3322 0.2991 0.1613 0.5196 0.2049 0.3241 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143483019 143483034
In [ ]:
# Region of interest: chromosome and start of the 1 kb window used for `interval`.
# (Previously `chr17, 35503550`, which raised NameError because `chr17` is undefined.)
REGION_CHROM, REGION_START = "chr17", 35503550
In [27]:
# All motif instances on chr17 lying strictly inside the plotted window
# (offsets 420-575 from the start of `interval`)
dfi_subset = dfi.query('example_chrom == "chr17" '
                       '& pattern_start_abs > 35503550 + 420 '
                       '& pattern_end_abs < 35503550 + 575')
In [28]:
# Instances inside the plotted window
dfi_subset
Out[28]:
pattern example_idx pattern_start pattern_end strand pattern_len pattern_center match_weighted match_weighted_p match_weighted_cat match_max match_max_task imp_weighted imp_weighted_p imp_weighted_cat imp_max imp_max_task seq_match seq_match_p seq_match_cat match/Klf4 match/Nanog match/Oct4 match/Sox2 imp/Klf4 imp/Nanog imp/Oct4 imp/Sox2 example_chrom example_start example_end example_strand example_interval_from_task pattern_short pattern_name pattern_start_abs pattern_end_abs
26 metacluster_0/pattern_0 2 498 513 + 15 505 0.5649 0.6728 high 0.5866 Sox2 1.5915 0.9989 low 2.0475 Oct4 10.2468 0.5427 medium 0.5583 0.5077 0.5804 0.5866 1.0633 0.9293 2.0475 1.7508 chr17 35503551 35504551 * Oct4 m0_p0 Oct4-Sox2 35504049 35504064
760152 metacluster_0/pattern_1 2 499 521 - 22 510 0.2606 0.0292 low 0.3322 Oct4 1.7406 1.0000 low 2.2173 Oct4 7.0253 0.2856 low 0.2900 0.1969 0.3322 0.2928 1.2676 1.5616 2.2173 2.0605 chr17 35503551 35504551 * Oct4 m0_p1 Sox2 35504050 35504072
5806380 metacluster_1/pattern_0 2 427 437 + 10 432 0.4595 0.0314 low 0.5568 Sox2 0.2391 0.0101 high 0.2699 Klf4 8.6701 0.3282 low 0.4641 0.3138 0.4473 0.5568 0.2699 0.0635 0.1076 0.1024 chr17 35503551 35504551 * Oct4 m1_p0 Klf4 35503978 35503988
5806381 metacluster_1/pattern_0 2 551 561 + 10 556 0.4719 0.0408 low 0.5052 Klf4 0.1506 0.0002 high 0.1576 Klf4 6.1363 0.0534 low 0.5052 0.2891 0.2286 0.3862 0.1576 0.1288 0.0807 0.1300 chr17 35503551 35504551 * Oct4 m1_p0 Klf4 35504102 35504112
12636696 metacluster_2/pattern_0 2 475 484 - 9 479 0.4232 0.1721 low 0.4616 Nanog 0.2685 0.7618 low 0.3171 Nanog 3.8011 0.1262 low 0.2817 0.4616 0.3874 0.3832 0.1955 0.3171 0.1451 0.1472 chr17 35503551 35504551 * Oct4 m2_p0 Nanog 35504026 35504035
12636697 metacluster_2/pattern_0 2 516 525 - 9 520 0.6456 0.9035 high 0.6990 Nanog 0.6106 0.9960 low 0.7815 Nanog 5.0813 0.3002 low 0.5602 0.6990 0.4753 0.5479 0.2167 0.7815 0.2457 0.3018 chr17 35503551 35504551 * Oct4 m2_p0 Nanog 35504067 35504076
12636698 metacluster_2/pattern_0 2 542 551 - 9 546 0.2560 0.0038 low 0.2989 Nanog 0.2548 0.7233 low 0.2963 Nanog 3.1094 0.0624 low 0.1332 0.2989 0.1664 0.2070 0.1627 0.2963 0.1347 0.2010 chr17 35503551 35504551 * Oct4 m2_p0 Nanog 35504093 35504102
In [31]:
# Centre of each instance relative to the start of the plotted window (interval start + 420).
# `.assign` returns a new frame instead of writing a column into the `query` result,
# which can trigger pandas' SettingWithCopyWarning.
dfi_subset = dfi_subset.assign(
    rel_center=dfi_subset['pattern_center'] + dfi_subset['example_start'] - (35503550 + 420))
In [36]:
# Width (bp) of each matched pattern. `.assign` avoids writing into a filtered view of `dfi`,
# which can trigger pandas' SettingWithCopyWarning.
dfi_subset = dfi_subset.assign(
    pattern_width=dfi_subset['pattern_end'] - dfi_subset['pattern_start'])
In [37]:
# Compact view: name, strand, window-relative centre and width of each instance
dfi_subset[['pattern_name', 'strand', 'rel_center', 'pattern_width']]
Out[37]:
pattern_name strand rel_center pattern_width
26 Oct4-Sox2 + 86 15
760152 Sox2 - 91 22
5806380 Klf4 + 13 10
5806381 Klf4 + 137 10
12636696 Nanog - 60 9
12636697 Nanog - 101 9
12636698 Nanog - 127 9
In [38]:
# Select the Oct4-Sox2 instance by name rather than by its index label (26), which depends
# on the row order of the loaded instance table.
oct_sox_rows = dfi_subset[dfi_subset['pattern_name'] == 'Oct4-Sox2']
assert len(oct_sox_rows) == 1, f"expected one Oct4-Sox2 instance, got {len(oct_sox_rows)}"
oct_sox_row = oct_sox_rows.iloc[0]
In [39]:
# Inspect the Oct4-Sox2 instance
oct_sox_row
Out[39]:
pattern          metacluster_0/pattern_0
example_idx                            2
pattern_start                        498
                          ...           
rel_center                            86
pattern_widtn                         15
pattern_width                         15
Name: 26, Length: 40, dtype: object
In [41]:
from basepair.extractors import Variant, extract_seq
In [42]:
# Only the plotted window: offsets 420-575 of `interval`
trimmed_interval = pybedtools.create_interval_from_list(
    ['chr17', 35503550 + 420, 35503550 + 575])
In [48]:
# Reference DNA sequence of the plotted window (the `None` argument presumably means "no variants" — confirm)
seq = extract_seq(trimmed_interval, None, bpnet.fasta_file)
In [52]:
# Oct4-Sox2 motif core in the plotted window (offsets 80-90)
seq[80:91]
Out[52]:
'ATGCATAACAA'
In [53]:
# Nanog motif core in the plotted window (offsets 98-101)
seq[98:102]
Out[53]:
'TGAT'
In [59]:
import random

# Random replacement sequences with the same lengths as the motif cores to disrupt
# (seq[80:91] -> 11 bp, seq[98:102] -> 4 bp). The fixed seed makes the draw
# reproducible when the notebook is re-run.
rng = random.Random(42)
random_os = ''.join(rng.choices("ACGT", k=11))
random_nanog = ''.join(rng.choices("ACGT", k=4))
In [60]:
# Random replacement for the Oct4-Sox2 core
random_os
Out[60]:
'GTTCGCTCGTG'
In [64]:
# Random replacement for the Nanog core
random_nanog
Out[64]:
'ATGT'

Original

In [72]:
# Reference sequence (no variants)
plot_region(interval);

Disrupt Oct4-Sox2

In [73]:
# Replace the Oct4-Sox2 core (window offsets 80-90) with a random 11-mer.
# The +1 presumably converts the 0-based offset into the 1-based position `Variant` expects.
oct_sox_variant = Variant('chr17', 35503550 + 420 + 80 + 1, 'ATGCATAACAA', 'GTTCGCTCGTG')
plot_region(interval, [oct_sox_variant]);

Disrupt Nanog

In [74]:
# Replace the Nanog core (window offsets 98-101) with a random 4-mer.
# The +1 presumably converts the 0-based offset into the 1-based position `Variant` expects.
nanog_variant = Variant('chr17', 35503550 + 420 + 98 + 1, 'TGAT', 'ATGT')
plot_region(interval, [nanog_variant]);

Investigating the importance scores

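  • plot the difference between the reference and the mutated sequence (ref − alt) for both predictions and importance scores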
In [87]:
def plot_region_diff(interval, variants, xlim=(420, 575), fmin=None, fmax=None):
    """Plot how `variants` change BPNet predictions and importance scores.

    The region is predicted twice: once on the reference sequence and once
    with `variants` applied. For every task in `bpnet.tasks`, four tracks are
    drawn:
      * "<task> Pred": the alternative (mutated) prediction.
      * "<task> ref - alt pred": reference minus alternative prediction.
      * "<task> Imp profile": the alternative DeepLIFT weighted profile
        importance multiplied by the one-hot sequence.
      * "<task> ref - alt Imp": reference minus alternative importance.
    Only the window `xlim` of each track is kept (via `filter_tracks`).

    Args:
      interval: pybedtools Interval to predict on.
      variants: list of `Variant`s defining the alternative sequence.
      xlim: (start, end) window passed to `filter_tracks`, in positions
        relative to the start of `interval`.
      fmin, fmax: dicts with keys 'Imp profile' and 'Pred' giving the y-axis
        limits. They default to the same fixed values used for the reference
        plots.

    Returns:
      The figure returned by `plot_tracks`.
    """
    if fmax is None:
        fmax = {'Imp profile': 0.265742, 'Pred': 34.442215}
    if fmin is None:
        fmin = {'Imp profile': -0.16751188, 'Pred': 0.07202636}

    ref_pred = bpnet.predict_intervals([interval], variants=None, imp_method='deeplift')[0]
    alt_pred = bpnet.predict_intervals([interval], variants=variants, imp_method='deeplift')[0]

    viz_dict = OrderedDict()
    for task in bpnet.tasks:
        ref_imp = ref_pred['imp_score'][f"{task}/weighted"] * ref_pred['seq']
        alt_imp = alt_pred['imp_score'][f"{task}/weighted"] * alt_pred['seq']
        viz_dict[f"{task} Pred"] = alt_pred['pred'][task]
        viz_dict[f"{task} ref - alt pred"] = ref_pred['pred'][task] - alt_pred['pred'][task]
        viz_dict[f"{task} Imp profile"] = alt_imp
        viz_dict[f"{task} ref - alt Imp"] = ref_imp - alt_imp
    viz_dict = filter_tracks(viz_dict, list(xlim))

    imp_range = (fmin['Imp profile'], fmax['Imp profile'])
    pred_max = fmax['Pred']
    ylim = []
    for key in viz_dict:
        feature = key.split(" ", 1)[1]  # drop the "<task> " prefix
        if "Imp" in feature:
            ylim.append(imp_range)
        elif feature == "Pred":
            # Same (0, max) axis as plot_region, so the mutated profile is
            # directly comparable to the reference plots.
            ylim.append((0, pred_max))
        else:
            # "ref - alt pred" can have either sign, so use a symmetric axis.
            ylim.append((-pred_max, pred_max))

    return plot_tracks(viz_dict,
                       title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                       fig_height_per_track=0.5,
                       rotate_y=0,
                       fig_width=get_figsize(frac=1)[0],
                       ylim=ylim,
                       legend=False)
In [88]:
# Effect of replacing the Oct4-Sox2 core (window offsets 80-90) with a random 11-mer
oct_sox_variant = Variant('chr17', 35503550 + 420 + 80 + 1, 'ATGCATAACAA', 'GTTCGCTCGTG')
plot_region_diff(interval, [oct_sox_variant]);
In [89]:
# Effect of replacing the Nanog core (window offsets 98-101) with a random 4-mer
nanog_variant = Variant('chr17', 35503550 + 420 + 98 + 1, 'TGAT', 'ATGT')
plot_region_diff(interval, [nanog_variant]);