In [2]:
from collections import OrderedDict
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'
imp_score = 'profile/wn'

motifs = OrderedDict([
    ("Oct4-Sox2", 'Oct4/m0_p0'),
    ("Oct4", 'Oct4/m0_p1'),
    # ("Strange-sym-motif", 'Oct4/m0_p5'),
    ("Sox2", 'Sox2/m0_p1'),
    ("Nanog", 'Nanog/m0_p1'),
    ("Zic3", 'Nanog/m0_p2'),
    ("Nanog-partner", 'Nanog/m0_p4'),
    ("Klf4", 'Klf4/m0_p0'),
])

Goal

  • Visualize what happens when you permute the motifs at the oct4 enhancer

Tasks

  • [x] get the trimmed seqlet locations
  • [x] sample the sequences from the background distribution
  • [x] make the prediction
    • show also the importance scores
In [3]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
from basepair.exp.paper.config import *
from basepair.extractors import Variant, extract_seq
hv.extension('bokeh')
Using TensorFlow backend.
In [4]:
import pybedtools
from basepair.utils import flatten_list
paper_config()
In [5]:
create_tf_session(0)
Out[5]:
<tensorflow.python.client.session.Session at 0x7fb182b4cda0>
In [6]:
# figures dir
model_dir = models_dir / exp
fdir = Path(f'{ddir}/figures/modisco/{exp}/in-vivo-perturb/oct4-enhancer')
In [7]:
!mkdir -p {fdir}
In [8]:
from basepair.seqmodel import SeqModel
In [9]:
ds = DataSpec.load(rdir / 'src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml')
In [10]:
# Get counts
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])

obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
In [81]:
(420+60 + 35503550, 420+120 + 35503550)
Out[81]:
(35504030, 35504090)
In [11]:
from basepair.BPNet import BPNetSeqModel

bpnet = BPNetSeqModel.from_mdir(model_dir)
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
2019-04-08 04:55:45,555 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
WARNING:tensorflow:From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
2019-04-08 04:55:55,532 [WARNING] From /users/avsec/bin/anaconda3/envs/chipnexus/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
In [12]:
# input_seqlen = 1000 - bpnet.body.get_len_change()  - bpnet.heads[0].net.get_len_change()
In [13]:
from genomelake.extractors import FastaExtractor
In [ ]:
fe = FastaExtractor(ds.fasta_file)

seq = fe([resize_interval(interval, input_seqlen)])
In [14]:
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])
In [15]:
# Hack
bpnet.fasta_file = ds.fasta_file
bpnet.bias_model = None
In [16]:
pred = bpnet.predict_intervals([interval], imp_method='deeplift')[0]
WARNING:tensorflow:From /users/avsec/workspace/basepair/basepair/heads.py:323: calling softmax (from tensorflow.python.ops.nn_ops) with dim is deprecated and will be removed in a future version.
Instructions for updating:
dim is deprecated, use axis instead
2019-04-08 04:56:05,756 [WARNING] From /users/avsec/workspace/basepair/basepair/heads.py:323: calling softmax (from tensorflow.python.ops.nn_ops) with dim is deprecated and will be removed in a future version.
Instructions for updating:
dim is deprecated, use axis instead
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
DeepExplain: running "deeplift" explanation method (5)
Model with multiple inputs:  True
In [69]:
def to_neg(track):
    track = track.copy()
    track[:, 1] = - track[:, 1]
    return track
def plot_region(interval, variants=None, ):
    pred = bpnet.predict_intervals([interval], variants=variants, imp_method='deeplift')[0]
    viz_dict = OrderedDict(flatten_list([[
                        (f"{task} Pred", to_neg(pred['pred'][task])),
                        (f"{task} Imp profile", pred['imp_score'][f"{task}/weighted"] * pred['seq']),
                        # (f"{task} Imp counts", sum(pred['grads'][task_idx]['counts'].values()) / 2 * seq),
                    ] for task_idx, task in enumerate(['Oct4', 'Nanog'])]))

    viz_dict = filter_tracks(viz_dict, [420+60, 420+120])

    # Hard-code the range
    fmax = {'Imp profile': 0.265742, 'Pred': 6.5}
    fmin = {'Imp profile': -0.16751188, 'Pred': -6.5}

    ylim = []
    for k in viz_dict:
        f = k.split(" ", 1)[1]
        if "Imp" in f:
            ylim.append((fmin[f], fmax[f]))
        else:
            ylim.append((fmin[f], fmax[f]))
            
            
    colors = []
    for task in ['Oct4', 'Nanog']:
        colors.append((tf_colors[task], tf_colors[task] + "80"))  # 80 add alpha=0.5
        colors.append(None)

    fig = plot_tracks(viz_dict,
                      #seqlets=shifted_seqlets,
                      title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                      fig_height_per_track=0.5,
                      rotate_y=0,
                      fig_width=get_figsize(frac=.5)[0],
                      color=colors,
                      ylim=ylim,
                      use_spine_subset=True,
                      legend=False)
    sns.despine(top=True, right=True, left=False, bottom=True)
    return fig

