import basepair
import modisco
from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
# Use gpus 1, 3, 5, 7
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1, 3, 5, 7"
# Project data root; the bare `ddir` on the next line is a notebook display
# expression (this file is an exported notebook).
ddir = get_data_dir()
ddir
# Base directory holding the per-TF bigwig tracks and IDR peak files
bdir = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/"
# Data specification: one TaskSpec per TF (Sox2, Oct4), each pointing at
# strand-specific count bigwigs and an IDR-filtered 12-column peak BED file,
# plus the mm10 reference fasta.
ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                           ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                           )
                          },
              fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta"
              )
def ds2bws(ds):
    """Map each task in a DataSpec to its pos/neg strand bigwig paths."""
    bigwigs = {}
    for task_name, spec in ds.task_specs.items():
        bigwigs[task_name] = {"pos": spec.pos_counts, "neg": spec.neg_counts}
    return bigwigs
# Get the training data
# chip_exo_nexus returns train/valid/test splits; index [0] holds the inputs
# and [1] the target dict (exact structure not visible here -- see
# basepair.datasets).
train, valid, test = chip_exo_nexus(ds, peak_width=1000)
# Pre-process the data
# Fit the total-count augmentation on the training targets only, then apply
# the same fitted transform to all three splits (avoids leakage).
preproc = AppendTotalCounts()
preproc.fit(train[1])
train[1] = preproc.transform(train[1])
valid[1] = preproc.transform(valid[1])
test[1] = preproc.transform(test[1])
# Notebook display: inspect the available target keys
train[1].keys()
# Task names used to build the per-task model outputs below
tasks=['Sox2', 'Oct4']
# TODO - play around with this
def seq_multitask_chipseq(filters=21,
                          conv1_kernel_size=21,
                          tconv_kernel_size=25,
                          #tconv_kernel_size2=25,
                          n_dil_layers=6,
                          lr=0.004,
                          c_task_weight=100,
                          use_profile=True,
                          use_counts=True,
                          tasks=tasks,
                          seq_len=201):
    """Build a multi-task ChIP-seq model: dilated conv tower + de-conv head.

    Args:
        filters: number of filters in the first and each dilated conv layer.
        conv1_kernel_size: kernel width of the first convolution.
        tconv_kernel_size: kernel width of the transposed convolution.
        n_dil_layers: number of dilated conv layers (dilation rate 2**i).
        lr: Adam learning rate.
        c_task_weight: how to upweight the count-prediction task.
        use_profile: if True, add a 2-strand profile output per task.
        use_counts: if True, add a 2-value count output per task (Dense on
            globally average-pooled conv features).
        tasks: list of task names. NOTE(review): the default binds the
            module-level `tasks` list at definition time.
        seq_len: input sequence length (one-hot over 4 bases).

    Returns:
        A compiled keras Model mapping 'seq' (seq_len, 4) to the selected
        profile and/or count outputs.
    """
    # TODO - build the reverse complement symmetry into the model
    inp = kl.Input(shape=(seq_len, 4), name='seq')
    first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)
    # Dilated tower with dense skip connections: each layer consumes the sum
    # of all previous layer outputs, with exponentially growing dilation.
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers + 1):
        if i == 1:
            prev_sum = first_conv
        else:
            prev_sum = kl.add(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)
    combined_conv = kl.add(prev_layers)
    # De-conv: produce 2 channels (strands) per task. Reshape to 4D because
    # only Conv2DTranspose is available here.
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(2*len(tasks), kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    #x = kl.UpSampling2D((2, 1))(x)
    #x = kl.Conv2DTranspose(len(tasks), kernel_size=(tconv_kernel_size2, 1), padding='same')(x)
    #x = kl.UpSampling2D((2, 1))(x)
    #x = kl.Conv2DTranspose(int(len(tasks)/2), kernel_size=(tconv_kernel_size3, 1), padding='same')(x)
    out = kl.Reshape((-1, 2 * len(tasks)))(x)
    # setup the output branches
    outputs = []
    losses = []
    loss_weights = []
    if use_profile:
        # Slice channels (2*i, 2*i+1) for task i. `arguments={"i": i}` binds
        # i per layer, avoiding the late-binding closure pitfall.
        output = [kl.Lambda(lambda x, i: x[:, :, (2 * i):(2 * i + 2)],
                            output_shape=(seq_len, 2),
                            name="profile/" + task,
                            arguments={"i": i})(out)
                  for i, task in enumerate(tasks)]
        outputs += output
        losses += [twochannel_multinomial_nll] * len(tasks)
        loss_weights += [1] * len(tasks)
    if use_counts:
        pooled = kl.GlobalAvgPool1D()(combined_conv)
        counts = [kl.Dense(2, name="counts/" + task)(pooled)
                  for task in tasks]
        outputs += counts
        losses += ["mse"] * len(tasks)
        loss_weights += [c_task_weight] * len(tasks)
    model = Model(inp, outputs)
    model.compile(Adam(lr=lr), loss=losses, loss_weights=loss_weights)
    return model
# NOTE(review): these imports appear after seq_multitask_chipseq above; in
# notebook execution order they still run before the function is called.
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
# Global run counter for checkpoint naming (incremented before get_model)
i=1
def get_model(mfn, mkwargs, fkwargs, i):
    """Instantiate a model-builder by name and derive a run name/checkpoint path.

    Args:
        mfn: name of a model-building function defined at module level
            (e.g. "seq_multitask_chipseq").
        mkwargs: model hyper-parameters; encoded into the run name.
        fkwargs: fixed kwargs merged into the builder call but kept out of
            the name.
        i: run index appended to the name to disambiguate repeated runs.

    Returns:
        Tuple (model, name, ckp_file) where ckp_file is the .h5 checkpoint
        path inside the models directory (created if missing).
    """
    mdir = f"{ddir}/processed/chipseq/exp/models/count+profile"
    name = mfn + "_" + \
        ",".join([f'{k}={v}' for k, v in mkwargs.items()]) + \
        "." + str(i)
    # Plain Python replaces the notebook-only `!mkdir -p {mdir}` shell escape.
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}
    # Look the builder up in module globals instead of eval(): identical for
    # plain function names, without arbitrary-code execution.
    model_fn = globals()[mfn]
    # (Removed: dead `import datetime` and a no-op `i += 1` that only rebound
    # the local parameter and was invisible to callers.)
    return model_fn(**all_kwargs), name, ckp_file
