from basepair.imports import *
from basepair.exp.paper.config import *
# Figure output directory for the v-plot method comparison.
# parents=True so the nested path is created on a fresh checkout
# (plain mkdir(exist_ok=True) raises FileNotFoundError if
# figures/method-comparison does not exist yet).
fdir = Path(f'{ddir}/figures/method-comparison/vplot')
fdir.mkdir(parents=True, exist_ok=True)
from basepair.cli.schemas import DataSpec, TaskSpec
from pybedtools import BedTool
from basepair.preproc import resize_interval
from basepair.plot.evaluate import regression_eval
# Notebook-only magics (this file is an exported Jupyter notebook).
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
paper_config()
# Input/output locations (cluster paths).
ATAC_PROC_DIR = Path('/oak/stanford/groups/akundaje/projects/bpnet/ATACseq_processed')
CROO_DIR = ATAC_PROC_DIR / 'croo'
OUTPUT_DIR = Path('/oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/atac')
METHODS = ['BPNet', 'MEME/FIMO']
dfi_list_dir = '/oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/chexmix-peakxus'
closest_motifs = read_pkl("/oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/chexmix-peakxus/closest_motifs.pkl")
motifs = list(closest_motifs)  # motif names = keys of the pickle dict
# Map ATAC sample labels (str_label) -> croo run ids.
df_id = pd.read_csv(ATAC_PROC_DIR / 'id_to_sample.txt', sep='\t',
                    usecols=['id', 'status', 'name', 'str_label'])
sample_to_id = {row.str_label: row.id for i,row in df_id.iterrows()}
samples = list(sample_to_id)
fasta_file = '/oak/stanford/groups/akundaje/avsec/basepair/data/processed/comparison/data/mm10_no_alt_analysis_set_ENCODE.fasta'
main_motifs = ['Oct4-Sox2', 'Sox2', 'Nanog', 'Klf4']
### Load the data
# dfi_list
dfi_list = read_pkl(f'{dfi_list_dir}/dfi_list.incl-conservation,is_erv.pkl')
dfi = dfi_list['BPNet']['dfi']  # BPNet motif-instance table
dict(dfi.example_chrom.value_counts())
# ChIP-nexus profiles, ranges
exp = 'nexus,peaks,OSNK,0,10,1,FALSE,same,0.5,64,25,0.004,9,FALSE,[1,50],TRUE'
model_dir = models_dir / exp
isf = ImpScoreFile(model_dir / 'deeplift.imp_score.h5', default_imp_score='profile/wn')
profiles = isf.get_profiles()
ranges = isf.get_ranges()
seqs = isf.get_seq()
contrib = isf.get_contrib()
np.argmax(seqs[:5, :5], axis=-1)
# Per-position contribution magnitude: sum over the last (base) axis.
contrib_size = {task: v.sum(axis=-1) for task, v in contrib.items()}
from basepair.preproc import moving_average
# Smoothed (window 10) contribution sizes for Oct4, flattened for thresholding.
_x = np.ravel(moving_average(contrib_size['Oct4'].swapaxes(0, 1), n=10).swapaxes(0,1))
np.mean(_x > 0.1)
class MotifMask(Dataset):
    """Per-example boolean mask of sequence positions covered by motif instances.

    Args:
        dfi: motif-instance table with `pattern_start` / `pattern_end` columns;
            re-indexed by `index_col` so instances can be looked up per example.
        seqs: array of one-hot sequences; only shape[1] (sequence length) is used.
        index_col: column linking instances to examples.

    __getitem__ returns a bool array of length seqs.shape[1], True inside
    motif instances, all-False when the example has no instances.
    """

    def __init__(self, dfi, seqs, index_col='example_idx'):
        # drop=False keeps the column so __getitem__ can sanity-check the lookup.
        self.dfi = dfi.copy().set_index(index_col, drop=False)
        self.seqs = seqs
        self.index_col = index_col

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        out = np.zeros((self.seqs.shape[1])).astype(bool)
        # Get instances
        if idx not in self.dfi.index:
            # no rows -> example has no motif instances; return the empty mask
            return out
        instances = self.dfi.loc[[idx]]
        assert np.all(instances[self.index_col] == idx)
        # assign motif locations
        # (the original also computed an unused `seqlen` local here - removed)
        for i, inst in instances.iterrows():
            out[int(inst.pattern_start):int(inst.pattern_end)] = True
        return out
