from basepair.utils import read_pkl
from pathlib import Path
from basepair.exp.paper.config import motifs, profile_mapping
from basepair.plot.config import paper_config, get_figsize
from basepair.exp.chipnexus.perturb_vdom import vdom_motif_pair, plot_spacing_hist
from basepair.exp.chipnexus.spacing import remove_edge_instances, get_motif_pairs
from basepair.plot.heatmaps import RowQuantileNormalizer, QuantileTruncateNormalizer
from tqdm import tqdm
from basepair.data import NumpyDataset
from basepair.exp.chipnexus.perturb import *
from copy import deepcopy
from plotnine import *
import plotnine
import warnings
warnings.filterwarnings("ignore")
from basepair.config import get_data_dir
# paper_config comes from basepair.plot.config; it is called once, up front.
paper_config()

# Root data directory; every path below is built from it.
ddir = get_data_dir()

# Common paths
model_dir = Path(
    f"{ddir}/processed/chipnexus/exp/models/oct-sox-nanog-klf/models/n_dil_layers=9/"
)
modisco_dir = model_dir.joinpath("modisco/all/deeplift/profile/")
# Outputs live in the same directory as the modisco results.
output_dir = modisco_dir
dataset_dir = output_dir.joinpath('perturbation-analysis')

# Motif pairs to analyse, built by get_motif_pairs from `motifs`.
pairs = get_motif_pairs(motifs)

# load the data
# NOTE(review): `pd` is not imported explicitly in the visible imports;
# presumably one of the star imports provides it - confirm.
motif_pair_lpdata = read_pkl(dataset_dir.joinpath('motif_pair_lpdata.pkl'))
dfab = pd.read_csv(dataset_dir.joinpath('dfab.csv.gz'))

# TODO - de-hardcode
tasks = ['Oct4', 'Sox2', 'Nanog', 'Klf4']
# fix one-hot-encoding
from basepair.exp.chipnexus.perturb_vdom import imp2seq
# For every motif pair, for both the 'x' and 'y' entries and for both the
# 'wide' and 'narrow' variants, multiply the reference 'count' importance
# scores by imp2seq() of the matching 'profile' importance scores.
# (presumably imp2seq yields a one-hot style mask - TODO confirm)
#
# The tasks are taken from the module-level `tasks` list rather than from a
# second hard-coded copy, so the two can no longer drift apart.
for pair in motif_pair_lpdata:
    for xy in ['x', 'y']:
        for wn in ['wide', 'narrow']:
            # Resolve the shared part of the nested lookup once per variant.
            ref_by_task = motif_pair_lpdata[pair][xy]['ref'][wn]
            for task in tasks:
                imp = ref_by_task[task]['imp']
                # Rebind to the product (not `*=`) so the stored value is a
                # new object, exactly as in the original expression.
                imp['count'] = imp['count'] * imp2seq(imp['profile'])
# Width handed to remove_edge_instances when filtering instances.
profile_width = 200

# Store, next to each loaded motif pair, the `dfab` rows that carry that
# pair's name after remove_edge_instances has filtered them.
for motif_pair_name, lpdata in tqdm(motif_pair_lpdata.items()):
    motif_pair = motif_pair_name.split("<>")
    pair_rows = dfab[dfab.motif_pair == motif_pair_name]
    dfab_subset = remove_edge_instances(pair_rows, profile_width=profile_width)
    # `lpdata` is the value stored under `motif_pair_name`, so this writes
    # into the same entry of `motif_pair_lpdata`.
    lpdata['dfab'] = dfab_subset
# Load all files from disk.
# NOTE: `%time` is an IPython line magic, so these statements only run under
# IPython/Jupyter; a plain Python interpreter rejects them as a syntax error.
# Each line times a single load and binds the result to the name on the left.
%time ref = NumpyDataset.load(dataset_dir / 'ref.h5')
%time single_mut = NumpyDataset.load(dataset_dir / 'single_mut.h5')
%time double_mut = NumpyDataset.load(dataset_dir / 'double_mut.h5')
%time dfi_subset = pd.read_csv(dataset_dir / 'dfi_subset.csv.gz')
# `dfab` is read again here from the same 'dfab.csv.gz' path as the earlier
# load, replacing the previously bound frame.
%time dfab = pd.read_csv(dataset_dir / 'dfab.csv.gz')
# Keep the previously loaded per-pair data reachable under a backup name.
# `motif_pair_lpdata` is rebound to a fresh dict on the next line and the old
# object is not mutated below, so a plain reference is sufficient; a deepcopy
# would only duplicate the whole structure in memory before the rebuild.
motif_pair_lpdata_bak = motif_pair_lpdata
motif_pair_lpdata = {}

# Rebuild the per-pair data for every pair in `pairs` from the datasets that
# were loaded from disk (`ref`, `single_mut`, `double_mut`).
# NOTE(review): the class name is spelled `ParturbationDataset`; it is not
# imported explicitly in the visible imports - presumably it comes from the
# `basepair.exp.chipnexus.perturb` star import. Confirm the spelling matches
# the definition.
for motif_pair in pairs:
    # Pair names are the pair members joined with "<>".
    motif_pair_name = "<>".join(motif_pair)
    dfab_subset = remove_edge_instances(dfab[dfab.motif_pair == motif_pair_name],
                                        profile_width=profile_width)
    pdata = ParturbationDataset(dfab_subset, ref, single_mut, double_mut,
                                profile_width=profile_width)
    pair_data = pdata.load_all(num_workers=1)
    # store also dfab
    pair_data['dfab'] = dfab_subset
    motif_pair_lpdata[motif_pair_name] = pair_data
# sort_idx = np.argsort(pdata.dfab.center_diff)

# Persist the rebuilt dictionary once, after all pairs have been processed.
write_pkl(motif_pair_lpdata, dataset_dir / 'motif_pair_lpdata.incl-whole.pkl')