Original

In [76]:
fdir
Out[76]:
PosixPath('/users/avsec/workspace/basepair/data/figures/modisco/nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE/in-vivo-perturb/oct4-enhancer')
In [77]:
plot_region(interval);
plt.savefig(fdir / 'wt.pdf')

Disrupt Oct-Sox2

In [78]:
plot_region(interval, [Variant('chr17', 35503550 +420+80 + 1, 'ATGCATAACAA', 'GTTCGCTCGTG')]);
plt.savefig(fdir / 'dOct4-Sox2.pdf')

Disrupt Nanog

In [79]:
plot_region(interval, [Variant('chr17', 35503550 +420+98 + 1, 'TGAT', 'ATGT')]);
plt.savefig(fdir / 'dNanog.pdf')

TODO

  • [X] Remove Sox2 and Klf4
  • [X] Crop only to the central 60 bp
  • [X] Reverse strand on the negative axis
  • [X] Color predicted profiles with TF-specific colors
  • [X] Try to bring both Oct4 and Nanog predicted tracks to the same axis
  • [ ] Try to put into a single PDF plot (Multiple columns)
In [64]:
tf_colors
Out[64]:
{'Klf4': '#357C42',
 'Oct4': '#9F1D20',
 'Sox2': '#3A3C97',
 'Nanog': '#9F8A31',
 'Esrrb': '#30BDC4'}

Investigating the importance scores

  • plot the diff
In [50]:
def plot_region_diff(interval, variants):
    ref_pred = bpnet.predict_intervals([interval], variants=None, imp_method='deeplift')[0]
    pred = bpnet.predict_intervals([interval], variants=variants, imp_method='deeplift')[0]
    
    viz_dict = OrderedDict(flatten_list([[
                        (f"{task} Pred", pred['pred'][task]),
                        (f"{task} ref - alt pred", ref_pred['pred'][task] - pred['pred'][task]),
                        (f"{task} Imp profile", pred['imp_score'][f"{task}/weighted"] * pred['seq']),
                        (f"{task} ref - alt Imp", ref_pred['imp_score'][f"{task}/weighted"] * ref_pred['seq'] - pred['imp_score'][f"{task}/weighted"] * pred['seq']),
                        # (f"{task} Imp counts", sum(pred['grads'][task_idx]['counts'].values()) / 2 * seq),
                    ] for task_idx, task in enumerate(bpnet.tasks)]))

    viz_dict = filter_tracks(viz_dict, [420, 575])

    # Hard-code the range
    fmax = {'Imp profile': 0.265742, 'Pred': 34.442215, }
    fmin = {'Imp profile': -0.16751188, 'Pred': 0.07202636}
    #fmax = {feature: max([viz_dict[f"{task} {feature}"].max() for task in bpnet.tasks])
    #        for feature in ['Imp profile', 'Pred']}
    #fmin = {feature: min([viz_dict[f"{task} {feature}"].min() for task in bpnet.tasks])
    #        for feature in ['Imp profile', 'Pred']}


    ylim = []
    for k in viz_dict:
        f = k.split(" ", 1)[1]
        if "Imp" in f:
            ylim.append((fmin['Imp profile'], fmax['Imp profile']))
        else:
            ylim.append((-fmax['Pred'], fmax['Pred']))

    fig = plot_tracks(viz_dict,
                      #seqlets=shifted_seqlets,
                      title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                      fig_height_per_track=0.5,
                      rotate_y=0,
                      fig_width=get_figsize(frac=1)[0],
                      ylim=ylim,
                      legend=False)
    return fig
In [51]:
plot_region_diff(interval, [Variant('chr17', 35503550 +420+80 + 1, 'ATGCATAACAA', 'GTTCGCTCGTG')]);
In [52]:
plot_region_diff(interval, [Variant('chr17', 35503550 +420+98 + 1, 'TGAT', 'ATGT')]);

Figure out which variant to use

In [22]:
mr = ModiscoResult(modisco_dir / 'modisco.h5')
mr.open()
In [40]:
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix, align_instance_center
In [24]:
dfi = load_instances(modisco_dir  / 'instances.parq', motifs=motifs, dedup=True)
number of deduplicatd instances: 1594551 (35.14740608536545%)
In [42]:
from basepair.modisco.pattern_instances import (multiple_load_instances, load_instances, filter_nonoverlapping_intervals, 
                                                plot_coocurence_matrix, align_instance_center, dfi2seqlets, annotate_profile)
