Goal

  • investigate the effect of motif perturbations

Tasks

  • [x] find all high-confidence instances in the peaks
  • [x] systematically perturb all the instances and visualize the results
In [167]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
hv.extension('bokeh')
In [168]:
paper_config()
In [4]:
# Common paths
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / f"modisco/all/deeplift/profile/"
# output_dir = Path("/srv/www/kundaje/avsec/chipnexus/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile")
In [56]:
create_tf_session(0)
Out[56]:
<tensorflow.python.client.session.Session at 0x7fb4c520d978>
In [153]:
patterns = read_pkl(modisco_dir / "patterns.pkl")  # aligned patterns
In [71]:
bpnet = BPNet.from_mdir(model_dir)
In [1115]:
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix, align_instance_center
from basepair.exp.paper.config import motifs, profile_mapping
from basepair.utils import flatten
In [1116]:
list(motifs)
Out[1116]:
['Oct4-Sox2', 'Sox2', 'Nanog', 'Klf4']
In [111]:
dfi = load_instances(modisco_dir / 'instances.parq', motifs=motifs, dedup=False)
In [112]:
dfi = filter_nonoverlapping_intervals(dfi)
In [155]:
mr = ModiscoResult(modisco_dir / 'modisco.h5')
mr.open()
In [156]:
# Add aligned instances
orig_patterns = [mr.get_pattern(pname) for pname in mr.patterns()]
dfi = align_instance_center(dfi, orig_patterns, patterns, trim_frac=0.08)
TF-MoDISco is using the TensorFlow backend.
In [672]:
from basepair.cli.imp_score import ImpScoreFile
imp_scores = ImpScoreFile.from_modisco_dir(modisco_dir)
In [181]:
profiles = imp_scores.get_profiles()
In [258]:
tasks = patterns[0].tasks()
In [259]:
tasks
Out[259]:
['Klf4', 'Nanog', 'Oct4', 'Sox2']

Find all high-confidence instances in the peaks

  • question: how many instances are there?
    • how long would it take to write out everything?
    • what will be the dataset size?
    • shall we rather do it on the fly
In [113]:
np.sum((dfi.imp_weighted_cat == 'high') & (dfi.match_weighted_cat != 'low'))
Out[113]:
36937
In [114]:
(dfi[(dfi.pattern_center > 400) & (dfi.pattern_center < 600)]
                    .query('match_weighted_p > 0.2')
                    .query('imp_weighted_p > 0')).shape
Out[114]:
(82707, 37)
In [817]:
dfi_subset= (dfi.query('match_weighted_p > 0.2')
             .query('imp_weighted_p > 0'))
dfi_subset['row_idx'] = np.arange(len(dfi_subset)).astype(int)
In [338]:
dfi_subset.shape
Out[338]:
(108190, 39)
In [117]:
np.mean(dfi_subset.groupby('example_idx').size() > 1)
Out[117]:
0.6317485432022537
In [118]:
dfi_subset.groupby('example_idx').size().plot.hist(10)
Out[118]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fb712983a20>
In [119]:
dfi.head()
Out[119]:
pattern example_idx pattern_start pattern_end strand pattern_len pattern_center match_weighted match_weighted_p match_weighted_cat match_max match_max_task imp_weighted imp_weighted_p imp_weighted_cat imp_max imp_max_task seq_match seq_match_p seq_match_cat match/Klf4 match/Nanog match/Oct4 match/Sox2 imp/Klf4 imp/Nanog imp/Oct4 imp/Sox2 example_chrom example_start example_end example_strand example_interval_from_task pattern_short pattern_name pattern_start_abs pattern_end_abs
0 metacluster_0/pattern_0 0 104 119 + 15 111 0.2469 0.0010 low 0.4484 Oct4 0.0365 NaN NaN 0.0524 Oct4 3.9825 0.0085 low 0.2654 -0.3593 0.4484 0.3725 0.0235 0.0246 0.0524 0.0313 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482648 143482663
1 metacluster_0/pattern_0 0 263 278 - 15 270 0.2518 0.0011 low 0.2893 Oct4 0.0217 NaN NaN 0.0342 Nanog 1.1066 0.0007 low 0.1618 0.2170 0.2893 0.2803 0.0143 0.0342 0.0200 0.0203 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482807 143482822
2 metacluster_0/pattern_0 0 282 297 + 15 289 0.3466 0.0069 low 0.3870 Sox2 0.0272 NaN NaN 0.0296 Sox2 5.0100 0.0188 low 0.2257 0.3454 0.3723 0.3870 0.0233 0.0258 0.0279 0.0296 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482826 143482841
3 metacluster_0/pattern_0 0 445 460 + 15 452 0.2399 0.0008 low 0.2469 Sox2 0.1852 0.0008 high 0.4145 Nanog 1.7775 0.0012 low 0.2458 0.2408 0.2314 0.2469 0.1234 0.4145 0.0949 0.1896 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143482989 143483004
4 metacluster_0/pattern_0 0 475 490 - 15 482 0.2903 0.0025 low 0.3322 Oct4 0.2889 0.0179 high 0.5196 Nanog 3.3390 0.0057 low 0.2092 0.2681 0.3322 0.2991 0.1613 0.5196 0.2049 0.3241 chrX 143482544 143483544 * Oct4 m0_p0 Oct4-Sox2 143483019 143483034

Questions

  • [x] add predictions to the reference importance score implementation
  • should we store the output brute-force?
    • 100k preds 1 seq * 4 imp
    • no need to store them. Recomputing is pretty fast
In [945]:
from kipoi.data import Dataset
from concise.preprocessing import encodeDNA
import random


def random_seq_onehot(l):
    """Generate random sequence one-hot-encoded
    
    Args:
      l: sequence length
    """
    return encodeDNA([''.join(random.choices("ACGT", k=int(l)))])[0]

class PerturbDataset(Dataset):
    def __init__(self, dfi, seqs):
        self.dfi = dfi
        self.seqs = seqs
        
    def __len__(self):
        return len(self.dfi)
    
    def __getitem__(self, idx):
        inst = self.dfi.iloc[idx]
        assert inst.row_idx == idx
        ref_seq = self.seqs[inst.example_idx]
        # generate the alternative sequence
        alt_seq = ref_seq.copy()
        alt_seq[int(inst.pattern_start):int(inst.pattern_end)] = random_seq_onehot(inst.pattern_end - inst.pattern_start)
        
        return {"inputs": alt_seq,
                "metadata": {"example_idx": inst.example_idx,
                             "pattern": inst.pattern,
                             "pattern_start": inst.pattern_start,
                             "pattern_end": inst.pattern_end,
                            }}
    
class DoublePerturbDatasetSeq(Dataset):
    def __init__(self, dfab, seqs):
        """Pertub pairs of motifs
        
        Args:
          dfab: motif pair data-frame
          seqs: original sequences
        """
        self.dfab = dfab
        self.seqs = seqs
        
    def __len__(self):
        return len(self.dfab)
    
    def __getitem__(self, idx):
        inst = self.dfab.iloc[idx]
        ref_seq = self.seqs[inst.example_idx]
        # generate the alternative sequence
        alt_seq = ref_seq.copy()
        alt_seq[int(inst.pattern_start_x):int(inst.pattern_end_x)] = random_seq_onehot(inst.pattern_end_x - inst.pattern_start_x)
        alt_seq[int(inst.pattern_start_y):int(inst.pattern_end_y)] = random_seq_onehot(inst.pattern_end_y - inst.pattern_start_y)
        
        return alt_seq
In [911]:
dfab_pairs['Nanog<>Nanog'].shape
Out[911]:
(19552, 114)
In [909]:
dfab_pairs['Nanog<>Nanog'].head()
Out[909]:
pattern_x example_idx pattern_start_x pattern_end_x strand_x pattern_len_x pattern_center_x match_weighted_x match_weighted_p_x match_weighted_cat_x match_max_x match_max_task_x imp_weighted_x imp_weighted_p_x imp_weighted_cat_x imp_max_x imp_max_task_x seq_match_x seq_match_p_x seq_match_cat_x match/Klf4_x match/Nanog_x match/Oct4_x match/Sox2_x imp/Klf4_x imp/Nanog_x imp/Oct4_x imp/Sox2_x example_chrom_x example_start_x example_end_x example_strand_x example_interval_from_task_x pattern_short_x pattern_name_x pattern_start_abs_x pattern_end_abs_x pattern_center_aln_x pattern_strand_aln_x row_idx_x pattern_y pattern_start_y pattern_end_y strand_y pattern_len_y pattern_center_y match_weighted_y match_weighted_p_y match_weighted_cat_y match_max_y match_max_task_y imp_weighted_y imp_weighted_p_y imp_weighted_cat_y imp_max_y imp_max_task_y seq_match_y seq_match_p_y seq_match_cat_y match/Klf4_y match/Nanog_y match/Oct4_y match/Sox2_y imp/Klf4_y imp/Nanog_y imp/Oct4_y imp/Sox2_y example_chrom_y example_start_y example_end_y example_strand_y example_interval_from_task_y pattern_short_y pattern_name_y pattern_start_abs_y pattern_end_abs_y pattern_center_aln_y pattern_strand_aln_y row_idx_y center_diff center_diff_aln strand_combination xy_ref_obs_inside xy_ref_obs_outside xy_ref_pred_inside xy_ref_pred_outside xy_ref_pred_match xy_ref_imp_inside xy_ref_imp_outside xy_ref_impcount_inside xy_ref_impcount_outside xy_alt_pred_inside xy_alt_pred_outside xy_alt_pred_match xy_alt_imp_inside xy_alt_imp_outside xy_alt_impcount_inside xy_alt_impcount_outside yx_ref_obs_inside yx_ref_obs_outside yx_ref_pred_inside yx_ref_pred_outside yx_ref_pred_match yx_ref_imp_inside yx_ref_imp_outside yx_ref_impcount_inside yx_ref_impcount_outside yx_alt_pred_inside yx_alt_pred_outside yx_alt_pred_match yx_alt_imp_inside yx_alt_imp_outside yx_alt_impcount_inside yx_alt_impcount_outside
7 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 254 254 ++ 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 259.4509 416.9008 0.2940 0.7618 3.4483 0.2434 7.5768 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 306.0172 824.3239 0.1890 0.4310 3.3727 0.5010 7.8317
8 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 726 735 - 9 730 0.5169 0.4961 medium 0.5441 Nanog 0.3132 0.8641 low 0.4321 Nanog 4.8914 0.2707 low 0.5222 0.5441 0.4026 0.4283 0.1153 0.4321 0.0166 0.0373 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220027 73220036 731 - 60208 263 263 ++ 1632.0 7260.0 406.1481 958.9734 0.3327 0.3393 3.9240 0.0568 4.0569 257.0172 419.3344 0.3261 0.4494 3.7607 0.1569 7.6633 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 276.7634 1043.7615 0.1898 0.3320 3.6129 0.4834 8.0046
9 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 766 775 + 9 770 0.4425 0.2255 low 0.4722 Nanog 0.0883 0.0004 high 0.1174 Nanog 1.2431 0.0059 low 0.3405 0.4722 0.3554 0.4507 0.0457 0.1174 0.0125 0.0156 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220067 73220076 770 + 60209 303 302 -+ 663.0 8229.0 142.7500 1222.3716 0.3615 0.0595 4.2038 0.0265 4.0873 99.6234 576.7283 0.3627 0.0778 4.1323 0.0702 7.7500 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 223.9338 1187.6025 0.1864 0.2474 4.0624 0.4810 8.0731
12 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 metacluster_2/pattern_0 726 735 - 9 730 0.5169 0.4961 medium 0.5441 Nanog 0.3132 0.8641 low 0.4321 Nanog 4.8914 0.2707 low 0.5222 0.5441 0.4026 0.4283 0.1153 0.4321 0.0166 0.0373 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220027 73220036 731 - 60208 9 9 ++ 1632.0 7260.0 406.1481 958.9734 0.3327 0.3393 3.9240 0.0568 4.0569 68.6485 1061.6925 0.1232 0.0384 3.7654 0.0757 8.2570 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 242.1011 1078.4238 0.2119 0.3414 3.6035 0.2307 8.2572
13 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 metacluster_2/pattern_0 766 775 + 9 770 0.4425 0.2255 low 0.4722 Nanog 0.0883 0.0004 high 0.1174 Nanog 1.2431 0.0059 low 0.3405 0.4722 0.3554 0.4507 0.0457 0.1174 0.0125 0.0156 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220067 73220076 770 + 60209 49 48 -+ 663.0 8229.0 142.7500 1222.3716 0.3615 0.0595 4.2038 0.0265 4.0873 37.4505 1092.8906 0.2602 0.0115 3.7922 0.0751 8.2576 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 446.5078 965.0286 0.2930 0.6285 3.6813 0.2409 8.3133
In [121]:
len(dfi)
Out[121]:
2947036

Estimate the total output size