# Boolean masks marking which positions fall inside motif instances.
# (The original recomputed `all_motif_mask` a second time after the
# sox2/oct4 masks - an identical, redundant pass removed here.)
all_motif_mask = MotifMask(dfi, seqs).load_all(num_workers=10)
dsox2_motifs = ['Oct4-Sox2', 'Sox2']
doct4_motifs = ['Oct4-Sox2', 'Oct4', 'Oct4-Oct4']
sox2_mask = MotifMask(dfi[dfi.pattern_name.isin(dsox2_motifs)], seqs).load_all(num_workers=10)
oct4_mask = MotifMask(dfi[dfi.pattern_name.isin(doct4_motifs)], seqs).load_all(num_workers=10)
print("Sox2 motifs cover", sox2_mask.mean())
print("Oct4 motifs cover", oct4_mask.mean())
print("All motifs cover", all_motif_mask.mean())
from basepair.modisco.sliding_similarities import pad_same
# ~40% of all sites with high importance is covered by motifs
# NOTE(review): `tasks` is used in the loops below but is only assigned later
# in this file; presumably it was defined earlier in the notebook session --
# confirm before running this file top-to-bottom.
t_list = [0.5, 0.3, 0.2, 0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
for task in tasks:
    # Smoothed per-position contribution; threshold at each t in t_list.
    _x = moving_average(contrib_size[task].swapaxes(0, 1), n=10).swapaxes(0,1)
    plt.plot([np.mean(_x>t) for t in t_list],
             [np.sum((_x>t) & all_motif_mask) / np.sum(_x>t) for t in t_list], label=task)
plt.xlabel("Fraction of sites called important (higher=less stringent)")
plt.ylabel("Fraction of important bases located in motif instances")
plt.legend()
t_list = [0.5, 0.3, 0.2, 0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
for task in tasks:
    _x = moving_average(contrib_size[task].swapaxes(0, 1), n=10).swapaxes(0,1)
    plt.plot(t_list, [np.sum((_x>t) & all_motif_mask) / np.sum(_x>t) for t in t_list], label=task)
plt.xlabel("Importance score threshold (higher=more stringent)")
plt.ylabel("Fraction of important bases located in motif instances")
plt.legend()
# Same plot restricted to the Sox2-related motif mask.
t_list = [0.5, 0.3, 0.2, 0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
for task in tasks:
    _x = moving_average(contrib_size[task].swapaxes(0, 1), n=10).swapaxes(0,1)
    plt.plot(t_list, [np.sum((_x>t) & sox2_mask) / np.sum(_x>t) for t in t_list], label=task)
plt.legend()
plt.xlabel("Importance score threshold (higher=more stringent)")
plt.ylabel("Fraction of important bases located in motif instances")
plt.title("Sox2 motif mask")
# Same plot restricted to the Oct4-related motif mask.
t_list = [0.5, 0.3, 0.2, 0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
for task in tasks:
    _x = moving_average(contrib_size[task].swapaxes(0, 1), n=10).swapaxes(0,1)
    plt.plot(t_list, [np.sum((_x>t) & oct4_mask) / np.sum(_x>t) for t in t_list], label=task)
plt.legend()
plt.xlabel("Importance score threshold (higher=more stringent)")
plt.ylabel("Fraction of important bases located in motif instances")
plt.title("Oct4 motif mask")
ranges['example_idx'] = ranges['idx']
# Centered at the ChIP-nexus data
# DataSpec pointing at stranded ATAC coverage bigWigs (sub-150bp fragments).
ds = DataSpec(task_specs={sample: TaskSpec(pos_counts=f'{OUTPUT_DIR}/{sample}/coverage.sub150bp-fragments.pos.bw',
                                           neg_counts=f'{OUTPUT_DIR}/{sample}/coverage.sub150bp-fragments.neg.bw',
                                           task=sample)
                          for sample in samples},
              fasta_file=fasta_file)
ds.touch_all_files()  # fail fast if any bigWig / fasta file is missing
from basepair.extractors import Interval
ranges
from pybedtools import BedTool
# ChIP-nexus peak regions as pybedtools Intervals for counting ATAC coverage.
all_intervals = list(BedTool.from_dataframe(ranges[['chrom', 'start', 'end']]))
tasks = ['Oct4', 'Sox2', 'Nanog', 'Klf4']
atac_profile = ds.load_counts(all_intervals)
def atac_peak_path(sample):
    """Return the optimal IDR narrowPeak path for `sample` (via its croo run id)."""
    run_id = sample_to_id[sample]
    return f'{CROO_DIR}/{run_id}/peak/idr_reproducibility/idr.optimal_peak.narrowPeak.gz'
# Per-sample IDR peak file paths.
atac_peak_paths = {sample: atac_peak_path(sample) for sample in samples}
# Use new code from pyranges
import pyranges as pr
from pyranges import PyRanges
def read_bed(f, output_df=False, nrows=None):
    """Read a (possibly gzipped) BED/narrowPeak file.

    The header is auto-detected by peeking at the second field of the first
    line: if it is not an integer, the first line is treated as a header row.
    Columns are renamed to the standard BED12 names (truncated to however
    many columns the file has).

    Args:
        f: path to the file (str); `.gz` suffix triggers gzip reading.
        output_df: if True return a pandas DataFrame, else a PyRanges.
        nrows: optionally limit the number of rows read.

    Fix vs. original: the peek used `gzip.open(f).readline()` / `open(f)`
    without ever closing the handle - now wrapped in `with` blocks.
    """
    columns = ("Chromosome Start End Name Score Strand ThickStart ThickEnd "
               "ItemRGB BlockCount BlockSizes BlockStarts").split()
    if f.endswith(".gz"):
        import gzip
        with gzip.open(f) as fh:
            first_start = fh.readline().split()[1]
    else:
        with open(f) as fh:
            first_start = fh.readline().split()[1]
    header = None
    try:
        int(first_start)
    except ValueError:
        # Second field of the first line is not an integer -> header row present.
        header = 0
    df = pd.read_csv(
        f,
        dtype={
            "Chromosome": "category",
            "Strand": "category"
        },
        nrows=nrows,
        header=header,
        sep="\t")
    df.columns = columns[:df.shape[1]]
    if not output_df:
        return PyRanges(df)
    else:
        return df
# Flag which ChIP-nexus ranges overlap each sample's ATAC IDR peaks.
for sample in samples:
    atac_peaks = read_bed(atac_peak_path(sample))
    atac_peaks = atac_peaks[['Chromosome', 'Start', 'End', 'ThickStart', 'BlockCount']]
    # ThickStart -> signalValue
    # BlockCount -> peak
    atac_peaks = atac_peaks.set_columns(['Chromosome', 'Start', 'End', 'signal_value', 'peak_summit'])
    ranges_pr = pr.PyRanges(ranges.rename(columns={'chrom': 'Chromosome', 'start': 'Start', 'end': 'End', 'idx': 'example_idx'}))
    ranges[f'overlaps_{sample}'] = ranges['idx'].isin(np.unique(ranges_pr.intersect(atac_peaks).example_idx))
# Fraction of ranges overlapping peaks, per sample.
ranges[[c for c in ranges if c.startswith('overlaps_')]].mean()
from basepair.plot.tracks import plot_tracks
plot_tracks({sample.replace('ATAC_', ''): counts.mean(axis=0).sum(axis=-1) for sample,counts in atac_profile.items()},
            fig_width=5,
            fig_height_per_track=.7,
            same_ylim=True,
            rotate_y=0,
            title=f'ATAC tracks (average number of counts per position)');
plt.xlabel('Position')
sns.despine(top=True, right=True)
# Append total number of counts to ranges
for sample,counts in atac_profile.items():
    ranges[sample] = counts.sum(axis=(1,2))
# tidy table
ranges_melt = ranges.melt(id_vars=[c for c in ranges.columns if c not in samples],
                          value_vars=samples)
ranges_melt['log10_value'] = np.log10(1+ranges_melt['value'])
ranges_melt['variable'] = ranges_melt['variable'].str.replace('ATAC_', '')
ranges_melt
from basepair.plot.utils import plt9_tilt_xlab
# Boxplots of log10 ATAC counts per sample, split by peak overlap.
(ggplot(aes(x='variable', y='log10_value', color='overlaps_atac_peak'), data=ranges_melt)
 + facet_grid('interval_from_task ~ .')
 + geom_boxplot()
 + scale_color_brewer('qual')
 + theme_classic()
 + plt9_tilt_xlab()
 )
from basepair.exp.chipnexus.perturb.gen import random_seq_onehot
from kipoi.data import Dataset
class PerturbedMotifsSeq(Dataset):
    """Dataset yielding sequences with all motif instances listed in `dfi`
    replaced by random one-hot sequence (in-silico motif "deletion").

    dfi: motif-instance table with `pattern_start` / `pattern_end` columns.
    seqs: one-hot sequences, indexed by example.
    """
    def __init__(self, dfi, seqs, index_col='example_idx'):
        # drop=False keeps the column so __getitem__ can sanity-check the lookup.
        self.dfi = dfi.copy().set_index(index_col, drop=False)
        self.seqs = seqs
        self.index_col = index_col
    def __len__(self):
        return len(self.seqs)
    def __getitem__(self, idx):
        ref_seq = self.seqs[idx]
        # generate the alternative sequence
        alt_seq = ref_seq.copy()
        # Get instances
        if idx not in self.dfi.index:
            # no rows. Return the same sequence.
            return alt_seq
        instances = self.dfi.loc[[idx]]
        assert np.all(instances[self.index_col] == idx)
        # mutate alternative sequence
        for i,inst in instances.iterrows():
            seqlen = inst.pattern_end - inst.pattern_start
            alt_seq[int(inst.pattern_start):int(inst.pattern_end)] = random_seq_onehot(seqlen)
        return alt_seq
motifs
assert len(ranges) == len(seqs)
# Position-shuffled sequences as a background.
# NOTE(review): assumes the sequence length is exactly 1000 -- TODO confirm.
seqs_random = seqs[:, np.random.permutation(1000), :]
# Fraction of sequences that have a motifs
print("dOct4", len(dfi[dfi.pattern_name.isin(doct4_motifs)].example_idx.unique()) / len(seqs))
print("dSox2", len(dfi[dfi.pattern_name.isin(dsox2_motifs)].example_idx.unique()) / len(seqs))
# wt sequences plus in-silico motif deletions (dsox2 / doct4).
seq_dict = dict(wt=seqs,
                dsox2=PerturbedMotifsSeq(dfi[dfi.pattern_name.isin(dsox2_motifs)], seqs).load_all(num_workers=10),
                doct4=PerturbedMotifsSeq(dfi[dfi.pattern_name.isin(doct4_motifs)], seqs).load_all(num_workers=10))
from basepair.BPNet import BPNetSeqModel
create_tf_session(2)
m = BPNetSeqModel.from_mdir(model_dir)
seq_dict['wt'].shape
# Predict ChIP-nexus profiles for each sequence variant.
pred_dict = {k: m.predict(seq) for k,seq in seq_dict.items()}
pred_dict['random'] = m.predict(seqs_random)
# Total predicted counts per example in long format.
pred_total = pd.concat([pd.DataFrame({task: v.sum(axis=(1,2)) for task,v in d.items()}).assign(seq=k, idx=np.arange(len(seqs)))
                        for k,d in pred_dict.items()])
dfpm = pred_total.melt(id_vars=['seq', 'idx'], var_name='task', value_name='total_counts')
dfpm['log_total_counts'] = np.log10(1+dfpm['total_counts'])
dfp = dfpm.pivot_table(index=['idx', 'task'], columns='seq', values='log_total_counts').reset_index()
dfp
a=1
# TODO - to how many regions does sox2 bind where there are no motif instances?
# Examples untouched by the perturbation (no motif instance to scramble).
same_seqs_dsox2 = np.all(seq_dict['wt'] == seq_dict['dsox2'], axis=(1,2))
same_seqs_dsox2.mean()
same_seqs_doct4 = np.all(seq_dict['wt'] == seq_dict['doct4'], axis=(1,2))
same_seqs_doct4.mean()
dfpm['same_seqs_doct4'] = same_seqs_doct4[dfpm.idx]
dfpm['same_seqs_dsox2'] = same_seqs_dsox2[dfpm.idx]
plotnine.options.figure_size = get_figsize(.7, 1/3)
# Density of wt Oct4 log-counts split by whether the doct4 perturbation changed the sequence.
g = (ggplot(aes(color='same_seqs_doct4', x='log_total_counts'), dfpm[((dfpm['task'] == 'Oct4') & (dfpm['seq'] == 'wt'))])
     + geom_density()
     + ggtitle("dOct4")
     # + geom_boxplot()
     # + facet_grid("task ~ .")
     + theme_classic()
     )
display(g)
plotnine.options.figure_size = get_figsize(.7, 1/3)
g = (ggplot(aes(color='same_seqs_dsox2', x='log_total_counts'), dfpm[((dfpm['task'] == 'Sox2') & (dfpm['seq'] == 'wt'))])
     + geom_density()
     + ggtitle("dSox2")
     # + geom_boxplot()
     # + facet_grid("task ~ .")
     + theme_classic()
     )
display(g);
# Large majority of sequences is the same
dfp
dfp['same_seqs_doct4'] = same_seqs_doct4[dfp.idx]
dfp['same_seqs_dsox2'] = same_seqs_dsox2[dfp.idx]
# 3-panel scatter: random vs dsox2, wt vs dsox2, random vs wt (Sox2 log-counts).
fig, axes = plt.subplots(1, 3, figsize=get_figsize(1, 1/4),
                         gridspec_kw=dict(wspace=0.2),
                         sharex=True, sharey=True)
ax = axes[0]
regression_eval(dfp[dfp['task'] == 'Sox2']['random'],
                dfp[dfp['task'] == 'Sox2']['dsox2'], task=f'Sox2', alpha=0.1, ax=ax);
ax.plot([1, 3], [1,3], '--', color='gray')
ax.set_ylabel("Random")
ax.set_xlabel("dsox2");
ax = axes[1]
regression_eval(dfp[dfp['task'] == 'Sox2']['wt'],
                dfp[dfp['task'] == 'Sox2']['dsox2'], task=f'Sox2', alpha=0.1, ax=ax);
ax.plot([1, 3], [1,3], '--', color='gray')
ax.set_ylabel("WT")
ax.set_xlabel("dsox2");
ax = axes[2]
regression_eval(dfp[dfp['task'] == 'Sox2']['random'],
                dfp[dfp['task'] == 'Sox2']['wt'], task=f'Sox2', alpha=0.1, ax=ax);
ax.plot([1, 3], [1,3], '--', color='gray')
ax.set_ylabel("Random")
ax.set_xlabel("WT");
plotnine.options.figure_size = get_figsize(.7, 1)
(ggplot(aes(x='log_total_counts', color='seq'), dfpm)
 # + geom_histogram(bins=100)
 + geom_density()
 + facet_grid("task ~ .")
 + theme_classic()
 )
dfpm
# Max of the length-normalized profile, averaged over strands ("peakiness").
pred_max_profile = {k: {task: (v / v.sum(axis=1, keepdims=True)).max(axis=1).mean(axis=-1) for task,v in d.items()}
                    for k,d in pred_dict.items()}
# X_bpnet_dsox2 = pd.DataFrame(pred_counts['wt']) - pd.DataFrame(pred_counts['dsox2'])
# X_bpnet_doct4 = pd.DataFrame(pred_counts['wt']) - pd.DataFrame(pred_counts['doct4'])
# Motif-count feature matrix per example (0 where an example has no instances).
X_feat = dfi.groupby(['example_idx', 'pattern_name']).size().unstack(fill_value=0).reset_index()
X_feat = pd.merge(ranges[['example_idx']], X_feat, on='example_idx', how='left').fillna(0)
del X_feat['example_idx']
def prefix_name(df, prefix):
    """Return a copy of `df` with every column name prefixed by `prefix`."""
    renamed = df.copy()
    renamed.columns = [prefix + column for column in renamed.columns]
    return renamed
# Add motif-count columns to `ranges` (prefixed with motif_counts_).
ranges = pd.concat([ranges, prefix_name(X_feat, 'motif_counts_')], axis=1)
from basepair.exp.chipnexus.spacing import remove_edge_instances, get_motif_pairs, motif_pair_dfi
# Generate motif pairs
pairs = get_motif_pairs(motifs)
# ordered names
pair_names = ["<>".join(x) for x in pairs]
# create motif pairs (excluding transposable-element / ERV instances)
dfab = pd.concat([motif_pair_dfi(dfi[(~dfi.is_te) & (~dfi.is_erv)], pair).assign(motif_pair="<>".join(pair)) for pair in pairs], axis=0)
# Remove self matches
dfab = dfab.query('~((pattern_center_aln_x == pattern_center_aln_y) & (pattern_strand_aln_x == pattern_strand_aln_x))')
# Row ids of partner motifs that fully overlap a composite motif
# (aligned center distance 0) -- these duplicate the composite itself.
exclude_sox2 = dfab[(dfab.motif_pair == 'Oct4-Sox2<>Sox2') &
                    (dfab['center_diff_aln'] == 0)].row_idx_y.values
exclude_oct4 = dfab[(dfab.motif_pair == 'Oct4-Sox2<>Oct4') &
                    (dfab['center_diff_aln'] == 0)].row_idx_y.values
exclude_oct4_v2 = dfab[(dfab.motif_pair == 'Oct4<>Oct4-Oct4') &
                       (dfab['center_diff_aln'] == 0)].row_idx_y.values
old_len = len(dfab)
# Exclude the overlapping row (on either side of the pair).
dfab = dfab[(dfab.pattern_name_x != 'Oct4') | (~dfab.row_idx_x.isin(exclude_oct4))]
dfab = dfab[(dfab.pattern_name_y != 'Oct4') | (~dfab.row_idx_y.isin(exclude_oct4))]
dfab = dfab[(dfab.pattern_name_x != 'Oct4') | (~dfab.row_idx_x.isin(exclude_oct4_v2))]
dfab = dfab[(dfab.pattern_name_y != 'Oct4') | (~dfab.row_idx_y.isin(exclude_oct4_v2))]
dfab = dfab[(dfab.pattern_name_x != 'Sox2') | (~dfab.row_idx_x.isin(exclude_sox2))]
dfab = dfab[(dfab.pattern_name_y != 'Sox2') | (~dfab.row_idx_y.isin(exclude_sox2))]
# Fix: the original computed len(dfab) - old_len (negative after filtering)
# and used the remaining row count as the denominator.
nd = old_len - len(dfab)
print(f"Removed {nd}/{old_len} instances")
# Total observed ChIP-nexus counts per example and task, appended to ranges.
nexus_counts = pd.DataFrame({k:v.sum(axis=(1,2)) for k,v in profiles.items()})
ranges = pd.concat([ranges, nexus_counts], axis=1)
tasks
from scipy.stats import pearsonr, spearmanr
# Spearman correlation between ATAC counts and ChIP-nexus counts per task.
df_cor_list = []
for sample in samples:
    for task in tasks:
        y_true = ranges[task]+1
        y_pred = ranges[sample]+1
        # Fix: the original mixed np.log10(y_true) with np.log(y_pred).
        # Spearman is rank-based, so any monotone transform gives the same
        # value -- but use a single base for consistency.
        spearman, spearman_pval = spearmanr(np.log10(y_true), np.log10(y_pred))
        df_cor_list.append(dict(sample=sample, task=task, spearman=spearman))
df_cor = pd.DataFrame(df_cor_list)
# Parse the TF name + timepoint out of the sample label.
df_cor['tf_time'] = df_cor['sample'].map(lambda x: x[5:9] + '_' + x.split('_')[-1])
df_cor['TF_expressed'] = df_cor['sample'].map(lambda x: x.split('_')[1][4:])
from basepair.plot.utils import plt9_tilt_xlab
plotnine.options.figure_size = get_figsize(.7, 1/3)
(ggplot(aes(x='task', y='spearman', fill='TF_expressed'), df_cor)
 + geom_bar(stat='identity', position='dodge')
 + facet_grid(".~tf_time")
 + scale_fill_brewer('qual', 4)
 + theme_classic()
 + plt9_tilt_xlab()
 + ggtitle("Spearman correlation: ATAC ~ ChIP-nexus")
 )
# Scatter plots: ATAC counts (y) vs ChIP-nexus counts (x), one panel per task.
for sample in samples:
    fig, axes = plt.subplots(1, len(tasks),
                             figsize=get_figsize(1, aspect=1/len(tasks)),
                             gridspec_kw=dict(wspace=0),
                             sharex=True,
                             sharey=True)
    for i, (ax, task) in enumerate(zip(axes, tasks)):
        regression_eval(ranges[task]+1, ranges[sample]+1, loglog=True, task=task, alpha=0.2, ax=ax);
        if i == 0:
            ax.set_ylabel(sample)
        else:
            ax.set_ylabel(None)
        ax.set_xlabel("ChIP-nexus")
ranges
ranges
# Tidy ATAC count table for plotting.
# Fix: the original displayed `_ranges_melt` BEFORE assigning it (NameError
# on a fresh run) -- the premature display line was removed.
_ranges_melt = ranges.melt(value_vars=samples)
_ranges_melt['log10_counts'] = np.log10(1 + _ranges_melt['value'])
_ranges_melt['variable'] = pd.Categorical(_ranges_melt['variable'], samples, ordered=True)
# cat.rename_categories(..., inplace=True) is deprecated; reassign instead.
_ranges_melt['variable'] = _ranges_melt['variable'].cat.rename_categories(lambda x: x.replace('ATAC_', ''))
from basepair.plot.utils import plt9_horizontal_facet_label, plt9_remove_facet_label_box
plotnine.options.figure_size = get_figsize(.7, 1)
(ggplot(aes(x='log10_counts'), _ranges_melt)
 + geom_histogram(bins=100)
 + facet_grid('variable ~ .')
 + theme_classic()
 + plt9_horizontal_facet_label()
 + plt9_remove_facet_label_box()
 + xlim(0, 3.5)
 + ggtitle("ATAC-seq count distribution in ChIP-nexus peaks (1kb)")
 )
# ON/OFF sample pairs (TF induced vs not induced).
on_samples = [s for s in samples if 'ON' in s]
feature_pairs = [(s, s.replace("ON", "OFF")) for s in samples if 'ON' in s]
sample_groups = [s.replace("ON", "").replace("ATAC_", "") for s in on_samples]
# log10 fold-change ON/OFF; pseudocount = 10th percentile of each condition.
for on_sample in on_samples:
    counts_on = ranges[on_sample]
    counts_off = ranges[on_sample.replace("ON", "OFF")]
    counts_on_pc = np.percentile(counts_on, 10)
    counts_off_pc = np.percentile(counts_off, 10)
    fc = np.log10((counts_on + counts_on_pc) / (counts_off + counts_off_pc))
    ranges['log10_fc_' + on_sample.replace("ON", "").replace("ATAC_", "")] = fc
ranges[['log10_fc_' + s for s in sample_groups]].melt()
(ggplot(aes(x='value'), ranges[['log10_fc_' + s for s in sample_groups]].melt())
 + geom_histogram(bins=100)
 + theme_classic()
 + plt9_horizontal_facet_label()
 + plt9_remove_facet_label_box()
 + facet_grid("variable ~ .")
 + xlab("log10 fold change (ON / OFF)")
 )
fold_change ~ beta_TF1 * (alt_binding_TF1 - ref_binding_TF1) + ...
from scipy.stats import pearsonr, spearmanr
# Spearman: ATAC ON/OFF fold-change vs ChIP-nexus counts, per task.
df_cor_list = []
for sample_group in sample_groups:
    for task in tasks:
        y_true = np.log10(ranges[task]+1)
        y_pred = ranges['log10_fc_' + sample_group]
        spearman, spearman_pval = spearmanr(y_true, y_pred)
        df_cor_list.append(dict(sample_group=sample_group, task=task, spearman=spearman))
df_cor = pd.DataFrame(df_cor_list)
from basepair.plot.utils import plt9_tilt_xlab
plotnine.options.figure_size = get_figsize(.7, 1/3)
(ggplot(aes(x='task', y='spearman'), df_cor)
 + geom_bar(stat='identity', position='dodge')
 + facet_grid(".~sample_group")
 + scale_fill_brewer('qual', 4)
 + theme_classic()
 + plt9_tilt_xlab()
 + ggtitle("Spearman correlation: ATAC_ON/OFF_fc ~ ChIP-nexus")
 )
Counts in central 200 bp for each TF?
pred_dict.keys()
# Total predicted counts and profile peakiness per sequence variant.
pred_counts = {k: {task: v.sum(axis=(1,2)) for task,v in d.items()}
               for k,d in pred_dict.items()}
pred_max_profile = {k: {task: (v / v.sum(axis=1, keepdims=True)).max(axis=1).mean(axis=-1) for task,v in d.items()}
                    for k,d in pred_dict.items()}
# Predicted count differences wt - perturbed.
X_bpnet_dsox2 = pd.DataFrame(pred_counts['wt']) - pd.DataFrame(pred_counts['dsox2'])
X_bpnet_doct4 = pd.DataFrame(pred_counts['wt']) - pd.DataFrame(pred_counts['doct4'])
main_motifs
motifs
dsox2_motifs
doct4_motifs
# Motif-count features, stripped of the column prefix.
X_motif_counts = ranges[['motif_counts_' + motif for motif in motifs]]
X_motif_counts.columns = [c.replace('motif_counts_', '') for c in X_motif_counts.columns]
X_motif_counts
X_motif_counts_dsox2 = X_motif_counts[dsox2_motifs] # wt - ref -> all the motifs cancel out except the key {dsox2_motifs}
X_motif_counts_doct4 = X_motif_counts[doct4_motifs]
ranges
# Response matrix: log10 ON/OFF fold-changes.
Y = ranges[['log10_fc_' + g for g in sample_groups]]
Y
# Fit linear model
from basepair.datasets import *
from basepair.config import valid_chr, test_chr
from kipoi.data_utils import numpy_collate
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_predict
from basepair.plot.evaluate import regression_eval
# Chromosome-held-out train/valid/test split.
train = ~ranges.chrom.isin(valid_chr+test_chr)
valid = ranges.chrom.isin(valid_chr)
test = ranges.chrom.isin(test_chr)
X_motif_counts
# TODO - setup the statistical test
# - does including motif pairs help?
# - what is the coefficient sign for the main motifs
# Motif-pair count features: short range, center distance in (10, 30].
X_feat_pairs_short_range = (dfab[(dfab.center_diff > 10) & (dfab.center_diff <= 30)]
                            .groupby(['example_idx', 'motif_pair']).size().unstack(fill_value=0).reset_index())
X_feat_pairs_short_range = pd.merge(ranges[['example_idx']], X_feat_pairs_short_range, on='example_idx', how='left').fillna(0)
del X_feat_pairs_short_range['example_idx']
# 10% of motif instances have at least one motif_pair
(X_feat_pairs_short_range > 0).any(axis=1).mean()
# Proximal range: center distance in (10, 150].
X_feat_pairs_proximal = (dfab[(dfab.center_diff > 10) & (dfab.center_diff <=150)]
                         .groupby(['example_idx', 'motif_pair']).size().unstack(fill_value=0).reset_index())
X_feat_pairs_proximal = pd.merge(ranges[['example_idx']], X_feat_pairs_proximal, on='example_idx', how='left').fillna(0)
del X_feat_pairs_proximal['example_idx']
oct4_involved_pairs = [c for c in X_feat_pairs_proximal.columns if 'Oct4' in c]
sox2_involved_pairs = [c for c in X_feat_pairs_proximal.columns if 'Sox2' in c]
import statsmodels.api as sm
from basepair.stats import tidy_ols
pval_threshold = 0.01
oct4_involved_pairs = [c for c in X_feat_pairs_proximal.columns if 'Oct4' in c]
sox2_involved_pairs = [c for c in X_feat_pairs_proximal.columns if 'Sox2' in c]
def standardize(x):
    """Z-score `x`; a small ridge (0.01) on the std avoids division by zero."""
    center = x.mean()
    spread = x.std() + 0.01
    return (x - center) / spread
def binarize(x):
    """Map `x` element-wise to 1.0 where strictly positive, else 0.0."""
    positive = x > 0
    return positive.astype(float)
def transform(x):
    """Binarize the count features, z-score them, and prepend an intercept."""
    design = standardize(binarize(x))
    return sm.add_constant(design)
# OLS fits: ATAC fold-change (sign-flipped to mutant - wt) ~ motif counts
# + motif-pair counts, restricted to ranges overlapping the ON ATAC peaks.
X_sox2 = pd.concat([X_motif_counts[dsox2_motifs], X_feat_pairs_proximal[sox2_involved_pairs]], axis=1)
X_oct4 = pd.concat([X_motif_counts[doct4_motifs], X_feat_pairs_proximal[oct4_involved_pairs]], axis=1)
oct4_subset = ranges.overlaps_ATAC_OCT4ON_S2iL
sox2_subset = ranges.overlaps_ATAC_SOX2ON_26h
print("Nucleosome short range [10,150]")
print("Sox2")
# mutant - wt
results = sm.OLS(-Y['log10_fc_SOX2_26h'][sox2_subset], transform(X_sox2[sox2_subset])).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
print("="*80)
print("Oct4")
results = sm.OLS(-Y['log10_fc_OCT4_S2iL'][oct4_subset], transform(X_oct4[oct4_subset])).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
# Same fits with only short-range [10,30] pair features.
X_sox2 = pd.concat([X_motif_counts[dsox2_motifs], X_feat_pairs_short_range[sox2_involved_pairs]], axis=1)
X_oct4 = pd.concat([X_motif_counts[doct4_motifs], X_feat_pairs_short_range[oct4_involved_pairs]], axis=1)
print("="*80)
print("="*80)
print("Only short range [10,30]")
print("Sox2")
# mutant - wt
results = sm.OLS(-Y['log10_fc_SOX2_26h'][sox2_subset], transform(X_sox2[sox2_subset])).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
print("="*80)
print("Oct4")
results = sm.OLS(-Y['log10_fc_OCT4_S2iL'][oct4_subset], transform(X_oct4[oct4_subset])).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
import statsmodels.api as sm
from basepair.stats import tidy_ols
# Repeat with ALL motif-count and ALL pair features (not just TF-specific).
X_sox2 = pd.concat([X_motif_counts, X_feat_pairs_proximal], axis=1)
X_oct4 = pd.concat([X_motif_counts, X_feat_pairs_proximal], axis=1)
oct4_subset = ranges.overlaps_ATAC_OCT4ON_S2iL
sox2_subset = ranges.overlaps_ATAC_SOX2ON_26h
print("Nucleosome short range [10,150]")
print("Sox2")
# mutant - wt
results = sm.OLS(-Y['log10_fc_SOX2_26h'][sox2_subset], transform(X_sox2[sox2_subset])).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
print("="*80)
print("Oct4")
results = sm.OLS(-Y['log10_fc_OCT4_S2iL'][oct4_subset], transform(X_oct4[oct4_subset])).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
X_sox2 = pd.concat([X_motif_counts, X_feat_pairs_short_range], axis=1)
X_oct4 = pd.concat([X_motif_counts, X_feat_pairs_short_range], axis=1)
print("="*80)
print("="*80)
print("Only short range [10,30]")
print("Sox2")
# mutant - wt
results = sm.OLS(-Y['log10_fc_SOX2_26h'][sox2_subset], transform(X_sox2[sox2_subset])).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
print("="*80)
print("Oct4")
results = sm.OLS(-Y['log10_fc_OCT4_S2iL'][oct4_subset], transform(X_oct4[oct4_subset])).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
# Correlation heatmaps of predicted per-task log counts (wt vs shuffled background).
fig = plt.figure(figsize=get_figsize(.5, 1))
sns.heatmap(standardize(np.log(pd.DataFrame(pred_counts['wt']))).corr(), annot=True, square=True)
plt.title("WT sequence")
fig = plt.figure(figsize=get_figsize(.5, 1))
sns.heatmap(standardize(np.log(pd.DataFrame(pred_counts['random']))).corr(), annot=True, square=True)
plt.title("Random shuffled sequence");
# Model prediction analysis
import statsmodels.api as sm
from basepair.stats import tidy_ols
# Looser threshold than the motif-feature analysis above (0.01).
pval_threshold = 0.05
def transform(x):
    """Z-score the features (callers add the intercept via sm.add_constant).

    Overrides the earlier binarizing `transform` for the model-prediction analysis.
    """
    return standardize(x)
oct4_subset = ranges.overlaps_ATAC_OCT4ON_S2iL
sox2_subset = ranges.overlaps_ATAC_SOX2ON_26h
print("="*80)
print("log(d<TF> / WT)")
# Features: BPNet-predicted log fold-change (perturbed / wt) per task.
X_sox2 = np.log(pd.DataFrame(pred_counts['dsox2']) / pd.DataFrame(pred_counts['wt'])) # / pd.DataFrame(pred_counts['random']))
X_oct4 = np.log(pd.DataFrame(pred_counts['doct4']) / pd.DataFrame(pred_counts['wt'])) # / pd.DataFrame(pred_counts['random']))
print("Sox2")
# mutant - wt
results = sm.OLS(-Y['log10_fc_SOX2_26h'][sox2_subset], sm.add_constant(transform(X_sox2[sox2_subset]))).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
print("="*80)
print("Oct4")
results = sm.OLS(-Y['log10_fc_OCT4_S2iL'][oct4_subset], sm.add_constant(transform(X_oct4[oct4_subset]))).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
print("="*80)
print("="*80)
print("WT")
# Features: plain wt predicted log counts.
X_sox2 = np.log(pd.DataFrame(pred_counts['wt'])) # / pd.DataFrame(pred_counts['random']))
X_oct4 = np.log(pd.DataFrame(pred_counts['wt'])) # / pd.DataFrame(pred_counts['random']))
print("Sox2")
# mutant - wt
results = sm.OLS(-Y['log10_fc_SOX2_26h'][sox2_subset], sm.add_constant(transform(X_sox2[sox2_subset]))).fit()
betas = tidy_ols(results).sort_values('coef')
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
print("="*80)
print("Oct4")
results = sm.OLS(-Y['log10_fc_OCT4_S2iL'][oct4_subset], sm.add_constant(transform(X_oct4[oct4_subset]))).fit()
betas = tidy_ols(results).sort_values('coef')
betas
print(betas[betas['P>|t|'].astype(float) < pval_threshold].to_string())
# Bottleneck-layer activations averaged over positions, for wt and the
# perturbed sequences.
# Fix: the wt `bottleneck` computation was commented out while
# `bottleneck_features` below still used it (NameError on a fresh run).
bottleneck = m.seqmodel.bottleneck_model().predict(seqs)
bottleneck_dsox2 = m.seqmodel.bottleneck_model().predict(seq_dict['dsox2'])
bottleneck_doct4 = m.seqmodel.bottleneck_model().predict(seq_dict['doct4'])
bottleneck_features_dsox2 = pd.DataFrame(bottleneck_dsox2.mean(axis=1))
bottleneck_features_doct4 = pd.DataFrame(bottleneck_doct4.mean(axis=1))
bottleneck_features = pd.DataFrame(bottleneck.mean(axis=1))
# Feature sets ("x_list") for predicting each ATAC log10 ON/OFF fold-change
# response; "subset" restricts to ranges overlapping the ON ATAC peaks.
# NOTE(review): "dSox2_40h" uses the 26h overlap subset
# (overlaps_ATAC_SOX2ON_26h) -- confirm that is intended.
datasets = {"dSox2": {"y": Y['log10_fc_SOX2_26h'],
                      "subset": ranges.overlaps_ATAC_SOX2ON_26h,
                      "x_list": {
                          'BPNet_bottleneck': bottleneck_features,
                          'BPNet_bottleneck_dsox2': bottleneck_features - bottleneck_features_dsox2,
                          'BPNet_diff': pd.DataFrame(pred_counts['wt']) - pd.DataFrame(pred_counts['dsox2']),
                          'BPNet_log_fc': np.log(pd.DataFrame(pred_counts['wt']) / pd.DataFrame(pred_counts['dsox2'])),
                          'BPNet_log_fc_random': np.log(pd.DataFrame(pred_counts['wt']) / pd.DataFrame(pred_counts['random'])),
                          'BPNet_maxref_fc': np.log(pd.DataFrame(pred_max_profile['wt']) / pd.DataFrame(pred_max_profile['dsox2'])),
                          'BPNet_maxref_diff': pd.DataFrame(pred_max_profile['wt']) - pd.DataFrame(pred_max_profile['dsox2']),
                          'BPNet_log_diff': np.log(pd.DataFrame(pred_counts['wt'])) - np.log(pd.DataFrame(pred_counts['dsox2'])),
                          'motif_counts': X_motif_counts[dsox2_motifs],
                          'motif_counts+pairs': pd.concat([X_motif_counts[dsox2_motifs], X_feat_pairs_proximal[sox2_involved_pairs]], axis=1),
                          'motif_counts+pairs-binary': binarize(pd.concat([X_motif_counts[dsox2_motifs], X_feat_pairs_proximal[sox2_involved_pairs]], axis=1)),
                          'motif_counts+all_pairs': pd.concat([X_motif_counts[dsox2_motifs], X_feat_pairs_proximal], axis=1),
                          'BPNet_wt_log': np.log(pd.DataFrame(pred_counts['wt'])),
                          'wt_motif_counts': X_motif_counts,
                          'wt_motif_counts+pairs': pd.concat([X_motif_counts, X_feat_pairs_proximal], axis=1),
                      }},
            "dSox2_40h": {"y": Y['log10_fc_SOX2_40h'],
                          "subset": ranges.overlaps_ATAC_SOX2ON_26h,
                          "x_list": {
                              'BPNet_bottleneck': bottleneck_features,
                              'BPNet_bottleneck_dsox2': bottleneck_features - bottleneck_features_dsox2,
                              'BPNet_diff': pd.DataFrame(pred_counts['wt']) - pd.DataFrame(pred_counts['dsox2']),
                              'BPNet_log_fc': np.log(pd.DataFrame(pred_counts['wt']) / pd.DataFrame(pred_counts['dsox2'])),
                              'BPNet_log_fc_random': np.log(pd.DataFrame(pred_counts['wt']) / pd.DataFrame(pred_counts['random'])),
                              'BPNet_maxref_fc': np.log(pd.DataFrame(pred_max_profile['wt']) / pd.DataFrame(pred_max_profile['dsox2'])),
                              'BPNet_maxref_diff': pd.DataFrame(pred_max_profile['wt']) - pd.DataFrame(pred_max_profile['dsox2']),
                              'BPNet_log_diff': np.log(pd.DataFrame(pred_counts['wt'])) - np.log(pd.DataFrame(pred_counts['dsox2'])),
                              'motif_counts': X_motif_counts[dsox2_motifs],
                              'motif_counts+pairs': pd.concat([X_motif_counts[dsox2_motifs], X_feat_pairs_proximal[sox2_involved_pairs]], axis=1),
                              'motif_counts+pairs-binary': binarize(pd.concat([X_motif_counts[dsox2_motifs], X_feat_pairs_proximal[sox2_involved_pairs]], axis=1)),
                              'motif_counts+all_pairs': pd.concat([X_motif_counts[dsox2_motifs], X_feat_pairs_proximal], axis=1),
                              'BPNet_wt_log': np.log(pd.DataFrame(pred_counts['wt'])),
                              'wt_motif_counts': X_motif_counts,
                              'wt_motif_counts+pairs': pd.concat([X_motif_counts, X_feat_pairs_proximal], axis=1),
                          }},
            "dOct4": {"y": Y['log10_fc_OCT4_S2iL'],
                      "subset": ranges.overlaps_ATAC_OCT4ON_S2iL,
                      "x_list": {
                          'BPNet_bottleneck': bottleneck_features,
                          'BPNet_bottleneck_doct4': bottleneck_features - bottleneck_features_doct4,
                          'BPNet_diff': pd.DataFrame(pred_counts['wt']) - pd.DataFrame(pred_counts['doct4']),
                          'BPNet_log_fc': np.log(pd.DataFrame(pred_counts['wt']) / pd.DataFrame(pred_counts['doct4'])),
                          'BPNet_log_fc_random': np.log(pd.DataFrame(pred_counts['wt']) / pd.DataFrame(pred_counts['random'])),
                          'BPNet_log_diff': np.log(pd.DataFrame(pred_counts['wt'])) - np.log(pd.DataFrame(pred_counts['doct4'])),
                          'BPNet_maxref_diff': pd.DataFrame(pred_max_profile['wt']) - pd.DataFrame(pred_max_profile['doct4']),
                          'BPNet_maxref_fc': np.log(pd.DataFrame(pred_max_profile['wt']) / pd.DataFrame(pred_max_profile['doct4'])),
                          'motif_counts': X_motif_counts[doct4_motifs],
                          'motif_counts+pairs': pd.concat([X_motif_counts[doct4_motifs], X_feat_pairs_proximal[oct4_involved_pairs]], axis=1),
                          'motif_counts+pairs-binary': binarize(pd.concat([X_motif_counts[doct4_motifs], X_feat_pairs_proximal[oct4_involved_pairs]], axis=1)),
                          'motif_counts+all_pairs': pd.concat([X_motif_counts[doct4_motifs], X_feat_pairs_proximal], axis=1),
                          'BPNet_wt_log': np.log(pd.DataFrame(pred_counts['wt'])),
                          'wt_motif_counts': X_motif_counts,
                          'wt_motif_counts+pairs': pd.concat([X_motif_counts, X_feat_pairs_proximal], axis=1),
                      }},
            }
# Linear model per feature set: fit on train chromosomes restricted to the
# peak-overlap subset, evaluate on the held-out test chromosomes (same subset).
for dataset,v in datasets.items():
    print(dataset)
    y = v['y']
    for method, X in v['x_list'].items():
        model = LinearRegression()
        # m = RandomForestRegressor(n_estimators=100, n_jobs=5)
        model.fit(X[train & v['subset']].values, y[train & v['subset']].values)
        fig, ax = plt.subplots(figsize=(3, 3))
        regression_eval(y[test & v['subset']], model.predict(X[test & v['subset']]), task=f'{dataset} {method}', alpha=0.5, ax=ax);
# Random-forest variant of the loop above.
# NOTE(review): unlike the linear-model loop, this fits/evaluates on
# train/test WITHOUT the v['subset'] peak-overlap filter -- confirm intended.
for dataset,v in datasets.items():
    print(dataset)
    y = v['y']
    for method, X in v['x_list'].items():
        # model = LinearRegression()
        model = RandomForestRegressor(n_estimators=100, n_jobs=5)
        model.fit(X[train].values, y[train].values)
        fig, ax = plt.subplots(figsize=(3, 3))
        regression_eval(y[test], model.predict(X[test]), task=f'{dataset} {method}', alpha=0.5, ax=ax);