Browser screenshot: https://epigenomegateway.wustl.edu/browser/?genome=mm10&session=KuA8LTv0cU&statusId=714272935
from basepair.config import get_data_dir, create_tf_session
# Initialise the TensorFlow session via the project helper.
# NOTE(review): the meaning of the argument `1` is defined in basepair.config — confirm.
create_tf_session(1)
# Directory of the trained model; every artifact below is read relative to it.
mdir = "/users/avsec/workspace/basepair-workflow/models/0/"
# List the model directory and the modisco output for Sox2/profile on the test split.
ls {mdir}
ls {mdir}/modisco/test/Sox2/profile
import numpy as np
import matplotlib.pyplot as plt
# Draw one small histogram of the saved modisco `distances.npy` array for every
# (protein, task, split) combination, iterating with protein outermost and
# split innermost.
combos = [
    (protein, task, split)
    for protein in ['Sox2', 'Oct4']
    for task in ['profile', 'counts']
    for split in ['valid', 'test']
]
for protein, task, split in combos:
    fig = plt.figure(figsize=(4, 1))
    distances = np.load(f"{mdir}/modisco/{split}/{protein}/{task}/distances.npy")
    plt.hist(distances, bins=50)
    plt.title(f"{protein} {split} {task}")
from basepair.utils import write_pkl, read_pkl
# data.pkl unpacks into three splits. Each split is indexed positionally below:
# [0] and [1] are later passed as model inputs/targets, [2] is used as a table
# with `task`, `chr`, `start` and `end` columns.
train, valid, test = read_pkl(f"{mdir}/data.pkl")
# Length of the first element of the training split.
len(train[0])
# Number of rows per task in each split.
train[2].task.value_counts()
valid[2].task.value_counts()
test[2].task.value_counts()
from keras.models import load_model
# Load the trained Keras model stored in the model directory.
model = load_model(f"{mdir}/model.h5")
from basepair.plots import Seq2Nexus
from basepair.math import softmax
from basepair.plots import *
class Seq2Nexus:
    """Plot observed and predicted profiles together with input-gradient logos.

    The constructor runs ``model.predict(x)`` once and stores the softmax of
    every returned output. ``plot`` then draws, for each selected example and
    each task, six stacked rows: observed profile, predicted profile, and four
    sequence logos of gradient * input (profile pos/neg, counts pos/neg).

    The names ``softmax``, ``samplers``, ``seqlogo`` and ``K`` are expected to
    be in scope where this class is defined.
    NOTE(review): ``samplers``, ``seqlogo`` and ``K`` are not imported
    explicitly in the visible code; presumably they come from
    ``from basepair.plots import *`` -- confirm.
    """

    def __init__(self, x, y, df, model, tasks=None):
        """Store the data and pre-compute the predictions.

        Args:
            x: model input; indexed with a list of example indices in ``plot``.
            y: mapping with ``"profile/<task>"`` keys whose values are indexed
                as ``[example, position, strand]``.
            df: table with ``chr``, ``start`` and ``end`` columns; used to
                build the ``chr:start-end`` label of each example.
            model: object exposing ``predict``, ``inputs`` and ``outputs``.
            tasks: task names, in the order of the model's profile outputs.
                Defaults to ``['sox2', 'oct4']``.
        """
        # A fresh list per instance avoids sharing one mutable default
        # between all instances.
        if tasks is None:
            tasks = ['sox2', 'oct4']
        self.x = x
        self.y = y
        self.df = df
        self.labels = df.chr + ":" + df.start.astype(str) + "-" + df.end.astype(str)
        self.model = model
        self.tasks = list(tasks)
        # Make the prediction. softmax is applied to every model output; only
        # the first len(tasks) entries are read back (in `plot`).
        self.y_pred = [softmax(p) for p in model.predict(x)]
        self.seq_len = self.y_pred[0].shape[1]
        self.t2i = {k: i for i, k in enumerate(self.tasks)}

    def input_grad(self, x, strand='pos', task_id=0, seq_grad='max'):
        """Return the gradient of one model output w.r.t. the first model input.

        Args:
            x: input batch fed to the gradient function.
            strand: ``'pos'`` or ``'neg'``; selects index 0 or 1 on the last
                axis of the chosen output.
            task_id: index of the task within ``self.tasks``.
            seq_grad: ``'count'`` differentiates output
                ``len(self.tasks) + task_id`` directly; ``'max'`` / ``'mean'``
                differentiate the max / mean (over the last remaining axis) of
                output ``task_id``.

        Raises:
            KeyError: if ``strand`` is neither ``'pos'`` nor ``'neg'``.
            ValueError: if ``seq_grad`` is not one of the supported modes.
        """
        strand_id = {"pos": 0, "neg": 1}[strand]
        if seq_grad == 'count':
            target = self.model.outputs[len(self.tasks) + task_id][:, strand_id]
        elif seq_grad in ('max', 'mean'):
            reduce_fn = K.max if seq_grad == 'max' else K.mean
            target = reduce_fn(self.model.outputs[task_id][:, :, strand_id], axis=-1)
        else:
            raise ValueError(f"seq_grad={seq_grad} couldn't be interpreted")
        inp = self.model.inputs[0]
        fn = K.function([inp], K.gradients(target, inp))
        return fn([x])[0]

    def _select_indices(self, n, sort):
        """Translate the ``sort`` specification into a list of example indices."""
        if sort == 'random':
            return samplers.random(self.x, n)
        if "_" in sort:
            # Split only on the first underscore so task names may contain "_".
            criterion, task = sort.split("_", 1)
            if task not in self.t2i:
                raise KeyError(task)
            if criterion == "max":
                return samplers.top_max_count(self.y["profile/" + task], n)
            if criterion == "sum":
                return samplers.top_sum_count(self.y["profile/" + task], n)
        # Every unrecognised specification ends here instead of leaving the
        # index list undefined.
        raise ValueError(f"sort={sort} couldn't be interpreted")

    def plot(self, n=10, kind='test', sort='random',
             seq_grad='max', figsize=(20, 6)):
        """Draw one figure per selected example.

        Args:
            n: number of examples to select.
            kind: unused; kept so existing callers keep working.
            sort: ``'random'``, ``'max_<task>'`` or ``'sum_<task>'``; chooses
                which ``samplers`` function picks the examples.
            seq_grad: profile gradient mode forwarded to ``input_grad``
                (``'max'`` or ``'mean'``).
            figsize: size of each figure.

        Raises:
            KeyError: if the task named in ``sort`` is not in ``self.tasks``.
            ValueError: if ``sort`` cannot be interpreted.
        """
        # Validate and select first, so bad arguments fail before any
        # plotting machinery is imported or any gradient is computed.
        idx_list = self._select_indices(n, sort)

        import matplotlib.pyplot as plt

        # compute grads: per task -> [profile pos, profile neg, count pos, count neg],
        # each multiplied element-wise by the selected inputs.
        x_sel = self.x[idx_list]
        grads = [[self.input_grad(x_sel, strand, tid, mode) * x_sel
                  for mode in (seq_grad, "count")
                  for strand in ('pos', 'neg')]
                 for tid in range(len(self.tasks))]

        rows_per_task = 6
        positions = np.arange(1, self.seq_len + 1)
        logo_labels = ["Pos. strand", "Neg. strand", "Counts Pos", "Counts Neg"]

        for i, idx in enumerate(idx_list):
            fig, axes = plt.subplots(rows_per_task * len(self.tasks), 1,
                                     sharex=True, figsize=figsize)
            for tid, task in enumerate(self.tasks):
                base = rows_per_task * tid
                observed = self.y["profile/" + task]

                ax_obs = axes[base]
                ax_obs.plot(positions, observed[idx, :, 0], label="pos")
                ax_obs.plot(positions, observed[idx, :, 1], label="neg")
                ax_obs.set_ylabel("Observed")
                ax_obs.legend()
                ax_obs.set_title(f"{task} {self.labels.iloc[idx]}")

                ax_pred = axes[base + 1]
                ax_pred.plot(positions, self.y_pred[tid][idx, :, 0], label="pos")
                ax_pred.plot(positions, self.y_pred[tid][idx, :, 1], label="neg")
                ax_pred.set_ylabel("Predicted")
                ax_pred.legend()

                # `i` indexes the selected subset, `idx` the full data set.
                for k, label in enumerate(logo_labels):
                    ax_logo = axes[base + 2 + k]
                    seqlogo(grads[tid][k][i], ax=ax_logo)
                    ax_logo.set_ylabel(label)

                axes[base + 5].set_xticks(list(range(0, self.seq_len, 5)))
