# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
hv.extension('bokeh')
paper_config()
# Common paths
# NOTE(review): `ddir` is assumed to be provided by `basepair.imports` -- confirm
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / f"modisco/all/deeplift/profile/"
# output_dir = Path("/srv/www/kundaje/avsec/chipnexus/oct-sox-nanog-klf/models/n_dil_layers=9/modisco/all/profile")
create_tf_session(0)  # create the TF session (device index 0 -- per basepair helper)
patterns = read_pkl(modisco_dir / "patterns.pkl") # aligned patterns
bpnet = BPNet.from_mdir(model_dir)
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix, align_instance_center
from basepair.exp.paper.config import motifs, profile_mapping
from basepair.utils import flatten
list(motifs)  # preview the configured motif names
# Load all motif instances (no de-duplication) and drop overlapping intervals
dfi = load_instances(modisco_dir / 'instances.parq', motifs=motifs, dedup=False)
dfi = filter_nonoverlapping_intervals(dfi)
mr = ModiscoResult(modisco_dir / 'modisco.h5')
mr.open()
# Add aligned instances
orig_patterns = [mr.get_pattern(pname) for pname in mr.patterns()]
dfi = align_instance_center(dfi, orig_patterns, patterns, trim_frac=0.08)
from basepair.cli.imp_score import ImpScoreFile
imp_scores = ImpScoreFile.from_modisco_dir(modisco_dir)
profiles = imp_scores.get_profiles()
tasks = patterns[0].tasks()
tasks
# how many instances have high importance and a non-low match score?
np.sum((dfi.imp_weighted_cat == 'high') & (dfi.match_weighted_cat != 'low'))
# instances near the sequence center passing the match/importance thresholds
(dfi[(dfi.pattern_center > 400) & (dfi.pattern_center < 600)]
 .query('match_weighted_p > 0.2')
 .query('imp_weighted_p > 0')).shape
# subset of instances used for all the perturbation analysis below
dfi_subset= (dfi.query('match_weighted_p > 0.2')
             .query('imp_weighted_p > 0'))
dfi_subset['row_idx'] = np.arange(len(dfi_subset)).astype(int)  # positional id used by the datasets below
dfi_subset.shape
np.mean(dfi_subset.groupby('example_idx').size() > 1)  # fraction of examples containing >1 instance
dfi_subset.groupby('example_idx').size().plot.hist(10)
dfi.head()
from kipoi.data import Dataset
from concise.preprocessing import encodeDNA
import random
def random_seq_onehot(l):
    """Generate a uniformly-random DNA sequence, one-hot-encoded.

    Args:
        l: sequence length (coerced to int)

    Returns:
        one-hot encoding of a single random sequence of length ``l``
    """
    seq = ''.join(random.choices("ACGT", k=int(l)))
    return encodeDNA([seq])[0]
class PerturbDataset(Dataset):
    """Dataset of sequences with one motif instance scrambled per item.

    Args:
        dfi: motif-instance data-frame (one row per instance; must carry
            `row_idx`, `example_idx`, `pattern_start/end`, `pattern`)
        seqs: one-hot encoded original sequences, indexed by `example_idx`
    """

    def __init__(self, dfi, seqs):
        self.dfi = dfi
        self.seqs = seqs

    def __len__(self):
        return len(self.dfi)

    def __getitem__(self, idx):
        instance = self.dfi.iloc[idx]
        # row_idx must be positional so that idx <-> instance stays consistent
        assert instance.row_idx == idx
        start = int(instance.pattern_start)
        end = int(instance.pattern_end)
        # copy the original sequence and replace the motif span with random bases
        mutated = self.seqs[instance.example_idx].copy()
        mutated[start:end] = random_seq_onehot(instance.pattern_end - instance.pattern_start)
        return {"inputs": mutated,
                "metadata": {"example_idx": instance.example_idx,
                             "pattern": instance.pattern,
                             "pattern_start": instance.pattern_start,
                             "pattern_end": instance.pattern_end,
                             }}
class DoublePerturbDatasetSeq(Dataset):
    """Dataset of sequences with BOTH motifs of a pair scrambled."""

    def __init__(self, dfab, seqs):
        """Pertub pairs of motifs

        Args:
            dfab: motif pair data-frame
            seqs: original sequences
        """
        self.dfab = dfab
        self.seqs = seqs

    def __len__(self):
        return len(self.dfab)

    def __getitem__(self, idx):
        pair = self.dfab.iloc[idx]
        mutated = self.seqs[pair.example_idx].copy()
        # scramble the x-motif span, then the y-motif span
        for start, end in ((pair.pattern_start_x, pair.pattern_end_x),
                           (pair.pattern_start_y, pair.pattern_end_y)):
            mutated[int(start):int(end)] = random_seq_onehot(end - start)
        return mutated
# NOTE(review): `dfab_pairs` is only defined further below -- notebook cells are out of order
dfab_pairs['Nanog<>Nanog'].shape
dfab_pairs['Nanog<>Nanog'].head()
len(dfi)
# Load sequences and contribution scores; run reference predictions
seqs = imp_scores.get_seq()
imp_scores_contrib = imp_scores.get_contrib()
imp_scores_contrib_counts = imp_scores.get_contrib(pred_summary='count')
preds = bpnet.predict(seqs)
%time preds2 = bpnet.predict(seqs)
from kipoi.writers import HDF5BatchWriter
# Write the predictions to hdf5
HDF5BatchWriter.dump(model_dir / 'preds.h5', preds)
!du -sh {model_dir}/preds.h5
%time preds2 = HDF5Reader.load(model_dir / 'preds.h5')
preds2['Oct4'].shape
preds['Oct4'].shape
# NOTE(review): this overwrites the `ImpScoreFile` object loaded above with a dict of arrays
imp_scores = bpnet.imp_score_all(seqs, method='deeplift', aggregate_strand=True)
# Perturb every motif instance, then recompute predictions and importance scores
alt_dataset = PerturbDataset(dfi_subset, seqs).load_all()
alt_seqs = alt_dataset['inputs']
alt_preds = bpnet.predict(alt_seqs)
alt_imp_scores = bpnet.imp_score_all(alt_seqs, method='deeplift', aggregate_strand=True)
# contribution = importance * one-hot sequence
alt_imp_scores_contrib = {k: v * alt_seqs for k,v in alt_imp_scores.items()}
alt_dataset['preds'] = alt_preds
alt_dataset['imp_scores'] = alt_imp_scores
%tqdm_restart
HDF5BatchWriter.dump(modisco_dir / 'perturb.motifs.h5', alt_dataset)
len(alt_seqs)
# Code for motif combinations
# setup config
from basepair.modisco.pattern_instances import construct_motif_pairs
# all unordered motif pairs, including self-pairs
pairs = []
for i in range(len(motifs)):
    for j in range(i, len(motifs)):
        pairs.append([ list(motifs)[i], list(motifs)[j], ])
# complement of a strand combination (applied when the pair order is swapped)
comp_strand_compbination = {
    "++": "--",
    "--": "++",
    "-+": "-+",
    "+-": "+-"
}
strand_combinations = ["++", "--", "+-", "-+"]
def motif_pair_dfi(dfi_filtered, motif_pair):
    """Construct the matrix of motif pairs.

    Args:
        dfi_filtered: dfi filtered to the desired property
        motif_pair: tuple of two pattern_name's

    Returns:
        pd.DataFrame with columns from dfi_filtered with _x and _y suffix
    """
    # complement of a strand combination, applied when motif y precedes motif x
    # (inlined here so the function does not depend on a notebook global)
    comp_strand_combination = {"++": "--", "--": "++", "-+": "-+", "+-": "+-"}
    dfa = dfi_filtered[dfi_filtered.pattern_name == motif_pair[0]]
    dfb = dfi_filtered[dfi_filtered.pattern_name == motif_pair[1]]
    dfab = pd.merge(dfa, dfb, on='example_idx', how='outer')
    # keep only examples where both motifs are present
    dfab = dfab[~dfab[['pattern_x', 'pattern_y']].isnull().any(axis=1)].copy()
    dfab['center_diff'] = dfab.pattern_center_y - dfab.pattern_center_x
    dfab['center_diff_aln'] = dfab.pattern_center_aln_y - dfab.pattern_center_aln_x
    dfab['strand_combination'] = dfab.strand_x + dfab.strand_y
    # assure the right strand combination.
    # BUG FIX: the original chained assignment `dfab[mask][col] = ...` wrote to
    # a temporary copy and never modified `dfab`; use .loc instead.
    neg = dfab.center_diff < 0
    dfab.loc[neg, 'strand_combination'] = dfab.loc[neg, 'strand_combination'].map(comp_strand_combination)
    if motif_pair[0] == motif_pair[1]:
        # same motif: "--" is equivalent to "++" after swapping the pair
        dfab.loc[dfab.strand_combination == "--", 'strand_combination'] = "++"
        # keep each unordered pair only once (and drop self-pairs at diff == 0)
        dfab = dfab[dfab.center_diff > 0]
    else:
        dfab.center_diff = np.abs(dfab.center_diff)
        dfab.center_diff_aln = np.abs(dfab.center_diff_aln)
        dfab = dfab[dfab.center_diff_aln != 0]  # exclude perfect matches
    return dfab
