"""Denovo motif analysis of the GGR fold-0 model (AI-TAC style).

Steps performed below:
1. Load logits, labels, input sequences ('features') and first-layer conv
   activations ('CONV1_ACTIVATION') from the predictions HDF5 file.
2. Keep only the sequences whose prediction/label correlation (one value per
   row, from ``get_task_cors``) exceeds ``CORR_THRESHOLD``.
3. Compute filter correlation/influence from the precomputed
   ``filter_predictions.npy``.
4. Build filter PWMs / MEME output for the selected sequences with ``get_memes``.
5. Load Tomtom matches and display the denovo-pattern and TF tables.
"""
import h5py
import numpy as np
from IPython.display import display
from vdom.helpers import b, details, summary

from matlas.aitac_motifs import get_filt_corr, get_memes, get_task_cors
from matlas.matches import DenovoAitac

# ---------------------------------------------------------------------------
# Configuration (all paths/thresholds previously hard-coded inline).
# ---------------------------------------------------------------------------
DATA_DIR = "/mnt/lab_data2/msharmin/oc-atlas/DanSkinData/fold_0_ggr"
PRED_FILE = "{}/model_preds_early/ggr.predictions.h5".format(DATA_DIR)
motif_name = "result_early"
aitacdir = "{}/{}".format(DATA_DIR, motif_name)
CORR_THRESHOLD = 0.75  # minimum per-sequence correlation to keep a sequence
MATCH_THRESHOLD = 0.05  # passed to get_motif_per_celltype(match_threshold=...)
DATABASE_NAME = "HOCOMOCO.nonredundant.annotated"
METHOD_NAME = "Aitac"

# ---------------------------------------------------------------------------
# 1. Load model outputs. The context manager closes the file even if a
#    dataset read fails; "r" opens it read-only.
# ---------------------------------------------------------------------------
with h5py.File(PRED_FILE, "r") as f:
    print(list(f))
    unfold_logits = f["logits"][:]
    labels = f["labels"][:]
    cur_seqs = f["features"][:]
    # Required by get_memes below; leaving this commented out made that
    # call fail with NameError.
    activations = f["CONV1_ACTIVATION"][:]

# ---------------------------------------------------------------------------
# 2. Select well-predicted sequences.
# ---------------------------------------------------------------------------
correlations = get_task_cors(labels, unfold_logits, verbose=True)
# flatnonzero always yields a 1-D index array; argwhere(...).squeeze() became
# 0-d when exactly one sequence passed, breaking len() and the row indexing.
idx = np.flatnonzero(np.asarray(correlations) > CORR_THRESHOLD)
print(len(idx), len(correlations))
if idx.size == 0:
    raise ValueError(
        "No sequences have correlation > {}".format(CORR_THRESHOLD))

x2 = cur_seqs[idx, :, :]
y2 = labels[idx, :]
pred_full_model2 = unfold_logits[idx, :]
# NOTE(review): assumes rows of CONV1_ACTIVATION align one-to-one with rows of
# 'features' (same sequences, same order) so they can be subset by idx to
# match x2/y2 — confirm.
activations2 = activations[idx]

# Per-sequence correlations for the selected subset. This was commented out
# while still being used below, which raised NameError.
correlations2 = np.array(get_task_cors(y2, pred_full_model2, verbose=True))
print(correlations2.shape)

# ---------------------------------------------------------------------------
# 3. Filter correlation / influence.
# ---------------------------------------------------------------------------
filter_predictions = np.load("{}/filter_predictions.npy".format(aitacdir))
print(filter_predictions.shape)
print(y2.shape)
filt_corr, filt_infl, ave_filt_infl = get_filt_corr(
    filter_predictions, y2, correlations2, verbose=True)

# ---------------------------------------------------------------------------
# 4. Filter PWMs / MEME output for the selected sequences.
#    (The original ``.format(i)`` referenced an undefined ``i`` on a string
#    with no placeholder; the output path itself is unchanged.)
# ---------------------------------------------------------------------------
pwm, act_ind, nseqs, activated_OCRs, n_activated_OCRs, OCR_matrix = get_memes(
    activations2.squeeze(), x2.squeeze(), y2,
    output_file_path="{}/".format(aitacdir))

# ---------------------------------------------------------------------------
# 5. Motif matching and display.
# ---------------------------------------------------------------------------
ob = DenovoAitac(aitacdir, influence=ave_filt_infl)
# NOTE(review): presumably this must be run once to produce the Tomtom report
# that load_matched_motifs reads for the same database_name — confirm.
# ob.fetch_tomtom_matches(
#     meme_db="/mnt/lab_data/kundaje/users/msharmin/annotations/HOCOMOCOv11_core_pwms_HUMAN_mono.renamed.nonredundant.annotated.meme",
#     database_name="HOCOMOCO.nonredundant.annotated",
#     save_report=True, tomtom_dir= "{0}/{1}_tomtomout".format(aitacdir, "HOCOMOCO.nonredundant.annotated"))
ob.load_matched_motifs(database_name=DATABASE_NAME)
ob.get_motif_per_celltype(match_threshold=MATCH_THRESHOLD,
                          database_name=DATABASE_NAME)
pattern_tab, pattern_dict = ob.visualize_pattern_table()
tf_tab, tf_dict = ob.visualize_tf_table(METHOD_NAME)

# Same collapsible summary widget for both tables.
for title, items, table in (
    ("Denovo Patterns", pattern_dict, pattern_tab),
    ("Motifs", tf_dict, tf_tab),
):
    display(details(
        summary("Click here for ", b(title), " by ", b(METHOD_NAME),
                " in ", b(motif_name), ": #{}".format(len(items))),
        table,
    ))