# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
# Common paths
models_dir = Path(f"{ddir}/processed/chipnexus/exp/models/osnk-pstat-sall-smad-zfp/models/")
# Names of the entries in models_dir (one per experiment). Listed with pathlib
# instead of the IPython `!ls {models_dir}` shell escape, which interpolated the
# path unquoted into a shell command (breaks on spaces / shell metacharacters)
# and only works inside IPython. Hidden entries are skipped and names sorted,
# mirroring plain `ls` output.
dirs = sorted(p.name for p in models_dir.iterdir() if not p.name.startswith('.'))
def load_metrics(mdir, *, history_file='history.csv', metric='val_loss'):
    """Return the metrics of the best epoch recorded for one experiment.

    Reads the CSV ``<mdir>/<history_file>`` and selects the row where the
    ``metric`` column is smallest.

    Args:
      mdir: experiment directory (str or ``pathlib.Path``); its basename is
        used as the experiment name.
      history_file: file name of the history CSV inside ``mdir``.
      metric: column whose minimum selects the best row.

    Returns:
      ``{"exp": <basename of mdir>, <column>: <value>, ...}`` for the best row,
      or just ``{"exp": <basename of mdir>}`` when the history file is missing,
      unreadable/empty, lacks ``metric``, or has no usable ``metric`` values.
    """
    exp = os.path.basename(mdir)
    try:
        df = pd.read_csv(os.path.join(mdir, history_file))
        best = df.loc[df[metric].idxmin()]
    except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError,
            KeyError, ValueError, TypeError):
        # Best-effort: experiments without a usable history are reported by
        # name only (and later dropped), but KeyboardInterrupt/SystemExit and
        # unrelated programming errors are no longer silently swallowed.
        return {"exp": exp}
    return {"exp": exp, **best.to_dict()}
# One row per experiment (indexed by name); experiments whose history could
# not be loaded have NaN metrics and are removed by dropna().
dfm = (
    pd.DataFrame([load_metrics(models_dir / name) for name in dirs])
    .dropna()
    .set_index('exp')
)
# Validation metrics only, best (lowest) validation loss first.
dfmv = dfm[[col for col in dfm.columns if col.startswith("val")]].sort_values("val_loss")
dfmv
# Colour each cell by the experiment's ascending rank within that metric's
# column; annotate it with the raw metric value.
fig, ax = plt.subplots(figsize=(20, 8))
sns.heatmap(data=dfmv.rank(), annot=dfmv, ax=ax)