from basepair.cli.schemas import DataSpec, TaskSpec
from basepair.datasets import chip_exo_nexus
from basepair.preproc import AppendTotalCounts
from basepair.config import get_data_dir, create_tf_session
# Restrict TensorFlow/keras to specific GPUs.
# NOTE(review): the original comment said "gpus 3, 5" but three devices
# ("3, 5, 7") are listed — confirm which set is intended. Also note this
# must execute before TensorFlow initializes CUDA to take effect, and some
# CUDA versions are picky about spaces in the list ("3,5,7" is the safer
# spelling) — verify.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3, 5, 7"
import keras.layers as kl
from keras.optimizers import Adam
from keras.models import Model
import keras.backend as K
from concise.utils.helper import get_from_module
from basepair.losses import twochannel_multinomial_nll
from basepair.layers import SpatialLifetimeSparsity
import keras.layers as kl
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, History
from keras.models import Model, load_model
def get_model(mfn, mkwargs, fkwargs, i, mdir=None):
    """Instantiate a model and derive its run name and checkpoint path.

    Args:
        mfn: model factory — a callable, or (legacy form) the name of a
            callable resolvable in this module's globals.
        mkwargs: model hyper-parameters; also encoded into the run name.
        fkwargs: fixed keyword arguments, merged over ``mkwargs``
            (``fkwargs`` wins on key collisions).
        i: integer suffix distinguishing repeated runs of the same config.
        mdir: checkpoint directory. Defaults to the project's processed-data
            model directory (requires the module-level ``ddir``).

    Returns:
        (model, name, ckp_file) tuple.
    """
    if mdir is None:
        mdir = f"{ddir}/processed/chipseq/exp/models/count+profile"
    # Encode the hyper-parameters into a unique, human-readable run name.
    fn_name = mfn.__name__ if callable(mfn) else mfn
    name = fn_name + "_" + \
        ",".join(f"{k}={v}" for k, v in mkwargs.items()) + \
        "." + str(i)
    # Replaces the IPython `!mkdir -p {mdir}` shell magic, which is invalid
    # in plain Python. (The original also incremented the local `i` here,
    # which had no effect on the caller — removed.)
    os.makedirs(mdir, exist_ok=True)
    ckp_file = f"{mdir}/{name}.h5"
    all_kwargs = {**mkwargs, **fkwargs}
    # NOTE(review): eval() on a caller-supplied string — acceptable in this
    # notebook-style script, but never pass untrusted input here. Prefer
    # passing the callable itself (now supported).
    model_fn = mfn if callable(mfn) else eval(mfn)
    return model_fn(**all_kwargs), name, ckp_file
# Input locations: peak BED/FASTA files and per-TF bigwig coverage tracks.
# (These f-strings contain no placeholders — plain strings would do.)
BED_DIR = f"/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/"
Sox2_BW_DIR = f"/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/Sox2/"
Oct4_BW_DIR = f"/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/Oct4/"
import pandas as pd
import numpy as np
from pybedtools import BedTool, Interval
from basepair.config import get_data_dir
from basepair.preproc import bin_counts
from concise.utils.helper import get_from_module
from tqdm import tqdm
from concise.preprocessing import encodeDNA
from random import Random
import joblib
from basepair.preproc import resize_interval
from genomelake.extractors import FastaExtractor, BigwigExtractor
from kipoi.data_utils import get_dataset_item
import logging
# Module-level logger; NullHandler suppresses "no handler" warnings when the
# embedding application has not configured logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def get_chipnexus_data(bed_file=f"{BED_DIR}//Sox2_123b_1_ppr.IDR0.05.filt.summit_centered_200bp.narrowPeak",
                       peak_fasta_file=f"{BED_DIR}//Sox2_123b_1_ppr.IDR0.05.filt.summit_centered_200bp.fasta",
                       bigwigs={"cuts_pos": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.pos_strand.bw",
                                "cuts_neg": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.neg_strand.bw",
                                }
                       ):
    """Load peak intervals with their sequences and bigwig cut profiles.

    Returns a pandas DataFrame with one row per peak: chr/start/end, one
    numpy profile column per bigwig track, the upper-cased sequence and
    its FASTA id.
    """
    from concise.utils.fasta import read_fasta
    import pyBigWig
    sequences = read_fasta(peak_fasta_file)
    peaks = BedTool(bed_file)
    assert len(sequences) == len(peaks)
    # Open every coverage track once, up front.
    bw_handles = {track: pyBigWig.open(path) for track, path in bigwigs.items()}
    # One record per peak; NaNs in the bigwig (uncovered bases) become 0.
    records = [{"chr": iv.chrom,
                "start": iv.start,
                "end": iv.stop,
                **{track: np.nan_to_num(bw.values(iv.chrom,
                                                  iv.start,
                                                  iv.stop,
                                                  numpy=True))
                   for track, bw in bw_handles.items()}}
               for iv in tqdm(peaks)]
    df = pd.DataFrame(records)
    # Attach the FASTA sequences; assumes the FASTA entries are in the same
    # order as the BED intervals (only the lengths are checked above).
    df['seq'] = list(sequences.values())
    df['seq_id'] = list(sequences)
    df['seq'] = df.seq.str.upper()
    return df
