Goal

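  • Find the best model for osnk-pstat-sall-smad-zfp (tasks: Klf4, Nanog, Oct4, Pstat3, Sall4, Smad3, Sox2, Zfp281), comparing each run at its lowest-val_loss epoch.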

Conclusions

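  • I'll use c_task_weight=5 (the `best` row is identical to it). Its total val_loss (6111.45) is higher than that of c_task_weight=1 (6096.16). However, the table values match val_loss = Σ profile losses + c_task_weight × Σ counts losses, so total val_loss is not comparable across different c_task_weight values. On the unweighted components, c_task_weight=5 is better than c_task_weight=1 on both: summed profile loss ≈ 6089.1 vs 6091.5, and summed counts loss ≈ 4.47 vs 4.61.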
In [67]:
# Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from basepair.imports import *
In [68]:
# Common paths
# One sub-directory per trained model variant (e.g. `c_task_weight=5`).
models_dir = Path(ddir) / "processed" / "chipnexus" / "exp" / "models" / "osnk-pstat-sall-smad-zfp" / "models"
# List run directories in Python instead of `!ls {models_dir}`:
# - the shell interpolation is unquoted and breaks on paths containing spaces;
# - `ls` would also return plain files.
# Hidden entries (e.g. `.ipynb_checkpoints`) are skipped, matching `ls`.
dirs = sorted(
    p.name
    for p in models_dir.iterdir()
    if p.is_dir() and not p.name.startswith(".")
)
In [69]:
def load_metrics(mdir, metric="val_loss", history_file="history.csv"):
    """Return the metrics of the best epoch of a single training run.

    Reads ``<mdir>/<history_file>`` (one row per epoch) and selects the row
    where ``metric`` is smallest.

    Args:
      mdir: run directory (str or Path); its basename is used as the run name.
      metric: column minimised to choose the best epoch (default ``val_loss``).
      history_file: name of the per-epoch history CSV inside ``mdir``.

    Returns:
      dict with key ``"exp"`` (the run name) plus every column of the
      best-epoch row. If the history cannot be read or has no usable
      ``metric`` column, only ``{"exp": name}`` is returned and a warning is
      issued, so one broken/unfinished run does not stop the comparison.
    """
    import warnings

    exp = os.path.basename(mdir)
    try:
        history = pd.read_csv(os.path.join(mdir, history_file))
        best_row = history.loc[history[metric].idxmin()]
    except Exception as e:  # best-effort: skip the run, but don't hide Ctrl-C / SystemExit
        warnings.warn(f"Skipping {exp}: {type(e).__name__}: {e}")
        return {"exp": exp}
    return {"exp": exp, **best_row.to_dict()}
In [70]:
# Best-epoch metrics of every run; failed runs come back as {"exp": name} only.
metrics_raw = pd.DataFrame([load_metrics(models_dir / md) for md in dirs])
assert "val_loss" in metrics_raw.columns, f"no readable history.csv found under {models_dir}"
# Drop only runs whose history could not be loaded (no val_loss). A bare
# `dropna()` would also discard successful runs that merely lack a column
# logged by some other run.
dfm = metrics_raw.dropna(subset=["val_loss"]).set_index("exp")
# Validation metrics only, lowest total val_loss first.
# NOTE: the table values are consistent with
#   val_loss = sum(profile losses) + c_task_weight * sum(counts losses),
# so total val_loss is not directly comparable across c_task_weight values.
dfmv = dfm[[c for c in dfm.columns if c.startswith("val")]].sort_values("val_loss")
In [71]:
# One row per run with its best-epoch validation metrics, sorted by total val_loss.
dfmv
Out[71]:
val_counts/Klf4_loss val_counts/Nanog_loss val_counts/Oct4_loss val_counts/Pstat3_loss val_counts/Sall4_loss val_counts/Smad3_loss val_counts/Sox2_loss val_counts/Zfp281_loss val_loss val_profile/Klf4_loss val_profile/Nanog_loss val_profile/Oct4_loss val_profile/Pstat3_loss val_profile/Sall4_loss val_profile/Smad3_loss val_profile/Sox2_loss val_profile/Zfp281_loss
exp
c_task_weight=1 0.6297 0.6785 0.4490 0.6577 0.4881 0.9742 0.2960 0.4403 6096.1603 633.8127 649.8054 844.0361 632.1145 898.0561 1121.1224 544.6467 767.9529
best 0.5778 0.6450 0.4379 0.6617 0.4706 0.9754 0.2876 0.4170 6111.4525 629.7085 651.2944 844.2224 632.2860 898.7354 1119.8374 544.6091 768.3939
c_task_weight=5 0.5778 0.6450 0.4379 0.6617 0.4706 0.9754 0.2876 0.4170 6111.4525 629.7085 651.2944 844.2224 632.2860 898.7354 1119.8374 544.6091 768.3939
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
n_dil_layers=6 0.5824 0.6753 0.4484 0.6481 0.4745 0.9707 0.2904 0.4481 6225.5293 638.7524 677.0816 857.5369 634.0042 906.1864 1142.9266 547.1171 776.5446
c_task_weight=50 0.5861 0.6263 0.4494 0.6591 0.4658 0.9534 0.2994 0.3868 6326.2509 634.2737 654.8904 845.8867 632.7123 899.8330 1122.3487 545.2802 769.7121
c_task_weight=100 0.6170 0.6360 0.4571 0.6593 0.4859 0.9623 0.2952 0.4306 6600.0668 639.3190 660.3210 851.7635 633.9174 908.4818 1127.5448 547.1889 777.1959

21 rows × 17 columns

In [72]:
# Colour = rank of each run within a column (1 = lowest loss); text = the raw loss value.
fig, ax = plt.subplots(figsize=(20, 8))
sns.heatmap(
    dfmv.rank(),
    annot=dfmv,
    fmt=".4g",  # default ".2g" would print e.g. 6096.16 as "6.1e+03"
    cbar_kws={"label": "rank within column (1 = lowest)"},
    ax=ax,
)
ax.set(title="Validation losses per run (colour: rank within column)", xlabel="metric", ylabel="run");
Out[72]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f156b5f9898>