# hyper-parameters
mfn = "seq_multitask_chipseq"  # model-builder name, resolved inside get_model
use_profile = True
use_counts = True
mkwargs = dict(filters=32,
               conv1_kernel_size=21,
               tconv_kernel_size=100,
               n_dil_layers=6,
               use_profile=use_profile,
               use_counts=use_counts,
               c_task_weight=10,
               seq_len=train[0].shape[1],  # sequence length taken from the data
               lr=0.004)
# kwargs passed to the builder but excluded from the run name
fixed_kwargs = dict(
    tasks=list(ds.task_specs)
)
import numpy as np
np.random.seed(20)  # reproducible initialization
i += 1  # bump the global run counter so checkpoint names don't collide
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)
# Train with early stopping; ModelCheckpoint keeps only the best weights
# (lowest validation loss) at ckp_file.
history = model.fit(train[0],
                    train[1],
                    batch_size=256,
                    epochs=100,
                    validation_data=valid[:2],
                    callbacks=[EarlyStopping(patience=5),
                               History(),
                               ModelCheckpoint(ckp_file, save_best_only=True)]
                    )
# Reload the best checkpoint; the custom loss must be passed explicitly so
# keras can deserialize the compiled model.
model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll})
from basepair.eval import evaluate
# Notebook display: evaluation metrics on the validation split
evaluate(model, valid[0], valid[1])
model_dir = "/srv/scratch/amr1/chipseq/basepair/basepair/../data/processed/chipseq/exp/models/count+profile/"
import os
import theano
import numpy as np
import modisco
import modisco.tfmodisco_workflow.workflow
import os
from basepair.cli.schemas import DataSpec, HParams
from basepair.modisco import ModiscoResult
from scipy.spatial.distance import correlation
from concise.utils.helper import write_json, read_json
from basepair.data import numpy_minibatch
import h5py
import numpy as np
import keras.backend as K
from basepair.config import create_tf_session
from keras.models import load_model
from basepair.cli.evaluate import load_data
import matplotlib.pyplot as plt
plt.switch_backend('agg')  # headless plotting backend
# Interpretation settings: input gradients of the Sox2 profile head,
# summarized with the softmax-weighted statistic; outputs go to output_dir.
data = dict(train=train, valid=valid, test=test)
x, y, m = data['valid']
grad_wrt='profile/Sox2'
summary='weighted'
output_dir="outputs_profile_sox2"
hp = HParams.load("/srv/scratch/amr1/chipseq/basepair-workflow/hparams.yaml")
os.makedirs(output_dir, exist_ok=True)
# Record the interpretation arguments alongside the outputs for provenance.
write_json(dict(model_dir=os.path.abspath(model_dir),
                grad_wrt=grad_wrt,
                output_dir=output_dir,
                summary=summary,
                split='valid'),
           os.path.join(output_dir, "kwargs.json"))