def plot_spacing(dfab,
                 alpha_scatter=0.01,
                 y_feature='profile_counts',
                 center_diff_variable='center_diff',
                 figsize=(3.42519, 6.85038)):
    """Plot motif-pair spacing: per strand-combination, a distance histogram
    plus a scatter of signal vs. motif distance with a lowess trend line.

    Args:
        dfab: motif-pair data-frame produced by `motif_pair_dfi`
        alpha_scatter: scatter-point transparency
        y_feature: which y-signal to plot ('profile_counts' or 'imp_weighted')
        center_diff_variable: column to use as the motif distance
        figsize: overall figure size in inches

    Returns:
        the matplotlib figure
    """
    from basepair.stats import smooth_window_agg, smooth_lowess, smooth_gam
    motif_pair = (dfab.iloc[0].pattern_name_x, dfab.iloc[0].pattern_name_y)
    strand_combinations = dfab.strand_combination.unique()
    # two rows per strand combination: histogram on top, scatter below
    fig_profile, axes = plt.subplots(2*len(strand_combinations), 1, figsize=figsize, sharex=True, sharey='row')
    motif_pair_c = motif_pair
    axes[0].set_title("<>".join(motif_pair), fontsize=7)
    j = 0 # first column
    # NOTE(review): `j` is never changed, so every `if j == 0` branch always runs
    dftw_filt = dfab[(dfab.center_diff < 150)] # & (dfab.imp_weighted_p.max(1) > 0.3)]
    for i, sc in enumerate(strand_combinations):
        if y_feature == 'profile_counts':
            y1 = np.log10(1+ dftw_filt[dftw_filt.strand_combination==sc][profile_mapping[motif_pair_c[0]] + "/profile_counts_x"])
            y2 = np.log10(1+ dftw_filt[dftw_filt.strand_combination==sc][profile_mapping[motif_pair_c[1]] + "/profile_counts_y"])
        elif y_feature == 'imp_weighted':
            y1 = np.log10(1+ dftw_filt[dftw_filt.strand_combination==sc]['imp_weighted_x'])
            y2 = np.log10(1+ dftw_filt[dftw_filt.strand_combination==sc]['imp_weighted_y'])
        else:
            raise ValueError(f"Unkown y_feature: {y_feature}")
        # y1 = dftw_filt[dftw_filt.strand_combination==sc]['imp_weighted'][motif_pair[0]]
        # y2 = dftw_filt[dftw_filt.strand_combination==sc]['imp_weighted'][motif_pair[1]]
        x = dftw_filt[dftw_filt.strand_combination==sc][center_diff_variable]
        #dm,ym,confi = average_distance(x,y, window=5)
        # lowess-smoothed trends for both motifs of the pair
        dm1,ym1,confi1 = smooth_lowess(x,y1, frac=0.15)
        dm2,ym2,confi2 = smooth_lowess(x,y2, frac=0.15)
        #dm,ym, confi = smooth_gam(x,y, 140, 20)
        ax = axes[2*i]
        ax.hist(dftw_filt[dftw_filt.strand_combination==sc][center_diff_variable], np.arange(10, 150, 1));
        if j == 0:
            ax.set_ylabel(sc)
        # second plot
        ax.set_xlim([0, 150])
        ax = axes[2*i+1]
        ax.scatter(x,y1, alpha=alpha_scatter, s=8)
        if confi1 is not None:
            ax.fill_between(dm1, confi1[:,0], confi1[:,1], alpha=0.2)
        ax.plot(dm1, ym1, linewidth=2, alpha=0.8)
        ax.scatter(x,y2, alpha=alpha_scatter, s=8)
        if confi2 is not None:
            ax.fill_between(dm2, confi2[:,0], confi2[:,1], alpha=0.2)
        ax.plot(dm2, ym2, linewidth=2, alpha=0.8)
        if j == 0:
            ax.set_ylabel(sc)
        ax.xaxis.set_minor_locator(plt.MultipleLocator(10))
        ax.xaxis.set_major_locator(plt.MultipleLocator(20))
        if j == 0:
            ax.set_ylabel(sc)
        if i == len(strand_combinations) - 1:
            ax.set_xlabel("Distance between motifs")
    fig_profile.subplots_adjust(wspace=0, hspace=0)
    return fig_profile
