Visualize the following aspects of the bias correction.
This shows how strongly the bias is present in the data
and whether its relationship with the output is linear.
Our current model is linear in the bias: y ~ model_output + bias * w2
If the relationship turns out to be non-linear, we would instead need:
y ~ model_output + f(bias)
# Imports
# NOTE(review): this star import presumably provides np, plt, hv, DataSpec and
# paper_config, which are used below without an explicit import -- confirm
# against basepair/imports.py.
from basepair.imports import *
# Select the Bokeh plotting backend (hv presumably = holoviews).
hv.extension('bokeh')
%matplotlib inline
# Render inline matplotlib figures in 'retina' (high-DPI) format.
%config InlineBackend.figure_format = 'retina'
# Presumably applies the project's shared figure style -- confirm.
paper_config()
# Load the ChIP-nexus datasets described by the dataspec (sex chromosomes
# excluded) and materialize the training split for plotting.
dataspec_path = '/users/amr1/basepair/src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml'
from basepair.datasets import get_StrandedProfile_datasets2
train, valid = get_StrandedProfile_datasets2(
    dataspec=dataspec_path,
    peak_width=1000,
    seq_width=1000,
    include_metadata=False,
    taskname_first=True,  # so that the output labels will be "{task}/profile"
    exclude_chr=['chrX', 'chrY'])
# NOTE(review): keeps only element [1] of the first validation entry;
# presumably `valid` is a list of (name, dataset) pairs -- confirm.
valid = valid[0][1]
ds = DataSpec.load(dataspec_path)
tasks = list(ds.task_specs)
# Presumably loads the whole training split into memory. The result is
# indexed below as train_set['targets'|'inputs'][<output name>].
train_set = train.load_all(batch_size=32, num_workers=5)

# Sanity-check the array shapes of every task's target profile and its bias
# track. Bare expressions are only displayed when they are the last line of
# a notebook cell (and never in a script), so print them explicitly.
for task in tasks:
    print(task,
          'targets:', train_set['targets'][f'{task}/profile'].shape,
          'bias:', train_set['inputs'][f'bias/{task}/profile'].shape)
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including numpy divide-by-zero warnings; consider narrowing it with
# `category=` / `message=` once it is clear which warning it was meant to hide.
warnings.filterwarnings('ignore')
import matplotlib
from scipy.stats import spearmanr

# Base font size for everything without an explicit fontsize (e.g. tick
# labels). Loop-invariant, so set once rather than on every iteration.
matplotlib.rcParams.update({'font.size': 32})

# Per task: scatter the total observed counts of each training example
# against the total counts of its bias track, and report their Spearman
# rank correlation. A strong, roughly linear trend supports the additive
# model `y ~ model_output + bias * w2`.
for task in tasks:
    # Sum over axes 1 and 2 -> one total per example along axis 0.
    total_counts = np.sum(train_set['targets'][f'{task}/profile'],
                          axis=(1, 2), dtype=np.float32)
    bias_total_counts = np.sum(train_set['inputs'][f'bias/{task}/profile'],
                               axis=(1, 2), dtype=np.float32)

    plt.figure(figsize=(20, 10))
    plt.scatter(total_counts, bias_total_counts, s=10, c="r", alpha=0.5, marker='x')
    plt.xlabel("total_counts", fontsize=20)
    plt.ylabel("bias_total_counts", fontsize=20)

    cc, p = spearmanr(total_counts, bias_total_counts)
    cc = round(cc, 5)
    # Scientific notation for the p-value: round(p, 5) printed "p-val=0.0"
    # for any p < 5e-6, hiding how significant the correlation actually is.
    plt.title(f'ChIP-nexus total counts for {task}: spearman correlation={cc}, p-val={p:.2e}', fontsize=24)
    plt.show()
from basepair.preproc import bin_counts

# Base font size for everything without an explicit fontsize (e.g. tick
# labels). Loop-invariant, so set once rather than on every iteration.
matplotlib.rcParams.update({'font.size': 32})

# Per task: compare observed vs. bias counts in 50 bp bins ("local" counts),
# pooling all bins of all training examples into one vector, on a log scale.
for task in tasks:
    # NOTE(review): assumes bin_counts keeps the strand axis at position 2,
    # so summing over axis=2 merges the strands -- confirm.
    local_counts = np.sum(
        bin_counts(train_set['targets'][f'{task}/profile'], binsize=50),
        axis=2, dtype=np.float32).flatten()
    bias_local_counts = np.sum(
        bin_counts(train_set['inputs'][f'bias/{task}/profile'], binsize=50),
        axis=2, dtype=np.float32).flatten()

    # Pseudocount of 1. With plain log10, empty bins became -inf and were
    # silently left out of the scatter plot (the divide-by-zero warning is
    # suppressed by the global warnings filter). For non-negative counts both
    # transforms preserve rank order, so the Spearman statistics are unchanged.
    log_local_counts = np.log10(1 + local_counts)
    log_bias_local_counts = np.log10(1 + bias_local_counts)

    plt.figure(figsize=(20, 10))
    plt.scatter(log_local_counts, log_bias_local_counts, s=10, c="y", alpha=0.5, marker='x')
    plt.xlabel("log10(1 + local_counts)", fontsize=20)
    plt.ylabel("log10(1 + bias_local_counts)", fontsize=20)

    cc, p = spearmanr(log_local_counts, log_bias_local_counts)
    cc = round(cc, 5)
    # Scientific notation for the p-value: round(p, 5) printed "p-val=0.0"
    # for any p < 5e-6, hiding how significant the correlation actually is.
    plt.title(f'ChIP-nexus local counts for {task}: spearman correlation={cc}, p-val={p:.2e}', fontsize=24)
    plt.show()