import basepair
from basepair.config import get_data_dir, create_tf_session
from keras.models import load_model
from basepair.datasets import *
from basepair import datasets
from basepair.preproc import AppendTotalCounts, transform_data, resize_interval
from basepair.plots import regression_eval
from basepair.BPNet import BPNetPredictor
from basepair.utils import write_pkl
from basepair.imports import *
# from basepair.preproc import resize_interval
from kipoiseq.transforms.functional import resize_interval
# Allocate a TF session pinned to GPU 0 before any Keras model is loaded.
create_tf_session(0)
tasks = ['Oct4', 'Sox2', 'Nanog', 'Klf4']
# NOTE(review): `ddir` is presumably injected by the `basepair.imports` star-import — confirm.
model_dir = Path(f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/")
genome_file = "/mnt/data/pipeline_genome_data/mm10/mm10.chrom.sizes"
klf4_bed = "/users/avsec/workspace/basepair-workflow/data/klf4_sites_mm10.bed"
oct4_bed = "/users/avsec/workspace/basepair-workflow/data/oct4_sites_mm10.bed"
klf4_oct4_windowed_bed = "/users/avsec/workspace/basepair-workflow/data/klf4_oct4_windowed_sites_mm10.1kb.bed"
# Data specification (input tracks / tasks) saved alongside the trained model.
ds = DataSpec.load(model_dir / 'dataspec.yaml')
!cat {klf4_bed} {oct4_bed}
# make windowed regions
!cat {klf4_bed} {oct4_bed} | bedtools makewindows -w 1000 -b stdin -i srcwinnum > {klf4_oct4_windowed_bed}
# Regions from Khyati (source: Google spreadsheet)
regions_from_khyati = """
chr6 122,707,340 122,707,540 Esrrb,Oct4,Sox2->Nanog
chr1 180,933,774 180,933,974 Esrrb,Oct4,Sox2->Lefty
chr5 77262224 77,262,424 Esrrb,Oct4,Sox2->REST
chr6 122707331 122,707,531 Klf4,Pbx1,Oct4,Sox2->Nanog
chr4 55,475,492 55,475,692 Esrrb,Oct4,Sox2,Stat3->Klf4
chr3 34,756,830 34,757,030 Oct4,Sox2,Nanog,Klf4,NR5A2, Sat3,Esrrb,Smad1,Ncoa3->Sox2(dist)
chr3 34,758,000 34,758,200 Oct4,Sox2,Nanog,Klf4,NR5A2, Sat3,Esrrb,Smad1,Ncoa3->Sox2(dist)
chr3 34,761,355 34,761,555 Oct4,Sox2,Nanog,Klf4,NR5A2, Sat3,Esrrb,Smad1,Ncoa3->Sox2(dist)
chr3 34,654,000 34,654,200 Oct4,Sox2,Nanog,P300,Smad1->Sox2(prox)
"""
from io import StringIO

# Parse the pasted TSV into a BED6-like dataframe.
df = pd.read_csv(StringIO(regions_from_khyati), sep='\t', header=None,
                 names=['chrom', 'start', 'stop', 'name'])
# Coordinates were copied with thousands separators -> strip commas, cast to int.
for _col in ('start', 'stop'):
    df[_col] = df[_col].str.replace(",", "").astype(int)
df['score'] = 0
df['strand'] = "."
df["stop"] - df["start"]  # notebook display: sanity-check the interval widths
new_intervals_from_khyati = list(BedTool.from_dataframe(df))
# TODO get the interesting regions from the genome browser
# Concatenate the Klf4 and Oct4 site files without merging overlapping entries.
bt = BedTool(klf4_bed).cat(BedTool(oct4_bed), postmerge=False)
intervals = list(bt)
bt_windowed = BedTool(klf4_oct4_windowed_bed)
# Centre every interval on a fixed 1 kb window (the model's input width).
intervals_windowed = [resize_interval(itv, 1000) for itv in bt_windowed]
resized_intervals = [resize_interval(itv, 1000) for itv in intervals]
# motif widths
for itv in intervals:
    print(itv.stop - itv.start, itv.name)
bpnet = BPNetPredictor.from_mdir(model_dir)
bpnet.plot_predict_grad(resized_intervals[:1], ds, xlim=[350, 650])
# chr17 35503943 35506057  Distal and Proximal Oct4 Enhancers, Combined 0
# +/-500 bp windows around the two Oct4-enhancer summits on chr17.
oct4_enhancer = [Interval('chr17', c - 500, c + 500) for c in (35504050, 35505100)]
use_intervals = intervals + oct4_enhancer + new_intervals_from_khyati
len(intervals + oct4_enhancer)        # notebook display
len(new_intervals_from_khyati)        # notebook display
# examples from Khyati start from 8 onwards
resized_intervals = [resize_interval(itv, 1000) for itv in use_intervals]
len(resized_intervals)                # notebook display
preds = bpnet.predict(resized_intervals)
# (display name, modisco pattern id) pairs to scan for below.
pattern_names = [
    ("Oct4-Sox2", "metacluster_0/pattern_0"),
    ("Errb", "metacluster_0/pattern_1"),
    ("Sox2", "metacluster_0/pattern_2"),
    ("Nanog", "metacluster_0/pattern_3"),
    ("Klf4", "metacluster_2/pattern_0"),
    #("Klf4-1", "metacluster_2/pattern_2"),
    #("Klf4-2", "metacluster_2/pattern_3"),
]
modisco_dir = model_dir / "modisco/by_peak_tasks/weighted/Oct4"
mr = ModiscoResult(modisco_dir / "modisco.h5")
mr.open()
# One-hot sequence of every predicted example, stacked along a new axis 0.
seq = np.stack([preds[i]['seq'] for i in range(len(preds))])
# Per-task contribution scores: profile gradients averaged with `mean`, then
# multiplied by the one-hot sequence.
# NOTE(review): `mean` comes from the star-imports; presumably it averages the
# dict of per-output gradient tracks — confirm against basepair.imports.
contrib = {t: np.stack([mean(preds[i]['grads'][ti]['profile'].values()) for i in range(len(preds))]) * seq
for ti, t in enumerate(bpnet.tasks)}
contrib['Oct4'].shape  # notebook display
# Normalization table used to calibrate pattern-match scores in get_instances().
dfm_norm = pd.read_csv(modisco_dir / "centroid_seqlet_matches.csv")
dfm_norm.info()
trim_frac = 0.08  # fraction of flanking IC trimmed off each pattern
n_jobs = 1
# Scan every pattern's contribution and sequence match over all examples and
# collect the called motif instances into one table (dfp).
# NOTE: `tf` deliberately leaks out of this loop; it is read again further down.
hit_frames = []
for tf, pattern_name in tqdm(pattern_names):
    patt = mr.get_pattern(pattern_name).trim_seq_ic(trim_frac)
    match, importance = patt.scan_importance(contrib, seq, tasks,
                                             n_jobs=n_jobs, verbose=False)
    seq_match = patt.scan_seq(seq, n_jobs=n_jobs, verbose=False)
    hits = patt.get_instances(tasks, match, importance, seq_match,
                              norm_df=dfm_norm[dfm_norm.pattern == pattern_name],
                              verbose=False, plot=False)
    hits['tf'] = tf
    hit_frames.append(hits)
dfp = pd.concat(hit_frames)
# Keep only instances with a positive sequence-match score.
dfp = dfp[dfp.seq_match > 0]
from basepair.modisco.core import dfi2seqlets


def dfi2seqlets(dfi, short_name=False):
    """Convert the data-frame produced by ``pattern.get_instances()``
    to a list of Seqlets.

    NOTE: this deliberately shadows the library ``dfi2seqlets`` imported
    above — unlike the library version, the seqlet name is taken from the
    ``tf`` column rather than from the pattern name.  The original local
    copy also contained a dead ``extract_name`` helper that was never
    called; it has been removed.

    Args:
      dfi: pd.DataFrame returned by `pattern.get_instances()`, with columns
        example_idx, pattern_start, pattern_end, tf, strand
      short_name: kept for interface compatibility with the library
        version; unused here (the name is always ``row.tf``)

    Returns:
      Seqlet list
    """
    return [Seqlet(row.example_idx,
                   row.pattern_start,
                   row.pattern_end,
                   row.tf,
                   row.strand)
            for _, row in dfi.iterrows()]
def show_example(example_idx, center_window=150, xlims=((400, 650),)):
    """Show the motif instances called near the centre of one example and plot
    the model's predictions/contributions with the instances highlighted.

    Consolidates ~15 copy-pasted notebook cells that differed only in the
    example index, the centre window and the plotting x-range.

    Args:
      example_idx: index into `resized_intervals` / `dfp.example_idx`
      center_window: only display instances within this many bp of position 500
      xlims: iterable of (start, end) x-ranges; one plot is drawn per range

    Returns:
      pd.DataFrame of the displayed motif instances (for interactive inspection;
      the original cells showed this slice as a bare expression)
    """
    query = (dfp.match_weighted_p > 0.2) & (dfp.imp_weighted_p > 0) & (dfp.example_idx == example_idx)
    shown = dfp[query & (np.abs(dfp.pattern_center - 500) < center_window)][
        ['tf', 'pattern_center', 'match_weighted_p', 'match_weighted_cat',
         'imp_weighted', 'imp_weighted_p', 'imp_weighted_cat', 'seq_match']]
    seqlets = dfi2seqlets(dfp[(dfp.example_idx == example_idx) & query])
    for xlim in xlims:
        bpnet.plot_predict_grad([resized_intervals[example_idx]], ds, seqlets=seqlets,
                                xlim=list(xlim), fig_width=30, same_ylim=True)
    return shown


# Hand-picked Klf4/Oct4 site examples.
for idx in (0, 1, 3):
    show_example(idx)
show_example(4, center_window=200, xlims=((400, 700),))
# The two appended Oct4-enhancer windows sit at the end of the example list.
show_example(len(preds) - 2, center_window=200, xlims=((400, 700),))
show_example(len(preds) - 1, center_window=200, xlims=((400, 700),))
# Examples from Khyati start from index 8 onwards; the last one (16) is also
# re-plotted with a wider x-range.
for idx in range(8, 16):
    show_example(idx, center_window=200, xlims=((400, 700),))
show_example(16, center_window=200, xlims=((400, 700), (300, 700)))
# Re-scan the Klf4 pattern on its own.
pattern_name = 'metacluster_2/pattern_0'
pattern = mr.get_pattern(pattern_name).trim_seq_ic(trim_frac)
match, importance = pattern.scan_importance(contrib, seq, tasks,
                                            n_jobs=n_jobs, verbose=False)
seq_match = pattern.scan_seq(seq, n_jobs=n_jobs, verbose=False)
dfm = pattern.get_instances(tasks, match, importance, seq_match,
                            norm_df=dfm_norm[dfm_norm.pattern == pattern_name],
                            verbose=False, plot=False)
# BUG FIX: the original wrote `dfm['tf'] = tf`, relying on the stale loop
# variable left over from the pattern-scanning loop above — correct only
# because that loop happens to end on "Klf4". Label explicitly instead.
dfm['tf'] = 'Klf4'
from kipoi.data_utils import get_dataset_item
from basepair.plot.tracks import filter_tracks, plot_tracks
def prefix_dict(d, prefix):
    """Return a copy of `d` with `prefix` prepended to every key.

    Args:
      d: mapping with string keys
      prefix: string prepended to each key

    Returns:
      new dict ``{prefix + key: value}``
    """
    # The original comprehension shadowed `d` with the loop key, which worked
    # only because `d.items()` is evaluated before the rebinding — use a
    # distinct loop name for clarity.
    return {prefix + k: v for k, v in d.items()}
tasks  # notebook display of the task list
# Overlay each task's max pattern-match score for example 0 ("m/<task>")
# with that example's contribution tracks, restricted to positions 300-700.
plot_dict = {**{"m/" + t: match[0,:,ti].max(axis=-1) for ti, t in enumerate(tasks)}, **get_dataset_item(contrib, 0)}
plot_tracks(filter_tracks(plot_dict, [300, 700]));