In [69]:
seqs = imp_scores.get_seq()
In [ ]:
imp_scores_contrib = imp_scores.get_contrib()
In [673]:
imp_scores_contrib_counts = imp_scores.get_contrib(pred_summary='count')
In [72]:
preds = bpnet.predict(seqs)
In [89]:
%time preds2 = bpnet.predict(seqs)
CPU times: user 29.4 s, sys: 9.68 s, total: 39.1 s
Wall time: 52 s
In [85]:
from kipoi.writers import HDF5BatchWriter
In [86]:
# Write the predictions to hdf5
HDF5BatchWriter.dump(model_dir / 'preds.h5', preds)
In [87]:
!du -sh {model_dir}/preds.h5
2.7G	/users/avsec/workspace/basepair/data/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/preds.h5
In [88]:
%time preds2 = HDF5Reader.load(model_dir / 'preds.h5')
CPU times: user 19 s, sys: 2.39 s, total: 21.4 s
Wall time: 20.6 s
In [91]:
preds2['Oct4'].shape
Out[91]:
(98428, 1000, 2)
In [92]:
preds['Oct4'].shape
Out[92]:
(98428, 1000, 2)
In [ ]:
imp_scores = bpnet.imp_score_all(seqs, method='deeplift', aggregate_strand=True)
In [123]:
alt_dataset = PerturbDataset(dfi_subset, seqs).load_all()
100%|██████████| 3381/3381 [01:30<00:00, 37.47it/s]
In [124]:
alt_seqs = alt_dataset['inputs']
In [125]:
alt_preds = bpnet.predict(alt_seqs)
In [127]:
alt_imp_scores = bpnet.imp_score_all(alt_seqs, method='deeplift', aggregate_strand=True)
In [564]:
alt_imp_scores_contrib = {k: v * alt_seqs for k,v in alt_imp_scores.items()}
In [139]:
alt_dataset['preds'] = alt_preds
In [140]:
alt_dataset['imp_scores'] = alt_imp_scores
In [923]:
%tqdm_restart
In [142]:
HDF5BatchWriter.dump(modisco_dir / 'perturb.motifs.h5', alt_dataset)
In [126]:
len(alt_seqs)
Out[126]:
108190
In [162]:
# Code for motif combinations

# setup config
from basepair.modisco.pattern_instances import construct_motif_pairs

pairs = []
for i in range(len(motifs)):
    for j in range(i, len(motifs)):
        pairs.append([ list(motifs)[i], list(motifs)[j], ])

comp_strand_compbination = {
    "++": "--",
    "--": "++",
    "-+": "-+",
    "+-": "+-"
}

strand_combinations = ["++", "--", "+-", "-+"]


def motif_pair_dfi(dfi_filtered, motif_pair):
    """Construct the matrix of motif pairs
    
    Args:
      dfi_filtered: dfi filtered to the desired property
      motif_pair: tuple of two pattern_name's 
    Returns:
      pd.DataFrame with columns from dfi_filtered with _x and _y suffix
    """
    dfa = dfi_filtered[dfi_filtered.pattern_name == motif_pair[0]]
    dfb = dfi_filtered[dfi_filtered.pattern_name == motif_pair[1]]

    dfab = pd.merge(dfa, dfb, on='example_idx', how='outer')
    dfab = dfab[~dfab[['pattern_x', 'pattern_y']].isnull().any(1)]

    dfab['center_diff'] = dfab.pattern_center_y - dfab.pattern_center_x
    dfab['center_diff_aln'] = dfab.pattern_center_aln_y - dfab.pattern_center_aln_x
    dfab['strand_combination'] = dfab.strand_x + dfab.strand_y
    # assure the right strand combination
    dfab[dfab.center_diff < 0]['strand_combination'] = dfab[dfab.center_diff < 0]['strand_combination'].map(comp_strand_compbination)

    if motif_pair[0] == motif_pair[1]:
        dfab['strand_combination'][dfab['strand_combination'] == "--"] = "++"
        dfab = dfab[dfab.center_diff > 0]
    else:
        dfab.center_diff = np.abs(dfab.center_diff)
        dfab.center_diff_aln = np.abs(dfab.center_diff_aln)
    dfab = dfab[dfab.center_diff_aln != 0]  # exclude perfect matches
    return dfab


def plot_spacing(dfab, 
                 alpha_scatter=0.01, 
                 y_feature='profile_counts', 
                 center_diff_variable='center_diff', 
                 figsize=(3.42519, 6.85038)):
    from basepair.stats import smooth_window_agg, smooth_lowess, smooth_gam
    
    motif_pair = (dfab.iloc[0].pattern_name_x, dfab.iloc[0].pattern_name_y)
    strand_combinations = dfab.strand_combination.unique()
    fig_profile, axes = plt.subplots(2*len(strand_combinations), 1, figsize=figsize, sharex=True, sharey='row')
    
    motif_pair_c = motif_pair
    axes[0].set_title("<>".join(motif_pair), fontsize=7)

    j = 0  # first column

    dftw_filt = dfab[(dfab.center_diff < 150)] #  & (dfab.imp_weighted_p.max(1) > 0.3)]
    for i, sc in enumerate(strand_combinations):
        if y_feature == 'profile_counts':
            y1 = np.log10(1+ dftw_filt[dftw_filt.strand_combination==sc][profile_mapping[motif_pair_c[0]] + "/profile_counts_x"])
            y2 = np.log10(1+ dftw_filt[dftw_filt.strand_combination==sc][profile_mapping[motif_pair_c[1]] + "/profile_counts_y"])
        elif y_feature == 'imp_weighted':
            y1 = np.log10(1+ dftw_filt[dftw_filt.strand_combination==sc]['imp_weighted_x'])
            y2 = np.log10(1+ dftw_filt[dftw_filt.strand_combination==sc]['imp_weighted_y'])
        else:
            raise ValueError(f"Unkown y_feature: {y_feature}")
            
        # y1 = dftw_filt[dftw_filt.strand_combination==sc]['imp_weighted'][motif_pair[0]]
        # y2 = dftw_filt[dftw_filt.strand_combination==sc]['imp_weighted'][motif_pair[1]]
        x = dftw_filt[dftw_filt.strand_combination==sc][center_diff_variable]

        #dm,ym,confi = average_distance(x,y, window=5)
        dm1,ym1,confi1 = smooth_lowess(x,y1, frac=0.15)
        dm2,ym2,confi2 = smooth_lowess(x,y2, frac=0.15)
        #dm,ym, confi = smooth_gam(x,y, 140, 20)

        ax = axes[2*i]
        ax.hist(dftw_filt[dftw_filt.strand_combination==sc][center_diff_variable], np.arange(10, 150, 1));
        if j == 0:
            ax.set_ylabel(sc)

        # second plot
        ax.set_xlim([0, 150])
        ax = axes[2*i+1]
        ax.scatter(x,y1, alpha=alpha_scatter, s=8)
        if confi1 is not None:
            ax.fill_between(dm1, confi1[:,0], confi1[:,1], alpha=0.2)
        ax.plot(dm1, ym1, linewidth=2, alpha=0.8)
        ax.scatter(x,y2, alpha=alpha_scatter, s=8)
        if confi2 is not None:
            ax.fill_between(dm2, confi2[:,0], confi2[:,1], alpha=0.2)
        ax.plot(dm2, ym2, linewidth=2, alpha=0.8)
        if j == 0:
            ax.set_ylabel(sc)
        ax.xaxis.set_minor_locator(plt.MultipleLocator(10))
        ax.xaxis.set_major_locator(plt.MultipleLocator(20))
        if j == 0:
            ax.set_ylabel(sc)
        if i == len(strand_combinations) - 1:
            ax.set_xlabel("Distance between motifs")
    fig_profile.subplots_adjust(wspace=0, hspace=0) 
    return fig_profile
