import basepair
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
# Use gpus 1, 3, 5
import os
# NOTE(review): must run before any TF/keras session is created for the GPU
# masking to take effect; "1, 3, 5" exposes three physical GPUs (remapped 0-2).
os.environ["CUDA_VISIBLE_DEVICES"] = "1, 3, 5"
# Root data directory resolved by basepair's config.
ddir = get_data_dir()
# Bare expression — notebook-style cell that displays the resolved path.
ddir
# Base directory holding the Sox2/Oct4 ChIP-seq bigWigs and peak BED files.
bdir = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/"
# DataSpec wiring the two TF tasks (Sox2, Oct4) to their strand-specific
# count bigWigs and IDR-filtered 12-column peak files, plus the mm10 fasta.
ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                           ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                           )
                          },
              fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta"
              )
def ds2bws(ds):
    """Map each task in *ds* to its pos/neg-strand count bigWig paths.

    Returns a dict of the form ``{task: {"pos": <path>, "neg": <path>}}``.
    """
    tracks = {}
    for task_name, spec in ds.task_specs.items():
        tracks[task_name] = {"pos": spec.pos_counts, "neg": spec.neg_counts}
    return tracks
# Get the training data
# NOTE(review): assumes each split is [inputs, targets, metadata]-like, with
# [1] being a dict of target arrays — confirm against chip_exo_nexus.
train, valid, test = chip_exo_nexus(ds, peak_width=1000)
# Pre-process the data
# Fit the total-count appender on TRAIN targets only, then apply to all splits
# (avoids leaking validation/test statistics into preprocessing).
preproc = AppendTotalCounts()
preproc.fit(train[1])
train[1] = preproc.transform(train[1])
valid[1] = preproc.transform(valid[1])
test[1] = preproc.transform(test[1])
# Notebook-style inspection of the available target keys.
train[1].keys()
# TODO - play around with this
def count_predict(filters=32,
                  conv1_kernel_size=21,
                  lr=0.004,
                  c_task_weight=100,
                  tasks=('sox2', 'oct4'),
                  dropout=0.1,
                  seq_len=201,
                  pool_type='avg'):
    """Build and compile a keras model predicting per-task (pos, neg) total
    counts from one-hot encoded DNA sequence.

    Args:
        filters: filter count of the first conv stage (later convs use 2x,
            the dense layer 8x).
        conv1_kernel_size: kernel width of the first convolution.
        lr: Adam learning rate.
        c_task_weight: loss weight applied to every count head.
        tasks: task names; one 2-unit "counts/<task>" head is created per
            task. (Default is a tuple, not a list — avoids the mutable
            default-argument pitfall; callers may still pass a list.)
        dropout: dropout rate before the output heads.
        seq_len: input sequence length.
        pool_type: 'max' or 'avg' pooling between conv stages.

    Returns:
        A compiled keras ``Model`` mapping the 'seq' input to a list of
        count heads, one per task, trained with MSE loss.

    Raises:
        ValueError: if ``pool_type`` is neither 'max' nor 'avg'.
    """
    def get_pool(pool_type):
        # Purpose: one fresh pooling layer per call (layers can't be reused
        # across graph positions). Fail loudly instead of silently returning
        # None for an unrecognized pool_type (the original bug surfaced later
        # as "'NoneType' object is not callable").
        if pool_type == 'max':
            return kl.MaxPool1D()
        elif pool_type == 'avg':
            return kl.AveragePooling1D()
        raise ValueError(f"pool_type={pool_type} not in ('max', 'avg')")

    inp = kl.Input(shape=(seq_len, 4), name='seq')
    x = kl.Conv1D(filters,
                  kernel_size=conv1_kernel_size,
                  padding='same',
                  activation='relu')(inp)
    x = kl.Conv1D(filters, 1, activation='relu', padding='same')(x)
    x = get_pool(pool_type)(x)
    x = kl.Conv1D(2 * filters, 7, activation='relu', padding='same')(x)
    x = get_pool(pool_type)(x)
    x = kl.Conv1D(2 * filters, 7, activation='relu', padding='same')(x)
    x = kl.GlobalAveragePooling1D()(x)
    x = kl.Dense(8 * filters, activation='relu')(x)
    x = kl.Dropout(dropout)(x)
    # setup the output branches: one (pos, neg) count head per task
    outputs = []
    losses = []
    loss_weights = []
    counts = [kl.Dense(2, name="counts/" + task)(x)
              for task in tasks]
    outputs += counts
    losses += ["mse"] * len(tasks)
    loss_weights += [c_task_weight] * len(tasks)
    model = Model(inp, outputs)
    model.compile(Adam(lr=lr),
                  loss=losses, loss_weights=loss_weights)
    return model
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
# NOTE(review): kl, Adam, and Model are imported twice (here and just above)
# — harmless but redundant, typical of merged notebook cells.
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
# Run counter used to suffix model names/checkpoint files.
i=1
def get_model(mfn, mkwargs, fkwargs, i):
    """Instantiate the model-builder named ``mfn`` and derive its run name
    and checkpoint path.

    Reads the module-level ``ddir`` to locate the model output directory,
    creating it if needed.

    Args:
        mfn: name (string) of a model-builder function defined in this
            module, e.g. "count_predict".
        mkwargs: hyper-parameter kwargs; also encoded into the run name.
        fkwargs: fixed kwargs merged in on top of ``mkwargs``.
        i: run index appended to the name to keep checkpoints unique.

    Returns:
        (model, name, ckp_file) — the built model, its run name, and the
        .h5 checkpoint path.
    """
    mdir = f"{ddir}/processed/chipseq/exp/models/count"
    # Encode the swept hyper-parameters into the run name for traceability.
    name = mfn + "_" + \
        ",".join([f'{k}={v}' for k, v in mkwargs.items()]) + \
        "." + str(i)
    # Plain-Python replacement for the IPython-only `!mkdir -p {mdir}` magic.
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}
    # Look the builder up by name instead of eval() — same result for a bare
    # identifier, without the arbitrary-code-execution footgun.
    return globals()[mfn](**all_kwargs), name, ckp_file
