# Imports
from basepair.imports import *
hv.extension('bokeh')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
paper_config()
# Common paths
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
modisco_dir = model_dir / f"modisco/all/deeplift/profile/"
output_dir = modisco_dir
create_tf_session(0)
patterns = read_pkl(modisco_dir / "patterns.pkl") # aligned patterns
bpnet = BPNet.from_mdir(model_dir)
from basepair.modisco.pattern_instances import load_instances, filter_nonoverlapping_intervals, plot_coocurence_matrix, align_instance_center
from basepair.exp.paper.config import motifs, profile_mapping
from basepair.utils import flatten
from kipoi.writers import HDF5BatchWriter
from basepair.exp.chipnexus.perturb import random_seq_onehot, PerturbSeqDataset, DoublePerturbSeqDataset, PerturbDataset, SingleMotifPerturbDataset, DoubleMotifPerturbDataset, OtherMotifPerturbDataset, PerturbationDataset
from basepair.exp.chipnexus.spacing import motif_pair_dfi, plot_spacing, get_motif_pairs
pairs = get_motif_pairs(motifs)
dfi = load_instances(modisco_dir / 'instances.parq', motifs=motifs, dedup=False)
dfi = filter_nonoverlapping_intervals(dfi)
mr = ModiscoResult(modisco_dir / 'modisco.h5')
mr.open()
# Add aligned instances
orig_patterns = [mr.get_pattern(pname) for pname in mr.patterns()]
dfi = align_instance_center(dfi, orig_patterns, patterns, trim_frac=0.08)
from basepair.cli.imp_score import ImpScoreFile
imp_scores = ImpScoreFile.from_modisco_dir(modisco_dir)
profiles = imp_scores.get_profiles()
# Single-mutant perturbed sequences -> predictions -> importance scores.
# NOTE(review): `dfi_subset` and `seqs` are defined in cells further DOWN in this
# export (notebook executed out of order) — run the dfi_subset / seqs cells first.
alt_seqs = PerturbSeqDataset(dfi_subset, seqs).load_all()
alt_preds = bpnet.predict(alt_seqs)
alt_imp_scores = bpnet.imp_score_all(alt_seqs, method='deeplift', aggregate_strand=True, batch_size=256)
# contribution scores = importance * one-hot sequence (projects scores onto the observed base)
alt_imp_scores_contrib = {k: v * alt_seqs for k,v in alt_imp_scores.items()}
# TODO: implement a function to go from seqlet -> ref / alt profile
pattern = 'metacluster_0/pattern_0'
from basepair.exp.chipnexus.simulate import profile_sim_metrics
from basepair.exp.chipnexus.perturb import get_reference_profile
ref_profiles = {p: get_reference_profile(mr, longer_pattern(sn), profiles, tasks) for p,sn in motifs.items()}
task  # get predictions
seqs = imp_scores.get_seq()
%time preds = bpnet.predict(seqs)
# get importance scores
%time ref_imp_scores = bpnet.imp_score_all(seqs, method='deeplift', aggregate_strand=True)
ref_imp_scores_contrib = {k: v * seqs for k,v in ref_imp_scores.items()}
# Reference (unperturbed) dataset per task: observed profile, model prediction
# and DeepLIFT importance scores, indexed by example_idx.
ref = NumpyDataset({t: {
    "obs": profiles[t],   # fixed: was profiles[task]; the comprehension variable is `t`
    "pred": preds[t],     # fixed: was preds[task]
    "imp": {
        "profile": ref_imp_scores_contrib[f"{t}/weighted"],
        "count": ref_imp_scores[f"{t}/count"],
    }} for t in tasks}, attrs={'index': 'example_idx'})
# get the interesting motif location
dfi_subset = (dfi.query('match_weighted_p > 0.2')
.query('imp_weighted_p > 0'))
dfi_subset['row_idx'] = np.arange(len(dfi_subset)).astype(int)
# co-occurence
# Co-occurrence of motif instances per example.
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
# fixed: pass ax=ax (as done for the other coocurrence_plot calls below) so the
# axis labels set next apply to the plotted axes, not a fresh one
coocurrence_plot(dfi_subset, list(motifs), ax=ax)
ax.set_ylabel("Motif of interest")
ax.set_xlabel("Motif partner");
alt_seqs = PerturbSeqDataset(dfi_subset, seqs).load_all()
alt_preds = bpnet.predict(alt_seqs)
alt_imp_scores = bpnet.imp_score_all(alt_seqs, method='deeplift', aggregate_strand=True, batch_size=256)
# Single-mutant dataset per task, indexed by row_idx of dfi_subset.
single_mut = NumpyDataset({t: {
    "pred": alt_preds[t],
    "imp": {
        "profile": alt_imp_scores_contrib[f"{t}/weighted"],
        # fixed: count scores come from the raw importance dict, matching the
        # reference dataset above (which uses ref_imp_scores, not the *_contrib dict)
        "count": alt_imp_scores[f"{t}/count"],
    }} for t in tasks}, attrs={'index': 'row_idx'})
