Goals

  • correct for biases in ChIP-nexus data

How are the model predictions correlated with the bias?

scatter-plot:

  • model predictions for total counts (before adding the bias term) vs. bias total counts (log-scale the latter only)
  • model predictions for local counts before adding the bias term (in, say, 50bp bins; bin the predictions with `bin_counts`) vs. bias local counts (log-scale the latter only)

Ideally, the model output (before the bias term is added) should not correlate with the bias.

In [ ]:
# Imports
# Star-import brings hv, plt, OrderedDict, plot_tracks, filter_tracks,
# DataSpec, etc. into scope -- the rest of the notebook relies on it.
from basepair.imports import *
hv.extension('bokeh')
# Render matplotlib figures inline at retina resolution
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Apply the project's paper-figure matplotlib styling -- TODO confirm what it sets
paper_config()
In [ ]:
# Restrict the deep-learning backend to GPUs 3 and 5.
# NOTE: CUDA_VISIBLE_DEVICES must be a comma-separated list WITHOUT
# spaces -- with "3, 5" some CUDA runtimes stop parsing at the space and
# silently ignore the second device.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,5"
In [ ]:
# Path to the DataSpec YAML describing the ChIP-nexus tasks, bigwig
# tracks and the genome fasta file
dataspec_path = '/users/amr1/basepair/src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml'
In [ ]:
from basepair.datasets import get_StrandedProfile_datasets2

# Build the train / validation datasets from the dataspec:
# 1kb peaks and input sequences, autosomes only (chrX/chrY excluded).
train, valid = get_StrandedProfile_datasets2(
    dataspec=dataspec_path,
    peak_width=1000,
    seq_width=1000,
    include_metadata=False,
    # taskname_first=True -> output labels look like "{task}/profile"
    taskname_first=True,
    exclude_chr=['chrX', 'chrY'],
    # no pooling of the bias profile; TODO: try profile_bias_pool_size=50
    profile_bias_pool_size=None)
In [ ]:
# Keep only the dataset object from the first validation entry --
# presumably valid is a list of (name, dataset) pairs; TODO confirm
valid = valid[0][1]
In [ ]:
# Print the dataspec file contents for reference
!cat {dataspec_path}
In [ ]:
# Load the parsed DataSpec object (DataSpec comes from basepair.imports)
ds = DataSpec.load(dataspec_path)
In [ ]:
# Task names (one per TF), in dataspec order
tasks = list(ds.task_specs)
In [ ]:
# Display the task list
tasks
In [ ]:
from basepair.trainers import SeqModelTrainer
from basepair.models import multihead_seq_model
In [ ]:
# Multi-headed sequence model: profile head (weight 1) + total-counts
# head (weight 10); b_loss_weight=0 disables the third head entirely --
# presumably a binary/binned head, TODO confirm.
m = multihead_seq_model(
    tasks=tasks,
    filters=64,
    n_dil_layers=9,
    conv1_kernel_size=25,
    tconv_kernel_size=25,
    b_loss_weight=0,
    c_loss_weight=10,
    p_loss_weight=1,
    use_bias=True,  # include the bias-track correction term
    lr=0.004,
    padding='same',
    batchnorm=False)
In [ ]:
# Where to write the model checkpoint and training logs
output_dir='/tmp/exp/m3-wb'
In [ ]:
# Make sure the output directory exists
!mkdir -p {output_dir}
In [ ]:
!rm /tmp/exp/m3-wb/model.h5
In [ ]:
# Wire the model to the train/valid datasets and the output directory
tr = SeqModelTrainer(m, train, valid, output_dir=output_dir)
In [ ]:
# Train for up to 100 epochs -- early-stopping behaviour depends on the
# trainer defaults, TODO confirm
tr.train(epochs=100)
In [ ]:
# Uncomment to drop into the post-mortem debugger after a failure
#%debug
In [ ]:
# Evaluate on the validation set
eval_metrics = tr.evaluate(metric=None)  # metric=None -> uses the default head metrics
print(eval_metrics)

Get the importance scores

In [ ]:
import pybedtools
from basepair.utils import flatten_list
# Re-apply the paper plotting config (already called above -- presumably
# idempotent)
paper_config()
In [ ]:
seq_model = tr.seq_model  # extract the seq_model from the trainer
In [ ]:
from genomelake.extractors import FastaExtractor
In [ ]:
# Get data for the oct4 enhancer.
# pybedtools Interval fields must be strings -- pass the coordinates as
# str rather than int (int fields can break create_interval_from_list).
interval = pybedtools.create_interval_from_list(['chr17', '35503550', '35504550'])
# Observed counts per task for this single interval
obs = {task: ds.task_specs[task].load_counts([interval])[0] for task in tasks}
# One-hot encoded sequence for the interval
# NOTE(review): this `seq` is overwritten below by a random validation batch
seq = FastaExtractor(ds.fasta_file)([interval])
In [ ]:
# Iterator over shuffled single-example validation batches
it = valid.batch_iter(batch_size=1, shuffle=True)
In [ ]:
batch = next(it)
# NOTE(review): this replaces the oct4-enhancer sequence extracted above
# with a random validation example
seq = batch['inputs']['seq']
In [ ]:
# Sanity-check the input shape -- presumably (batch, seq_len, 4) one-hot
seq.shape
In [ ]:
# Check that the Oct4 profile head was built with the bias term enabled
seq_model.all_heads['Oct4'][0].use_bias
In [ ]:
# Importance scores for all heads/tasks on this one sequence
imp_scores = seq_model.imp_score_all(seq, batch_size=1)
In [ ]:
# Inspect which score tracks were computed
imp_scores.keys()
In [ ]:
# Shape of one importance-score track
imp_scores['Oct4/profile/wn'].shape
In [ ]:
# Pre-activation predictions -- presumably the model output before the
# bias term / final activation, TODO confirm
# NOTE(review): `preds` is not used by the cells below
preds = seq_model.predict_preact(seq)
In [ ]:
# Re-check the input shape
seq.shape