dfi_subset.shape
# original
dfi.pattern_name.value_counts()
# example spacing plot for one motif pair
dfab = motif_pair_dfi(dfi_subset, ['Oct4-Sox2', 'Sox2'])
fig = plot_spacing(dfab, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
dfab.shape
alt_imp_scores['Oct4/weighted'].shape
p = patterns[0]
p.profile['Klf4'].shape
# TODO - get the reference profile for the normal trimmed profile
from basepair.exp.chipnexus.simulate import profile_sim_metrics
# NOTE(review): the next line is leftover markdown text, not valid Python
Implement a function to go from seqlet -> ref / alt profile
## TOOD - get the reference seqlet profiles
pattern = 'metacluster_0/pattern_0'
motifs
mr.tasks()
# NOTE(review): this overwrites the real task list from `patterns[0].tasks()` -- confirm intended
tasks = ['']
def get_reference_profile(mr, pattern, tasks, profile_width=70, trim_frac=0.08, seqlen=1000):
    # Average the observed profile over all (trimmed, resized) seqlets of `pattern`
    seqlets_ref = mr._get_seqlets(pattern, trim_frac=trim_frac)
    seqlets_ref = resize_seqlets(seqlets_ref, profile_width, seqlen=seqlen)
    task = 'Oct4'  # NOTE(review): dead assignment -- immediately overwritten by the loop
    out = {}
    for task in tasks:
        # reads the notebook-global `profiles` dict (task -> observed profile array)
        seqlet_profile_ref = extract_signal(profiles[task], seqlets_ref)
        avg_profile = seqlet_profile_ref.mean(axis=0)
        out[task] = avg_profile
    # metrics_ref = pd.DataFrame([profile_sim_metrics(avg_profile, cp) for cp in seqlet_profile_ref])
    return out
ref_profiles = {p: get_reference_profile(mr, longer_pattern(sn), tasks) for p,sn in motifs.items()}
plot_stranded_profile(ref_profiles['Oct4-Sox2']['Oct4'])
# NOTE(review): `metrics_ref` is never defined (only in the commented line above) -- this cell errors
metrics_ref
alt_imp_scores.keys()
from basepair.stats import symmetric_kl
symmetric_kl??
imp_scores
class PerturbDataset(Dataset):
    # NOTE(review): this redefines the `PerturbDataset` class from above with a
    # completely different interface (feature extraction instead of sequence
    # generation) -- consider renaming to avoid the shadowing.
    """Summarize reference vs. perturbed signal around motif instances.

    Holds the reference data (sequences, predictions, observed profiles,
    contribution scores) together with their perturbed counterparts and
    extracts per-seqlet summary statistics via `get_change`.
    """

    def __init__(self, dfi, seqs, preds, profiles, imp_scores, imp_scores_counts,
                 alt_dataset, alt_seqs, alt_preds, alt_imp_scores_contrib,
                 ref_profiles,
                 profile_mapping):
        # dfi: motif-instance data-frame (positional `row_idx` required)
        self.dfi = dfi
        # seqs: one-hot reference sequences, indexed by example_idx
        self.seqs = seqs
        # preds / profiles: predicted and observed profiles per task
        self.preds = preds
        self.profiles = profiles
        # imp_scores / imp_scores_counts: profile / count contribution scores
        self.imp_scores = imp_scores
        self.imp_scores_counts = imp_scores_counts
        # alt_*: counterparts computed on the perturbed sequences
        self.alt_dataset = alt_dataset
        self.alt_seqs = alt_seqs
        self.alt_preds = alt_preds
        self.alt_imp_scores_contrib = alt_imp_scores_contrib
        # ref_profiles: pattern -> task -> average seqlet profile
        self.ref_profiles = ref_profiles
        self.profile_mapping = profile_mapping

    def __len__(self):
        return len(self.dfi)

    def get_change(self, mutated_seqlet_idx, signal_seqlet_idx):
        """Measure the signal at `signal_seqlet_idx` given that the seqlet at
        `mutated_seqlet_idx` was perturbed (both must share the example).

        Returns:
            nested dict: {'ref'|'alt'} -> {'obs'|'pred'|'imp'|'impcount'} ->
            inside/outside sums (and a 'match' score vs. the reference
            profile, computed with symmetric KL).
        """
        inst = self.dfi.iloc[signal_seqlet_idx]
        assert inst.row_idx == signal_seqlet_idx
        # NOTE(review): uses the module-level `profile_mapping`, not
        # `self.profile_mapping` -- confirm intended
        task = profile_mapping[inst.pattern_name]
        ref_profile = self.ref_profiles[inst.pattern_name][task]
        mutated_inst = self.dfi.iloc[mutated_seqlet_idx]
        assert inst.example_idx == mutated_inst.example_idx
        # narrow seqlet = exact motif span; wide seqlet = 70bp window around it
        narrow_seqlet = Seqlet(inst.example_idx, inst.pattern_start, inst.pattern_end,
                               name=inst.pattern, strand=inst.strand)
        wide_seqlet = narrow_seqlet.resize(70)
        # ref
        ref_preds = self.preds[task][inst.example_idx] # all predictions
        ref_preds_seqlet = wide_seqlet.extract(self.preds[task])
        ref_preds_inside = ref_preds_seqlet.sum()
        ref_preds_outside = ref_preds.sum() - ref_preds_inside
        ref_obs = self.profiles[task][inst.example_idx] # all predictions
        ref_obs_seqlet = wide_seqlet.extract(self.profiles[task])
        ref_obs_inside = ref_obs_seqlet.sum()
        ref_obs_outside = ref_obs.sum() - ref_obs_inside
        try:
            ref_preds_match = symmetric_kl(ref_preds_seqlet, ref_profile).mean() # compare with the reference
        except:
            # best-effort: fall back to NaN when the KL computation fails
            ref_preds_match = np.nan
        # ref imp
        ref_imp_scores = self.imp_scores[f"{task}"][inst.example_idx]
        ref_imp_scores_seqlet = narrow_seqlet.extract(self.imp_scores[f"{task}"])
        ref_imp_inside = ref_imp_scores_seqlet.sum() # sum in the seqlet region
        ref_imp_outside = ref_imp_scores.sum() - ref_imp_inside # total - seqlet
        # ref imp counts
        ref_imp_scores_c = self.imp_scores_counts[f"{task}"][inst.example_idx]
        ref_imp_scores_seqlet_c = narrow_seqlet.extract(self.imp_scores_counts[f"{task}"])
        ref_imp_inside_c = ref_imp_scores_seqlet_c.sum() # sum in the seqlet region
        ref_imp_outside_c = ref_imp_scores_c.sum() - ref_imp_inside_c # total - seqlet
        # alt
        # alt arrays are indexed by the perturbed-instance row, not example_idx
        narrow_seqlet.seqname = mutated_seqlet_idx # change the sequence name
        wide_seqlet.seqname = mutated_seqlet_idx
        alt_preds = self.alt_preds[task][mutated_seqlet_idx]
        alt_preds_seqlet = wide_seqlet.extract(self.alt_preds[task])
        alt_preds_inside = alt_preds_seqlet.sum() # sum in the seqlet region
        alt_preds_outside = alt_preds.sum() - alt_preds_inside # total - seqlet
        try:
            alt_preds_match = symmetric_kl(alt_preds_seqlet, ref_profile).mean()
        except:
            # best-effort: fall back to NaN when the KL computation fails
            alt_preds_match = np.nan
        alt_imp_scores = self.alt_imp_scores_contrib[f"{task}/weighted"][mutated_seqlet_idx]
        alt_imp_scores_seqlet = narrow_seqlet.extract(self.alt_imp_scores_contrib[f"{task}/weighted"])
        alt_imp_inside = alt_imp_scores_seqlet.sum() # sum in the seqlet region
        alt_imp_outside = alt_imp_scores.sum() - alt_imp_inside # total - seqlet
        alt_imp_scores_c = self.alt_imp_scores_contrib[f"{task}/count"][mutated_seqlet_idx]
        alt_imp_scores_seqlet_c = narrow_seqlet.extract(self.alt_imp_scores_contrib[f"{task}/count"])
        alt_imp_inside_c = alt_imp_scores_seqlet_c.sum() # sum in the seqlet region
        alt_imp_outside_c = alt_imp_scores_c.sum() - alt_imp_inside_c # total - seqlet
        return {
            "ref": {
                "obs": {
                    "inside": ref_obs_inside,
                    "outside": ref_obs_outside
                },
                "pred": {
                    "inside": ref_preds_inside,
                    "outside": ref_preds_outside,
                    "match": ref_preds_match
                },
                "imp": {
                    "inside": ref_imp_inside,
                    "outside": ref_imp_outside,
                },
                "impcount": {
                    "inside": ref_imp_inside_c,
                    "outside": ref_imp_outside_c,
                },
            },
            "alt": {
                "pred": {
                    "inside": alt_preds_inside,
                    "outside": alt_preds_outside,
                    "match": alt_preds_match
                },
                "imp": {
                    "inside": alt_imp_inside,
                    "outside": alt_imp_outside,
                },
                "impcount": {
                    "inside": alt_imp_inside_c,
                    "outside": alt_imp_outside_c,
                },
            },
        }
class SingleMotifPerturbDataset(Dataset):
    """Effect of perturbing each motif instance on itself (A | dA)."""

    def __init__(self, pdata):
        # pdata: PerturbDataset providing `dfi` and `get_change`
        self.pdata = pdata

    def __len__(self):
        return len(self.pdata.dfi)

    def __getitem__(self, idx):
        # mutated and measured seqlet are the same instance
        return self.pdata.get_change(idx, idx)
class DoubleMotifPerturbDataset(Dataset):
    """Extract prediction features from sequences where BOTH motifs of a
    pair were perturbed (predictions from `DoublePerturbDatasetSeq`)."""

    def __init__(self, dfab, dalt_preds, ref_profiles):
        """Use the predictions from DoublePerturbDatasetSeq

        Args:
            dfab: motif pair data-frame
            dalt_preds: predictions from sequences obtained from DoublePerturbDatasetSeq
            ref_profiles: reference profiles for background seqlets
        """
        self.dfab = dfab
        self.ref_profiles = ref_profiles
        self.dalt_preds = dalt_preds

    def __len__(self):
        return len(self.dfab)

    def extract_features(self, idx, pattern_name, pattern_start, pattern_end, pattern, strand):
        # Summarize the predicted profile around one motif of the pair:
        # inside/outside/total sums over a 70bp window + match vs. reference.
        # Relies on the module-level `profile_mapping`, `Seqlet` and `symmetric_kl`.
        task = profile_mapping[pattern_name]
        ref_profile = self.ref_profiles[pattern_name][task]
        narrow_seqlet = Seqlet(idx, pattern_start, pattern_end,
                               name=pattern, strand=strand)
        wide_seqlet = narrow_seqlet.resize(70)
        ref_preds = self.dalt_preds[task][idx] # all predictions
        ref_preds_seqlet = wide_seqlet.extract(self.dalt_preds[task])
        ref_preds_inside = ref_preds_seqlet.sum()
        ref_preds_total = ref_preds.sum()
        ref_preds_outside = ref_preds_total - ref_preds_inside
        try:
            ref_preds_match = symmetric_kl(ref_preds_seqlet, ref_profile).mean() # compare with the reference
        except:
            # best-effort: fall back to NaN when the KL computation fails
            ref_preds_match = np.nan
        return {
            "pred": {
                "inside": ref_preds_inside,
                "outside": ref_preds_outside,
                "total": ref_preds_total,
                "match": ref_preds_match
            }}

    def __getitem__(self, idx):
        # features for both motifs (x and y) of the pair at row `idx`
        inst = self.dfab.iloc[idx]
        return {"dxy": {
            "x": self.extract_features(idx,
                                       pattern_name=inst.pattern_name_x,
                                       pattern_start=int(inst.pattern_start_x),
                                       pattern_end=int(inst.pattern_end_x),
                                       pattern=inst.pattern_x,
                                       strand=inst.strand_x),
            "y": self.extract_features(idx,
                                       pattern_name=inst.pattern_name_y,
                                       pattern_start=int(inst.pattern_start_y),
                                       pattern_end=int(inst.pattern_end_y),
                                       pattern=inst.pattern_y,
                                       strand=inst.strand_y)
        }}
dfab_pairs['Nanog<>Nanog'].head()
%tqdm_restart
# Bundle all reference/alternative data for the perturbation feature extraction
smpdata = PerturbDataset( dfi_subset, seqs, preds, profiles, imp_scores_contrib,imp_scores_contrib_counts,
                         alt_dataset, alt_seqs, alt_preds, alt_imp_scores_contrib,
                         ref_profiles,
                         profile_mapping)
spdata = SingleMotifPerturbDataset(smpdata).load_all(num_workers=20)
# flatten the nested per-seqlet dicts into columns and join with the instance table
dfsm = pd.DataFrame(flatten(spdata), index=dfi_subset.index)
dfsm = pd.concat([dfsm, dfi_subset], axis=1)
dfsm.head()
from basepair.config import test_chr
# observed vs predicted counts: all chromosomes, then test chromosomes only
np.log10(1+dfsm[['ref_obs_inside', 'ref_pred_inside']]).plot.scatter("ref_obs_inside", "ref_pred_inside", alpha=0.05, s=1)
np.log10(1+dfsm[dfsm.example_chrom.isin(test_chr)][['ref_obs_inside', 'ref_pred_inside']]).plot.scatter("ref_obs_inside", "ref_pred_inside", alpha=0.05, s=1)
dfsm.plot.scatter("ref_imp_inside", "imp_weighted", alpha=0.1, s=1)
plt.scatter(dfsm.imp_weighted, np.log10(dfsm.ref_pred_inside), alpha=0.1, s=1)
dfsm.plot.scatter("ref_pred_match", "diff_pred_match", alpha=0.1, s=1)
# signed log of the prediction change inside the seqlet
dfsm['log_diff_pred_inside'] = np.log10(1+np.abs(dfsm.diff_pred_inside)) * np.sign(dfsm.diff_pred_inside)
dfsm.plot.scatter("log_diff_pred_inside", "diff_imp_inside", alpha=0.05, s=1)
class OtherMotifPerturbDataset(Dataset):
    """Cross-perturbation features for motif pairs: mutate one motif of the
    pair and measure the signal at the other motif."""

    def __init__(self, pdata, dfab):
        self.pdata = pdata
        self.dfab = dfab

    def __len__(self):
        return len(self.dfab)

    def __getitem__(self, idx):
        row = self.dfab.iloc[idx]
        xidx = int(row.row_idx_x)
        yidx = int(row.row_idx_y)
        return {"xy": self.pdata.get_change(xidx, yidx), # mutate x, measure y
                "yx": self.pdata.get_change(yidx, xidx)} # mutate y, measure x
motif_pair = ['Nanog', 'Klf4']
# Compute cross-perturbation features for every motif pair
dfab_pairs = {}
for i, motif_pair in enumerate(pairs):
    print(f"{i+1}/{len(pairs)}")
    dfab = motif_pair_dfi(dfi_subset, motif_pair)
    opdata = OtherMotifPerturbDataset(smpdata, dfab).load_all(num_workers=20)
    dfab_sm = pd.DataFrame(flatten(opdata), index=dfab.index)
    dfab_sm = pd.concat([dfab, dfab_sm], axis=1)
    dfab_pairs["<>".join(motif_pair)] = dfab_sm
# keep a backup so the expensive computation can be restored
dfab_pairs_bak = deepcopy(dfab_pairs)
dfab_pairs = deepcopy(dfab_pairs_bak)
%tqdm_restart
# add also the double perturbations
for k, dfab in dfab_pairs.items():
    dpdata_seqs = DoublePerturbDatasetSeq(dfab, seqs).load_all(num_workers=10)
    dalt_preds = bpnet.predict(dpdata_seqs)
    dpdata = DoubleMotifPerturbDataset(dfab, dalt_preds, ref_profiles).load_all(num_workers=20)
    df_dpdata = pd.DataFrame(flatten(dpdata), index=dfab.index)
    dfab = pd.concat([dfab, df_dpdata], axis=1)
    dfab_pairs[k] = dfab # override the new dfab
# append A|dA
# prefix the single-perturbation columns so they can be joined per pair member
dfsm_prefixed_x = dfsm.copy()
dfsm_prefixed_y = dfsm.copy()
dfsm_prefixed_x.columns = ["dx_x_" + c for c in dfsm_prefixed_x.columns]
dfsm_prefixed_y.columns = ["dy_y_" + c for c in dfsm_prefixed_y.columns]
for k, dfab in dfab_pairs.items():
    dfab = pd.merge(dfab, dfsm_prefixed_x, how='left', left_on='row_idx_x', right_on='dx_x_row_idx')
    dfab = pd.merge(dfab, dfsm_prefixed_y, how='left', left_on='row_idx_y', right_on='dy_y_row_idx')
    dfab_pairs[k] = dfab # override the new dfab
# store the pairs
write_pkl(dfab_pairs, modisco_dir / 'dfab_pairs.pkl')
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
# Diff (X -> Y, Y -> X)
dfab_sm.head()
# reference (X->X)
dfsm.head()
fig = plt.figure(figsize=get_figsize(.5))
dfab_pairs['Oct4-Sox2<>Sox2'].center_diff.plot.hist(30)
# Plot all pairs
# NOTE(review): relies on IPython's shell alias/automagic for `mkdir` -- not plain Python
mkdir -p {ddir}/figures/modisco/spacing/preturb
figures = Path(f"{ddir}/figures/modisco/spacing/preturb")
# dfab_sm.
from scipy.stats import wilcoxon
def plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=None, xl=(0, 2), pval=True):
    """Scatter the alt/ref ratios of two signals with Wilcoxon annotations.

    Args:
        x_ref, x_alt: reference / perturbed values for the x-axis (plotted as ratio)
        y_ref, y_alt: reference / perturbed values for the y-axis (plotted as ratio)
        ax: matplotlib axes to draw on
        alpha: scatter-point transparency
        s: scatter-point size
        label: scatter label (for the legend)
        xl: axis limits, applied to both axes and used for the diagonal
        pval: if True, annotate with paired Wilcoxon signed-rank p-values
    """
    ax.scatter(x_alt / x_ref,
               y_alt / y_ref, alpha=alpha, s=s, label=label)
    if pval:
        xpval = wilcoxon(x_ref, x_alt).pvalue
        ypval = wilcoxon(y_ref, y_alt).pvalue
        kwargs = dict(size="small", horizontalalignment='center')
        ax.text(1.8, 1, f"{xpval:.2g}", **kwargs)
        # BUG FIX: the y-axis annotation previously re-used xpval
        ax.text(1, 1.8, f"{ypval:.2g}", **kwargs)
    # guide lines: identity diagonal and the ratio == 1 cross-hairs
    guide_alpha = .5  # renamed from `alpha` to avoid shadowing the parameter
    ax.plot(xl, xl, c='grey', alpha=guide_alpha)
    ax.axvline(1, c='grey', alpha=guide_alpha)
    ax.axhline(1, c='grey', alpha=guide_alpha)
    ax.set_xlim(xl)
    ax.set_ylim(xl)