# construct the matrix of all interesting motif pairs
dfab = pd.concat([motif_pair_dfi(dfi_subset, motif_pair).assign(motif_pair='<>'.join(motif_pair))
for motif_pair in pairs], axis=0)
dfab['motif_pair_idx'] = np.arange(len(dfab))
len(dfab)
# Double-mutant sequences -> predictions -> importance scores.
# NOTE(review): `DoublePerturbDatasetSeq` is not among the names imported above
# (L18 imports DoublePerturbSeqDataset) — presumably provided by the star import;
# confirm the class name.
dpdata_seqs = DoublePerturbDatasetSeq(dfab, seqs).load_all(num_workers=10)
double_alt_preds = bpnet.predict(dpdata_seqs)
# fixed: importance scores are computed on the perturbed SEQUENCES (as for the
# single mutants above), not on the predictions
double_alt_imp_scores = bpnet.imp_score_all(dpdata_seqs, method='deeplift', aggregate_strand=True, batch_size=256)
# fixed: contribution = importance * one-hot of the SAME (double-mutant) batch;
# multiplying by alt_seqs mixed batches of different length/content
double_alt_imp_contrib = {k: v * dpdata_seqs for k, v in double_alt_imp_scores.items()}
double_mut = NumpyDataset({t: {
    "imp": {
        "profile": double_alt_imp_contrib[f"{t}/weighted"],
        # fixed: count scores from the raw importance dict, consistent with `ref`
        "count": double_alt_imp_scores[f"{t}/count"],
    }} for t in tasks}, attrs={'index': 'motif_pair_idx'})