Predicted and observed

In [ ]:
# Tracks to visualise: per task, the observed profile and the
# importance-score profile (scores projected onto the one-hot sequence).
viz_dict = OrderedDict()
for task in tasks:
    viz_dict[f"{task} Obs"] = obs[task]
    viz_dict[f"{task} Imp profile"] = imp_scores[f"{task}/profile/wn"][0] * seq[0]

# Restrict every track to the 420-575bp window
viz_dict = filter_tracks(viz_dict, [420, 575])
In [ ]:
# Shared y-axis limits so tracks of the same feature type are directly
# comparable across tasks.
features = ['Imp profile', 'Obs']
fmax = {feat: max(viz_dict[f"{task} {feat}"].max() for task in tasks)
        for feat in features}
fmin = {feat: min(viz_dict[f"{task} {feat}"].min() for task in tasks)
        for feat in features}


ylim = []
for track_name in viz_dict:
    # Track names are "{task} {feature}" -> recover the feature part
    feature = track_name.split(" ", 1)[1]
    # Importance tracks can be negative; observed counts start at 0
    lower = fmin[feature] if "Imp" in feature else 0
    ylim.append((lower, fmax[feature]))
In [ ]:
import warnings
# Silence the (noisy) matplotlib/numpy warnings for the figure below
warnings.filterwarnings('ignore')
%matplotlib inline
# One row per track in viz_dict, with the shared y-limits computed above
fig = plot_tracks(viz_dict,
                  #seqlets=shifted_seqlets,
                  title="{i.chrom}:{i.start}-{i.end}, {i.name}".format(i=interval),
                  fig_height_per_track=2,
                  rotate_y=0,
                  fig_width=20,
                  ylim=ylim,
                  legend=False)
In [ ]:
# Materialize the whole training set into memory
train_set = train.load_all(batch_size=32, num_workers=5)
In [ ]:
import matplotlib
from scipy.stats import spearmanr
from basepair.preproc import bin_counts
In [ ]:
# Pre-activation predictions (before the bias term) on the training set
train_preds = seq_model.predict_preact(train_set['inputs']['seq'])
In [ ]:
# Scatter of predicted total counts (pre-bias) vs. log10 bias total
# counts, per task. Ideally the spearman correlation is ~0.
for task in tasks:
    # Sum over positions and strands -> one total per region
    pred_totals = np.sum(train_preds[f'{task}/profile'], axis=(1, 2), dtype=np.float32)
    bias_totals = np.sum(train_set['inputs'][f'bias/{task}/profile'], axis=(1, 2), dtype=np.float32)
    log_bias_totals = np.log10(bias_totals)

    plt.figure(figsize=(20, 10))
    matplotlib.rcParams.update({'font.size': 32})
    plt.scatter(pred_totals, log_bias_totals, s=10, c="r", alpha=0.5, marker='x')
    plt.xlabel("preds_for_total", fontsize=20)
    plt.ylabel("log_bias_total_counts", fontsize=20)

    # Spearman is rank-based, so the log transform does not change cc
    cc, p = spearmanr(pred_totals, log_bias_totals)
    cc = round(cc, 5)
    p = round(p, 5)
    plt.title(f'ChIP-nexus total counts and preds for {task}: spearman correlation={cc}, p-val={p}', fontsize=24)
    plt.show()
In [ ]:
# Same comparison at local (50bp-binned) resolution: predicted binned
# counts (pre-bias) vs. log10 binned bias counts, flattened over regions.
for task in tasks:
    binned_preds = bin_counts(train_preds[f'{task}/profile'], binsize=50)
    pred_locals = np.sum(binned_preds, axis=2, dtype=np.float32).flatten()
    binned_bias = bin_counts(train_set['inputs'][f'bias/{task}/profile'], binsize=50)
    log_bias_locals = np.log10(np.sum(binned_bias, axis=2, dtype=np.float32).flatten())

    plt.figure(figsize=(20, 10))
    matplotlib.rcParams.update({'font.size': 32})
    plt.scatter(pred_locals, log_bias_locals, s=10, c="y", alpha=0.5, marker='x')
    plt.xlabel("preds_for_local", fontsize=20)
    plt.ylabel("log_bias_local_counts", fontsize=20)

    cc, p = spearmanr(pred_locals, log_bias_locals)
    cc = round(cc, 5)
    p = round(p, 5)
    plt.title(f'ChIP-nexus local counts and preds for {task}: spearman correlation={cc}, p-val={p}', fontsize=24)
    plt.show()