In [163]:
dfi_subset.shape
Out[163]:
(108190, 39)
In [164]:
# original
dfi.pattern_name.value_counts()
Out[164]:
Nanog        1542247
Klf4          503291
Oct4-Sox2     490437
Sox2          411061
Name: pattern_name, dtype: int64
In [165]:
dfab = motif_pair_dfi(dfi_subset, ['Oct4-Sox2', 'Sox2'])
In [169]:
fig = plot_spacing(dfab, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
In [170]:
dfab.shape
Out[170]:
(3532, 80)

Task

  • stich together all the predictions
    • focus on the example idx

Single motif focused

  • reference / permuted counts (inside, outside)
  • reference / permuted importance (inside, outside)
In [172]:
alt_imp_scores['Oct4/weighted'].shape
Out[172]:
(108190, 1000, 4)
In [216]:
p = patterns[0]
In [219]:
p.profile['Klf4'].shape
Out[219]:
(200, 2)
In [220]:
# TODO - get the reference profile for the normal trimmed profile
In [221]:
from basepair.exp.chipnexus.simulate import profile_sim_metrics

Implement a function to go from seqlet -> ref / alt profile

In [224]:
## TOOD - get the reference seqlet profiles
In [248]:
pattern = 'metacluster_0/pattern_0'
In [252]:
motifs
Out[252]:
OrderedDict([('Oct4-Sox2', 'm0_p0'),
             ('Sox2', 'm0_p1'),
             ('Nanog', 'm2_p0'),
             ('Klf4', 'm1_p0')])
In [254]:
mr.tasks()
Out[254]:
['Klf4/weighted', 'Nanog/weighted', 'Oct4/weighted', 'Sox2/weighted']
In [ ]:
tasks = ['']
In [260]:
def get_reference_profile(mr, pattern, tasks, profile_width=70, trim_frac=0.08, seqlen=1000):
    seqlets_ref = mr._get_seqlets(pattern, trim_frac=trim_frac)
    seqlets_ref = resize_seqlets(seqlets_ref, profile_width, seqlen=seqlen)

    task = 'Oct4'
    out = {}
    for task in tasks:
        seqlet_profile_ref = extract_signal(profiles[task], seqlets_ref)
        avg_profile = seqlet_profile_ref.mean(axis=0)
        out[task] = avg_profile
        # metrics_ref = pd.DataFrame([profile_sim_metrics(avg_profile, cp) for cp in seqlet_profile_ref])
    return out
In [262]:
ref_profiles = {p: get_reference_profile(mr, longer_pattern(sn), tasks) for p,sn in motifs.items()}
In [263]:
plot_stranded_profile(ref_profiles['Oct4-Sox2']['Oct4'])
In [244]:
metrics_ref
Out[244]:
counts counts_frac max max_frac simmetric_kl
0 34.0 0.3882 2.0 0.9263 inf
1 81.0 0.9249 4.0 1.8525 inf
2 34.0 0.3882 2.0 0.9263 inf
... ... ... ... ... ...
154 14.0 0.1599 1.0 0.4631 inf
155 10.0 0.1142 1.0 0.4631 inf
156 15.0 0.1713 1.0 0.4631 inf

157 rows × 5 columns

TODO

  • [x] match the profile with the reference profile
  • [x] add also the actual reference counts
  • [x] generate the dataset where mutated_seqlet_idx == signal_seqlet_idx
  • [x] generate the dataset where all valid pairwise combinations are present
In [269]:
alt_imp_scores.keys()
Out[269]:
dict_keys(['Oct4/weighted', 'Sox2/weighted', 'Nanog/weighted', 'Klf4/weighted', 'Oct4/count', 'Sox2/count', 'Nanog/count', 'Klf4/count'])
In [284]:
from basepair.stats import symmetric_kl
In [270]:
symmetric_kl??
Object `symmetric_kl` not found.
In [ ]:
imp_scores
In [968]:
class PerturbDataset(Dataset):
    def __init__(self, dfi, seqs, preds, profiles, imp_scores, imp_scores_counts, 
                 alt_dataset, alt_seqs, alt_preds, alt_imp_scores_contrib, 
                 ref_profiles,
                 profile_mapping):
        self.dfi = dfi
        self.seqs = seqs
        self.preds = preds
        self.profiles = profiles
        self.imp_scores = imp_scores
        self.imp_scores_counts = imp_scores_counts
        self.alt_dataset = alt_dataset
        self.alt_seqs = alt_seqs
        self.alt_preds = alt_preds
        self.alt_imp_scores_contrib = alt_imp_scores_contrib
        self.ref_profiles = ref_profiles
        self.profile_mapping = profile_mapping
        
    def __len__(self):
        return len(self.dfi)
    
    def get_change(self, mutated_seqlet_idx, signal_seqlet_idx):
        inst = self.dfi.iloc[signal_seqlet_idx]
        
        assert inst.row_idx == signal_seqlet_idx
        
        task = profile_mapping[inst.pattern_name]
        
        ref_profile = self.ref_profiles[inst.pattern_name][task]
        
        mutated_inst = self.dfi.iloc[mutated_seqlet_idx]
        assert inst.example_idx == mutated_inst.example_idx
        
        narrow_seqlet = Seqlet(inst.example_idx, inst.pattern_start, inst.pattern_end, 
                               name=inst.pattern, strand=inst.strand)
        wide_seqlet = narrow_seqlet.resize(70)

        # ref
        ref_preds = self.preds[task][inst.example_idx]  # all predictions
        ref_preds_seqlet = wide_seqlet.extract(self.preds[task])
        ref_preds_inside = ref_preds_seqlet.sum()
        ref_preds_outside = ref_preds.sum() - ref_preds_inside
        
        ref_obs = self.profiles[task][inst.example_idx]  # all predictions
        ref_obs_seqlet = wide_seqlet.extract(self.profiles[task])
        ref_obs_inside = ref_obs_seqlet.sum()
        ref_obs_outside = ref_obs.sum() - ref_obs_inside
        try:
            ref_preds_match = symmetric_kl(ref_preds_seqlet, ref_profile).mean()  # compare with the reference
        except:
            ref_preds_match = np.nan
        
        # ref imp
        ref_imp_scores = self.imp_scores[f"{task}"][inst.example_idx]
        ref_imp_scores_seqlet = narrow_seqlet.extract(self.imp_scores[f"{task}"])
        ref_imp_inside = ref_imp_scores_seqlet.sum()  # sum in the seqlet region
        ref_imp_outside = ref_imp_scores.sum() - ref_imp_inside  # total - seqlet

        # ref imp counts
        ref_imp_scores_c = self.imp_scores_counts[f"{task}"][inst.example_idx]
        ref_imp_scores_seqlet_c = narrow_seqlet.extract(self.imp_scores_counts[f"{task}"])
        ref_imp_inside_c = ref_imp_scores_seqlet_c.sum()  # sum in the seqlet region
        ref_imp_outside_c = ref_imp_scores_c.sum() - ref_imp_inside_c  # total - seqlet
        
        # alt
        narrow_seqlet.seqname = mutated_seqlet_idx  # change the sequence name
        wide_seqlet.seqname = mutated_seqlet_idx
        
        alt_preds = self.alt_preds[task][mutated_seqlet_idx]
        alt_preds_seqlet = wide_seqlet.extract(self.alt_preds[task])
        alt_preds_inside = alt_preds_seqlet.sum()  # sum in the seqlet region
        alt_preds_outside = alt_preds.sum() - alt_preds_inside  # total - seqlet
        try:
            alt_preds_match = symmetric_kl(alt_preds_seqlet, ref_profile).mean()
        except:
            alt_preds_match = np.nan
        
        alt_imp_scores = self.alt_imp_scores_contrib[f"{task}/weighted"][mutated_seqlet_idx]
        alt_imp_scores_seqlet = narrow_seqlet.extract(self.alt_imp_scores_contrib[f"{task}/weighted"])
        alt_imp_inside = alt_imp_scores_seqlet.sum()  # sum in the seqlet region
        alt_imp_outside = alt_imp_scores.sum() - alt_imp_inside  # total - seqlet
        
        alt_imp_scores_c = self.alt_imp_scores_contrib[f"{task}/count"][mutated_seqlet_idx]
        alt_imp_scores_seqlet_c = narrow_seqlet.extract(self.alt_imp_scores_contrib[f"{task}/count"])
        alt_imp_inside_c = alt_imp_scores_seqlet_c.sum()  # sum in the seqlet region
        alt_imp_outside_c = alt_imp_scores_c.sum() - alt_imp_inside_c  # total - seqlet
        
        return {
        "ref": {
                "obs": {
                    "inside": ref_obs_inside,
                    "outside": ref_obs_outside
                },
                "pred": {
                    "inside": ref_preds_inside,
                    "outside": ref_preds_outside,
                    "match": ref_preds_match
                },
                "imp": {
                    "inside": ref_imp_inside,
                    "outside": ref_imp_outside,
                },
                "impcount": {
                    "inside": ref_imp_inside_c,
                    "outside": ref_imp_outside_c,
                },
        },
            "alt": {
                "pred": {
                    "inside": alt_preds_inside,
                    "outside": alt_preds_outside,
                    "match": alt_preds_match
                },
                "imp": {
                    "inside": alt_imp_inside,
                    "outside": alt_imp_outside,
                },
                "impcount": {
                    "inside": alt_imp_inside_c,
                    "outside": alt_imp_outside_c,
                },                
            },
        }
    
class SingleMotifPerturbDataset(Dataset):
    def __init__(self, pdata):
        self.pdata = pdata
        
    def __len__(self):
        return len(self.pdata.dfi)
    
    def __getitem__(self, idx):
        return self.pdata.get_change(idx, idx)
    

class DoubleMotifPerturbDataset(Dataset):
    def __init__(self, dfab, dalt_preds, ref_profiles):
        """Use the predictions from DoublePerturbDatasetSeq
        
        Args:
          dfab: motif pair dataset
          dalt_preds: predictions from sequences obtained from DoublePerturbDatasetSeq
          ref_profiles: reference profiles for background seqlets
        """
        self.dfab = dfab
        self.ref_profiles = ref_profiles
        self.dalt_preds = dalt_preds
        
    def __len__(self):
        return len(self.dfab)
    
    def extract_features(self, idx, pattern_name, pattern_start, pattern_end, pattern, strand):
        task = profile_mapping[pattern_name]
        ref_profile = self.ref_profiles[pattern_name][task]
        narrow_seqlet = Seqlet(idx, pattern_start, pattern_end, 
                               name=pattern, strand=strand)
        wide_seqlet = narrow_seqlet.resize(70)
        ref_preds = self.dalt_preds[task][idx]  # all predictions
        ref_preds_seqlet = wide_seqlet.extract(self.dalt_preds[task])
        ref_preds_inside = ref_preds_seqlet.sum()
        ref_preds_total = ref_preds.sum()
        ref_preds_outside = ref_preds_total - ref_preds_inside
        try:
            ref_preds_match = symmetric_kl(ref_preds_seqlet, ref_profile).mean()  # compare with the reference
        except:
            ref_preds_match = np.nan        
        return {
            "pred": {
                "inside": ref_preds_inside,
                "outside": ref_preds_outside,
                "total": ref_preds_total,
                "match": ref_preds_match
            }}

    def __getitem__(self, idx):
        inst = self.dfab.iloc[idx]
        return {"dxy": {
            "x": self.extract_features(idx,
                                       pattern_name=inst.pattern_name_x,
                                       pattern_start=int(inst.pattern_start_x),
                                       pattern_end=int(inst.pattern_end_x),
                                       pattern=inst.pattern_x,
                                       strand=inst.strand_x),
            "y": self.extract_features(idx,
                                       pattern_name=inst.pattern_name_y,
                                       pattern_start=int(inst.pattern_start_y),
                                       pattern_end=int(inst.pattern_end_y),
                                       pattern=inst.pattern_y,
                                       strand=inst.strand_y)
        }}
In [925]:
dfab_pairs['Nanog<>Nanog'].head()
Out[925]:
pattern_x example_idx pattern_start_x pattern_end_x strand_x pattern_len_x pattern_center_x match_weighted_x match_weighted_p_x match_weighted_cat_x match_max_x match_max_task_x imp_weighted_x imp_weighted_p_x imp_weighted_cat_x imp_max_x imp_max_task_x seq_match_x seq_match_p_x seq_match_cat_x match/Klf4_x match/Nanog_x match/Oct4_x match/Sox2_x imp/Klf4_x imp/Nanog_x imp/Oct4_x imp/Sox2_x example_chrom_x example_start_x example_end_x example_strand_x example_interval_from_task_x pattern_short_x pattern_name_x pattern_start_abs_x pattern_end_abs_x pattern_center_aln_x pattern_strand_aln_x row_idx_x pattern_y pattern_start_y pattern_end_y strand_y pattern_len_y pattern_center_y match_weighted_y match_weighted_p_y match_weighted_cat_y match_max_y match_max_task_y imp_weighted_y imp_weighted_p_y imp_weighted_cat_y imp_max_y imp_max_task_y seq_match_y seq_match_p_y seq_match_cat_y match/Klf4_y match/Nanog_y match/Oct4_y match/Sox2_y imp/Klf4_y imp/Nanog_y imp/Oct4_y imp/Sox2_y example_chrom_y example_start_y example_end_y example_strand_y example_interval_from_task_y pattern_short_y pattern_name_y pattern_start_abs_y pattern_end_abs_y pattern_center_aln_y pattern_strand_aln_y row_idx_y center_diff center_diff_aln strand_combination xy_ref_obs_inside xy_ref_obs_outside xy_ref_pred_inside xy_ref_pred_outside xy_ref_pred_match xy_ref_imp_inside xy_ref_imp_outside xy_ref_impcount_inside xy_ref_impcount_outside xy_alt_pred_inside xy_alt_pred_outside xy_alt_pred_match xy_alt_imp_inside xy_alt_imp_outside xy_alt_impcount_inside xy_alt_impcount_outside yx_ref_obs_inside yx_ref_obs_outside yx_ref_pred_inside yx_ref_pred_outside yx_ref_pred_match yx_ref_imp_inside yx_ref_imp_outside yx_ref_impcount_inside yx_ref_impcount_outside yx_alt_pred_inside yx_alt_pred_outside yx_alt_pred_match yx_alt_imp_inside yx_alt_imp_outside yx_alt_impcount_inside yx_alt_impcount_outside
7 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 254 254 ++ 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 259.4509 416.9008 0.2940 0.7618 3.4483 0.2434 7.5768 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 306.0172 824.3239 0.1890 0.4310 3.3727 0.5010 7.8317
8 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 726 735 - 9 730 0.5169 0.4961 medium 0.5441 Nanog 0.3132 0.8641 low 0.4321 Nanog 4.8914 0.2707 low 0.5222 0.5441 0.4026 0.4283 0.1153 0.4321 0.0166 0.0373 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220027 73220036 731 - 60208 263 263 ++ 1632.0 7260.0 406.1481 958.9734 0.3327 0.3393 3.9240 0.0568 4.0569 257.0172 419.3344 0.3261 0.4494 3.7607 0.1569 7.6633 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 276.7634 1043.7615 0.1898 0.3320 3.6129 0.4834 8.0046
9 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 766 775 + 9 770 0.4425 0.2255 low 0.4722 Nanog 0.0883 0.0004 high 0.1174 Nanog 1.2431 0.0059 low 0.3405 0.4722 0.3554 0.4507 0.0457 0.1174 0.0125 0.0156 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220067 73220076 770 + 60209 303 302 -+ 663.0 8229.0 142.7500 1222.3716 0.3615 0.0595 4.2038 0.0265 4.0873 99.6234 576.7283 0.3627 0.0778 4.1323 0.0702 7.7500 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 223.9338 1187.6025 0.1864 0.2474 4.0624 0.4810 8.0731
12 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 metacluster_2/pattern_0 726 735 - 9 730 0.5169 0.4961 medium 0.5441 Nanog 0.3132 0.8641 low 0.4321 Nanog 4.8914 0.2707 low 0.5222 0.5441 0.4026 0.4283 0.1153 0.4321 0.0166 0.0373 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220027 73220036 731 - 60208 9 9 ++ 1632.0 7260.0 406.1481 958.9734 0.3327 0.3393 3.9240 0.0568 4.0569 68.6485 1061.6925 0.1232 0.0384 3.7654 0.0757 8.2570 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 242.1011 1078.4238 0.2119 0.3414 3.6035 0.2307 8.2572
13 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 metacluster_2/pattern_0 766 775 + 9 770 0.4425 0.2255 low 0.4722 Nanog 0.0883 0.0004 high 0.1174 Nanog 1.2431 0.0059 low 0.3405 0.4722 0.3554 0.4507 0.0457 0.1174 0.0125 0.0156 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220067 73220076 770 + 60209 49 48 -+ 663.0 8229.0 142.7500 1222.3716 0.3615 0.0595 4.2038 0.0265 4.0873 37.4505 1092.8906 0.2602 0.0115 3.7922 0.0751 8.2576 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 446.5078 965.0286 0.2930 0.6285 3.6813 0.2409 8.3133
In [972]:
%tqdm_restart

In [688]:
smpdata = PerturbDataset( dfi_subset, seqs, preds, profiles, imp_scores_contrib,imp_scores_contrib_counts, 
                 alt_dataset, alt_seqs, alt_preds, alt_imp_scores_contrib, 
                 ref_profiles,
                 profile_mapping)
In [689]:
spdata = SingleMotifPerturbDataset(smpdata).load_all(num_workers=20)
100%|██████████| 3381/3381 [00:12<00:00, 280.68it/s]
In [690]:
dfsm = pd.DataFrame(flatten(spdata), index=dfi_subset.index)
dfsm = pd.concat([dfsm, dfi_subset], axis=1)
In [405]:
dfsm.head()
Out[405]:
ref_obs_inside ref_obs_outside ref_pred_inside ref_pred_outside ref_pred_match ref_imp_inside ref_imp_outside alt_pred_inside alt_pred_outside alt_pred_match alt_imp_inside alt_imp_outside diff_pred_inside diff_pred_outside diff_pred_match diff_imp_inside diff_imp_outside pattern example_idx pattern_start pattern_end strand pattern_len pattern_center match_weighted match_weighted_p match_weighted_cat match_max match_max_task imp_weighted imp_weighted_p imp_weighted_cat imp_max imp_max_task seq_match seq_match_p seq_match_cat match/Klf4 match/Nanog match/Oct4 match/Sox2 imp/Klf4 imp/Nanog imp/Oct4 imp/Sox2 example_chrom example_start example_end example_strand example_interval_from_task pattern_short pattern_name pattern_start_abs pattern_end_abs pattern_center_aln pattern_strand_aln
10 3898.0 12300.0 12766.9316 45467.4844 0.3456 0.8159 5.6606 5816.3760 24692.4727 0.4979 0.2534 -19.3449 -6950.5557 -20775.0117 0.1523 -0.5625 -25.0055 metacluster_0/pattern_0 1 437 452 + 15 444 0.5391 0.5105 medium 0.5627 Sox2 0.9221 0.8271 low 1.0504 Oct4 9.6558 0.4369 medium 0.5452 0.4714 0.5524 0.5627 0.5605 1.0090 1.0504 0.9147 chr3 122145063 122146063 * Oct4 m0_p0 Oct4-Sox2 122145500 122145515 448 +
11 4946.0 11252.0 20336.1445 37898.2695 0.3172 0.7206 5.7559 12745.0439 26239.5859 0.5107 -0.0966 -22.9252 -7591.1006 -11658.6836 0.1935 -0.8172 -28.6811 metacluster_0/pattern_0 1 458 473 + 15 465 0.5174 0.3763 medium 0.5613 Oct4 1.0389 0.9077 low 1.2170 Sox2 11.3139 0.7737 high 0.4683 0.4598 0.5613 0.5274 0.5500 1.1198 1.0884 1.2170 chr3 122145063 122146063 * Oct4 m0_p0 Oct4-Sox2 122145521 122145536 469 +
14 5233.0 10965.0 28666.8789 29567.5352 0.2252 1.0919 5.3846 5705.7559 8943.7070 0.3381 0.1306 -18.9671 -22961.1230 -20623.8281 0.1129 -0.9613 -24.3517 metacluster_0/pattern_0 1 499 514 + 15 506 0.5154 0.3673 medium 0.5390 Oct4 1.3206 0.9862 low 1.5825 Oct4 9.8925 0.4740 medium 0.4916 0.4731 0.5390 0.5267 0.6803 1.0429 1.5825 1.5485 chr3 122145063 122146063 * Oct4 m0_p0 Oct4-Sox2 122145562 122145577 510 +
16 4337.0 11861.0 27433.3633 30801.0508 0.2641 0.8260 5.6505 12769.5527 17427.2871 0.3377 0.2582 -18.4827 -14663.8105 -13373.7637 0.0736 -0.5678 -24.1332 metacluster_0/pattern_0 1 520 535 + 15 527 0.4897 0.2392 low 0.5008 Oct4 1.4250 0.9950 low 1.6570 Sox2 7.7218 0.1351 low 0.4555 0.4992 0.5008 0.4895 0.8025 1.2853 1.6062 1.6570 chr3 122145063 122146063 * Oct4 m0_p0 Oct4-Sox2 122145583 122145598 531 +
18 3800.0 12398.0 22574.7500 35659.6641 0.2272 0.9338 5.5427 5425.8335 14069.9590 0.3850 0.0627 -18.7416 -17148.9160 -21589.7051 0.1579 -0.8711 -24.2843 metacluster_0/pattern_0 1 541 556 + 15 548 0.4843 0.2165 low 0.5198 Oct4 1.2319 0.9746 low 1.4439 Oct4 9.6558 0.4369 medium 0.4183 0.4421 0.5198 0.5058 0.6828 1.0318 1.4439 1.4188 chr3 122145063 122146063 * Oct4 m0_p0 Oct4-Sox2 122145604 122145619 552 +
In [570]:
from basepair.config import test_chr
In [659]:
np.log10(1+dfsm[['ref_obs_inside', 'ref_pred_inside']]).plot.scatter("ref_obs_inside", "ref_pred_inside", alpha=0.05, s=1)
Out[659]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fb7064840f0>
In [658]:
np.log10(1+dfsm[dfsm.example_chrom.isin(test_chr)][['ref_obs_inside', 'ref_pred_inside']]).plot.scatter("ref_obs_inside", "ref_pred_inside", alpha=0.05, s=1)
Out[658]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fb7064840b8>
In [657]:
dfsm.plot.scatter("ref_imp_inside", "imp_weighted", alpha=0.1, s=1)
Out[657]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fb6e3e03390>
In [656]:
plt.scatter(dfsm.imp_weighted, np.log10(dfsm.ref_pred_inside), alpha=0.1, s=1)
Out[656]:
<matplotlib.collections.PathCollection at 0x7fb710693320>
In [655]:
dfsm.plot.scatter("ref_pred_match", "diff_pred_match", alpha=0.1, s=1)
Out[655]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fb712959898>
In [597]:
dfsm['log_diff_pred_inside'] = np.log10(1+np.abs(dfsm.diff_pred_inside)) * np.sign(dfsm.diff_pred_inside)
In [599]:
dfsm.plot.scatter("log_diff_pred_inside", "diff_imp_inside", alpha=0.05, s=1)
Out[599]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fb70b591470>

Take into account motif pairs

In [691]:
class OtherMotifPerturbDataset(Dataset):
    def __init__(self, pdata, dfab):
        self.pdata = pdata
        self.dfab = dfab
        
    def __len__(self):
        return len(self.dfab)
    
    def __getitem__(self, idx):
        xidx = int(self.dfab.iloc[idx].row_idx_x)
        yidx = int(self.dfab.iloc[idx].row_idx_y)
        return {"xy": self.pdata.get_change(xidx, yidx),  # mutate x, measure y
                "yx": self.pdata.get_change(yidx, xidx)}  # mutate y, measure x
In [692]:
motif_pair = ['Nanog', 'Klf4']
In [693]:
dfab_pairs = {}
for i, motif_pair in enumerate(pairs):
    print(f"{i+1}/{len(pairs)}")
    dfab = motif_pair_dfi(dfi_subset, motif_pair)
    opdata = OtherMotifPerturbDataset(smpdata, dfab).load_all(num_workers=20)
    dfab_sm = pd.DataFrame(flatten(opdata), index=dfab.index)
    dfab_sm = pd.concat([dfab, dfab_sm], axis=1)
    dfab_pairs["<>".join(motif_pair)] = dfab_sm
0/10
100%|██████████| 63/63 [00:01<00:00, 62.25it/s]
1/10
100%|██████████| 111/111 [00:01<00:00, 80.76it/s]
2/10
100%|██████████| 412/412 [00:05<00:00, 79.60it/s]
3/10
100%|██████████| 178/178 [00:02<00:00, 76.10it/s]
4/10
100%|██████████| 52/52 [00:00<00:00, 68.91it/s]
5/10
100%|██████████| 341/341 [00:04<00:00, 82.29it/s]
6/10
100%|██████████| 155/155 [00:01<00:00, 83.80it/s]
7/10
100%|██████████| 611/611 [00:06<00:00, 88.65it/s]
8/10
100%|██████████| 697/697 [00:08<00:00, 85.02it/s]
9/10
100%|██████████| 209/209 [00:02<00:00, 72.29it/s]
In [941]:
dfab_pairs_bak = deepcopy(dfab_pairs)
In [1000]:
dfab_pairs = deepcopy(dfab_pairs_bak)
In [1002]:
%tqdm_restart
In [1003]:
# add also the double perturbations
for k, dfab in dfab_pairs.items():
    dpdata_seqs = DoublePerturbDatasetSeq(dfab, seqs).load_all(num_workers=10)
    dalt_preds = bpnet.predict(dpdata_seqs)
    dpdata = DoubleMotifPerturbDataset(dfab, dalt_preds, ref_profiles).load_all(num_workers=20)
    df_dpdata = pd.DataFrame(flatten(dpdata), index=dfab.index)
    dfab = pd.concat([dfab, df_dpdata], axis=1)
    dfab_pairs[k] = dfab  # override the new dfab
100%|██████████| 63/63 [00:00<00:00, 149.14it/s]
100%|██████████| 63/63 [00:00<00:00, 201.19it/s]
100%|██████████| 111/111 [00:00<00:00, 170.95it/s]
100%|██████████| 111/111 [00:00<00:00, 198.28it/s]
100%|██████████| 412/412 [00:02<00:00, 152.89it/s]
100%|██████████| 412/412 [00:01<00:00, 267.74it/s]
100%|██████████| 178/178 [00:01<00:00, 154.88it/s]
100%|██████████| 178/178 [00:00<00:00, 236.39it/s]
100%|██████████| 52/52 [00:00<00:00, 144.90it/s]
100%|██████████| 52/52 [00:00<00:00, 207.60it/s]
100%|██████████| 341/341 [00:02<00:00, 164.25it/s]
100%|██████████| 341/341 [00:01<00:00, 265.50it/s]
100%|██████████| 155/155 [00:01<00:00, 147.93it/s]
100%|██████████| 155/155 [00:00<00:00, 210.16it/s]
100%|██████████| 611/611 [00:04<00:00, 150.51it/s]
100%|██████████| 611/611 [00:02<00:00, 257.67it/s]
100%|██████████| 697/697 [00:04<00:00, 156.84it/s]
100%|██████████| 697/697 [00:02<00:00, 238.46it/s]
100%|██████████| 209/209 [00:01<00:00, 169.89it/s]
100%|██████████| 209/209 [00:00<00:00, 232.00it/s]
In [1004]:
# append A|dA
dfsm_prefixed_x = dfsm.copy()
dfsm_prefixed_y = dfsm.copy()
dfsm_prefixed_x.columns = ["dx_x_" + c for c in dfsm_prefixed_x.columns]
dfsm_prefixed_y.columns = ["dy_y_" + c for c in dfsm_prefixed_y.columns]
for k, dfab in dfab_pairs.items():
    dfab = pd.merge(dfab, dfsm_prefixed_x, how='left', left_on='row_idx_x', right_on='dx_x_row_idx')
    dfab = pd.merge(dfab, dfsm_prefixed_y, how='left', left_on='row_idx_y', right_on='dy_y_row_idx')
    dfab_pairs[k] = dfab  # override the new dfab
In [1005]:
# store the pairs
write_pkl(dfab_pairs, modisco_dir / 'dfab_pairs.pkl')
In [474]:
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
In [478]:
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
In [482]:
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
In [455]:
# Diff (X -> Y, Y -> X)
dfab_sm.head()
Out[455]:
pattern_x example_idx pattern_start_x pattern_end_x strand_x pattern_len_x pattern_center_x match_weighted_x match_weighted_p_x match_weighted_cat_x match_max_x match_max_task_x imp_weighted_x imp_weighted_p_x imp_weighted_cat_x imp_max_x imp_max_task_x seq_match_x seq_match_p_x seq_match_cat_x match/Klf4_x match/Nanog_x match/Oct4_x match/Sox2_x imp/Klf4_x imp/Nanog_x imp/Oct4_x imp/Sox2_x example_chrom_x example_start_x example_end_x example_strand_x example_interval_from_task_x pattern_short_x pattern_name_x pattern_start_abs_x pattern_end_abs_x pattern_center_aln_x pattern_strand_aln_x row_idx_x pattern_y pattern_start_y pattern_end_y strand_y pattern_len_y pattern_center_y match_weighted_y match_weighted_p_y match_weighted_cat_y match_max_y match_max_task_y imp_weighted_y imp_weighted_p_y imp_weighted_cat_y imp_max_y imp_max_task_y seq_match_y seq_match_p_y seq_match_cat_y match/Klf4_y match/Nanog_y match/Oct4_y match/Sox2_y imp/Klf4_y imp/Nanog_y imp/Oct4_y imp/Sox2_y example_chrom_y example_start_y example_end_y example_strand_y example_interval_from_task_y pattern_short_y pattern_name_y pattern_start_abs_y pattern_end_abs_y pattern_center_aln_y pattern_strand_aln_y row_idx_y center_diff center_diff_aln strand_combination xy_ref_obs_inside xy_ref_obs_outside xy_ref_pred_inside xy_ref_pred_outside xy_ref_pred_match xy_ref_imp_inside xy_ref_imp_outside xy_alt_pred_inside xy_alt_pred_outside xy_alt_pred_match xy_alt_imp_inside xy_alt_imp_outside xy_diff_pred_inside xy_diff_pred_outside xy_diff_pred_match xy_diff_imp_inside xy_diff_imp_outside yx_ref_obs_inside yx_ref_obs_outside yx_ref_pred_inside yx_ref_pred_outside yx_ref_pred_match yx_ref_imp_inside yx_ref_imp_outside yx_alt_pred_inside yx_alt_pred_outside yx_alt_pred_match yx_alt_imp_inside yx_alt_imp_outside yx_diff_pred_inside yx_diff_pred_outside yx_diff_pred_match yx_diff_imp_inside yx_diff_imp_outside
0 metacluster_0/pattern_0 1 437.0 452.0 + 15.0 444.0 0.5391 0.5105 medium 0.5627 Sox2 0.9221 0.8271 low 1.0504 Oct4 9.6558 0.4369 medium 0.5452 0.4714 0.5524 0.5627 0.5605 1.0090 1.0504 0.9147 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 448.0 + 0.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 33.0 29.0 ++ 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 2586.5879 4602.4053 0.2493 0.2469 -2.8443 -3585.9199 -5166.8760 -0.0125 -0.2896 -9.3520 3898.0 12300.0 12766.9316 45467.4844 0.3456 0.8159 5.6606 6903.5137 25602.2246 0.3308 -2.9269 -18.7954 -5863.4180 -19865.2598 -0.0148 -3.7428 -24.4560
1 metacluster_0/pattern_0 1 458.0 473.0 + 15.0 465.0 0.5174 0.3763 medium 0.5613 Oct4 1.0389 0.9077 low 1.2170 Sox2 11.3139 0.7737 high 0.4683 0.4598 0.5613 0.5274 0.5500 1.1198 1.0884 1.2170 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 469.0 + 1.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 12.0 8.0 ++ 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 4536.5723 7998.9043 0.2250 -0.1606 -4.4746 -1635.9355 -1770.3770 -0.0367 -0.6971 -10.9824 4946.0 11252.0 20336.1445 37898.2695 0.3172 0.7206 5.7559 10751.0908 21754.6484 0.3000 -3.5680 -18.1543 -9585.0537 -16143.6211 -0.0172 -4.2886 -23.9102
2 metacluster_0/pattern_0 1 499.0 514.0 + 15.0 506.0 0.5154 0.3673 medium 0.5390 Oct4 1.3206 0.9862 low 1.5825 Oct4 9.8925 0.4740 medium 0.4916 0.4731 0.5390 0.5267 0.6803 1.0429 1.5825 1.5485 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 510.0 + 2.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 29.0 33.0 ++ 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 1627.8390 3511.0293 0.2044 0.3352 -1.9466 -4544.6689 -6258.2520 -0.0574 -0.2014 -8.4543 5233.0 10965.0 28666.8789 29567.5352 0.2252 1.0919 5.3846 14991.4434 17514.2949 0.2531 -3.4670 -18.2552 -13675.4355 -12053.2402 0.0280 -4.5589 -23.6399
3 metacluster_0/pattern_0 1 520.0 535.0 + 15.0 527.0 0.4897 0.2392 low 0.5008 Oct4 1.4250 0.9950 low 1.6570 Sox2 7.7218 0.1351 low 0.4555 0.4992 0.5008 0.4895 0.8025 1.2853 1.6062 1.6570 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 531.0 + 3.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 50.0 54.0 ++ 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 3066.9443 4847.3604 0.2474 0.5003 -1.8956 -3105.5635 -4921.9209 -0.0144 -0.0363 -8.4033 4337.0 11861.0 27433.3633 30801.0508 0.2641 0.8260 5.6505 14265.9512 18239.7871 0.2324 -4.8920 -16.8303 -13167.4121 -12561.2637 -0.0318 -5.7180 -22.4808
4 metacluster_0/pattern_0 1 541.0 556.0 + 15.0 548.0 0.4843 0.2165 low 0.5198 Oct4 1.2319 0.9746 low 1.4439 Oct4 9.6558 0.4369 medium 0.4183 0.4421 0.5198 0.5058 0.6828 1.0318 1.4439 1.4188 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 552.0 + 4.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 71.0 75.0 ++ 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 2871.7312 4026.5842 0.2666 0.4514 -2.9302 -3300.7766 -5742.6973 0.0048 -0.0852 -9.4380 3800.0 12398.0 22574.7500 35659.6641 0.2272 0.9338 5.5427 12351.7324 20154.0059 0.1992 -3.0180 -18.7042 -10223.0176 -15505.6582 -0.0280 -3.9518 -24.2469
In [453]:
# reference (X->X)
dfsm.head()
Out[453]:
xy_ref_obs_inside xy_ref_obs_outside xy_ref_pred_inside xy_ref_pred_outside xy_ref_pred_match xy_ref_imp_inside xy_ref_imp_outside xy_alt_pred_inside xy_alt_pred_outside xy_alt_pred_match xy_alt_imp_inside xy_alt_imp_outside xy_diff_pred_inside xy_diff_pred_outside xy_diff_pred_match xy_diff_imp_inside xy_diff_imp_outside yx_ref_obs_inside yx_ref_obs_outside yx_ref_pred_inside yx_ref_pred_outside yx_ref_pred_match yx_ref_imp_inside yx_ref_imp_outside yx_alt_pred_inside yx_alt_pred_outside yx_alt_pred_match yx_alt_imp_inside yx_alt_imp_outside yx_diff_pred_inside yx_diff_pred_outside yx_diff_pred_match yx_diff_imp_inside yx_diff_imp_outside pattern_x example_idx pattern_start_x pattern_end_x strand_x pattern_len_x pattern_center_x match_weighted_x match_weighted_p_x match_weighted_cat_x match_max_x match_max_task_x imp_weighted_x imp_weighted_p_x imp_weighted_cat_x imp_max_x imp_max_task_x seq_match_x seq_match_p_x seq_match_cat_x match/Klf4_x match/Nanog_x match/Oct4_x match/Sox2_x imp/Klf4_x imp/Nanog_x imp/Oct4_x imp/Sox2_x example_chrom_x example_start_x example_end_x example_strand_x example_interval_from_task_x pattern_short_x pattern_name_x pattern_start_abs_x pattern_end_abs_x pattern_center_aln_x pattern_strand_aln_x row_idx_x pattern_y pattern_start_y pattern_end_y strand_y pattern_len_y pattern_center_y match_weighted_y match_weighted_p_y match_weighted_cat_y match_max_y match_max_task_y imp_weighted_y imp_weighted_p_y imp_weighted_cat_y imp_max_y imp_max_task_y seq_match_y seq_match_p_y seq_match_cat_y match/Klf4_y match/Nanog_y match/Oct4_y match/Sox2_y imp/Klf4_y imp/Nanog_y imp/Oct4_y imp/Sox2_y example_chrom_y example_start_y example_end_y example_strand_y example_interval_from_task_y pattern_short_y pattern_name_y pattern_start_abs_y pattern_end_abs_y pattern_center_aln_y pattern_strand_aln_y row_idx_y center_diff center_diff_aln strand_combination
0 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 2586.5879 4602.4053 0.2493 0.2469 -2.8443 -3585.9199 -5166.8760 -0.0125 -0.2896 -9.3520 3898.0 12300.0 12766.9316 45467.4844 0.3456 0.8159 5.6606 6903.5137 25602.2246 0.3308 -2.9269 -18.7954 -5863.4180 -19865.2598 -0.0148 -3.7428 -24.4560 metacluster_0/pattern_0 1 437.0 452.0 + 15.0 444.0 0.5391 0.5105 medium 0.5627 Sox2 0.9221 0.8271 low 1.0504 Oct4 9.6558 0.4369 medium 0.5452 0.4714 0.5524 0.5627 0.5605 1.0090 1.0504 0.9147 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 448.0 + 0.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 33.0 29.0 ++
1 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 4536.5723 7998.9043 0.2250 -0.1606 -4.4746 -1635.9355 -1770.3770 -0.0367 -0.6971 -10.9824 4946.0 11252.0 20336.1445 37898.2695 0.3172 0.7206 5.7559 10751.0908 21754.6484 0.3000 -3.5680 -18.1543 -9585.0537 -16143.6211 -0.0172 -4.2886 -23.9102 metacluster_0/pattern_0 1 458.0 473.0 + 15.0 465.0 0.5174 0.3763 medium 0.5613 Oct4 1.0389 0.9077 low 1.2170 Sox2 11.3139 0.7737 high 0.4683 0.4598 0.5613 0.5274 0.5500 1.1198 1.0884 1.2170 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 469.0 + 1.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 12.0 8.0 ++
2 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 1627.8390 3511.0293 0.2044 0.3352 -1.9466 -4544.6689 -6258.2520 -0.0574 -0.2014 -8.4543 5233.0 10965.0 28666.8789 29567.5352 0.2252 1.0919 5.3846 14991.4434 17514.2949 0.2531 -3.4670 -18.2552 -13675.4355 -12053.2402 0.0280 -4.5589 -23.6399 metacluster_0/pattern_0 1 499.0 514.0 + 15.0 506.0 0.5154 0.3673 medium 0.5390 Oct4 1.3206 0.9862 low 1.5825 Oct4 9.8925 0.4740 medium 0.4916 0.4731 0.5390 0.5267 0.6803 1.0429 1.5825 1.5485 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 510.0 + 2.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 29.0 33.0 ++
3 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 3066.9443 4847.3604 0.2474 0.5003 -1.8956 -3105.5635 -4921.9209 -0.0144 -0.0363 -8.4033 4337.0 11861.0 27433.3633 30801.0508 0.2641 0.8260 5.6505 14265.9512 18239.7871 0.2324 -4.8920 -16.8303 -13167.4121 -12561.2637 -0.0318 -5.7180 -22.4808 metacluster_0/pattern_0 1 520.0 535.0 + 15.0 527.0 0.4897 0.2392 low 0.5008 Oct4 1.4250 0.9950 low 1.6570 Sox2 7.7218 0.1351 low 0.4555 0.4992 0.5008 0.4895 0.8025 1.2853 1.6062 1.6570 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 531.0 + 3.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 50.0 54.0 ++
4 5397.0 8362.0 6172.5078 9769.2812 0.2618 0.5366 6.5078 2871.7312 4026.5842 0.2666 0.4514 -2.9302 -3300.7766 -5742.6973 0.0048 -0.0852 -9.4380 3800.0 12398.0 22574.7500 35659.6641 0.2272 0.9338 5.5427 12351.7324 20154.0059 0.1992 -3.0180 -18.7042 -10223.0176 -15505.6582 -0.0280 -3.9518 -24.2469 metacluster_0/pattern_0 1 541.0 556.0 + 15.0 548.0 0.4843 0.2165 low 0.5198 Oct4 1.2319 0.9746 low 1.4439 Oct4 9.6558 0.4369 medium 0.4183 0.4421 0.5198 0.5058 0.6828 1.0318 1.4439 1.4188 chr3 1.2215e+08 1.2215e+08 * Oct4 m0_p0 Oct4-Sox2 1.2215e+08 1.2215e+08 552.0 + 4.0 metacluster_2/pattern_0 473.0 482.0 + 9.0 477.0 0.4591 0.2761 low 0.5634 Nanog 0.6052 0.9958 low 0.7585 Nanog 3.8229 0.13 low 0.3008 0.5634 0.1316 0.2534 0.2079 0.7585 0.3207 0.349 chr3 1.2215e+08 1.2215e+08 * Oct4 m2_p0 Nanog 1.2215e+08 1.2215e+08 477.0 + 60200.0 71.0 75.0 ++
In [654]:
fig = plt.figure(figsize=get_figsize(.5))
dfab_pairs['Oct4-Sox2<>Sox2'].center_diff.plot.hist(30)
Out[654]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fb70b200c88>
In [ ]:
# Plot all pairs
In [611]:
mkdir -p {ddir}/figures/modisco/spacing/preturb
In [613]:
figures = Path(f"{ddir}/figures/modisco/spacing/preturb")
In [ ]:
# dfab_sm.
In [1074]:
from scipy.stats import wilcoxon

def plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=None, xl=[0, 2], pval=True):
    ax.scatter(x_alt / x_ref, 
               y_alt / y_ref, alpha=alpha, s=s, label=label)
    if pval:
        xpval = wilcoxon(x_ref, x_alt).pvalue
        ypval = wilcoxon(y_ref, y_alt).pvalue
        kwargs = dict(size="small", horizontalalignment='center')
        ax.text(1.8, 1, f"{xpval:.2g}", **kwargs)
        ax.text(1, 1.8, f"{xpval:.2g}", **kwargs)
    alpha = .5
    ax.plot(xl, xl, c='grey', alpha=alpha)
    ax.axvline(1, c='grey', alpha=alpha)
    ax.axhline(1, c='grey', alpha=alpha)
    ax.set_xlim(xl)
    ax.set_ylim(xl);