dataset_dir = output_dir / 'perturbation-analysis'
dataset_dir.mkdir(exist_ok=True)
!du -sh {dataset_dir}/*
%time o = NumpyDataset.load(dataset_dir / 'double_mut.h5')
!zcat {dataset_dir}/dfab.csv.gz | wc -l
# store all files to disk
%time ref.save(dataset_dir / 'ref.h5')
%time single_mut.save(dataset_dir / 'single_mut.h5')
%time double_mut.save(dataset_dir / 'double_mut.h5')
%time dfi_subset.to_csv(dataset_dir / 'dfi_subset.csv.gz', compression='gzip')
%time dfab.to_csv(dataset_dir / 'dfab.csv.gz', compression='gzip')
# Load all files from disk
%time ref = NumpyDataset.load(dataset_dir / 'ref.h5')
%time single_mut = NumpyDataset.load(dataset_dir / 'single_mut.h5')
%time double_mut = NumpyDataset.load(dataset_dir / 'double_mut.h5')
%time dfi_subset = pd.read_csv(dataset_dir / 'dfi_subset.csv.gz')
%time dfab = pd.read_csv(dataset_dir / 'dfab.csv.gz')
opdata = OtherMotifPerturbDataset(smpdata, dfab).load_all(num_workers=20)
def remove_edge_instances(dfab, profile_width=70, total_width=1000):
    """Drop motif-pair rows whose profile window would cross a sequence edge.

    Keeps only rows where a window of ceil(profile_width/2) around both
    pattern_center_x and pattern_center_y lies strictly inside (0, total_width).
    """
    half_width = (profile_width + 1) // 2  # == profile_width // 2 + profile_width % 2
    within = (
        (dfab.pattern_center_x > half_width)
        & (dfab.pattern_center_x < total_width - half_width)
        & (dfab.pattern_center_y > half_width)
        & (dfab.pattern_center_y < total_width - half_width)
    )
    return dfab[within]
from basepair.plot.profiles import plot_stranded_profile, multiple_plot_stranded_profile
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile, heatmap_stranded_profile
%matplotlib inline
paper_config()
# Build per-motif-pair perturbation datasets (ref / single / double mutants).
# NOTE(review): `profile_width` is not defined anywhere in this export — it was
# presumably set in an earlier notebook cell; confirm its value.
motif_pair_lpdata = {}
for motif_pair in pairs:
    motif_pair_name = "<>".join(motif_pair)
    dfab_subset = remove_edge_instances(dfab[dfab.motif_pair == motif_pair_name], profile_width=profile_width)
    # fixed typo: ParturbationDataset -> PerturbationDataset (imported at the top)
    pdata = PerturbationDataset(dfab_subset, ref, single_mut, double_mut, profile_width=profile_width)
    motif_pair_lpdata[motif_pair_name] = pdata.load_all(num_workers=1)
    # store also dfab
    motif_pair_lpdata[motif_pair_name]['dfab'] = dfab_subset
    # sort_idx = np.argsort(pdata.dfab.center_diff)
# sort_idx = np.argsort(pdata.dfab.center_diff)
write_pkl(motif_pair_lpdata, dataset_dir / 'motif_pair_lpdata.pkl')
# add also the double perturbations
# Annotate each motif-pair table with double-perturbation features.
# NOTE(review): this loop duplicates the one further below (same body) except
# that DoubleMotifPerturbDataset here also receives `profile_mapping` — confirm
# which signature is current and delete the stale copy.
for k, dfab in dfab_pairs.items():
    dpdata_seqs = DoublePerturbDatasetSeq(dfab, seqs).load_all(num_workers=10)
    dalt_preds = bpnet.predict(dpdata_seqs)
    dpdata = DoubleMotifPerturbDataset(dfab, dalt_preds, ref_profiles, profile_mapping).load_all(num_workers=20)
    df_dpdata = pd.DataFrame(flatten(dpdata), index=dfab.index)
    dfab = pd.concat([dfab, df_dpdata], axis=1)
    dfab_pairs[k] = dfab # override the new dfab
%tqdm_restart
smpdata = PerturbDataset( dfi_subset, seqs, preds, profiles, ref_imp_scores_contrib,
alt_dataset, alt_seqs, alt_preds, alt_imp_scores_contrib,
ref_profiles,
profile_mapping)
spdata = SingleMotifPerturbDataset(smpdata).load_all(num_workers=20)
dfsm = pd.DataFrame(flatten(spdata), index=dfi_subset.index)
dfsm = pd.concat([dfsm, dfi_subset], axis=1)
motif_pair = ['Nanog', 'Sox2']
dfab_pairs = {}
for i, motif_pair in enumerate(pairs):
print(f"{i+1}/{len(pairs)}")
dfab = motif_pair_dfi(dfi_subset, motif_pair)
opdata = OtherMotifPerturbDataset(smpdata, dfab).load_all(num_workers=20)
from basepair.config import test_chr
motif_pair = ['Nanog', 'Klf4']
dfab_pairs_bak = deepcopy(dfab_pairs)
dfab_pairs = {}
for i, motif_pair in enumerate(pairs):
print(f"{i+1}/{len(pairs)}")
dfab = motif_pair_dfi(dfi_subset, motif_pair)
opdata = OtherMotifPerturbDataset(smpdata, dfab).load_all(num_workers=20)
dfab_sm = pd.DataFrame(flatten(opdata), index=dfab.index)
dfab_sm = pd.concat([dfab, dfab_sm], axis=1)
dfab_pairs["<>".join(motif_pair)] = dfab_sm
dfab_pairs_bak = deepcopy(dfab_pairs)
dfab_pairs = deepcopy(dfab_pairs_bak)
%tqdm_restart
# add also the double perturbations
for k, dfab in dfab_pairs.items():
dpdata_seqs = DoublePerturbDatasetSeq(dfab, seqs).load_all(num_workers=10)
dalt_preds = bpnet.predict(dpdata_seqs)
dpdata = DoubleMotifPerturbDataset(dfab, dalt_preds, ref_profiles).load_all(num_workers=20)
df_dpdata = pd.DataFrame(flatten(dpdata), index=dfab.index)
dfab = pd.concat([dfab, df_dpdata], axis=1)
dfab_pairs[k] = dfab # override the new dfab
# append A|dA
dfsm_prefixed_x = dfsm.copy()
dfsm_prefixed_y = dfsm.copy()
dfsm_prefixed_x.columns = ["dx_x_" + c for c in dfsm_prefixed_x.columns]
dfsm_prefixed_y.columns = ["dy_y_" + c for c in dfsm_prefixed_y.columns]
for k, dfab in dfab_pairs.items():
dfab = pd.merge(dfab, dfsm_prefixed_x, how='left', left_on='row_idx_x', right_on='dx_x_row_idx')
dfab = pd.merge(dfab, dfsm_prefixed_y, how='left', left_on='row_idx_y', right_on='dy_y_row_idx')
dfab_pairs[k] = dfab # override the new dfab
# store the pairs
write_pkl(dfab_pairs, modisco_dir / 'dfab_pairs.pkl')
dfab_pairs = read_pkl(modisco_dir / 'dfab_pairs.pkl')
# Shall we display it using 4 different heatmaps, or shall we always just use a single metric to display it?
# Focus on Oct4-Sox2 interactions
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
fig = plot_spacing(dfab_sm, alpha_scatter=0.05, y_feature='imp_weighted', figsize=get_figsize(.4, aspect=2))
!mkdir -p {ddir}/figures/modisco/spacing/preturb
figures = Path(f"{ddir}/figures/modisco/spacing/preturb")
import warnings
warnings.filterwarnings("ignore")
from basepair.exp.chipnexus.perturb import plot_scatter, compute_features, plt_diag, plot_scatter, plot_pairs
plot_pairs(dfab_pairs, pairs, ['Total counts', 'Profile counts', 'Profile importance', 'Count importance', 'Profile match'], pseudo_count_quantile=.2, variable=None, pval=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)
plot_pairs(dfab_pairs, pairs, ['Total counts', 'Profile counts', 'Profile importance', 'Count importance', 'Profile match'], variable=None, pval=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)
cvar = 'dist'
plot_pairs(dfab_pairs, pairs, plot_features, variable=f"cat_{cvar}", pval=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.pdf', raster=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.png', raster=True, transparent=False)
plot_features = ['Total counts', 'Profile counts', 'Profile importance', 'Count importance', 'Profile match']
cvar = 'imp'
plot_pairs(dfab_pairs, pairs, plot_features, variable=f"cat_{cvar}", pval=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.pdf', raster=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.png', raster=True, transparent=False)
cvar = 'strand'
plot_pairs(dfab_pairs, pairs, plot_features, variable=f"cat_{cvar}", pval=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.pdf', raster=True)
plt.savefig(figures / f'pairwise_all.color={cvar}.png', raster=True, transparent=False)
total_examples = len(dfi.example_idx.unique())
total_examples
# TODO - generalize this table to also have the diagonal in
motif_pair = ['Nanog', 'Nanog']
dfiab = dfab_pairs["<>".join(motif_pair)]
x_total = dfi_subset[dfi_subset.pattern_name == motif_pair[0]].shape[0]
xy_total = len(dfiab[dfiab.center_diff < 150].row_idx_x.unique())
x_total # total number of instances of motif A
xy_total
dfab_pairs_filt = {k: v[v.center_diff < 150] for k,v in dfab_pairs.items()}
from basepair.exp.chipnexus.perturb import plot_mutation_heatmap
from basepair.exp.chipnexus.spacing import co_occurence_matrix, fisher_test_coc, coocurrence_plot
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset, list(motifs), ax=ax)
df = pd.read_csv(f"{ddir}/processed/chipnexus/external-data.tsv", sep='\t')
dfs = df[df.assay.isin(['PolII', 'H3K27ac'])]
import pybedtools
from pybedtools import BedTool
from basepair.extractors import MultiAssayExtractor
from basepair.data import NumpyDataset
df_regions = dfi_subset[['example_chrom', 'example_start', 'example_end', 'example_idx']].drop_duplicates()
bt = BedTool.from_dataframe(df_regions)
extractor = MultiAssayExtractor(dfs, None, use_strand=False, n_jobs=10)
regions = extractor.extract(list(bt))
r = NumpyDataset(regions)
dfc = pd.DataFrame(r.aggregate(np.sum, axis=1))
dfc['example_idx'] = df_regions['example_idx'].values
dfc.head()
high = np.quantile(dfc.H3K27ac, .9)
high
dfc.H3K27ac.plot.hist(30);
# Add H3K27 ac to the table
dfi_subset = pd.merge(dfi_subset, dfc, on='example_idx', how='left')
np.quantile(dfi_subset[['example_idx', 'H3K27ac']].drop_duplicates().H3K27ac, .9)
!ls {figures}
!mkdir -p {figures}/../co-occurence/
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset, list(motifs), ax=ax)
plt.savefig(f"{figures}/../co-occurence/all.pdf")
plt.savefig(f"{figures}/../co-occurence/all.png")
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset[dfi_subset.H3K27ac > np.quantile(dfc.H3K27ac, .9)], list(motifs), ax=ax)
plt.savefig(f"{figures}/../co-occurence/H3K27ac>90percentile.pdf")
plt.savefig(f"{figures}/../co-occurence/H3K27ac>90percentile.png")
# old matrix
fig, ax = plt.subplots(figsize=get_figsize(.5, aspect=1))
coocurrence_plot(dfi_subset[dfi_subset.PolII > np.quantile(dfc.PolII, .9)], list(motifs), ax=ax)
plt.savefig(f"{figures}/../co-occurence/PolII>90percentile.pdf")
plt.savefig(f"{figures}/../co-occurence/PolII>90percentile.png")
# Exploratory cell: compare H3K27ac / PolII between examples containing
# Klf4, Oct4-Sox2, both, or neither.
dfi_subset.pattern_name.head()
# fixed: stray trailing `]` removed (was a syntax error)
dfi_subset.pattern_name.isin(["Oct4-Sox2", "Klf4"])
# fixed: unclosed list inside groupby (was a syntax error)
dfi_subset[dfi_subset.pattern_name.isin(["Oct4-Sox2", "Klf4"])].groupby(["example_idx", 'pattern_name']).pattern_name.size()
counts = pd.pivot_table(dfi_subset[dfi_subset.pattern_name.isin(["Oct4-Sox2", "Klf4"])], index='example_idx', columns='pattern_name', values='H3K27ac', aggfunc=len, fill_value=0)
c = counts > 0
values = c.Klf4.map({False:"", True: "Klf4"}) + c['Oct4-Sox2'].map({False:"", True: "Oct4-Sox2"})
values.value_counts()
# fixed: pd.where does not exist; np.where keeps the apparent intent of the scratch line
pd.Series(np.where(c.Klf4, "", "Klf4"), index=c.index) + "a"
dfs = dfi_subset[['example_idx', "H3K27ac", "PolII"]].drop_duplicates()
dfs = dfs.set_index("example_idx")
dfsj = pd.DataFrame({"feat": values}).join(dfs)
# fixed: `dfsj.unstack(id_vars=)` was a syntax error; the `id_vars` keyword
# suggests melt was intended — TODO confirm
dfsj.melt(id_vars='feat')
ggplot(aes(x='feat', y='H3K27ac'), data=dfsj) + geom_boxplot() + scale_y_continuous(trans='log10') + geom_violin()
ggplot(aes(x='feat', y='PolII'), data=dfsj) + geom_boxplot() + scale_y_continuous(trans='log10')
values
c
figsize_single = get_figsize(.5, aspect=1)
fig, axes = plt.subplots(1, len(features), figsize=(figsize_single[0]*len(features), figsize_single[0]))
for i, (feat, ax) in enumerate(zip(features, axes)):
if i >= 2 :
max_frac = 1.5
# smaller y-scale
else:
max_frac = 2
plot_mutation_heatmap(dfab_pairs, pairs, list(motifs), feat, ax=ax, max_frac=max_frac)
plt.tight_layout()
# A | dA & dB — add 2 new importance metrics to the plot (including Wilcoxon test):
#   (A|dB - A|dA&dB) / (A - A|dA)
#   (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
# Questions:
# Stratification: how are the above questions influenced by it?
# Final goal:
# Plot the profile instances
from basepair.modisco.pattern_instances import dfi2seqlets, annotate_profile
from basepair.modisco.results import Seqlet, resize_seqlets
from basepair.plot.profiles import extract_signal
from basepair.plot.profiles import plot_stranded_profile, multiple_plot_stranded_profile
from basepair.plot.heatmaps import heatmap_stranded_profile, multiple_heatmap_stranded_profile
dfi_subset = (dfi.query('match_weighted_p > .2')
.query('imp_weighted_p > 0.0')
.query('pattern_name=="Oct4-Sox2"'))
seqlets = dfi2seqlets(dfi_subset)
seqlets = resize_seqlets(seqlets, 70, seqlen=1000)
seqlet_profiles = {k: extract_signal(v, seqlets) for k,v in profiles.items()}
multiple_plot_stranded_profile({p:v for p,v in seqlet_profiles.items()}, figsize_tmpl=(2.55,2))
multiple_heatmap_stranded_profile(seqlet_profiles, sort_idx=np.arange(1000), figsize=(10,10));
k = 'Oct4-Sox2<>Nanog'
dfab_sma = dfab_pairs[k]
dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
cat_dist = pd.Categorical(pd.cut(dfab_sma.center_diff, [0, 35, 70, 150]))
cat_strand = pd.Categorical(dfab_sma.strand_combination)
match_threshold = .2
cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
(dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
(dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
nc = 5
nr = 2
fig = plt.figure(figsize=(get_figsize(.5, aspect=1)[0]*nc, get_figsize(.5, aspect=1)[1]*2))
variable = 'cat_imp'
for k, cat in enumerate(eval(variable).categories):
dfab_sm = dfab_sma[eval(variable) == cat]
plt.subplot(2, nc, 1)
plt.scatter(dfab_sm.dxy_y_pred_total, dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
plt.xlabel("y_total | dx&dy")
plt.ylabel("y_total | dx")
plt_diag([0, 500])
plt.title(k);
plt.subplot(2, nc, 2)
plt.scatter((dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside), dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
plt.xlabel("y_total")
plt.ylabel("y_total | dx")
plt_diag([0, 1000])
plt.subplot(2, nc, 3)
plt.scatter((dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside, s=2, alpha=0.2, label=cat)
plt.xlabel("y_total | dy")
plt.ylabel("y_total | dx")
plt_diag([0, 1000])
plt.subplot(2, nc, 4)
plt.scatter((dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_total | dy")
plt.ylabel("y_total")
plt_diag([0, 1000])
plt.subplot(2, nc, nc+1)
plt.scatter(dfab_sm.dxy_y_pred_inside, dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint | dx&dy")
plt.ylabel("y_footprint | dx")
plt_diag([0, 200])
plt.title(k);
plt.subplot(2, nc, nc+2)
plt.scatter(dfab_sm.xy_ref_pred_inside , dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint")
plt.ylabel("y_footprint | dx")
plt_diag([0, 400])
plt.subplot(2, nc, nc+3)
plt.scatter(dfab_sm.dy_y_alt_pred_inside, dfab_sm.xy_alt_pred_inside, s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint | dy")
plt.ylabel("y_footprint | dx")
plt_diag([0, 400])
plt.subplot(2, nc, nc+4)
plt.scatter(dfab_sm.dy_y_alt_pred_inside, dfab_sm.xy_ref_pred_inside, s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint | dy")
plt.ylabel("y_footprint")
plt_diag([0, 400])
plt.subplot(2, nc, nc+5)
plt.scatter(dfab_sm.dxy_y_pred_inside, dfab_sm.xy_ref_pred_inside, s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint | dy&dx")
plt.ylabel("y_footprint")
plt_diag([0, 400])
plt.legend(scatterpoints=1, markerscale=10, columnspacing=0, handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
nc = 5
nr = 2
fig = plt.figure(figsize=(get_figsize(.5, aspect=1)[0]*nc, get_figsize(.5, aspect=1)[1]*2))
variable = 'cat_imp'
for k, cat in enumerate(eval(variable).categories):
dfab_sm = dfab_sma[eval(variable) == cat]
plt.subplot(2, nc, 1)
plt.scatter(np.log(dfab_sm.dxy_y_pred_total), np.log(dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_total | dx&dy")
plt.ylabel("y_total | dx")
plt_diag([4, 7])
plt.title(k);
plt.subplot(2, nc, 2)
plt.scatter(np.log(dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside), np.log(dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_total")
plt.ylabel("y_total | dx")
plt_diag([4, 7])
plt.subplot(2, nc, 3)
plt.scatter(np.log(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), np.log(dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_total | dy")
plt.ylabel("y_total | dx")
plt_diag([4, 7])
plt.subplot(2, nc, 4)
plt.scatter(np.log(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside), np.log(dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_total | dy")
plt.ylabel("y_total")
plt_diag([4, 7])
plt.subplot(2, nc, nc+1)
plt.scatter(np.log(dfab_sm.dxy_y_pred_inside), np.log(dfab_sm.xy_alt_pred_inside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint | dx&dy")
plt.ylabel("y_footprint | dx")
plt_diag([2, 6])
plt.title(k);
plt.subplot(2, nc, nc+2)
plt.scatter(np.log(dfab_sm.xy_ref_pred_inside) , np.log(dfab_sm.xy_alt_pred_inside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint")
plt.ylabel("y_footprint | dx")
plt_diag([2, 6])
plt.subplot(2, nc, nc+3)
plt.scatter(np.log(dfab_sm.dy_y_alt_pred_inside), np.log(dfab_sm.xy_alt_pred_inside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint | dy")
plt.ylabel("y_footprint | dx")
plt_diag([2, 6])
plt.subplot(2, nc, nc+4)
plt.scatter(np.log(dfab_sm.dy_y_alt_pred_inside), np.log(dfab_sm.xy_ref_pred_inside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint | dy")
plt.ylabel("y_footprint")
plt_diag([2, 6])
plt.subplot(2, nc, nc+5)
plt.scatter(np.log(dfab_sm.dxy_y_pred_inside), np.log(dfab_sm.xy_ref_pred_inside), s=2, alpha=0.2, label=cat)
plt.xlabel("y_footprint | dy&dx")
plt.ylabel("y_footprint")
plt_diag([2, 6])
plt.legend(scatterpoints=1, markerscale=10, columnspacing=0, handletextpad=0, borderpad=0, frameon=False, title=variable)
plt.tight_layout()
# Note: All are on the log-scale
fig = plt.figure(figsize=get_figsize(.5))
x_alt = np.log(dfab_sm.xy_alt_pred_inside) - np.log(dfab_sm.dxy_y_pred_inside) # total counts | dB - dxy_x_pred_total
x_ref = np.log(dfab_sm.xy_ref_pred_inside) - np.log(dfab_sm.dy_y_alt_pred_inside)
plt.scatter(x_ref, x_alt, s=2, alpha=0.2)
plt.xlabel("y - y|dy")
plt.ylabel("y|dx - y|dx&dy")
plt_diag([-.5, 2])
fig = plt.figure(figsize=get_figsize(.5))
frac = x_alt / x_ref
plt.hist(frac[(frac<2) & (frac > 0)], 30);
plt.ylabel("Frequency");
plt.xlabel("(y|dx - y|dx&dy) / (y - y|dy)\n(y=Nanog, x=Oct4-Sox2)");
fig = plt.figure(figsize=get_figsize(.5))
plt.scatter(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside, dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside, s=2, alpha=0.2)
nf = 7
fig, axes = plt.subplots(nrows=len(pairs), ncols=nf, figsize=get_figsize(1/3*nf, len(pairs) / nf))
for i, motif_pair in enumerate(pairs):
k = "<>".join(motif_pair)
dfab_sma = dfab_pairs[k]
dfab_sma = dfab_sma[dfab_sma.center_diff < 150]
cat_dist = pd.cut(dfab_sma.center_diff, [0, 35, 70, 150])
cat_strand = pd.Categorical(dfab_sma.strand_combination)
match_threshold = .2
cat_match = pd.Categorical(((dfab_sma.match_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
(dfab_sma.match_weighted_p_y > .2).map({True: 'high', False: 'low'})))
cat_imp = pd.Categorical(((dfab_sma.imp_weighted_p_x > match_threshold).map({True: 'high', False: 'low'}) + "-" +
(dfab_sma.imp_weighted_p_y > .2).map({True: 'high', False: 'low'})))
variable = 'cat_imp'
dfab_sm = dfab_sma
for j, ax in enumerate(axes[i]):
if j == 0:
if i == 0:
ax.set_title("Corrected total counts (ratio)")
# (total_counts|dB - total_counts|dA&dB) / (total_counts - total_counts|dA)
# x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
# x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - (dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
# y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - dfab_sm.dxy_x_pred_total
# y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - (dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
x_alt = (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside)/dfab_sm.dxy_y_pred_total # total counts | dB - dxy_x_pred_total
x_ref = (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside)/(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
y_alt = (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside)/ dfab_sm.dxy_x_pred_total
y_ref = (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside)/(dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
if j == 1:
if i == 0:
ax.set_title("Corrected footprint counts (ratio)")
# (A|dB - A|dA&dB)/(A - A|dA)
# x_alt = dfab_sm.xy_alt_pred_inside - dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
# x_ref = dfab_sm.xy_ref_pred_inside - dfab_sm.dy_y_alt_pred_inside
# y_alt = dfab_sm.yx_alt_pred_inside - dfab_sm.dxy_x_pred_inside
# y_ref = dfab_sm.yx_ref_pred_inside - dfab_sm.dx_x_alt_pred_inside
x_alt = dfab_sm.xy_alt_pred_inside / dfab_sm.dxy_y_pred_inside # total counts | dB - dxy_x_pred_total
x_ref = dfab_sm.xy_ref_pred_inside / dfab_sm.dy_y_alt_pred_inside
y_alt = dfab_sm.yx_alt_pred_inside / dfab_sm.dxy_x_pred_inside
y_ref = dfab_sm.yx_ref_pred_inside / dfab_sm.dx_x_alt_pred_inside
if j == 2:
if i == 0:
ax.set_title("Log Corrected total counts")
x_alt = np.log(dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside) - np.log(dfab_sm.dxy_y_pred_total) # total counts | dB - dxy_x_pred_total
x_ref = np.log(dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside) - np.log(dfab_sm.dy_y_alt_pred_inside + dfab_sm.dy_y_alt_pred_outside)
y_alt = np.log(dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside) - np.log(dfab_sm.dxy_x_pred_total)
y_ref = np.log(dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside) - np.log(dfab_sm.dx_x_alt_pred_inside + dfab_sm.dx_x_alt_pred_outside)
if j == 3:
if i == 0:
ax.set_title("Log corrected footprint counts (ratio)")
x_alt = np.log(dfab_sm.xy_alt_pred_inside) - np.log(dfab_sm.dxy_y_pred_inside) # total counts | dB - dxy_x_pred_total
x_ref = np.log(dfab_sm.xy_ref_pred_inside) - np.log(dfab_sm.dy_y_alt_pred_inside)
y_alt = np.log(dfab_sm.yx_alt_pred_inside) - np.log(dfab_sm.dxy_x_pred_inside)
y_ref = np.log(dfab_sm.yx_ref_pred_inside) - np.log(dfab_sm.dx_x_alt_pred_inside)
elif j == 4:
if i == 0:
ax.set_title("Profile importance")
x_alt = (dfab_sm.xy_alt_imp_inside * (dfab_sm.xy_alt_pred_inside + dfab_sm.xy_alt_pred_outside))
x_ref = (dfab_sm.xy_ref_imp_inside * (dfab_sm.xy_ref_pred_inside + dfab_sm.xy_ref_pred_outside))
y_alt = (dfab_sm.yx_alt_imp_inside * (dfab_sm.yx_alt_pred_inside + dfab_sm.yx_alt_pred_outside))
y_ref = (dfab_sm.yx_ref_imp_inside * (dfab_sm.yx_ref_pred_inside + dfab_sm.yx_ref_pred_outside))
elif j == 5:
if i == 0:
ax.set_title("Norm. profile importance")
x_alt = dfab_sm.xy_alt_imp_inside
x_ref = dfab_sm.xy_ref_imp_inside
y_alt = dfab_sm.yx_alt_imp_inside
y_ref = dfab_sm.yx_ref_imp_inside
elif j == 6:
if i == 0:
ax.set_title("Deeplift w.r.t. counts")
x_alt = dfab_sm.xy_alt_impcount_inside
x_ref = dfab_sm.xy_ref_impcount_inside
y_alt = dfab_sm.yx_alt_impcount_inside
y_ref = dfab_sm.yx_ref_impcount_inside
plot_scatter(x_ref, x_alt, y_ref, y_alt, ax, alpha=.2, s=1, label=cat)
#if j == nf - 1 and i == 0:
# ax.legend(scatterpoints=1, ncol=2, markerscale=10, columnspacing=0, loc='upper right', handletextpad=0, borderpad=0, frameon=False, title=variable)
# ax.legend(labels=list(cat_strand.categories))
ax.set_xlabel(r"${}\;(\Delta {})$".format(motif_pair[1], motif_pair[0]))
ax.set_ylabel(r"${}\;(\Delta {})$".format(motif_pair[0], motif_pair[1]))
# plt.suptitle("mutated / ref", y=1.02);
plt.tight_layout()
# plt.savefig(figures / 'pairwise_all.wilcox_pval.pdf', raster=True)
# plt.savefig(figures / 'pairwise_all.wilcox_pval.png', raster=True, transparent=False)