dfab_pairs['Nanog<>Nanog'].head()  # sanity-check the merged pair table
def plt_diag(xl, ax=None):
    """Set both axes to the limits `xl` and draw the y=x diagonal.

    Args:
        xl: (low, high) limits applied to the x- and y-axis
        ax: axes to draw on; defaults to the current axes
    """
    if ax is None:
        ax = plt.gca()
    ax.set_xlim(xl)
    ax.set_ylim(xl)
    ax.plot(xl, xl, c='grey', alpha=0.5)
import warnings
warnings.filterwarnings("ignore")
# Inspect one motif pair in detail, stratified by a categorical variable
k = 'Oct4-Sox2<>Nanog'
dfab_sma = dfab_pairs[k]
dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
# candidate stratification variables (selected below via `variable`)
cat_dist = pd.Categorical(pd.cut(dfab_sma.center_diff, [0, 35, 70, 150]))
cat_strand = pd.Categorical(dfab_sma.strand_combination)
match_threshold = .2
cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                            (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                          (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
nc = 5
nr = 2
fig = plt.figure(figsize=(get_figsize(.5, aspect=1)[0]*nc, get_figsize(.5, aspect=1)[1]*2))
# NOTE(review): `eval(variable)` picks one of the cat_* variables defined above
variable = 'cat_imp'
for k, cat in enumerate(eval(variable).categories):
    dfab_sm = dfab_sma[eval(variable) == cat]
    # top row: total counts; bottom row: footprint (inside-window) counts
    plt.subplot(2, nc, 1)
    plt.scatter(dfab_sm.dxy_y_pred_total, dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dx&dy")
    plt.ylabel("y_total | dx")
    plt_diag([0, 500])
    plt.title(k);
    plt.subplot(2, nc, 2)
    plt.scatter((dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside), dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total")
    plt.ylabel("y_total | dx")
    plt_diag([0, 1000])
    plt.subplot(2, nc, 3)
    plt.scatter((dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dy")
    plt.ylabel("y_total | dx")
    plt_diag([0, 1000])
    plt.subplot(2, nc, 4)
    plt.scatter((dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside), s=2, alpha=0.2, label=cat)
    plt.xlabel("y_total | dy")
    plt.ylabel("y_total")
    plt_diag([0, 1000])
    plt.subplot(2, nc, nc+1)
    plt.scatter(dfab_sm.dxy_y_pred_inside, dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dx&dy")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 200])
    plt.title(k);
    plt.subplot(2, nc, nc+2)
    plt.scatter(dfab_sm.yx_ref_pred_inside , dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 400])
    plt.subplot(2, nc, nc+3)
    plt.scatter(dfab_sm.dy_y_alt_pred_inside, dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy")
    plt.ylabel("y_footprint | dx")
    plt_diag([0, 400])
    plt.subplot(2, nc, nc+4)
    plt.scatter(dfab_sm.dy_y_alt_pred_inside, dfab_sm.yx_ref_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy")
    plt.ylabel("y_footprint")
    plt_diag([0, 400])
    plt.subplot(2, nc, nc+5)
    plt.scatter(dfab_sm.dxy_y_pred_inside, dfab_sm.yx_ref_pred_inside, s=2, alpha=0.2, label=cat)
    plt.xlabel("y_footprint | dy&dx")
    plt.ylabel("y_footprint")
    plt_diag([0, 400])
plt.legend(scatterpoints=1, markerscale=10, columnspacing=0, handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
fig = plt.figure(figsize=get_figsize(.5))
plt.scatter(dfab_sm.dxy_y_pred_total, dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2)
plt.xlabel("y_total | dx&dy")
plt.ylabel("y_total | dx")
plt_diag([0, 500])
fig = plt.figure(figsize=get_figsize(.5))
plt.scatter(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside, dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside, s=2, alpha=0.2)
# Grid of pairwise perturbation scatter plots: one row per motif pair, one
# column per feature; each point is a motif-pair instance, annotated with
# paired Wilcoxon p-values by plot_scatter.
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    variable = 'cat_imp'
    dfab_sm = dfab_sma
    for j, ax in enumerate(axes[i]):
        if j == 0:
            if i == 0:
                ax.set_title("Corrected total counts")
            # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
        if j == 1:
            if i == 0:
                ax.set_title("Corrected footprint counts")
            # (A|dB - A|dA&dB)/(A - A|dA)
            x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
            x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
        if j == 2:
            if i == 0:
                ax.set_title("Total counts")
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
        if j == 3:
            if i == 0:
                ax.set_title("Footprint counts")
            x_alt = dfab_sm.xy_alt_pred_inside
            x_ref = dfab_sm.xy_ref_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside
        # TODO
        elif j == 4:
            if i == 0:
                ax.set_title("Profile importance")
            x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
            x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
            y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
            y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
        elif j == 5:
            if i == 0:
                ax.set_title("Norm. profile importance")
            x_alt = dfab_sm.xy_alt_imp_inside
            x_ref = dfab_sm.xy_ref_imp_inside
            y_alt = dfab_sm.yx_alt_imp_inside
            y_ref = dfab_sm.yx_ref_imp_inside
        # elif j == 4:
        #     if i == 0:
        #         ax.set_title("Count importance")
        #     ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 , # / 2 seems to fix the scatterplots for an unknown reason
        #                dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
        elif j == 6:
            if i == 0:
                ax.set_title("Match of the footprint")
            x_alt = dfab_sm.xy_alt_pred_match
            x_ref = dfab_sm.xy_ref_pred_match
            y_alt = dfab_sm.yx_alt_pred_match
            y_ref = dfab_sm.yx_ref_pred_match
        # if j == 1:
        #     plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, xl=[-5, 5])
        # else:
        # NOTE(review): `cat` leaks from an earlier cell here -- confirm intended
        plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat)
        #if j == nf - 1 and i == 0:
        #    ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
        #    ax.legend(labels=list(cat_strand.categories))
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
# plt.suptitle("mutated / ref", y=1.02);
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)
# Same grid, but with the "corrected" columns normalized by the double
# perturbation (|dA&dB) instead of the single perturbation (|dA).
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    variable = 'cat_imp'
    dfab_sm = dfab_sma
    for j, ax in enumerate(axes[i]):
        if j == 0:
            if i == 0:
                ax.set_title("Corrected total counts")
            # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA&dB)
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - dfab_sm.dxy_y_pred_total
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - dfab_sm.dxy_x_pred_total
        if j == 1:
            if i == 0:
                ax.set_title("Corrected footprint counts")
            # (A|dB - A|dA&dB)/(A - A|dA&dB)
            x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
            x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dxy_y_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dxy_x_pred_inside
        if j == 2:
            if i == 0:
                ax.set_title("Total counts")
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
        if j == 3:
            if i == 0:
                ax.set_title("Footprint counts")
            x_alt = dfab_sm.xy_alt_pred_inside
            x_ref = dfab_sm.xy_ref_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside
        # TODO
        elif j == 4:
            if i == 0:
                ax.set_title("Profile importance")
            x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
            x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
            y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
            y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
        elif j == 5:
            if i == 0:
                ax.set_title("Norm. profile importance")
            x_alt = dfab_sm.xy_alt_imp_inside
            x_ref = dfab_sm.xy_ref_imp_inside
            y_alt = dfab_sm.yx_alt_imp_inside
            y_ref = dfab_sm.yx_ref_imp_inside
        # elif j == 4:
        #     if i == 0:
        #         ax.set_title("Count importance")
        #     ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 , # / 2 seems to fix the scatterplots for an unknown reason
        #                dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
        elif j == 6:
            if i == 0:
                ax.set_title("Match of the footprint")
            x_alt = dfab_sm.xy_alt_pred_match
            x_ref = dfab_sm.xy_ref_pred_match
            y_alt = dfab_sm.yx_alt_pred_match
            y_ref = dfab_sm.yx_ref_pred_match
        # if j == 1:
        #     plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, xl=[-5, 5])
        # else:
        # NOTE(review): `cat` leaks from an earlier cell here -- confirm intended
        plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat)
        #if j == nf - 1 and i == 0:
        #    ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
        #    ax.legend(labels=list(cat_strand.categories))
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
# plt.suptitle("mutated / ref", y=1.02);
plt.tight_layout()
# plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)
# Pairwise motif-interaction scatterplots: one row per motif pair, one
# column per effect statistic (nf columns). Points are coloured by the
# levels of `variable` (here cat_imp: high/low importance of each motif).
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    # restrict to motif pairs whose centers are < 150 bp apart
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    # candidate grouping variables for colouring the points
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .2
    # NOTE(review): the *_y comparisons use the literal .2 rather than
    # match_threshold (equal here, but easy to get out of sync)
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    variable = 'cat_imp'  # name of the grouping variable (resolved via eval below)
    for j, ax in enumerate(axes[i]):
        # one scatter layer per category level; NB: `k` is re-bound here,
        # shadowing the pair key above (unused afterwards, so harmless)
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[eval(variable) == cat]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total  # total counts | dB - dxy_x_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside  # total counts | dB - dxy_x_pred_total
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside
            # TODO
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
            # elif j == 4:
            #     if i == 0:
            #         ax.set_title("Count importance")
            #     ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,  # / 2 seems to fix the scatterplots for an unknown reason
            #                dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match
            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
        # legend only once, in the last column of the first row
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.color=imp.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.color=imp.png', raster=True, transparent=False)
# NOTE(review): this cell is an exact duplicate of the previous one
# (same variable='cat_imp', same thresholds, same output files) -
# likely a copy-paste leftover; consider removing one of the two.
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    # restrict to motif pairs whose centers are < 150 bp apart
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .2
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    variable = 'cat_imp'  # grouping variable used for colouring (resolved via eval)
    for j, ax in enumerate(axes[i]):
        # `k` is re-bound here, shadowing the pair key above (unused afterwards)
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[eval(variable) == cat]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total  # total counts | dB - dxy_x_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside
            # TODO
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
            # elif j == 4:
            #     if i == 0:
            #         ax.set_title("Count importance")
            #     ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,  # / 2 seems to fix the scatterplots for an unknown reason
            #                dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match
            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
        # legend only once, in the last column of the first row
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.color=imp.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.color=imp.png', raster=True, transparent=False)
# Pairwise motif-interaction scatterplots coloured by the PWM-match
# category (cat_match) of the two motifs; pairs < 150 bp apart.
# Layout: rows = motif pairs, columns = nf effect statistics.
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    # restrict to motif pairs whose centers are < 150 bp apart
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    cat_dist = pd.Categorical(pd.cut(dfab_sma.center_diff, [0, 35, 70, 150]))
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .4
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > match_threshold).map({True: 'high', False: 'low'})))
    imp_threshold = .2
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > imp_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > imp_threshold).map({True: 'high', False: 'low'})))
    variable = 'cat_match'  # grouping variable used for colouring (resolved via eval)
    for j, ax in enumerate(axes[i]):
        # one scatter layer per category level; `k` is re-bound here
        # (shadows the pair key above; unused afterwards)
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[eval(variable) == cat]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match
            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
        # legend only once, in the last column of the first row
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
# FIX: this figure is coloured by cat_match, but it previously saved to
# 'pairwise_all.color=imp.*', silently overwriting the cat_imp figure.
plt.savefig(figures / 'pairwise_all.color=match.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.color=match.png', raster=True, transparent=False)
# Pairwise motif-interaction scatterplots coloured by the pairwise
# distance bin (cat_dist: 0-35, 35-70, 70-150 bp); pairs < 150 bp apart.
# Layout: rows = motif pairs, columns = nf effect statistics.
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    # restrict to motif pairs whose centers are < 150 bp apart
    dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
    cat_dist = pd.Categorical(pd.cut(dfab_sma.center_diff, [0, 35, 70, 150]))
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .2
    # NOTE(review): the *_y comparisons use the literal .2 rather than
    # match_threshold (equal here, but easy to get out of sync)
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    variable = 'cat_dist'  # grouping variable used for colouring (resolved via eval)
    for j, ax in enumerate(axes[i]):
        # one scatter layer per category level; `k` is re-bound here
        # (shadows the pair key above; unused afterwards)
        for k, cat in enumerate(eval(variable).categories):
            dfab_sm = dfab_sma[eval(variable) == cat]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match
            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
        # legend only once, in the last column of the first row
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
# FIX: this figure is coloured by cat_dist, but it previously saved to
# 'pairwise_all.color=imp.*', silently overwriting the cat_imp figure.
plt.savefig(figures / 'pairwise_all.color=dist.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.color=dist.png', raster=True, transparent=False)
# Short-range (< 35 bp) pairwise interaction scatterplots, coloured by
# the strand combination (cat_strand) and restricted to pairs where both
# motifs have high importance (cat_imp == 'high-high').
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    # short-range: centers less than 35 bp apart
    dfab_sma = dfab_sma[dfab_sma.center_diff < 35]
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .2
    # NOTE(review): the *_y comparisons use the literal .2 rather than
    # match_threshold (equal here, but easy to get out of sync)
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
    variable = 'cat_strand'  # grouping variable used for colouring (resolved via eval)
    for j, ax in enumerate(axes[i]):
        # one scatter layer per category level; `k` is re-bound here
        # (shadows the pair key above; unused afterwards)
        for k, cat in enumerate(eval(variable).categories):
            # additionally keep only pairs where both motifs are important
            dfab_sm = dfab_sma[(eval(variable) == cat) & (cat_imp == 'high-high')]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match
            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
        # legend only once, in the last column of the first row
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
# FIX: this figure is coloured by cat_strand, but it previously saved to
# 'pairwise_all.short-range.color=imp.*', which the next cell (coloured
# by cat_match) then overwrote.
plt.savefig(figures / 'pairwise_all.short-range.color=strand.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.short-range.color=strand.png', raster=True, transparent=False)
# Short-range (< 35 bp) pairwise interaction scatterplots, coloured by
# the PWM-match category (cat_match) and restricted to pairs where both
# motifs have high importance (cat_imp == 'high-high').
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    # short-range: centers less than 35 bp apart
    dfab_sma = dfab_sma[dfab_sma.center_diff < 35]
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .4
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > match_threshold).map({True: 'high', False: 'low'})))
    imp_threshold = .2
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > imp_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > imp_threshold).map({True: 'high', False: 'low'})))
    variable = 'cat_match'  # grouping variable used for colouring (resolved via eval)
    for j, ax in enumerate(axes[i]):
        # one scatter layer per category level; `k` is re-bound here
        # (shadows the pair key above; unused afterwards)
        for k, cat in enumerate(eval(variable).categories):
            # additionally keep only pairs where both motifs are important
            dfab_sm = dfab_sma[(eval(variable) == cat) & (cat_imp == 'high-high')]
            if j == 0:
                if i == 0:
                    ax.set_title("Corrected total counts")
                # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
            if j == 1:
                if i == 0:
                    ax.set_title("Corrected footprint counts")
                # (A|dB - A|dA&dB)/(A - A|dA)
                x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
            if j == 2:
                if i == 0:
                    ax.set_title("Total counts")
                x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
                x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
                y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
                y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
            if j == 3:
                if i == 0:
                    ax.set_title("Footprint counts")
                x_alt = dfab_sm.xy_alt_pred_inside
                x_ref = dfab_sm.xy_ref_pred_inside
                y_alt = dfab_sm.yx_alt_pred_inside
                y_ref = dfab_sm.yx_ref_pred_inside
            elif j == 4:
                if i == 0:
                    ax.set_title("Profile importance")
                x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
                x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
                y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
                y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
            elif j == 5:
                if i == 0:
                    ax.set_title("Norm. profile importance")
                x_alt = dfab_sm.xy_alt_imp_inside
                x_ref = dfab_sm.xy_ref_imp_inside
                y_alt = dfab_sm.yx_alt_imp_inside
                y_ref = dfab_sm.yx_ref_imp_inside
            elif j == 6:
                if i == 0:
                    ax.set_title("Match of the footprint")
                x_alt = dfab_sm.xy_alt_pred_match
                x_ref = dfab_sm.xy_ref_pred_match
                y_alt = dfab_sm.yx_alt_pred_match
                y_ref = dfab_sm.yx_ref_pred_match
            plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=False)
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
        # legend only once, in the last column of the first row
        if j == nf - 1 and i == 0:
            ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