Add p-values

In [988]:
dfab_pairs['Nanog<>Nanog'].head()
Out[988]:
pattern_x example_idx pattern_start_x pattern_end_x strand_x pattern_len_x pattern_center_x match_weighted_x match_weighted_p_x match_weighted_cat_x match_max_x match_max_task_x imp_weighted_x imp_weighted_p_x imp_weighted_cat_x imp_max_x imp_max_task_x seq_match_x seq_match_p_x seq_match_cat_x match/Klf4_x match/Nanog_x match/Oct4_x match/Sox2_x imp/Klf4_x imp/Nanog_x imp/Oct4_x imp/Sox2_x example_chrom_x example_start_x example_end_x example_strand_x example_interval_from_task_x pattern_short_x pattern_name_x pattern_start_abs_x pattern_end_abs_x pattern_center_aln_x pattern_strand_aln_x row_idx_x pattern_y pattern_start_y pattern_end_y strand_y pattern_len_y pattern_center_y match_weighted_y match_weighted_p_y match_weighted_cat_y match_max_y match_max_task_y imp_weighted_y imp_weighted_p_y imp_weighted_cat_y imp_max_y imp_max_task_y seq_match_y seq_match_p_y seq_match_cat_y match/Klf4_y match/Nanog_y match/Oct4_y match/Sox2_y imp/Klf4_y imp/Nanog_y imp/Oct4_y imp/Sox2_y example_chrom_y example_start_y example_end_y example_strand_y example_interval_from_task_y pattern_short_y pattern_name_y pattern_start_abs_y pattern_end_abs_y pattern_center_aln_y pattern_strand_aln_y row_idx_y center_diff center_diff_aln strand_combination xy_ref_obs_inside xy_ref_obs_outside xy_ref_pred_inside xy_ref_pred_outside xy_ref_pred_match xy_ref_imp_inside xy_ref_imp_outside xy_ref_impcount_inside xy_ref_impcount_outside xy_alt_pred_inside xy_alt_pred_outside xy_alt_pred_match xy_alt_imp_inside xy_alt_imp_outside xy_alt_impcount_inside xy_alt_impcount_outside yx_ref_obs_inside yx_ref_obs_outside yx_ref_pred_inside yx_ref_pred_outside yx_ref_pred_match yx_ref_imp_inside yx_ref_imp_outside yx_ref_impcount_inside yx_ref_impcount_outside yx_alt_pred_inside yx_alt_pred_outside yx_alt_pred_match yx_alt_imp_inside yx_alt_imp_outside yx_alt_impcount_inside yx_alt_impcount_outside dxy_x_pred_inside dxy_x_pred_outside dxy_x_pred_total dxy_x_pred_match dxy_y_pred_inside dxy_y_pred_outside dxy_y_pred_total dxy_y_pred_match dx_x_ref_obs_inside dx_x_ref_obs_outside dx_x_ref_pred_inside dx_x_ref_pred_outside dx_x_ref_pred_match dx_x_ref_imp_inside dx_x_ref_imp_outside dx_x_ref_impcount_inside dx_x_ref_impcount_outside dx_x_alt_pred_inside dx_x_alt_pred_outside dx_x_alt_pred_match dx_x_alt_imp_inside dx_x_alt_imp_outside dx_x_alt_impcount_inside dx_x_alt_impcount_outside dx_x_pattern dx_x_example_idx dx_x_pattern_start dx_x_pattern_end dx_x_strand dx_x_pattern_len dx_x_pattern_center dx_x_match_weighted dx_x_match_weighted_p dx_x_match_weighted_cat dx_x_match_max dx_x_match_max_task dx_x_imp_weighted dx_x_imp_weighted_p dx_x_imp_weighted_cat dx_x_imp_max dx_x_imp_max_task dx_x_seq_match dx_x_seq_match_p dx_x_seq_match_cat dx_x_match/Klf4 dx_x_match/Nanog dx_x_match/Oct4 dx_x_match/Sox2 dx_x_imp/Klf4 dx_x_imp/Nanog dx_x_imp/Oct4 dx_x_imp/Sox2 dx_x_example_chrom dx_x_example_start dx_x_example_end dx_x_example_strand dx_x_example_interval_from_task dx_x_pattern_short dx_x_pattern_name dx_x_pattern_start_abs dx_x_pattern_end_abs dx_x_pattern_center_aln dx_x_pattern_strand_aln dx_x_row_idx dy_y_ref_obs_inside dy_y_ref_obs_outside dy_y_ref_pred_inside dy_y_ref_pred_outside dy_y_ref_pred_match dy_y_ref_imp_inside dy_y_ref_imp_outside dy_y_ref_impcount_inside dy_y_ref_impcount_outside dy_y_alt_pred_inside dy_y_alt_pred_outside dy_y_alt_pred_match dy_y_alt_imp_inside dy_y_alt_imp_outside dy_y_alt_impcount_inside dy_y_alt_impcount_outside dy_y_pattern dy_y_example_idx dy_y_pattern_start dy_y_pattern_end dy_y_strand dy_y_pattern_len dy_y_pattern_center dy_y_match_weighted dy_y_match_weighted_p dy_y_match_weighted_cat dy_y_match_max dy_y_match_max_task dy_y_imp_weighted dy_y_imp_weighted_p dy_y_imp_weighted_cat dy_y_imp_max dy_y_imp_max_task dy_y_seq_match dy_y_seq_match_p dy_y_seq_match_cat dy_y_match/Klf4 dy_y_match/Nanog dy_y_match/Oct4 dy_y_match/Sox2 dy_y_imp/Klf4 dy_y_imp/Nanog dy_y_imp/Oct4 dy_y_imp/Sox2 dy_y_example_chrom dy_y_example_start dy_y_example_end dy_y_example_strand dy_y_example_interval_from_task dy_y_pattern_short dy_y_pattern_name dy_y_pattern_start_abs dy_y_pattern_end_abs dy_y_pattern_center_aln dy_y_pattern_strand_aln dy_y_row_idx
0 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 254 254 ++ 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 259.4509 416.9008 0.2940 0.7618 3.4483 0.2434 7.5768 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 306.0172 824.3239 0.1890 0.4310 3.3727 0.5010 7.8317 65.9286 466.0667 531.9953 0.3524 44.3159 487.6794 531.9953 0.2297 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 45.3026 631.0491 0.2509 0.0305 4.1796 0.1119 7.7083 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 71.1106 1059.2305 0.1858 0.0161 3.7876 0.0697 8.2629 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207
1 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 726 735 - 9 730 0.5169 0.4961 medium 0.5441 Nanog 0.3132 0.8641 low 0.4321 Nanog 4.8914 0.2707 low 0.5222 0.5441 0.4026 0.4283 0.1153 0.4321 0.0166 0.0373 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220027 73220036 731 - 60208 263 263 ++ 1632.0 7260.0 406.1481 958.9734 0.3327 0.3393 3.9240 0.0568 4.0569 257.0172 419.3344 0.3261 0.4494 3.7607 0.1569 7.6633 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 276.7634 1043.7615 0.1898 0.3320 3.6129 0.4834 8.0046 87.0752 590.9130 677.9881 0.3100 97.2186 580.7695 677.9881 0.3011 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 45.3026 631.0491 0.2509 0.0305 4.1796 0.1119 7.7083 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 1632.0 7260.0 406.1481 958.9734 0.3327 0.3393 3.9240 0.0568 4.0569 236.0682 1084.4568 0.4628 0.0875 3.8574 0.0856 8.4024 metacluster_2/pattern_0 14 726 735 - 9 730 0.5169 0.4961 medium 0.5441 Nanog 0.3132 0.8641 low 0.4321 Nanog 4.8914 0.2707 low 0.5222 0.5441 0.4026 0.4283 0.1153 0.4321 0.0166 0.0373 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220027 73220036 731 - 60208
2 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 metacluster_2/pattern_0 766 775 + 9 770 0.4425 0.2255 low 0.4722 Nanog 0.0883 0.0004 high 0.1174 Nanog 1.2431 0.0059 low 0.3405 0.4722 0.3554 0.4507 0.0457 0.1174 0.0125 0.0156 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220067 73220076 770 + 60209 303 302 -+ 663.0 8229.0 142.7500 1222.3716 0.3615 0.0595 4.2038 0.0265 4.0873 99.6234 576.7283 0.3627 0.0778 4.1323 0.0702 7.7500 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 223.9338 1187.6025 0.1864 0.2474 4.0624 0.4810 8.0731 48.0065 735.9346 783.9411 0.2355 122.3456 661.5955 783.9411 0.3804 2426.0 6466.0 223.6235 1141.4980 0.1867 0.2558 4.0075 0.1976 3.9161 45.3026 631.0491 0.2509 0.0305 4.1796 0.1119 7.7083 metacluster_2/pattern_0 14 463 472 - 9 467 0.4649 0.2975 low 0.6524 Nanog 0.1970 0.4306 medium 0.2558 Nanog 7.4683 0.7938 high 0.0817 0.6524 -0.1516 0.2432 0.0492 0.2558 0.1079 0.0763 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73219764 73219773 468 - 60206 663.0 8229.0 142.7500 1222.3716 0.3615 0.0595 4.2038 0.0265 4.0873 151.8345 1259.7019 0.4517 0.0847 4.2251 0.0721 8.4821 metacluster_2/pattern_0 14 766 775 + 9 770 0.4425 0.2255 low 0.4722 Nanog 0.0883 0.0004 high 0.1174 Nanog 1.2431 0.0059 low 0.3405 0.4722 0.3554 0.4507 0.0457 0.1174 0.0125 0.0156 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220067 73220076 770 + 60209
3 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 metacluster_2/pattern_0 726 735 - 9 730 0.5169 0.4961 medium 0.5441 Nanog 0.3132 0.8641 low 0.4321 Nanog 4.8914 0.2707 low 0.5222 0.5441 0.4026 0.4283 0.1153 0.4321 0.0166 0.0373 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220027 73220036 731 - 60208 9 9 ++ 1632.0 7260.0 406.1481 958.9734 0.3327 0.3393 3.9240 0.0568 4.0569 68.6485 1061.6925 0.1232 0.0384 3.7654 0.0757 8.2570 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 242.1011 1078.4238 0.2119 0.3414 3.6035 0.2307 8.2572 38.5565 1055.2196 1093.7761 0.2217 36.6094 1057.1667 1093.7761 0.2554 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 71.1106 1059.2305 0.1858 0.0161 3.7876 0.0697 8.2629 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 1632.0 7260.0 406.1481 958.9734 0.3327 0.3393 3.9240 0.0568 4.0569 236.0682 1084.4568 0.4628 0.0875 3.8574 0.0856 8.4024 metacluster_2/pattern_0 14 726 735 - 9 730 0.5169 0.4961 medium 0.5441 Nanog 0.3132 0.8641 low 0.4321 Nanog 4.8914 0.2707 low 0.5222 0.5441 0.4026 0.4283 0.1153 0.4321 0.0166 0.0373 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220027 73220036 731 - 60208
4 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 metacluster_2/pattern_0 766 775 + 9 770 0.4425 0.2255 low 0.4722 Nanog 0.0883 0.0004 high 0.1174 Nanog 1.2431 0.0059 low 0.3405 0.4722 0.3554 0.4507 0.0457 0.1174 0.0125 0.0156 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220067 73220076 770 + 60209 49 48 -+ 663.0 8229.0 142.7500 1222.3716 0.3615 0.0595 4.2038 0.0265 4.0873 37.4505 1092.8906 0.2602 0.0115 3.7922 0.0751 8.2576 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 446.5078 965.0286 0.2930 0.6285 3.6813 0.2409 8.3133 57.9874 1038.4796 1096.4669 0.2608 28.9500 1067.5170 1096.4669 0.3113 1696.0 7196.0 412.4007 952.7209 0.2843 0.5852 3.6780 0.0877 4.0261 71.1106 1059.2305 0.1858 0.0161 3.7876 0.0697 8.2629 metacluster_2/pattern_0 14 717 726 - 9 721 0.5582 0.6582 medium 0.5698 Nanog 0.4231 0.9623 low 0.5852 Nanog 6.3508 0.5510 medium 0.5499 0.5698 0.5033 0.5389 0.1437 0.5852 0.0262 0.0522 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220018 73220027 722 - 60207 663.0 8229.0 142.7500 1222.3716 0.3615 0.0595 4.2038 0.0265 4.0873 151.8345 1259.7019 0.4517 0.0847 4.2251 0.0721 8.4821 metacluster_2/pattern_0 14 766 775 + 9 770 0.4425 0.2255 low 0.4722 Nanog 0.0883 0.0004 high 0.1174 Nanog 1.2431 0.0059 low 0.3405 0.4722 0.3554 0.4507 0.0457 0.1174 0.0125 0.0156 chr5 73219301 73220301 * Oct4 m2_p0 Nanog 73220067 73220076 770 + 60209
In [1058]:
def plt_diag(xl, ax=None):
    if ax is None:
        ax = plt.gca()
    ax.set_xlim(xl)
    ax.set_ylim(xl)
    ax.plot(xl, xl, c='grey', alpha=0.5);    
