How are the model's predictions correlated with the bias?
Scatter plots below check this.
Ideally, the model output does not correlate with the bias.
# Imports
# NOTE(review): notebook-style script — the `%` magics and `!` shell lines below
# only run under IPython/Jupyter, not plain Python.
from basepair.imports import *
hv.extension('bokeh')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
paper_config()
import os
# Restrict TensorFlow/Keras to GPUs 3 and 5 on this machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "3, 5"
# Dataspec describing the ChIP-nexus tracks (tasks, bigwigs, fasta, peaks).
dataspec_path = '/users/amr1/basepair/src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml'
from basepair.datasets import get_StrandedProfile_datasets2
# Build train/validation datasets of 1kb windows; sex chromosomes held out.
train, valid = get_StrandedProfile_datasets2(
dataspec=dataspec_path,
peak_width = 1000,
seq_width=1000,
include_metadata = False,
taskname_first = True, # so that the output labels will be "{task}/profile"
exclude_chr = ['chrX', 'chrY'],
profile_bias_pool_size=None) # TODO: try profile_bias_pool_size = 50
# NOTE(review): presumably `valid` is a list of (name, dataset) pairs and this
# keeps only the dataset of the first entry — confirm against
# get_StrandedProfile_datasets2's return value.
valid = valid[0][1]
# Print the dataspec for the record.
!cat {dataspec_path}
ds = DataSpec.load(dataspec_path)
# Task names (TFs) defined in the dataspec, in declaration order.
tasks = list(ds.task_specs)
tasks
from basepair.trainers import SeqModelTrainer
from basepair.models import multihead_seq_model
m = multihead_seq_model(tasks=tasks,
filters=64,
n_dil_layers=9,
conv1_kernel_size=25,tconv_kernel_size=25,
b_loss_weight=0, c_loss_weight=10, p_loss_weight=1,
use_bias=True,
lr=0.004, padding='same', batchnorm=False)
output_dir='/tmp/exp/m3-wb'
!mkdir -p {output_dir}
!rm /tmp/exp/m3-wb/model.h5
tr = SeqModelTrainer(m, train, valid, output_dir=output_dir)
tr.train(epochs=100)
#%debug
eval_metrics = tr.evaluate(metric=None) # metric=None -> uses the default head metrics
print(eval_metrics)
import pybedtools
from basepair.utils import flatten_list
paper_config()
seq_model = tr.seq_model # extract the seq_model
from genomelake.extractors import FastaExtractor
# Get data for the oct4 enhancer
# NOTE(review): pybedtools interval fields are normally strings; the int
# coordinates here may need str() — confirm this cell actually runs as-is.
interval = pybedtools.create_interval_from_list(['chr17', 35503550, 35504550])
# Observed counts for the enhancer interval, one array per task.
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
# One-hot sequence for the enhancer interval.
seq = FastaExtractor(ds.fasta_file)([interval])
it = valid.batch_iter(batch_size=1, shuffle=True)
batch = next(it)
# NOTE(review): this overwrites the enhancer sequence extracted above with a
# random validation example, so `obs` (enhancer) and `seq` (random region) no
# longer describe the same locus when plotted together below — confirm this
# is intentional.
seq = batch['inputs']['seq']
seq.shape
# Sanity check: the Oct4 head was built with the bias input enabled.
seq_model.all_heads['Oct4'][0].use_bias
# Importance (contribution) scores for every head/task on this sequence.
imp_scores = seq_model.imp_score_all(seq, batch_size=1)
imp_scores.keys()
imp_scores['Oct4/profile/wn'].shape
# Pre-activation predictions for the same sequence.
preds = seq_model.predict_preact(seq)
seq.shape
# Assemble the tracks to visualize: per task, the observed counts and the
# importance scores projected onto the one-hot sequence.
# (The original iterated `enumerate(tasks)` but never used the index.)
viz_dict = OrderedDict(flatten_list([[
    (f"{task} Obs", obs[task]),
    (f"{task} Imp profile", imp_scores[f"{task}/profile/wn"][0] * seq[0]),
] for task in tasks]))
# Restrict all tracks to the 420-575 bp window.
viz_dict = filter_tracks(viz_dict, [420, 575])
# Shared y-axis limits per feature type so tracks are comparable across tasks.
fmax = {feature: max(viz_dict[f"{task} {feature}"].max() for task in tasks)
        for feature in ['Imp profile', 'Obs']}