def sox2_oct4_peaks_sox2(valid_chr=['chr2', 'chr3', 'chr4'],
                         test_chr=['chr1', 'chr8', 'chr9']):
    """Build (train, valid, test) splits for the Sox2/Oct4 ChIP-nexus task.

    Each split is a (x, y, metadata) triple: one-hot encoded sequences,
    a dict of per-TF strand-stacked cut profiles, and the matching
    DataFrame rows.

    The default chromosome split is roughly 60/20/20.
    """
    # Validation and test chromosomes must not overlap.
    for chrom in valid_chr:
        assert chrom not in test_chr
    dfc = get_chipnexus_data(
        bigwigs={"sox2_pos": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.pos_strand.bw",
                 "sox2_neg": f"{Sox2_BW_DIR}/sox2_pooled_reps_1b_2b_4b.neg_strand.bw",
                 "oct4_pos": f"{Oct4_BW_DIR}/Oct4_1234.pos.bw",
                 "oct4_neg": f"{Oct4_BW_DIR}/Oct4_1234.neg.bw"})
    seq = encodeDNA(dfc.seq)
    # Stack the per-strand profiles into (n_peaks, seq_len, 2) arrays
    # with pos/neg as the last axis.
    sox2_cuts = np.stack([np.stack(dfc.sox2_pos), np.stack(dfc.sox2_neg)], axis=-1)
    oct4_cuts = np.stack([np.stack(dfc.oct4_pos), np.stack(dfc.oct4_neg)], axis=-1)
    # Chromosome-based splits; anything not valid/test is training data.
    is_test = dfc.chr.isin(test_chr)
    is_valid = dfc.chr.isin(valid_chr)
    is_train = ~(is_test | is_valid)
    return tuple((seq[mask],                    # x
                  {"sox2": sox2_cuts[mask],     # y
                   "oct4": oct4_cuts[mask]},
                  dfc[mask])                    # metadata
                 for mask in [is_train, is_valid, is_test])
# hyper-parameters for the baseline multitask model
from basepair.models import seq_multitask
mfn2 = "seq_multitask"  # factory name; resolved via eval() inside get_model
use_profile = True      # train the per-base profile heads
use_counts = True       # train the total-count heads
mkwargs2 = dict(filters=32,
                conv1_kernel_size=21,
                tconv_kernel_size=25,
                n_dil_layers=6,
                use_profile=use_profile,
                use_counts=use_counts,
                c_task_weight=10,
                lr=0.004)
# Load the Sox2/Oct4 data and transform the raw splits into keras-ready arrays.
data2 = sox2_oct4_peaks_sox2()
from basepair.preproc import transform_data
train_nex, valid_nex, test_nex = transform_data(data2, use_profile, use_counts)
ddir = get_data_dir()  # used by get_model's default checkpoint directory
bdir = "/srv/scratch/amr1/chipseq/sox2-oct4-chipseq/"
# Declare the two TF tasks: strand-specific coverage bigwigs + IDR peaks each.
ds = DataSpec(task_specs={"Sox2": TaskSpec(task="Sox2",
                                           pos_counts=f"{bdir}/Sox2/pos.bw",
                                           neg_counts=f"{bdir}/Sox2/neg.bw",
                                           peaks=f"{bdir}/Sox2/Sox2_1_rep1-pr.IDR0.05.filt.12-col.bed.gz",
                                           ),
                          "Oct4": TaskSpec(task="Oct4",
                                           pos_counts=f"{bdir}/Oct4/pos2.bw",
                                           neg_counts=f"{bdir}/Oct4/neg2.bw",
                                           peaks=f"{bdir}/Oct4/Oct4_12_ppr.IDR0.05.filt.12-col.bed.gz",
                                           )
                          },
              fasta_file="/mnt/data/pipeline_genome_data/mm10/mm10_no_alt_analysis_set_ENCODE.fasta"
              )