# FIX: this figure is coloured by cat_match, but it previously saved to
# 'pairwise_all.short-range.color=imp.*', overwriting the cat_strand
# figure produced by the previous cell.
plt.savefig(figures / 'pairwise_all.short-range.color=match.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.short-range.color=match.png', raster=True, transparent=False)
# Mid-range (35-70 bp) pairwise interaction scatterplots, restricted to
# pairs where both motifs have high importance. No category colouring
# (a single scatter per panel) and p-values are shown (pval=True).
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
    k = "<>".join(motif_pair)
    dfab_sma = dfab_pairs[k]
    # mid-range: centers between 35 and 70 bp apart
    dfab_sma = dfab_sma[(dfab_sma.center_diff > 35) & (dfab_sma.center_diff <= 70)]
    cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
    cat_strand = pd.Categorical(dfab_sma.strand_combination)
    match_threshold = .4
    cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
                                (dfab_sma.match_weighted_p_y > match_threshold).map({True: 'high', False: 'low'})))
    imp_threshold = .2
    cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > imp_threshold).map({True: 'high', False: 'low'}) + "-" +
                              (dfab_sma.imp_weighted_p_y > imp_threshold).map({True: 'high', False: 'low'})))
    cat = None  # no per-category colouring here (label=None in plot_scatter)
    for j, ax in enumerate(axes[i]):
        # for k, cat in enumerate(eval(variable).categories):
        # keep only pairs where both motifs are important
        dfab_sm = dfab_sma[(cat_imp == 'high-high')]
        if j == 0:
            if i == 0:
                ax.set_title("Corrected total counts")
            # (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total  # total counts | dB - dxy_x_pred_total
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
        if j == 1:
            if i == 0:
                ax.set_title("Corrected footprint counts")
            # (A|dB - A|dA&dB)/(A - A|dA)
            x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside
            x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
        if j == 2:
            if i == 0:
                ax.set_title("Total counts")
            x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)
            x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)
            y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)
            y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)
        if j == 3:
            if i == 0:
                ax.set_title("Footprint counts")
            x_alt = dfab_sm.xy_alt_pred_inside
            x_ref = dfab_sm.xy_ref_pred_inside
            y_alt = dfab_sm.yx_alt_pred_inside
            y_ref = dfab_sm.yx_ref_pred_inside
        # TODO
        elif j == 4:
            if i == 0:
                ax.set_title("Profile importance")
            x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
            x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
            y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
            y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
        elif j == 5:
            if i == 0:
                ax.set_title("Norm. profile importance")
            x_alt = dfab_sm.xy_alt_imp_inside
            x_ref = dfab_sm.xy_ref_imp_inside
            y_alt = dfab_sm.yx_alt_imp_inside
            y_ref = dfab_sm.yx_ref_imp_inside
        # elif j == 4:
        #     if i == 0:
        #         ax.set_title("Count importance")
        #     ax.scatter(dfab_sm.xy_alt_impcount_inside / dfab_sm.xy_ref_impcount_inside / 2 ,  # / 2 seems to fix the scatterplots for an unknown reason
        #                dfab_sm.yx_alt_impcount_inside / dfab_sm.yx_ref_impcount_inside / 2 , alpha=0.2, s=1)
        elif j == 6:
            if i == 0:
                ax.set_title("Match of the footprint")
            x_alt = dfab_sm.xy_alt_pred_match
            x_ref = dfab_sm.xy_ref_pred_match
            y_alt = dfab_sm.yx_alt_pred_match
            y_ref = dfab_sm.yx_ref_pred_match
        plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat, pval=True)
        ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
        ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