# grad_wrt is "<head>/<task>", e.g. "profile/Sox2"
dtype, task = grad_wrt.split("/")
out = model.outputs[ds.task2idx(task, dtype)]
inp = model.inputs[0]
# Build one keras function per strand that computes
# d(summary-of-output)/d(input).
if "counts" in grad_wrt or summary == "count":
    # Count head output is (batch, 2): gradient of each strand's scalar.
    pos_strand_ginp_max = K.function([inp], K.gradients(out[:, 0], inp))
    neg_strand_ginp_max = K.function([inp], K.gradients(out[:, 1], inp))
else:
    if summary == 'weighted':
        print("Using weighted stat")
        # Softmax-weighted sum over positions; stop_gradient freezes the
        # weights so only the profile values are differentiated.
        pos_strand_ginp_max = K.function([inp], K.gradients(K.sum(K.stop_gradient(K.softmax(out[:, :, 0])) *
                                                                  out[:, :, 0], axis=-1), inp))
        neg_strand_ginp_max = K.function([inp], K.gradients(K.sum(K.stop_gradient(K.softmax(out[:, :, 1])) *
                                                                  out[:, :, 1], axis=-1), inp))
    elif summary == "max":
        print("Using max stat")
        # Gradient of the per-example positional maximum
        pos_strand_ginp_max = K.function([inp],
                                         K.gradients(K.max(out[:, :, 0], axis=1), inp))
        neg_strand_ginp_max = K.function([inp],
                                         K.gradients(K.max(out[:, :, 1], axis=1), inp))
    else:
        print("Using avg stat")
        # Gradient of the per-example positional mean
        pos_strand_ginp_max = K.function([inp],
                                         K.gradients(K.mean(out[:, :, 0], axis=1), inp))
        neg_strand_ginp_max = K.function([inp],
                                         K.gradients(K.mean(out[:, :, 1], axis=1), inp))
# Pre-compute the predictions and bottlenecks
y_true = y[grad_wrt]
# Input gradients per strand, computed in minibatches to bound memory
grads_pos = np.concatenate([pos_strand_ginp_max([batch])[0]
                            for batch in numpy_minibatch(x, 512)])
grads_neg = np.concatenate([neg_strand_ginp_max([batch])[0]
                            for batch in numpy_minibatch(x, 512)])
# grad*input attributions (NOTE(review): igrads_* are not used below)
igrads_pos = grads_pos * x
igrads_neg = grads_neg * x
grads_pos_ext = grads_pos.reshape((grads_pos.shape[0], -1))
grads_neg_ext = grads_neg.reshape((grads_neg.shape[0], -1))
# compute the distances
# Per-example correlation distance between the two strands' gradient maps;
# only examples where the strands agree (small distance) are kept.
distances = np.array([correlation(grads_neg_ext[i], grads_pos_ext[i])
                      for i in range(len(grads_neg_ext))])