fixed_kwargs = dict(
    tasks=list(ds.task_specs)
)
# Run counter used for get_model's name/checkpoint suffix.
# NOTE(review): i is incremented immediately after being set, so the first
# run is numbered 2 — presumably leftover from notebook re-execution; verify.
i=1
i += 1
model2, name2, ckp_file2 = get_model(mfn2, mkwargs2, fixed_kwargs, i)
# Train with early stopping; ModelCheckpoint keeps only the best weights.
history2 = model2.fit(train_nex[0],
                      train_nex[1],
                      batch_size=256,
                      epochs=100,
                      validation_data=valid_nex[:2],
                      callbacks=[EarlyStopping(patience=5),
                                 History(),
                                 ModelCheckpoint(ckp_file2, save_best_only=True)]
                      )
# get the best model (reload the checkpoint saved by ModelCheckpoint;
# custom layers/losses must be passed explicitly to load_model)
model2 = load_model(ckp_file2, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll,
                                               "SpatialLifetimeSparsity": SpatialLifetimeSparsity})
from basepair.eval import evaluate
evaluate(model2, valid_nex[0], valid_nex[1])
def seq_multitask_newloss(filters=21,
                          conv1_kernel_size=21,
                          tconv_kernel_size=25,
                          n_dil_layers=6,
                          lr=0.004,
                          c_task_weight=100,
                          use_profile=True,
                          use_counts=True,
                          tasks=['sox2', 'oct4'],
                          outputs_per_task=2,
                          task_use_bias=False,
                          seq_len=201):  # TODO - automatically infer sequence length
    """Multi-task sequence model: a dilated-conv trunk with per-task
    profile heads (transposed conv) and total-count heads (Dense).

    Args:
        filters: number of channels in every conv layer of the trunk.
        conv1_kernel_size: kernel width of the first conv layer.
        tconv_kernel_size: kernel width of the transposed-conv profile head.
        n_dil_layers: number of dilated conv layers (dilation rate 2**i).
        lr: Adam learning rate.
        c_task_weight: how much to upweight the count-prediction task.
        use_profile: add per-base profile outputs (multinomial NLL loss).
        use_counts: add total-count outputs (MAE loss).
        tasks: task names; one profile and/or one count head per task.
        outputs_per_task (int or list): output channels per task (e.g. 2 =
            pos/neg strand); an int is broadcast to all tasks.
        task_use_bias (bool or a list of bools): if True, a
            bias term is assumed to be provided at the input.
        lr: learning rate.

    Returns:
        Compiled keras Model.

    NOTE(review): this function references ``basepair.losses`` below, but
    ``basepair`` is only imported at module level further down the file —
    do not call this before that import has run.
    """
    # TODO - split the body of this model into multiple subparts:
    # - encoder
    # - profile_decoder
    # - profile_decoder_w_bias
    # - counts_decoder
    # - counts_decoder_w_bias
    # Broadcast scalar per-task settings into per-task lists.
    if isinstance(outputs_per_task, int):
        outputs_per_task = [outputs_per_task for i in range(len(tasks))]
    else:
        assert len(tasks) == len(outputs_per_task)
    if isinstance(task_use_bias, bool):
        task_use_bias = [task_use_bias for i in range(len(tasks))]
    else:
        assert len(tasks) == len(task_use_bias)
    # TODO - build the reverse complement symmetry into the model
    inp = kl.Input(shape=(seq_len, 4), name='seq')  # one-hot DNA input
    first_conv = kl.Conv1D(filters,
                           kernel_size=conv1_kernel_size,
                           padding='same',
                           activation='relu')(inp)
    # Optional bias inputs, created only for tasks with task_use_bias[i] set.
    bias_profile_inputs = {task: kl.Input(shape=(seq_len, outputs_per_task[i]), name=f"bias/profile/{task}")
                           for i, task in enumerate(tasks) if task_use_bias[i]}
    bias_counts_inputs = [kl.Input(shape=(outputs_per_task[i], ), name=f"bias/counts/{task}")
                          for i, task in enumerate(tasks) if task_use_bias[i]]
    # Dilated conv stack with dense additive skips: each layer consumes the
    # sum of ALL previous layers' outputs.
    prev_layers = [first_conv]
    for i in range(1, n_dil_layers + 1):
        if i == 1:
            prev_sum = first_conv
        else:
            prev_sum = kl.add(prev_layers)
        conv_output = kl.Conv1D(filters, kernel_size=3, padding='same', activation='relu', dilation_rate=2**i)(prev_sum)
        prev_layers.append(conv_output)
    combined_conv = kl.add(prev_layers)
    # De-conv: 1D transposed conv emulated via Conv2DTranspose with a
    # singleton width axis; produces all tasks' profile channels at once.
    x = kl.Reshape((-1, 1, filters))(combined_conv)
    x = kl.Conv2DTranspose(sum(outputs_per_task), kernel_size=(tconv_kernel_size, 1), padding='same')(x)
    out = kl.Reshape((-1, sum(outputs_per_task)))(x)
    # batch x seqlen x tasks*2 array
    # need another array of the same length
    # setup the output branches
    outputs = []
    losses = []
    loss_weights = []
    if use_profile:
        # TODO - use a different loss function for the same profiles
        # Channel ranges [start_idx[i], end_idx[i]) slice task i out of `out`.
        start_idx = np.cumsum([0] + outputs_per_task[:-1])
        end_idx = np.cumsum(outputs_per_task)

        def get_output_name(task):
            # Bias tasks get a temporary "lambda/" name so the final
            # bias-corrected Conv1D below can own the "profile/<task>" name.
            if task in bias_profile_inputs:
                return "lambda/profile/" + task
            else:
                return "profile/" + task
        # `arguments=` passes i/sidx/eidx per layer explicitly, avoiding the
        # late-binding-closure pitfall of capturing the loop variables.
        output = [kl.Lambda(lambda x, i, sidx, eidx: x[:, :, sidx:eidx],
                            output_shape=(seq_len, outputs_per_task[i]),
                            name=get_output_name(task),
                            arguments={"i": i, "sidx": start_idx[i], "eidx": end_idx[i]})(out)
                  for i, task in enumerate(tasks)]
        for i, task in enumerate(tasks):
            if task in bias_profile_inputs:
                # Mix the raw profile with the provided bias track via a 1x1 conv.
                output_with_bias = kl.concatenate([output[i],
                                                   bias_profile_inputs[task]], axis=-1)  # batch x seqlen x (2+2)
                output[i] = kl.Conv1D(outputs_per_task[i],
                                      1,
                                      name="profile/" + task)(output_with_bias)
        outputs += output
        losses += [basepair.losses.get(f"mc_multinomial_nll_{nt}") for nt in outputs_per_task]
        loss_weights += [1] * len(tasks)
    if use_counts:
        # Total-count heads operate on the globally pooled trunk features.
        pooled = kl.GlobalAvgPool1D()(combined_conv)
        if bias_counts_inputs:
            pooled = kl.concatenate([pooled] + bias_counts_inputs, axis=-1)  # add bias as additional features
        counts = [kl.Dense(outputs_per_task[i], name="counts/" + task)(pooled)
                  for i, task in enumerate(tasks)]
        outputs += counts
        losses += ["mae"] * len(tasks)
        loss_weights += [c_task_weight] * len(tasks)
    model = Model([inp] + list(bias_profile_inputs.values()) + bias_counts_inputs, outputs)
    model.compile(Adam(lr=lr), loss=losses, loss_weights=loss_weights)
    return model