plt.tight_layout()
plt.savefig(figures / 'pairwise_all.mid-range.pdf', raster=True)
plt.savefig(figures / 'pairwise_all.mid-range.png', raster=True, transparent=False)
# Co-occurrence counts for a single (example) motif pair.
total_examples = len(dfi.example_idx.unique())
total_examples
# TODO - generalize this table to also have the diagonal in
motif_pair = ['Nanog', 'Nanog']
dfiab = dfab_pairs["<>".join(motif_pair)]
# number of filtered instances of motif A
x_total = dfi_subset[dfi_subset.pattern_name == motif_pair[0]].shape[0]
# number of distinct A-instances that have a B-instance < 150 bp away
xy_total = len(dfiab[dfiab.center_diff < 150].row_idx_x.unique())
x_total  # total number of instances of motif A
xy_total
# restrict every pair table to pairs < 150 bp apart
dfab_pairs_filt = {k: v[v.center_diff < 150] for k,v in dfab_pairs.items()}
def plot_coocurence_matrix(dfi, dfiab_pairs, pairs, signif_threshold=1e-5, ax=None):
    """Plot a log2 odds-ratio heatmap of pairwise motif co-occurrence.

    For every motif pair (A, B) in `pairs`, builds a 2x2 contingency
    table comparing how often instances of A lie close to an instance
    of B (and vice versa), runs Fisher's exact test and shows
    log2(odds-ratio) as a heatmap, starring cells with
    p < `signif_threshold`.

    NOTE(review): this local definition shadows the
    `plot_coocurence_matrix` imported from
    `basepair.modisco.pattern_instances` earlier in the file.

    Args:
        dfi: motif-instance dataframe with a `pattern_name` column.
        dfiab_pairs: dict "A<>B" -> pair dataframe with `row_idx_x` /
            `row_idx_y` columns (one row per close A/B instance pair).
        pairs: list of (motif_a, motif_b) tuples to test.
        signif_threshold: p-value cutoff for the significance star.
        ax: matplotlib axes to draw on (default: current axes).
    """
    from scipy.stats import fisher_exact
    import seaborn as sns
    import matplotlib.pyplot as plt
    if ax is None:
        ax = plt.gca()
    motifs = list({x for p in pairs for x in p})
    # map motif name -> matrix index (the original built the inverse
    # mapping and then indexed it with motif names from enumerate(pairs))
    motif_to_idx = {m: i for i, m in enumerate(motifs)}
    o = np.zeros((len(motifs), len(motifs)))   # log2 odds-ratios
    op = np.ones((len(motifs), len(motifs)))   # p-values; 1 so untested cells are never starred
    totals = {m: dfi[dfi.pattern_name == m].shape[0]
              for m in motifs}
    for motif_pair in pairs:
        i = motif_to_idx[motif_pair[0]]
        j = motif_to_idx[motif_pair[1]]
        dfiab = dfiab_pairs["<>".join(motif_pair)]  # was the global dfab_pairs, ignoring the parameter
        x_total = totals[motif_pair[0]]
        y_total = totals[motif_pair[1]]
        x_closeto_y_total = len(dfiab.row_idx_x.unique())
        y_closeto_x_total = len(dfiab.row_idx_y.unique())
        # 2x2 table: rows = motif A / motif B,
        # columns = instances close to the partner / not close.
        # NOTE(review): reconstructed from the unfinished original (it
        # referenced undefined a, b, c, d, xn, yn) - please confirm the
        # intended contingency table.
        ct = [[x_closeto_y_total, x_total - x_closeto_y_total],
              [y_closeto_x_total, y_total - y_closeto_x_total]]
        oddsratio, pval = fisher_exact(ct)
        o[i, j] = np.log2(oddsratio)
        op[i, j] = pval
    signif = op < signif_threshold
    a = np.zeros_like(signif).astype(str)
    a[signif] = "*"
    a[~signif] = ""
    np.fill_diagonal(a, '')  # self-pairs are not informative
    sns.heatmap(pd.DataFrame(o, columns=motifs, index=motifs),  # was undefined `ndxs`
                annot=a, fmt="", vmin=-4, vmax=4,
                cmap='RdBu_r', ax=ax)
    ax.set_title(f"Log2 odds-ratio. (*: p<{signif_threshold})")