In [1056]:
import warnings
warnings.filterwarnings("ignore")
In [1063]:
k = 'Oct4-Sox2<>Nanog'
dfab_sma = dfab_pairs[k]
dfab_sma = dfab_sma[dfab_sma.center_diff < 150]

cat_dist = pd.Categorical(pd.cut(dfab_sma.center_diff, [0, 35, 70, 150]))
cat_strand = pd.Categorical(dfab_sma.strand_combination)

match_threshold = .2
cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
             (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
            (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
In [1068]:
nc = 5
nr = 2
fig = plt.figure(figsize=(get_figsize(.5, aspect=1)[0]*nc, get_figsize(.5, aspect=1)[1]*2))

variable = 'cat_imp'
for k, cat in enumerate(eval(variable).categories):
    dfab_sm = dfab_sma[eval(variable) == cat]
    plt.subplot(2, nc, 1)
    plt.scatter(dfab_sm.dxy_y_pred_total, dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dx&dy")
    plt.ylabel("y_total | dx")
    plt_diag([0, 500])
    plt.title(k);

    plt.subplot(2, nc, 2)
    plt.scatter((dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside), dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total")
    plt.ylabel("y_total | dx")
    plt_diag([0, 1000])

    plt.subplot(2, nc, 3)
    plt.scatter((dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dy")
    plt.ylabel("y_total | dx")
    plt_diag([0, 1000])

    plt.subplot(2, nc, 4)
    plt.scatter((dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dy")
    plt.ylabel("y_total")
    plt_diag([0, 1000])

    plt.subplot(2, nc, nc+1)
    plt.scatter(dfab_sm.dxy_y_pred_inside, dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dx&dy")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 200])
    plt.title(k);

    plt.subplot(2, nc, nc+2)
    plt.scatter(dfab_sm.yx_ref_pred_inside , dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 400])

    plt.subplot(2, nc, nc+3)
    plt.scatter(dfab_sm.dy_y_alt_pred_inside, dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 400])

    plt.subplot(2, nc, nc+4)
    plt.scatter(dfab_sm.dy_y_alt_pred_inside, dfab_sm.yx_ref_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy")
    plt.ylabel("y_footprint")
    plt_diag([0, 400])
    
    plt.subplot(2, nc, nc+5)
    plt.scatter(dfab_sm.dxy_y_pred_inside, dfab_sm.yx_ref_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy&dx")
    plt.ylabel("y_footprint")
    plt_diag([0, 400])

plt.legend(scatterpoints=1, markerscale=10, columnspacing=0, handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
In [ ]:
fig = plt.figure(figsize=get_figsize(.5))
plt.scatter(dfab_sm.dxy_y_pred_total, dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2)
plt.xlabel("y_total | dx&dy")
plt.ylabel("y_total | dx")
plt_diag([0, 500])
In [ ]:
fig = plt.figure(figsize=get_figsize(.5))
plt.scatter(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside, dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside, s=2, alpha=0.2)
In [1082]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    
    variable = 'cat_imp'
    dfab_sm = dfab_sma
    for j, ax in enumerate(axes[i]):
        if j == 0:
            if i == 0:
                ax.set_title("Corrected total counts")
            # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
        if j == 1:
            if i == 0:
                ax.set_title("Corrected footprint counts")
            # (A|dB - A|dA&dB)/(A - A|dA)
            x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
            x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
        if j == 2:
            if i == 0:
                ax.set_title("Total counts")
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
        if j == 3:
            if i == 0:
                ax.set_title("Footprint counts")
            x_alt = dfab_sm.xy_alt_pred_inside
            x_ref = dfab_sm.xy_ref_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside            
            # TODO
        elif j == 4:
            if i == 0:
                ax.set_title("Profile importance")
            x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
            x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
            y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
            y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
        elif j == 5:
            if i == 0:
                ax.set_title("Norm. profile importance")
            x_alt = dfab_sm.xy_alt_imp_inside
            x_ref = dfab_sm.xy_ref_imp_inside
            y_alt = dfab_sm.yx_alt_imp_inside
            y_ref = dfab_sm.yx_ref_imp_inside
#         elif j == 4:
#             if i == 0:
#                 ax.set_title("Count importance")
#             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
#                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
        elif j == 6:
            if i == 0:
                ax.set_title("Match of the footprint")
            x_alt = dfab_sm.xy_alt_pred_match
            x_ref = dfab_sm.xy_ref_pred_match
            y_alt = dfab_sm.yx_alt_pred_match
            y_ref = dfab_sm.yx_ref_pred_match
#         if j == 1:
#             plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, xl=[-5, 5])
#         else:
        plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat)
        #if j == nf - 1 and i == 0:
        #    ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
            # ax.legend(labels=list(cat_strand.categories))
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
# plt.suptitle("mutated / ref", y=1.02);
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)

Open questions

  • why is the sign inverted?
In [1069]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    
    variable = 'cat_imp'
    dfab_sm = dfab_sma
    for j, ax in enumerate(axes[i]):
        if j == 0:
            if i == 0:
                ax.set_title("Corrected total counts")
            # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA&dB)
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - dfab_sm.dxy_y_pred_total
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - dfab_sm.dxy_x_pred_total
        if j == 1:
            if i == 0:
                ax.set_title("Corrected footprint counts")
            # (A|dB - A|dA&dB)/(A - A|dA&dB)
            x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
            x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dxy_y_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dxy_x_pred_inside
        if j == 2:
            if i == 0:
                ax.set_title("Total counts")
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
        if j == 3:
            if i == 0:
                ax.set_title("Footprint counts")
            x_alt = dfab_sm.xy_alt_pred_inside
            x_ref = dfab_sm.xy_ref_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside            
            # TODO
        elif j == 4:
            if i == 0:
                ax.set_title("Profile importance")
            x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
            x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
            y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
            y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
        elif j == 5:
            if i == 0:
                ax.set_title("Norm. profile importance")
            x_alt = dfab_sm.xy_alt_imp_inside
            x_ref = dfab_sm.xy_ref_imp_inside
            y_alt = dfab_sm.yx_alt_imp_inside
            y_ref = dfab_sm.yx_ref_imp_inside
#         elif j == 4:
#             if i == 0:
#                 ax.set_title("Count importance")
#             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
#                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
        elif j == 6:
            if i == 0:
                ax.set_title("Match of the footprint")
            x_alt = dfab_sm.xy_alt_pred_match
            x_ref = dfab_sm.xy_ref_pred_match
            y_alt = dfab_sm.yx_alt_pred_match
            y_ref = dfab_sm.yx_ref_pred_match
#         if j == 1:
#             plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, xl=[-5, 5])
#         else:
        plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat)
        #if j == nf - 1 and i == 0:
        #    ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
            # ax.legend(labels=list(cat_strand.categories))
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
# plt.suptitle("mutated / ref", y=1.02);
plt.tight_layout()
# plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)

Color by importance

In [1075]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    
    variable = 'cat_imp'
    for j, ax in enumerate(axes[i]):
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[eval(variable) == cat]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside            
                # TODO
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
    #         elif j == 4:
    #             if i == 0:
    #                 ax.set_title("Count importance")
    #             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
    #                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match

            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1])) 
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
            
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.color=imp.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.color=imp.png', raster=True, transparent=False)

Color by strand

In [1075]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    
    variable = 'cat_imp'
    for j, ax in enumerate(axes[i]):
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[eval(variable) == cat]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside            
                # TODO
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
    #         elif j == 4:
    #             if i == 0:
    #                 ax.set_title("Count importance")
    #             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
    #                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match

            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1])) 
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
            
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.color=imp.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.color=imp.png', raster=True, transparent=False)

Color by match

In [1079]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    
    cat_dist = pd.Categorical(pd.cut(dfab_sma.center_diff, [0, 35, 70, 150]))
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .4
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > match_threshold).map({True: 'high', False: 'low'})))
    imp_threshold = .2
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > imp_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > imp_threshold).map({True: 'high', False: 'low'})))
    
    variable = 'cat_match'
    for j, ax in enumerate(axes[i]):
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[eval(variable) == cat]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside            
                # TODO
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
    #         elif j == 4:
    #             if i == 0:
    #                 ax.set_title("Count importance")
    #             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
    #                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match

            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1])) 
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
            
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.color=imp.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.color=imp.png', raster=True, transparent=False)