# Setup different scores
top_distances = distances < hp.modisco.max_strand_distance  # filter sites
# Hypothetical contributions: strand-summed gradients, mean-centered over
# the base axis, restricted to the filtered sites.
hyp_scores = grads_pos + grads_neg
hyp_scores = hyp_scores[top_distances]
hyp_scores = hyp_scores - hyp_scores.mean(-1, keepdims=True)
onehot_data = x[top_distances]
# Actual contributions: hypothetical scores masked by the observed sequence
scores = hyp_scores * onehot_data
# -------------------------------------------------------------
# run modisco
kw = hp.modisco.get_modisco_kwargs()
# These two keys configure this script's filtering, not TF-MoDISco itself.
del kw['max_strand_distance']
del kw['threshold_for_counting_sign']
# Run the TF-MoDISco workflow on the filtered contribution scores.
tfmodisco_results = modisco.tfmodisco_workflow.workflow.TfModiscoWorkflow(
    **kw,
    seqlets_to_patterns_factory=modisco.tfmodisco_workflow.seqlets_to_patterns.TfModiscoSeqletsToPatternsFactory(
        trim_to_window_size=15,
        initial_flank_to_add=5,
        kmer_len=5, num_gaps=1,
        num_mismatches=0,
        final_min_cluster_size=60))(
            task_names=[grad_wrt],
            contrib_scores={grad_wrt: scores},
            hypothetical_contribs={grad_wrt: hyp_scores},
            one_hot=onehot_data)
# -------------------------------------------------------------
# save the results
output_path = os.path.join(output_dir, "modisco.h5")
if os.path.exists(output_path):
    print("Output path exists")
    output_path += "backup.h5"
# Open with an explicit mode: h5py >= 3 defaults to read-only, which would
# break save_hdf5 ("a" matches the old default). Close the handle so the
# file is flushed to disk even if save_hdf5 raises.
grp = h5py.File(output_path, "a")
try:
    tfmodisco_results.save_hdf5(grp)
finally:
    grp.close()
# dump the top distances to file
top_distances.dump("{0}/included_samples.npy".format(output_dir))
distances.dump("{0}/distances.npy".format(output_dir))
def ic_scale(x):
    """Scale a base-frequency array by information content.

    Uses the mm10 genome background frequencies (A, C, G, T).
    """
    from modisco.visualization import viz_sequence
    bg_freqs = np.array([0.27, 0.23, 0.23, 0.27])
    return viz_sequence.ic_scale(x, background=bg_freqs)
from basepair.modisco import ModiscoResult
# Load the modisco results and the boolean mask of examples that passed the
# strand-concordance filter.
mr = ModiscoResult(output_dir + "/modisco.h5")
incl = np.load(output_dir + "/included_samples.npy")
tasks = list(ds.task_specs)
# Plotting configuration
legend = False
rc_vec = None                    # optional per-pattern reverse-complement flags
rc_fn = lambda x: x[::-1, ::-1]  # reverse-complement for (position, channel) arrays
start_vec = None                 # optional per-pattern trim starts
n_bootstrap = None
n_limit = 35                     # skip patterns supported by fewer seqlets
seq_height = 1.5
n_bootstrap = 100                # overrides the None default above
# BUG FIX: the template must contain a "{}" placeholder -- without it,
# fpath_template.format(j) is a no-op and every pattern's figure overwrites
# the same file.
fig_dir = output_dir + "_figures"
os.makedirs(fig_dir, exist_ok=True)
fpath_template = fig_dir + "/{}"
figsize = (12, 2.5)
# Observed profile tracks per task, restricted to the included examples
tracks = {task: y[f"profile/{task}"][incl] for task in tasks}
x = x[incl]
import matplotlib.pyplot as plt
from concise.utils.plot import seqlogo_fig, seqlogo
print(mr.patterns())
def bootstrap_mean(x, n=100):
    """Bootstrap estimate of the mean of `x` along axis 0.

    Args:
        x: array of shape (n_samples, ...), indexed along the first axis.
        n: number of bootstrap resamples.

    Returns:
        (mean, std): mean and standard deviation of the per-resample means,
        each of shape x.shape[1:].
    """
    means = []
    for _ in range(n):
        # Sample len(x) rows with replacement -- plain numpy instead of the
        # original pandas Series.sample round-trip (same semantics, no
        # pandas dependency, faster).
        idx = np.random.randint(0, len(x), size=len(x))
        means.append(x[idx].mean(0))
    stacked = np.stack(means)
    return stacked.mean(0), stacked.std(0)