fmin = {feature: min(viz_dict[f"{task} {feature}"].min() for task in tasks)
        for feature in ['Imp profile', 'Obs']}
# Importance tracks can be negative -> span (fmin, fmax); observed counts are
# anchored at 0.
ylim = [(fmin[f], fmax[f]) if "Imp" in f else (0, fmax[f])
        for f in (k.split(" ", 1)[1] for k in viz_dict)]
import warnings
# Silence matplotlib/numpy warnings in the notebook output.
warnings.filterwarnings('ignore')
%matplotlib inline
# Plot observed counts and importance tracks for the selected window,
# one panel per entry in viz_dict, with the shared y-limits computed above.
fig = plot_tracks(viz_dict,
#seqlets=shifted_seqlets,
title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
fig_height_per_track=2,
rotate_y=0,
fig_width=20,
ylim=ylim,
legend=False)
# Load the entire training set into memory for the bias-correlation analysis.
train_set = train.load_all(batch_size=32, num_workers=5)
import matplotlib
from scipy.stats import spearmanr
from basepair.preproc import bin_counts
# Pre-activation predictions for all training sequences, keyed by
# "{task}/profile" (see taskname_first=True above).
train_preds = seq_model.predict_preact(train_set['inputs']['seq'])
# Per-region TOTAL counts: does the predicted total correlate with the
# (log10) total bias counts? Ideally Spearman correlation ~ 0.
# Hoisted out of the loop: rcParams mutation is global and loop-invariant
# (the original re-applied it every iteration).
matplotlib.rcParams.update({'font.size': 32})
for task in tasks:
    # Sum predictions over positions and strands -> one scalar per region.
    preds_for_total = np.sum(train_preds[f'{task}/profile'], axis=(1, 2), dtype=np.float32)
    # log10 of the summed bias counts per region.
    # NOTE(review): regions whose bias track sums to 0 yield -inf here —
    # confirm the bias tracks are strictly positive.
    log_bias_total_counts = np.log10(np.sum(train_set['inputs'][f'bias/{task}/profile'],
                                            axis=(1, 2), dtype=np.float32))
    plt.figure(figsize=(20, 10))
    plt.scatter(preds_for_total, log_bias_total_counts, s=10, c="r", alpha=0.5, marker='x')
    plt.xlabel("preds_for_total", fontsize=20)
    plt.ylabel("log_bias_total_counts", fontsize=20)
    cc, p = spearmanr(preds_for_total, log_bias_total_counts)
    cc = round(cc, 5)
    p = round(p, 5)
    plt.title(f'ChIP-nexus total counts and preds for {task}: spearman correlation={cc}, p-val={p}', fontsize=24)
    plt.show()
# LOCAL (binned) version: 50-bp bin sums of predictions vs. log10 binned bias
# counts, pooled across all regions. Ideally Spearman correlation ~ 0.
# Hoisted out of the loop: rcParams mutation is global and loop-invariant
# (the original re-applied it every iteration).
matplotlib.rcParams.update({'font.size': 32})
for task in tasks:
    # Bin predictions into 50-bp bins, sum over strands, flatten over regions.
    preds_for_local = np.sum(bin_counts(train_preds[f'{task}/profile'], binsize=50),
                             axis=2, dtype=np.float32).flatten()
    # Same binning for the bias counts, then log10.
    # NOTE(review): empty bins yield -inf here — confirm the bias tracks are
    # strictly positive per bin.
    log_bias_local_counts = np.log10(np.sum(bin_counts(train_set['inputs'][f'bias/{task}/profile'],
                                                       binsize=50),
                                            axis=2, dtype=np.float32).flatten())
    plt.figure(figsize=(20, 10))
    plt.scatter(preds_for_local, log_bias_local_counts, s=10, c="y", alpha=0.5, marker='x')
    plt.xlabel("preds_for_local", fontsize=20)
    plt.ylabel("log_bias_local_counts", fontsize=20)
    cc, p = spearmanr(preds_for_local, log_bias_local_counts)
    cc = round(cc, 5)
    p = round(p, 5)
    plt.title(f'ChIP-nexus local counts and preds for {task}: spearman correlation={cc}, p-val={p}', fontsize=24)
    plt.show()