def compute_features(dfab_sm):
    """Collect alt/ref read-out metrics from a motif-pair mutation table.

    Args:
        dfab_sm: table (or namespace of columns) holding the model
            predictions for motif pairs — `*_pred_inside/_outside` counts,
            `*_imp_inside` importances and `*_pred_match` scores for the
            reference and alternative (mutated) alleles.

    Returns:
        dict: feature name -> dict with keys 'x_alt', 'x_ref',
        'y_alt', 'y_ref' holding the per-pair values for that feature.
    """
    d = dfab_sm
    # Total counts = footprint (inside) + flanking (outside) counts;
    # hoisted because several features reuse them.
    x_alt_total = d.xy_alt_pred_inside + d.xy_alt_pred_outside
    x_ref_total = d.xy_ref_pred_inside + d.xy_ref_pred_outside
    y_alt_total = d.yx_alt_pred_inside + d.yx_alt_pred_outside
    y_ref_total = d.yx_ref_pred_inside + d.yx_ref_pred_outside
    return {
        "Corrected total counts": {
            # total counts | dB, corrected by the double-mutant baseline
            "x_alt": x_alt_total - d.dxy_y_pred_total,
            "x_ref": x_ref_total - (d.dy_y_alt_pred_inside + d.dy_y_alt_pred_outside),
            "y_alt": y_alt_total - d.dxy_x_pred_total,
            "y_ref": y_ref_total - (d.dx_x_alt_pred_inside + d.dx_x_alt_pred_outside),
        },
        "Corrected footprint counts": {
            # (A|dB - A|dA&dB) / (A - A|dA)
            "x_alt": d.xy_alt_pred_inside - d.dxy_y_pred_inside,
            "x_ref": d.xy_ref_pred_inside - d.dy_y_alt_pred_inside,
            "y_alt": d.yx_alt_pred_inside - d.dxy_x_pred_inside,
            "y_ref": d.yx_ref_pred_inside - d.dx_x_alt_pred_inside,
        },
        "Total counts": {
            "x_alt": x_alt_total,
            "x_ref": x_ref_total,
            "y_alt": y_alt_total,
            "y_ref": y_ref_total,
        },
        "Footprint counts": {
            "x_alt": d.xy_alt_pred_inside,
            "x_ref": d.xy_ref_pred_inside,
            "y_alt": d.yx_alt_pred_inside,
            "y_ref": d.yx_ref_pred_inside,
        },
        "Profile importance": {
            # normalized importance scaled back by the total counts
            "x_alt": d.xy_alt_imp_inside * x_alt_total,
            "x_ref": d.xy_ref_imp_inside * x_ref_total,
            "y_alt": d.yx_alt_imp_inside * y_alt_total,
            "y_ref": d.yx_ref_imp_inside * y_ref_total,
        },
        "Norm. profile importance": {
            "x_alt": d.xy_alt_imp_inside,
            "x_ref": d.xy_ref_imp_inside,
            "y_alt": d.yx_alt_imp_inside,
            "y_ref": d.yx_ref_imp_inside,
        },
        "Match of the footprint": {
            "x_alt": d.xy_alt_pred_match,
            "x_ref": d.xy_ref_pred_match,
            "y_alt": d.yx_alt_pred_match,
            "y_ref": d.yx_ref_pred_match,
        },
    }
# Ordered feature names; each entry must match a key of the dict
# returned by compute_features().
features = [
    "Corrected total counts",
    "Corrected footprint counts",
    "Total counts",
    "Footprint counts",
    "Profile importance",
    "Norm. profile importance",
    "Match of the footprint",
]
def plot_mutation_heatmap(dfab_pairs, pairs, motif_list, feature='Corrected footprint counts', signif_threshold=1e-5, ax=None, max_frac=2):
    """Heatmap of the mean alt/ref effect of mutating one motif on another.

    Cell (i, j) holds the mean ratio of `feature` for motif i's readout
    when motif j is mutated ("d<motif>" columns); '*' marks cells whose
    Wilcoxon signed-rank p-value is below `signif_threshold`.

    Args:
        dfab_pairs: dict mapping "A<>B" -> per-pair mutation dataframe
        pairs: list of (motif_a, motif_b) pairs to evaluate
        motif_list: motif order for the heatmap rows/columns
        feature: key into compute_features() output
        signif_threshold: p-value threshold for the '*' annotation
        ax: matplotlib axis to draw on (default: current axis)
        max_frac: upper color-scale limit
    """
    from scipy.stats import wilcoxon
    if ax is None:
        ax = plt.gca()
    motifs = motif_list
    motif_to_idx = {m: i for i, m in enumerate(motifs)}
    o = np.zeros((len(motifs), len(motifs)))
    # BUG FIX: initialize p-values to 1 so untested cells (e.g. the
    # diagonal) are not marked significant by the default 0.
    op = np.ones((len(motifs), len(motifs)))
    for motif_pair in pairs:
        i, j = motif_to_idx[motif_pair[0]], motif_to_idx[motif_pair[1]]
        dfab_sma = dfab_pairs["<>".join(motif_pair)]
        # only proximal pairs (centers within 150bp)
        dfab_sma = dfab_sma[(dfab_sma.center_diff < 150)]
        # renamed local (was `features`, shadowing the module-level list)
        feats = compute_features(dfab_sma)[feature]
        o[i, j] = np.mean(feats['y_alt'] / feats['y_ref'])  # x|dy
        o[j, i] = np.mean(feats['x_alt'] / feats['x_ref'])  # y|dx
        op[i, j] = wilcoxon(feats['y_ref'], feats['y_alt']).pvalue
        # BUG FIX: was `op[i, j] = ...` again, overwriting the previous
        # test and leaving op[j, i] untouched.
        op[j, i] = wilcoxon(feats['x_ref'], feats['x_alt']).pvalue
    signif = op < signif_threshold
    a = np.zeros_like(signif).astype(str)
    a[signif] = "*"
    a[~signif] = ""
    # NOTE(review): vmin=max_frac-1 gives a [1, 2] color range by default —
    # presumably intentional for ratios >= 1; confirm.
    sns.heatmap(pd.DataFrame(o, columns=["d" + x for x in motifs], index=motifs),
                annot=a, fmt="", vmin=max_frac - 1, vmax=max_frac,
                cmap='RdBu_r', ax=ax)
    ax.set_title(f"{feature} (alt / ref) (*: p<{signif_threshold})")