import pandas as pd
from matplotlib.ticker import FormatStrFormatter
from basepair.plots import show_figure
# y-limits; the plotting loop only applies them per-axis when this is a list
# of [lo, hi] pairs, so a flat [0, 3] leaves the axes auto-scaled.
ylim=[0, 3]
%matplotlib inline
# Plot each discovered motif: per-task strand profiles on top, the
# information-content-scaled sequence logo at the bottom.
# NOTE(review): the inner loop reuses the names `i` and `y`, shadowing the
# outer pattern index (preserved in `j`) and the target dict `y`.
for i, pattern in enumerate(mr.patterns()):
    j = i  # keep the pattern index; `i` is shadowed by the inner loop below
    seqs = mr.extract_signal(x)[pattern]
    sequence = ic_scale(seqs.mean(axis=0))
    if rc_vec is not None and rc_vec[i]:
        sequence = rc_fn(sequence)
    if start_vec is not None:
        # NOTE(review): `width` is not defined in this file; this branch only
        # runs when start_vec is provided -- confirm `width` is set upstream.
        start = start_vec[i]
        sequence = sequence[start:(start + width)]
    n = len(seqs)
    if n < n_limit:
        continue  # skip weakly-supported patterns
    # One row per task track plus one for the sequence logo
    fig, ax = plt.subplots(1 + len(tracks), 1, sharex=True,
                           figsize=figsize,
                           gridspec_kw={'height_ratios': [1] * len(tracks) + [seq_height]})
    ax[0].set_title(f"motif {i} ({n})")
    for i, (k, y) in enumerate(tracks.items()):
        # Aggregate profile signal across this pattern's seqlets, optionally
        # bootstrapped to get a std estimate for the shaded band.
        signal = mr.extract_signal(y, rc_fn)[pattern]
        if start_vec is not None:
            start = start_vec[i]
            signal = signal[:, start:(start + width)]
        if n_bootstrap is None:
            signal_mean = signal.mean(axis=0)
            signal_std = signal.std(axis=0)
        else:
            signal_mean, signal_std = bootstrap_mean(signal, n=n_bootstrap)
        if rc_vec is not None and rc_vec[i]:
            signal_mean = rc_fn(signal_mean)
            signal_std = rc_fn(signal_std)
        # Positive strand (column 0) with a +/- 2 std band
        ax[i].plot(np.arange(1, len(signal_mean) + 1), signal_mean[:, 0], label='pos')
        if n_bootstrap is not None:
            ax[i].fill_between(np.arange(1, len(signal_mean) + 1),
                               signal_mean[:, 0] - 2 * signal_std[:, 0],
                               signal_mean[:, 0] + 2 * signal_std[:, 0],
                               alpha=0.1)
            # label='pos')
        # Negative strand (column 1)
        ax[i].plot(np.arange(1, len(signal_mean) + 1), signal_mean[:, 1], label='neg')
        if n_bootstrap is not None:
            ax[i].fill_between(np.arange(1, len(signal_mean) + 1),
                               signal_mean[:, 1] - 2 * signal_std[:, 1],
                               signal_mean[:, 1] + 2 * signal_std[:, 1],
                               alpha=0.1)
            # label='pos')
        ax[i].set_ylabel(f"{k}")
        ax[i].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax[i].spines['top'].set_visible(False)
        ax[i].spines['right'].set_visible(False)
        ax[i].spines['bottom'].set_visible(False)
        ax[i].xaxis.set_ticks_position('none')
        if isinstance(ylim[0], list):
            # per-track limits only when ylim is a list of [lo, hi] pairs
            ax[i].set_ylim(ylim[i])
        if legend:
            ax[i].legend()
    # Bottom axis: sequence logo of the pattern
    seqlogo(sequence, ax=ax[-1])
    ax[-1].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax[-1].set_ylabel("Inf. content")
    ax[-1].spines['top'].set_visible(False)
    ax[-1].spines['right'].set_visible(False)
    ax[-1].spines['bottom'].set_visible(False)
    ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))
    if fpath_template is not None:
        # .format(j) injects the pattern index wherever the template contains
        # a "{}" placeholder
        plt.savefig(fpath_template.format(j) + '.png', dpi=600)
        plt.savefig(fpath_template.format(j) + '.pdf', dpi=600)
        plt.close(fig)  # close the figure
    show_figure(fig)
    plt.show()