Color by distance

In [1078]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    
    cat_dist = pd.Categorical(pd.cut(dfab_sma.center_diff, [0, 35, 70, 150]))
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    
    variable = 'cat_dist'
    for j, ax in enumerate(axes[i]):
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[eval(variable) == cat]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside            
                # TODO
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
    #         elif j == 4:
    #             if i == 0:
    #                 ax.set_title("Count importance")
    #             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
    #                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match

            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1])) 
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
            
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.color=imp.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.color=imp.png', raster=True, transparent=False)

high-high importance, distance < 35, Color by strand

In [1080]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 35]
    
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    
    variable = 'cat_strand'
    for j, ax in enumerate(axes[i]):
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[(eval(variable) == cat) & (cat_imp == 'high-high')]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside            
                # TODO
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
    #         elif j == 4:
    #             if i == 0:
    #                 ax.set_title("Count importance")
    #             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
    #                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match

            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1])) 
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
            
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.short-range.color=imp.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.short-range.color=imp.png', raster=True, transparent=False)

high-high importance, distance < 35, Color by match

In [1081]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 35]
    
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .4
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > match_threshold).map({True: 'high', False: 'low'})))
    imp_threshold = .2
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > imp_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > imp_threshold).map({True: 'high', False: 'low'})))
    
    variable = 'cat_match'
    for j, ax in enumerate(axes[i]):
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[(eval(variable) == cat) & (cat_imp == 'high-high')]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside            
                # TODO
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
    #         elif j == 4:
    #             if i == 0:
    #                 ax.set_title("Count importance")
    #             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
    #                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match

            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1])) 
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
            
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.short-range.color=imp.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.short-range.color=imp.png', raster=True, transparent=False)