# Build the plotter on the test split: test[0] -> x, test[1] -> y, test[2] -> df.
sn = Seq2Nexus(test[0], test[1], test[2], model, ["Sox2", "Oct4"])
# Plot the examples selected by the "max" criterion, once per task.
sn.plot(sort='max_Oct4', figsize=(20,12))
sn.plot(sort='max_Sox2', figsize=(20,12))
from basepair.BPNet import BPNetPredictor
bp = BPNetPredictor(model,
fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta",
tasks=["Sox2", "Oct4"],
preproc=preproc)
from pysam import FastaFile
fa = FastaFile("/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta")
fa.close()
import pandas as pd
dfm = pd.concat([train[2], valid[2], test[2]])
bt = BedTool.from_dataframe(dfm.sort_values(["chr", 'start'])[['chr', 'start', 'end']])
dfm = dfm.sort_values(["chr", 'start'])
genome = [(c, l) for c, l in zip(fa.references, fa.lengths)
if c in list(dfm['chr'].unique())]
list(dfm['chr'].unique())
from basepair.cli.export_bw import export_bigwigs
export_bigwigs(mdir, f"{mdir}/bw/", profile_grad='max', batch_size=32, gpu=3)
list(zip(fa.references, fa.lengths))
bt = BedTool.from_dataframe(train[2][['chr', 'start', 'end']])
bt[0]
a = bp.predict([bt[0]])
a[0]['scale_profile']
bp.predict_plot([bt[0]])
ls {mdir}
preproc = read_pkl(f"{mdir}/preprocessor.pkl")
scaler = preproc.objects['profile/Oct4'].steps[1][1]
scaler.mean_
scaler.scale_
from pybedtools import BedTool
from basepair.datasets import BED_DIR
from basepair.cli.schemas import DataSpec
# Compare the Sox2 peaks referenced by this model's dataspec with an older peak file.
ds = DataSpec.load(f"{mdir}/dataspec.yaml")
ddir = get_data_dir()
# Older Sox2 peak file to compare against.
orig_peak = f"{BED_DIR}/Sox2_123b_1_ppr.IDR0.05.filt.summit_centered_200bp.narrowPeak"
# NOTE(review): this path literal ends with a trailing space. It is only
# interpolated into the shell command below, where it is presumably harmless,
# but it looks unintended — confirm.
genome_file = "/mnt/data/pipeline_genome_data/mm10/mm10.chrom.sizes "
!mkdir -p {ddir}/processed/chipnexus/basepair-workflow
# Extend each dataspec Sox2 peak by 100 bp on both sides (-l 100 -r 100) and
# write the result to sox2.peaks.bed.
!bedtools slop -i {ds.task_specs['Sox2'].peaks} -g {genome_file} -l 100 -r 100 > \
{ddir}/processed/chipnexus/basepair-workflow/sox2.peaks.bed
!head {ddir}/processed/chipnexus/basepair-workflow/sox2.peaks.bed
new_sox2 = BedTool(f"{ddir}/processed/chipnexus/basepair-workflow/sox2.peaks.bed")
old_sox2 = BedTool(orig_peak)
# Size of each peak set, followed by the number of its peaks that overlap the
# other set (wa + u: each peak of the first set is reported at most once).
len(new_sox2)
len(new_sox2.intersect(old_sox2, wa=True, u=True))
len(old_sox2)
len(old_sox2.intersect(new_sox2, wa=True, u=True))
The old Sox2 peak set had 2x the number of peaks. The intersection between the old and new peak sets is pretty high.
# Line count of the extended Sox2 peak file written above.
!wc -l {ddir}/processed/chipnexus/basepair-workflow/sox2.peaks.bed
ls
# Show the raw dataspec file, the loaded DataSpec object, and the first lines
# of the older peak file.
!cat {mdir}/dataspec.yaml
ds
!head {orig_peak}