# hyper-parameters
mfn = "count_predict"  # name of the model-builder function to instantiate
mkwargs = dict(filters=32,
               conv1_kernel_size=21,
               c_task_weight=10,
               seq_len=train[0].shape[1],  # infer input length from the data
               lr=0.0004)
# Kwargs held fixed (not swept): one output head per DataSpec task.
fixed_kwargs = dict(
    tasks=list(ds.task_specs)
)
# Bump the run counter so this run gets a fresh name/checkpoint file.
i += 1
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)
# Train with early stopping; the lowest-val-loss weights are checkpointed.
history = model.fit(train[0],
                    train[1],
                    batch_size=256,
                    epochs=100,
                    validation_data=valid[:2],  # (x_valid, y_valid)
                    callbacks=[EarlyStopping(patience=10),
                               History(),
                               ModelCheckpoint(ckp_file, save_best_only=True)]
                    )
# get the best model
model = load_model(ckp_file)
from basepair.eval import evaluate
# Report metrics on the validation split.
evaluate(model, valid[0], valid[1])
from basepair.math import softmax
from basepair import samplers
from basepair.preproc import bin_counts
import numpy as np
class Seq2Sox2Oct4:
    """Plot predicted vs observed Sox2/Oct4 strand profiles for a trained model.

    Assumes model.predict(x) returns [sox2_head, oct4_head] in that order and
    that y holds observed arrays under "profile/Sox2" / "profile/Oct4" with a
    trailing (pos, neg) channel axis — TODO confirm against the data pipeline.
    """
    def __init__(self, x, y, model):
        self.x = x          # input sequences
        self.y = y          # dict of observed profiles keyed "profile/<task>"
        self.model = model  # trained keras model
        # Make the prediction
        # softmax converts each predicted head to per-position probabilities
        # (the loop variable y shadows the parameter only inside the
        # comprehension scope; self.y above is unaffected).
        self.y_pred = [softmax(y) for y in model.predict(x)]
    def plot(self, n=10, kind='test', sort='random', figsize=(20, 2), fpath_template=None, binsize=1):
        """Plot 4 panels (pred/obs Sox2, pred/obs Oct4) for n chosen examples.

        sort: 'random', or "<max|sum>_<Task>" (e.g. "max_Sox2") to take the
            top-n examples by max/sum observed count for that task. Note the
            `kind` parameter is silently overwritten in the "_" branch.
        fpath_template: optional format string with one slot for the example
            index; each figure is saved as <formatted>.png and .pdf.
        binsize: bin width used only for visualization (via bin_counts).
        """
        import matplotlib.pyplot as plt
        if sort == 'random':
            idx_list = samplers.random(self.x, n)
        elif "_" in sort:
            # e.g. "max_Sox2" -> kind="max", task="Sox2" (shadows the kind arg)
            kind, task = sort.split("_")
            #task_id = {"Sox2": 0, "Oct4": 1}[task]
            if kind == "max":
                idx_list = samplers.top_max_count(self.y[f"profile/{task}"], n)
            elif kind == "sum":
                idx_list = samplers.top_sum_count(self.y[f"profile/{task}"], n)
            else:
                raise ValueError("")
        else:
            raise ValueError(f"sort={sort} couldn't be interpreted")
        # for visualization, we use bucketize
        for i, idx in enumerate(idx_list):
            fig = plt.figure(figsize=figsize)
            # Panel 1: predicted Sox2 (pos/neg strands; legend shows argmax position).
            plt.subplot(141)
            if i == 0:
                plt.title("Predicted Sox2")
            plt.plot(bin_counts(self.y_pred[0], binsize=binsize)[idx, :, 0], label='pos,m={}'.format(np.argmax(self.y_pred[0][idx, :, 0])))
            plt.plot(bin_counts(self.y_pred[0], binsize=binsize)[idx, :, 1], label='neg,m={}'.format(np.argmax(self.y_pred[0][idx, :, 1])))
            plt.legend()
            # Panel 2: observed Sox2.
            plt.subplot(142)
            if i == 0:
                plt.title("Observed Sox2")
            plt.plot(bin_counts(self.y["profile/Sox2"], binsize=binsize)[idx, :, 0], label='pos,m={}'.format(np.argmax(self.y["profile/Sox2"][idx, :, 0])))
            plt.plot(bin_counts(self.y["profile/Sox2"], binsize=binsize)[idx, :, 1], label='neg,m={}'.format(np.argmax(self.y["profile/Sox2"][idx, :, 1])))
            plt.legend()
            # Panel 3: predicted Oct4.
            plt.subplot(143)
            if i == 0:
                plt.title("Predicted Oct4")
            plt.plot(bin_counts(self.y_pred[1], binsize=binsize)[idx, :, 0], label='pos,m={}'.format(np.argmax(self.y_pred[1][idx, :, 0])))
            plt.plot(bin_counts(self.y_pred[1], binsize=binsize)[idx, :, 1], label='neg,m={}'.format(np.argmax(self.y_pred[1][idx, :, 1])))
            plt.legend()
            # Panel 4: observed Oct4.
            plt.subplot(144)
            if i == 0:
                plt.title("Observed Oct4")
            plt.plot(bin_counts(self.y["profile/Oct4"], binsize=binsize)[idx, :, 0], label='pos,m={}'.format(np.argmax(self.y["profile/Oct4"][idx, :, 0])))
            plt.plot(bin_counts(self.y["profile/Oct4"], binsize=binsize)[idx, :, 1], label='neg,m={}'.format(np.argmax(self.y["profile/Oct4"][idx, :, 1])))
            plt.legend()
            if fpath_template is not None:
                plt.savefig(fpath_template.format(i) + '.png', dpi=600)
                plt.savefig(fpath_template.format(i) + '.pdf', dpi=600)
                plt.close(fig)  # close the figure
                # NOTE(review): show_figure is not defined anywhere in this
                # file — this raises NameError unless it exists in the
                # interactive session. Confirm or remove.
                show_figure(fig)
            plt.show()
# Predict on the held-out test split; output is a list of per-task count heads.
y_pred = model.predict(test[0])
from basepair.plots import regression_eval
# Observed vs predicted mean counts (averaged over the pos/neg channel axis).
# NOTE(review): the `- 2` offset presumably maps DataSpec's full output index
# onto this counts-only model's output list — verify against ds.task2idx.
regression_eval(test[1]['counts/Sox2'].mean(-1), y_pred[ds.task2idx("Sox2", 'counts')-2].mean(-1))
regression_eval(test[1]['counts/Oct4'].mean(-1), y_pred[ds.task2idx("Oct4", 'counts')-2].mean(-1))
# Peek at the test-split metadata (e.g. peak ranges) dataframe.
test[2].head()