In [ ]:
# instance_parq_paths = {t: model_dir / f'deeplift/{t}/out/{imp_score}/instances.parq' 
#                        for t in tasks}

# dfi = multiple_load_instances(instance_parq_paths, motifs)
In [25]:
dfi.head()
Out[25]:
pattern example_idx pattern_start pattern_end strand pattern_len pattern_center match_weighted match_weighted_p match_weighted_cat match_max match_max_task imp_weighted imp_weighted_p imp_weighted_cat imp_max imp_max_task seq_match seq_match_p seq_match_cat match/Klf4 match/Nanog match/Oct4 match/Sox2 imp/Klf4 imp/Nanog imp/Oct4 imp/Sox2 example_chrom example_start example_end example_strand example_interval_from_task pattern_short pattern_name pattern_start_abs pattern_end_abs
0 metacluster_0/pattern_0 0 104 119 + 15 111 0.2469 0.0010 low 0.4484 Oct4 0.0365 NaN NaN 0.0524 Oct4 3.9825 0.0085 low 0.2654 -0.3593 0.4484 0.3725 0.0235 0.0246 0.0524 0.0313 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482648 143482663
1 metacluster_0/pattern_0 0 263 278 - 15 270 0.2518 0.0011 low 0.2893 Oct4 0.0217 NaN NaN 0.0342 Nanog 1.1066 0.0007 low 0.1618 0.2170 0.2893 0.2803 0.0143 0.0342 0.0200 0.0203 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482807 143482822
2 metacluster_0/pattern_0 0 282 297 + 15 289 0.3466 0.0069 low 0.3870 Sox2 0.0272 NaN NaN 0.0296 Sox2 5.0100 0.0188 low 0.2257 0.3454 0.3723 0.3870 0.0233 0.0258 0.0279 0.0296 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482826 143482841
3 metacluster_0/pattern_0 0 445 460 + 15 452 0.2399 0.0008 low 0.2469 Sox2 0.1852 0.0008 high 0.4145 Nanog 1.7775 0.0012 low 0.2458 0.2408 0.2314 0.2469 0.1234 0.4145 0.0949 0.1896 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482989 143483004
4 metacluster_0/pattern_0 0 475 490 - 15 482 0.2903 0.0025 low 0.3322 Oct4 0.2889 0.0179 high 0.5196 Nanog 3.3390 0.0057 low 0.2092 0.2681 0.3322 0.2991 0.1613 0.5196 0.2049 0.3241 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143483019 143483034
In [ ]:
# get all instances in this region
dfi_subset = dfi.query('example_chrom == "chr17" & pattern_start_abs > 35503550 + 420 & pattern_end_abs < 35503550 + 575')
In [15]:
# Subset the motifs
dfi_subset = dfi.query('match_weighted_p > .2').query('imp_weighted_p > .01')
In [27]:
# get all instances in this region
dfi_subset = dfi.query('example_chrom == "chr17" & pattern_start_abs > 35503550 + 420 & pattern_end_abs < 35503550 + 575')
In [31]:
dfi_subset['rel_center'] = dfi_subset['pattern_center'] + dfi_subset['example_start']- (35503550 +420)
In [36]:
dfi_subset['pattern_width'] = dfi_subset['pattern_end'] - dfi_subset['pattern_start']
In [37]:
dfi_subset[['pattern_name', 'strand', 'rel_center', 'pattern_width']]
Out[37]:
pattern_name strand rel_center pattern_width
26 Oct4-Sox2 + 86 15
760152 Sox2 - 91 22
5806380 Klf4 + 13 10
5806381 Klf4 + 137 10
12636696 Nanog - 60 9
12636697 Nanog - 101 9
12636698 Nanog - 127 9
In [38]:
oct_sox_row = dfi_subset.loc[26]
In [39]:
oct_sox_row
Out[39]:
pattern          metacluster_0/pattern_0
example_idx                            2
pattern_start                        498
                          ...           
rel_center                            86
pattern_widtn                         15
pattern_width                         15
Name: 26, Length: 40, dtype: object
In [46]:
 
In [42]:
trimmed_interval = pybedtools.create_interval_from_list(['chr17', 35503550 +420, 35503550 +575])
In [48]:
seq = extract_seq(trimmed_interval, None, bpnet.fasta_file)
In [52]:
# Oct4-Sox2
seq[80:91]
Out[52]:
'ATGCATAACAA'
In [53]:
# Nanog
seq[98:102]
Out[53]:
'TGAT'
In [59]:
import random
random_os = ''.join(random.choices("ACGT", k=11))
random_nanog = ''.join(random.choices("ACGT", k=4))
In [60]:
random_os
Out[60]:
'GTTCGCTCGTG'
In [64]:
random_nanog
Out[64]:
'ATGT'