high-high importance, distance = (35, 70], Color by match

In [1083]:
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[(dfab_sma.center_diff > 35) & (dfab_sma.center_diff <= 70)]
    
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)

    match_threshold = .4
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" + 
                 (dfab_sma.match_weighted_p_y > match_threshold).map({True: 'high', False: 'low'})))
    imp_threshold = .2
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > imp_threshold).map({True: 'high', False: 'low'}) + "-" + 
                (dfab_sma.imp_weighted_p_y > imp_threshold).map({True: 'high', False: 'low'})))
    
    cat = None
    for j, ax in enumerate(axes[i]):
        # for k, cat in enumerate(eval(variable).categories):
        dfab_sm = dfab_sma[(cat_imp == 'high-high')]
        if j == 0:
            if i == 0:
                ax.set_title("Corrected total counts")
            # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
        if j == 1:
            if i == 0:
                ax.set_title("Corrected footprint counts")
            # (A|dB - A|dA&dB)/(A - A|dA)
            x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
            x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
        if j == 2:
            if i == 0:
                ax.set_title("Total counts")
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
        if j == 3:
            if i == 0:
                ax.set_title("Footprint counts")
            x_alt = dfab_sm.xy_alt_pred_inside
            x_ref = dfab_sm.xy_ref_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside            
            # TODO
        elif j == 4:
            if i == 0:
                ax.set_title("Profile importance")
            x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
            x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
            y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
            y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
        elif j == 5:
            if i == 0:
                ax.set_title("Norm. profile importance")
            x_alt = dfab_sm.xy_alt_imp_inside
            x_ref = dfab_sm.xy_ref_imp_inside
            y_alt = dfab_sm.yx_alt_imp_inside
            y_ref = dfab_sm.yx_ref_imp_inside
#         elif j == 4:
#             if i == 0:
#                 ax.set_title("Count importance")
#             ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,   # / 2 seems to fix the scatterplots for an unknown reason
#                        dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
        elif j == 6:
            if i == 0:
                ax.set_title("Match of the footprint")
            x_alt = dfab_sm.xy_alt_pred_match
            x_ref = dfab_sm.xy_ref_pred_match
            y_alt = dfab_sm.yx_alt_pred_match
            y_ref = dfab_sm.yx_ref_pred_match

        plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=True)
        
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1])) 
            
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.mid-range.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.mid-range.png', raster=True, transparent=False)

Heatmaps

In [808]:
total_examples = len(dfi.example_idx.unique())
total_examples
Out[808]:
64434
In [814]:
# TODO - generalize this table to also have the diagonal in
In [822]:
motif_pair = ['Nanog', 'Nanog']
In [828]:
dfiab = dfab_pairs["<>".join(motif_pair)]
x_total = dfi_subset[dfi_subset.pattern_name == motif_pair[0]].shape[0]
xy_total = len(dfiab[dfiab.center_diff < 150].row_idx_x.unique())
In [829]:
x_total  # total number of instances of motif A
Out[829]:
47990
In [832]:
xy_total
Out[832]:
13321
In [837]:
dfab_pairs_filt = {k: v[v.center_diff < 150] for k,v in dfab_pairs.items()}
In [819]:
def plot_coocurence_matrix(dfi, dfiab_pairs, pairs, signif_threshold=1e-5, ax=None):
    """Test for motif co-occurence in example regions
    """
    if ax is None:
        ax = plt.gca()
    from sklearn.metrics import matthews_corrcoef
    from scipy.stats import fisher_exact
    import statsmodels as sm
    import seaborn as sns
    import matplotlib.pyplot as plt

    motifs = list({x for p in pairs for x in p})
    idx_to_motif = {i: m for i,m in enumerate(motifs)}
    
    o = np.zeros((len(motifs), len(motifs)))
    op = np.zeros((len(motifs), len(motifs)))

    totals = {m: dfi[dfi.pattern_name == m].shape[0]
              for m in motifs}
    
    for motif_pair in enumerate(pairs):
        i = idx_to_motif[motif_pair[0]]
        j = idx_to_motif[motif_pair[1]]
        dfiab = dfab_pairs["<>".join(motif_pair)]
        x_total = totals[motif_pair[0]]
        y_total = totals[motif_pair[1]]
        x_closeto_y_total = len(dfiab.row_idx_x.unique())
        y_closeto_x_total = len(dfiab.row_idx_y.unique())
        
        ct = [[a, b],
              [c, d]]
        ct = pd.crosstab(c[xn], c[yn])
        # add not-counted 0 values:
        ct.iloc[0, 0] += total_examples - len(c)
        t22 = sm.stats.contingency_tables.Table2x2(ct)
        o[i, j] = np.log2(t22.oddsratio)
        op[i, j] = t22.oddsratio_pvalue()
    signif = op < signif_threshold
    a = np.zeros_like(signif).astype(str)
    a[signif] = "*"
    a[~signif] = ""
    np.fill_diagonal(a, '')

    sns.heatmap(pd.DataFrame(o, columns=ndxs, index=ndxs), 
                annot=a, fmt="", vmin=-4, vmax=4, 
                cmap='RdBu_r', ax=ax)
    ax.set_title(f"Log2 odds-ratio. (*: p<{signif_threshold})")