mfn = "seq_multitask_newloss"
mkwargs = dict(filters=32,
conv1_kernel_size=21,
tconv_kernel_size=25,
n_dil_layers=6,
use_profile=use_profile,
use_counts=use_counts,
c_task_weight=10,
lr=0.004)
import basepair
i += 1
model, name, ckp_file = get_model(mfn, mkwargs, fixed_kwargs, i)
history = model.fit(train_nex[0],
train_nex[1],
batch_size=256,
epochs=100,
validation_data=valid_nex[:2],
callbacks=[EarlyStopping(patience=5),
History(),
ModelCheckpoint(ckp_file, save_best_only=True)]
)
# get the best model
model = load_model(ckp_file, custom_objects={"twochannel_multinomial_nll": twochannel_multinomial_nll,
"SpatialLifetimeSparsity": SpatialLifetimeSparsity})
evaluate(model, valid_nex[0], valid_nex[1])
from basepair.plots import regression_eval
y_pred_new = model.predict(test_nex[0])
y_pred_old = model2.predict(test_nex[0])
print("OLD SOX2")
regression_eval(test_nex[1][2].mean(-1), y_pred_old[ds.task2idx("Sox2", 'counts')].mean(-1))
print("NEW SOX2")
regression_eval(test_nex[1][2].mean(-1), y_pred_new[ds.task2idx("Sox2", 'counts')].mean(-1))
print("OLD OCT4")
regression_eval(test_nex[1][3].mean(-1), y_pred_old[ds.task2idx("Oct4", 'counts')].mean(-1))
print("NEW OCT4")
regression_eval(test_nex[1][3].mean(-1), y_pred_new[ds.task2idx("Oct4", 'counts')].mean(-1))