# NOTE(review): bare notebook-cell expression; `norm_counts` is only a local
# inside co_occurence_matrix below, so this presumably relied on a value
# defined in an earlier (unseen) cell — verify before converting to a script.
norm_counts
def co_occurence_matrix(dfi_subset, query_string=""):
    """Returns the fraction of times pattern x (row) overlaps pattern y (column)

    Args:
        dfi_subset: instance dataframe with pattern_name, pattern_center_aln,
            pattern_strand_aln, pattern_center and example_idx columns
        query_string: optional pandas query applied to the instance
            cross-product (columns carry `_x` / `_y` suffixes)

    Returns:
        tuple: (count_matrix, norm_count_matrix, norm_counts)
    """
    from basepair.stats import norm_matrix
    # normalization: minimum number of counts
    # BUG FIX: removed a dead duplicate of the two lines below that grouped
    # by the wrong column ('pattern') and was immediately overwritten.
    total_number = dfi_subset.groupby(['pattern_name']).size()
    norm_counts = norm_matrix(total_number)
    # cross-product of all instances sharing the same example
    cols = ['pattern_name', 'pattern_center_aln', 'pattern_strand_aln',
            'pattern_center', 'example_idx']
    dfi_filt_crossp = pd.merge(dfi_subset[cols].set_index('example_idx'),
                               dfi_subset[cols].set_index('example_idx'),
                               how='outer', left_index=True, right_index=True).reset_index()
    # remove self-matches (same motif, same aligned center, same strand)
    # BUG FIX: the strand clause compared pattern_strand_aln_x with itself
    # (always true), so opposite-strand duplicates were dropped as well.
    dfi_filt_crossp = dfi_filt_crossp.query('~((pattern_name_x == pattern_name_y) & (pattern_center_aln_x == pattern_center_aln_y) & (pattern_strand_aln_x == pattern_strand_aln_y))')
    if query_string:
        dfi_filt_crossp = dfi_filt_crossp.query(query_string)
    match_sizes = dfi_filt_crossp.groupby(['pattern_name_x', 'pattern_name_y']).size()
    count_matrix = match_sizes.unstack(fill_value=0)
    norm_count_matrix = count_matrix / norm_counts
    norm_count_matrix = norm_count_matrix.fillna(0)  # these examples didn't have any paired pattern
    return count_matrix, norm_count_matrix, norm_counts
def fisher_test_coc(random_coocurrence_counts, random_coocurrence, random_coocurrence_norm, coocurrence_counts, coocurrence, coocurrence_norm):
    """Odds-ratio test of observed vs shuffled co-occurrence counts.

    For every motif pair a 2x2 table (not co-occurring / co-occurring x
    shuffled / observed) is built and tested with statsmodels' Table2x2.

    Args:
        random_coocurrence_counts: pairwise counts from the shuffled NULL
        random_coocurrence: unused; kept for backward compatibility
        random_coocurrence_norm: normalization totals for the NULL
        coocurrence_counts: observed pairwise counts
        coocurrence: unused; kept for backward compatibility
        coocurrence_norm: normalization totals for the observed counts

    Returns:
        tuple: (odds-ratio DataFrame, p-value DataFrame), both indexed by motif
    """
    # BUG FIX: `import statsmodels as sm` does not expose `sm.stats`;
    # the submodule has to be imported explicitly.
    from statsmodels.stats.contingency_tables import Table2x2
    cols = list(coocurrence_norm.columns)
    n = len(random_coocurrence_counts)
    o = np.zeros((n, n))
    op = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            # rows: (not co-occurring, co-occurring); cols: (NULL, observed)
            ct = [[random_coocurrence_norm.iloc[i, j] - random_coocurrence_counts.iloc[i, j],
                   coocurrence_norm.iloc[i, j] - coocurrence_counts.iloc[i, j]],
                  [random_coocurrence_counts.iloc[i, j], coocurrence_counts.iloc[i, j]]]
            t22 = Table2x2(np.array(ct))
            o[i, j] = t22.oddsratio
            op[i, j] = t22.oddsratio_pvalue()
    return pd.DataFrame(o, columns=cols, index=cols), pd.DataFrame(op, columns=cols, index=cols)
def coocurrence_plot(dfi_subset, motif_list, query_string="(abs(pattern_center_aln_x- pattern_center_aln_y) <= 150)", signif_threshold=1e-5, ax=None):
    """Test for co-occurence

    Compares observed motif co-occurrence against a NULL obtained by
    permuting example_idx, and plots the odds-ratio heatmap with '*'
    marking significant cells.

    Args:
        dfi_subset: desired subset of dfi
        motif_list: list of motifs used to order the heatmap
        query_string: string used with df_cross.query() to determine the valid motif pairs
        signif_threshold: significance threshold for Fisher's exact test
        ax: matplotlib axis to draw on (default: current axis).
            BUG FIX: the body used `ax` without it being a parameter or
            being initialized, raising NameError unless a global existed.
    """
    if ax is None:
        ax = plt.gca()
    c_counts, c, c_norm = co_occurence_matrix(dfi_subset, query_string=query_string)
    # Generate the NULL by shuffling which example each instance belongs to
    dfi_subset_random = dfi_subset.copy()
    np.random.seed(42)
    dfi_subset_random['example_idx'] = dfi_subset_random['example_idx'].sample(frac=1).values
    rc_counts, rc, rc_norm = co_occurence_matrix(dfi_subset_random, query_string=query_string)
    # test for significance
    o, op = fisher_test_coc(rc_counts, rc, rc_norm, c_counts, c, c_norm)
    # re-order rows/columns to the requested motif order
    o = o[motif_list].loc[motif_list]
    op = op[motif_list].loc[motif_list]
    signif = op < signif_threshold
    a = np.zeros_like(signif).astype(str)
    a[signif] = "*"
    a[~signif] = ""
    sns.heatmap(o, annot=a, fmt="", vmin=0, vmax=2,
                cmap='RdBu_r', ax=ax)
    ax.set_title(f"odds-ratio (proximal / non-proximal) (*: p<{signif_threshold})")
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
# NOTE(review): the second positional argument here is `total_examples`,
# which matches the *imported* basepair plot_coocurence_matrix signature,
# not the local redefinition above (which expects the pair dict) — confirm
# which implementation is intended at this point of the notebook.
plot_coocurence_matrix(dfi_subset[(dfi_subset.pattern_center > 400) & (dfi_subset.pattern_center < 600)], total_examples, list(motifs), ax=ax)
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset, list(motifs))
ax.set_ylabel("Motif of interest")
ax.set_xlabel("Motif partner");
figsize_single = get_figsize(.5, aspect=1)
# One mutation heatmap per feature, laid out in a single row
fig, axes = plt.subplots(1, len(features), figsize=(figsize_single[0]*len(features), figsize_single[0]))
for i, (feat, ax) in enumerate(zip(features, axes)):
plot_mutation_heatmap(dfab_pairs, pairs, list(motifs), feat, ax=ax, max_frac=2)
plt.tight_layout()
TODO: compute A | dA & dB and add 2 new importance metrics to the plot (including the Wilcoxon test):
- (A|dB - A|dA&dB) / (A - A|dA)
- (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
Questions:
- Stratification: how are the above questions influenced by it?
Final goal:
# Plot the profile instances
from basepair.modisco.pattern_instances import dfi2seqlets, annotate_profile
from basepair.modisco.results import Seqlet, resize_seqlets
from basepair.plot.profiles import extract_signal
from basepair.plot.profiles import plot_stranded_profile, multiple_plot_stranded_profile
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile
# NOTE(review): re-binds dfi_subset (previously all motifs) to Oct4-Sox2 only
dfi_subset = (dfi.query('match_weighted_p > .2')
.query('imp_weighted_p > 0.0')
.query('pattern_name=="Oct4-Sox2"'))
# Convert instance rows to seqlets, then resize to a fixed 70bp window
# within the 1kb example region
seqlets = dfi2seqlets(dfi_subset)
seqlets = resize_seqlets(seqlets, 70, seqlen=1000)
# Extract the measured/predicted profile signal at each seqlet, per task
seqlet_profiles = {k: extract_signal(v, seqlets) for k,v in profiles.items()}
multiple_plot_stranded_profile({p:v for p,v in seqlet_profiles.items()}, figsize_tmpl=(2.55,2))
multiple_heatmap_stranded_profile(seqlet_profiles, sort_idx=np.arange(1000), figsize=(10,10));