How are the model predictions correlated with the bias?
Scatter plot: we hope that the model output does not correlate with the bias track.
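The check at the end of this notebook boils down to the sketch below (a minimal outline with assumed array shapes, not the library code): per region, compare the model's predicted total counts against the total counts of the bias track and quantify the relationship with a Spearman correlation. Rs close to 0 is what we hope for.
from scipy.stats import spearmanr
import numpy as np

def pred_vs_bias_correlation(pred_profile, bias_profile):
    """Sketch: pred_profile and bias_profile are (n_regions, seq_len, n_strands) arrays."""
    pred_total = pred_profile.sum(axis=(1, 2))                # predicted total counts per region
    log_bias_total = np.log10(bias_profile.sum(axis=(1, 2)))  # log10 total bias counts per region
    rs, p = spearmanr(pred_total, log_bias_total)
    return rs, p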
# Imports
from basepair.imports import *
import warnings
warnings.filterwarnings('ignore')
hv.extension('bokeh')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
paper_config()
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
dataspec_path = '/users/amr1/basepair/src/chipnexus/train/seqmodel/ChIP-seq.dataspec.yml'
from basepair.datasets import get_StrandedProfile_datasets2
train, valid = get_StrandedProfile_datasets2(
    dataspec=dataspec_path,
    peak_width=1000,
    seq_width=1000,
    include_metadata=False,
    taskname_first=True,  # so that the output labels will be "{task}/profile"
    exclude_chr=['chrX', 'chrY'],
    profile_bias_pool_size=None)
valid = valid[0][1]  # valid is a list of (name, dataset) pairs; keep the first dataset
!cat {dataspec_path}
ds = DataSpec.load(dataspec_path)
tasks = list(ds.task_specs)
tasks
from basepair.trainers import SeqModelTrainer
from basepair.models import multihead_seq_model
from basepair.plot.evaluate import regression_eval
m = multihead_seq_model(tasks=tasks,
                        filters=64,
                        n_dil_layers=6,
                        conv1_kernel_size=25,
                        tconv_kernel_size=50,
                        b_loss_weight=0, c_loss_weight=10, p_loss_weight=1,
                        use_bias=True,
                        lr=0.004, padding='same', batchnorm=False)
output_dir = '/tmp/exp/m3-wb'
!mkdir -p {output_dir}
!rm -f {output_dir}/model.h5  # start fresh: drop any previous checkpoint
tr = SeqModelTrainer(m, train, valid, output_dir=output_dir)
tr.train(epochs=100)
eval_metrics = tr.evaluate(metric=None) # metric=None -> uses the default head metrics
print(eval_metrics)
import pybedtools
from basepair.utils import flatten_list
paper_config()
seq_model = tr.seq_model # extract the seq_model
from genomelake.extractors import FastaExtractor
# Get data for the oct4 enhancer
interval = pybedtools.create_interval_from_list(['chr17', '35503550', '35504550'])
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
seq = FastaExtractor(ds.fasta_file)([interval])
# Grab a random validation example (note: this overwrites the FASTA-extracted `seq` above)
it = valid.batch_iter(batch_size=1, shuffle=True)
batch = next(it)
seq = batch['inputs']['seq']
seq.shape
seq_model.all_heads['Oct4'][0].use_bias  # sanity check: the head was built with use_bias=True
imp_scores = seq_model.imp_score_all(seq, batch_size=1)
imp_scores.keys()
imp_scores['Oct4/profile/wn'].shape
preds = seq_model.predict_preact(seq)
seq.shape
viz_dict = OrderedDict(flatten_list([[
    (f"{task} Obs", obs[task]),
    (f"{task} Imp profile", imp_scores[f"{task}/profile/wn"][0] * seq[0]),
] for task in tasks]))
viz_dict = filter_tracks(viz_dict, [420, 575])
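filter_tracks is assumed here to crop every track in the dict to the given position window; a minimal equivalent of that assumed behavior (each track indexed along its first, positional axis):
from collections import OrderedDict

def filter_tracks_sketch(tracks, xrange):
    # Keep only positions [start, end) of every track
    start, end = xrange
    return OrderedDict((name, track[start:end]) for name, track in tracks.items())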
fmax = {feature: max([viz_dict[f"{task} {feature}"].max() for task in tasks])
        for feature in ['Imp profile', 'Obs']}
fmin = {feature: min([viz_dict[f"{task} {feature}"].min() for task in tasks])
        for feature in ['Imp profile', 'Obs']}
ylim = []
for k in viz_dict:
    f = k.split(" ", 1)[1]
    if "Imp" in f:
        ylim.append((fmin[f], fmax[f]))
    else:
        ylim.append((0, fmax[f]))
fig = plot_tracks(viz_dict,
                  # seqlets=shifted_seqlets,
                  title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                  fig_height_per_track=2,
                  rotate_y=0,
                  fig_width=20,
                  ylim=ylim,
                  legend=False)
train_set = train.load_all(batch_size=32, num_workers=5)
import matplotlib
%matplotlib inline
from scipy.stats import spearmanr
from basepair.preproc import bin_counts
binsize=50
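bin_counts (imported above from basepair.preproc) is assumed here to sum counts within non-overlapping bins of width binsize along the position axis; a minimal NumPy equivalent of that assumed behavior, for reference only:
def bin_counts_sketch(x, binsize=50):
    # x: (n_regions, seq_len, n_strands); seq_len assumed divisible by binsize
    n, seq_len, n_strands = x.shape
    return x.reshape(n, seq_len // binsize, binsize, n_strands).sum(axis=2)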
train_preds = seq_model.predict_preact(train_set['inputs']['seq'])
fig, axes = plt.subplots(2, len(tasks), figsize=get_figsize(2/4*len(tasks), 2/len(tasks)))
matplotlib.rcParams.update({'font.size': 32})
for i, task in enumerate(train.tasks):
    # Top row: per-region predicted total counts vs. log10 total bias counts
    preds_for_total = np.sum(train_preds[f'{task}/profile'], axis=(1, 2), dtype=np.float32)
    log_bias_total_counts = np.log10(np.sum(train_set['inputs'][f'bias/{task}/profile'],
                                            axis=(1, 2), dtype=np.float32))
    cc, p = spearmanr(preds_for_total, log_bias_total_counts)
    cc = "{0:.2f}".format(cc)
    ax = axes[0, i]
    ax.scatter(preds_for_total, log_bias_total_counts, s=10, c="b", alpha=0.5, marker='o',
               label=f'Rs={cc}')
    ax.legend()
    ax.set_xlabel("preds_for_total")
    if i == 0:
        ax.set_ylabel("log_bias_total_counts")
    else:
        ax.set_ylabel("")
    ax.set_title(task)

    # Bottom row: the same comparison within `binsize`-bp bins (local counts)
    preds_for_local = np.ravel(np.sum(bin_counts(train_preds[f'{task}/profile'],
                                                 binsize=binsize), axis=-1, dtype=np.float32))
    log_bias_local_counts = np.log10(np.ravel(np.sum(bin_counts(train_set['inputs'][f'bias/{task}/profile'],
                                                                binsize=binsize), axis=-1, dtype=np.float32)))
    cc, p = spearmanr(preds_for_local, log_bias_local_counts)
    cc = "{0:.2f}".format(cc)
    ax = axes[1, i]
    ax.scatter(preds_for_local, log_bias_local_counts, s=10, c="b", alpha=0.5, marker='o',
               label=f'Rs={cc}')
    ax.legend()
    ax.set_xlabel("preds_for_local")
    if i == 0:
        ax.set_ylabel("log_bias_local_counts")
    else:
        ax.set_ylabel("")