In [1086]:
def compute_features(dfab_sm):
    return {"Corrected total counts": dict(
        x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total, # total counts | dB - dxy_x_pred_total
        x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside),
        y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total,
        y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
        ),
    "Corrected footprint counts": dict(
        # (A|dB - A|dA&dB)/(A - A|dA)
        x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside, # total counts | dB - dxy_x_pred_total
        x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside,
        y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside,
        y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
    ),
    "Total counts": dict(
        x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside),
        x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside),
        y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside),
        y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
    ),

    "Footprint counts": dict(
        x_alt = dfab_sm.xy_alt_pred_inside,
        x_ref = dfab_sm.xy_ref_pred_inside,
        y_alt = dfab_sm.yx_alt_pred_inside,
        y_ref = dfab_sm.yx_ref_pred_inside
    ),
    "Profile importance": dict(
        x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)),
        x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)),
        y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)),
        y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
    ),
    "Norm. profile importance": dict(
        x_alt = dfab_sm.xy_alt_imp_inside, 
        x_ref = dfab_sm.xy_ref_imp_inside, 
        y_alt = dfab_sm.yx_alt_imp_inside, 
        y_ref = dfab_sm.yx_ref_imp_inside
    ),
    "Match of the footprint": dict(
        x_alt = dfab_sm.xy_alt_pred_match,
        x_ref = dfab_sm.xy_ref_pred_match,
        y_alt = dfab_sm.yx_alt_pred_match,
        y_ref = dfab_sm.yx_ref_pred_match
    )}
In [1107]:
features = ["Corrected total counts",
    "Corrected footprint counts",
    "Total counts",
    "Footprint counts",
    "Profile importance",
    "Norm. profile importance",
    "Match of the footprint"]
In [1264]:
def plot_mutation_heatmap(dfab_pairs, pairs, motif_list, feature='Corrected footprint counts', signif_threshold=1e-5, ax=None, max_frac=2):
    if ax is None:
        ax = plt.gca()
    
    motifs = motif_list    
    motif_to_idx = {m: i for i,m in enumerate(motifs)}

    o = np.zeros((len(motifs), len(motifs)))
    op = np.zeros((len(motifs), len(motifs)))

    for motif_pair in pairs:
        i, j = motif_to_idx[motif_pair[0]], motif_to_idx[motif_pair[1]]
        dfab_sma = dfab_pairs["<>".join(motif_pair)]
        dfab_sma = dfab_sma[(dfab_sma.center_diff < 150)]
        features = compute_features(dfab_sma)[feature]

        o[i, j] = np.mean(features['y_alt'] / features['y_ref']) # x|dy
        o[j, i] = np.mean(features['x_alt'] / features['x_ref'])  # y|dx
        op[i,j] = wilcoxon(features['y_ref'], features['y_alt']).pvalue
        op[i,j] = wilcoxon(features['x_ref'], features['x_alt']).pvalue

    signif = op < signif_threshold
    a = np.zeros_like(signif).astype(str)
    a[signif] = "*"
    a[~signif] = ""

    sns.heatmap(pd.DataFrame(o, columns=["d" + x for x in motifs], index=motifs), 
                annot=a, fmt="", vmin=max_frac - 1, vmax=max_frac, 
                cmap='RdBu_r', ax=ax)
    ax.set_title(f"{feature} (alt / ref) (*: p<{signif_threshold})")    
In [1129]:
 
In [1156]:
norm_counts
Out[1156]:
pattern_name Klf4 Nanog Oct4-Sox2 Sox2
pattern_name
Klf4 30132 30132 30132 30132
Nanog 47990 47990 47990 47990
Oct4-Sox2 17417 17417 17417 17417
Sox2 12651 12651 12651 12651
In [1288]:
def co_occurence_matrix(dfi_subset, query_string=""):
    """Returns the fraction of times pattern x (row) overlaps pattern y (column)
    """
    from basepair.stats import norm_matrix
    total_number = dfi_subset.groupby(['pattern']).size()
    norm_counts = norm_matrix(total_number)

    # normalization: minimum number of counts
    total_number = dfi_subset.groupby(['pattern_name']).size()
    norm_counts = norm_matrix(total_number)

    # cross-product
    dfi_filt_crossp = pd.merge(dfi_subset[['pattern_name', 'pattern_center_aln', 'pattern_strand_aln', 'pattern_center', 'example_idx']].set_index('example_idx'), 
                               dfi_subset[['pattern_name', 'pattern_center_aln', 'pattern_strand_aln', 'pattern_center', 'example_idx']].set_index('example_idx'), 
                               how='outer', left_index=True, right_index=True).reset_index()
    # remove self-matches
    dfi_filt_crossp = dfi_filt_crossp.query('~((pattern_name_x == pattern_name_y) & (pattern_center_aln_x == pattern_center_aln_y) & (pattern_strand_aln_x == pattern_strand_aln_x))')

    if query_string:
        dfi_filt_crossp = dfi_filt_crossp.query(query_string)
    match_sizes = dfi_filt_crossp.groupby(['pattern_name_x', 'pattern_name_y']).size()
    count_matrix = match_sizes.unstack(fill_value=0)

    norm_count_matrix = count_matrix / norm_counts# .truediv(min_counts, axis='columns').truediv(total_number, axis='index')
    norm_count_matrix = norm_count_matrix.fillna(0)  # these examples didn't have any paired pattern
    return count_matrix, norm_count_matrix, norm_counts


def fisher_test_coc(random_coocurrence_counts, random_coocurrence, random_coocurrence_norm, coocurrence_counts, coocurrence, coocurrence_norm):
    import statsmodels as sm
    cols = list(coocurrence_norm.columns)
    n = len(random_coocurrence_counts)
    o = np.zeros((n,n))
    op = np.zeros((n,n))
    for i in range(n):
        for j in range(n):
            ct = [[random_coocurrence_norm.iloc[i,j] - random_coocurrence_counts.iloc[i,j], coocurrence_norm.iloc[i,j] - coocurrence_counts.iloc[i,j]],
                  [random_coocurrence_counts.iloc[i,j], coocurrence_counts.iloc[i,j]]]
            t22 = sm.stats.contingency_tables.Table2x2(np.array(ct))
            o[i, j] = t22.oddsratio
            op[i, j] = t22.oddsratio_pvalue()
    return pd.DataFrame(o, columns=cols, index=cols), pd.DataFrame(op, columns=cols, index=cols)


def coocurrence_plot(dfi_subset, motif_list, query_string="(abs(pattern_center_aln_x- pattern_center_aln_y) <= 150)", signif_threshold=1e-5):
    """Test for co-occurence
    
    Args:
      dfi_subset: desired subset of dfi
      motif_list: list of motifs used to order the heatmap
      query_string: string used with df_cross.query() to detering the valid motif pairs
      signif_threshold: significance threshold for Fisher's exact test
    """
    c_counts, c, c_norm = co_occurence_matrix(dfi_subset, query_string=query_string)
    
    # Generate the NULL
    dfi_subset_random  = dfi_subset.copy()
    np.random.seed(42)
    dfi_subset_random['example_idx'] = dfi_subset_random['example_idx'].sample(frac=1).values
    rc_counts, rc, rc_norm = co_occurence_matrix(dfi_subset_random, query_string=query_string)
    
    # test for significance
    o, op = fisher_test_coc(rc_counts, rc, rc_norm, c_counts, c, c_norm)
    
    # re-order
    o = o[motif_list].loc[motif_list]
    op = op[motif_list].loc[motif_list]
    
    signif = op < signif_threshold
    a = np.zeros_like(signif).astype(str)
    a[signif] = "*"
    a[~signif] = ""

    sns.heatmap(o, annot=a, fmt="", vmin=0, vmax=2, 
                cmap='RdBu_r', ax=ax)
    ax.set_title(f"odds-ratio (proximal / non-proximal) (*: p<{signif_threshold})");
In [1162]:
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
plot_coocurence_matrix(dfi_subset[(dfi_subset.pattern_center > 400) & (dfi_subset.pattern_center < 600)], total_examples, list(motifs), ax=ax)

Co-occurence

In [1293]:
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset, list(motifs))
ax.set_ylabel("Motif of interest")
ax.set_xlabel("Motif partner");

Effect of partner motif perturbation

In [1123]:
figsize_single = get_figsize(.5, aspect=1)
fig, axes = plt.subplots(1, len(features), figsize=(figsize_single[0]*len(features), figsize_single[0]))
for i, (feat, ax) in enumerate(zip(features, axes)):
    plot_mutation_heatmap(dfab_pairs, pairs, list(motifs), feat, ax=ax, max_frac=2)
plt.tight_layout()

TODO

  • [x] Re-order plots (TF pair X metric)
  • [X] Add deeplift profile * total counts
  • [x] Add deeplift total counts importances
  • [X] color the points according to the pairwise distance
    • < 35, 35-70, 70-150
    • strand (+-, --, ...)
    • strength of A, strength of B (high/low affinity)
  • [x] Add Wilcoxon test to the plot.
  • [x] compute A | dA & dB and add 2 new importance metrics to the plot (Including wilcoxon test)
    • add (A|dB - A|dA&dB ) / ( A - A|dA)
    • add (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA) plot
  • [x] Plot the heatmap of average effect (TFxTF, color=alt/ref, text=signif stars)
  • [~] generate 3 heatmaps
    • [X] co-occurence
    • [X] perturbation effect
    • [ ] 10bp periodicity
  • [ ] merge all into Figure 5

Open questions

  • [ ] can we ignore the strand and just aggregate all the results across 4 (or 3) strand combinations?
    • test: add them up and see if you are still getting qualitatively the same results
      • overplot or plot very close to each other (4 strands + aggregate):
        • simulations data (to see subtle differences)
        • histogram (as a chart step)
        • observed counts in the genome
  • [ ] is the simulated motif aligned with observed seqlet?
    • if not, then re-generate all the simulation data
  • [ ] how to test for significance for motifs of the same kind?

TODO

  • [x] why are the alt importance scores negative? now fixed
    • I was using the hypothetical contribs instead of contribs
  • [x] Dissect the contribution from flanks on counts / profile
  • [x] figure out the dependency graph
  • shall we weight the imp-scores with the total counts?

Analyze the dataset

  • for each motif individually, study the impact of mutations on the proximal and distal counts
    • how much of the local profile can be explained by the motif?
      • note: that point could be circular since the architecture is designed the 'local' way
  • for each motif pair, study the impact of the mutations on other motifs
    • measure the change in the following variables when perturbing motif A:
      • counts inside (A, B)x[ref, alt]
      • profile match (A,B) x [ref alt]
      • counts outside (A)x[ref, alt]
      • importance score (B)x[alt], (A,B)x[ref]
      • pairwise distance
      • strand orientation
  • [ ] compile the dataset such that you can extract the dataset for a probabilistic graphical model

Questions:

  • How much does the importance scores of B decrease when perturbing A?
  • How much do the profile counts of B decrease when perturbing A?

Stratification. How are the above questions influenced by

  • distance or orientation?
  • presence of other motifs (is there some redundancy)?
  • weak / strong affinity motifs of A?

Final goal:

  • paper figures for each motif pair
  • graph showing different connections between core motifs
    • visualize edges in networkx

Other plots

In [211]:
# Plot the profile instances
from basepair.modisco.pattern_instances import dfi2seqlets, annotate_profile
from basepair.modisco.results import Seqlet, resize_seqlets
from basepair.plot.profiles import extract_signal
from basepair.plot.profiles import  plot_stranded_profile, multiple_plot_stranded_profile
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile

dfi_subset = (dfi.query('match_weighted_p > .2')
                 .query('imp_weighted_p > 0.0')
                 .query('pattern_name=="Oct4-Sox2"'))
seqlets = dfi2seqlets(dfi_subset)
seqlets = resize_seqlets(seqlets, 70, seqlen=1000)
seqlet_profiles = {k: extract_signal(v, seqlets) for k,v in profiles.items()}

multiple_plot_stranded_profile({p:v for p,v in seqlet_profiles.items()}, figsize_tmpl=(2.55,2))
multiple_heatmap_stranded_profile(seqlet_profiles, sort_idx=np.arange(1000), figsize=(10,10));