MeCo/zero-cost-nas/notebooks/nasbench101_correlations.ipynb
HamsterMimi 189df25fd3 upload
2023-05-04 13:09:03 +08:00

154 lines
3.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import os, pickle, sys\n",
"import matplotlib.pyplot as plt\n",
"from scipy import stats\n",
"import numpy as np\n",
"import glob\n",
"from tqdm import tqdm\n",
"from prettytable import PrettyTable"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 96/96 [00:03<00:00, 30.17it/s]\n"
]
}
],
"source": [
"d = '../results_release/nasbench1/proxies'\n",
"runs = []\n",
"processed = set()\n",
"\n",
"for f in tqdm(os.listdir(d)):\n",
" pf = open(os.path.join(d,f),'rb')\n",
" while 1:\n",
" try:\n",
" p = pickle.load(pf)\n",
" if p['hash'] in processed:\n",
" continue\n",
" processed.add(p['hash'])\n",
" runs.append(p)\n",
" except EOFError:\n",
" break\n",
" pf.close()\n",
"with open('../data/nasbench1_accuracy.p','rb') as f:\n",
" all_accur = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"423624 423624\n"
]
}
],
"source": [
"print(len(runs),len(all_accur))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"../results_release/nasbench1/proxies 423624\n",
"+---------+-----------+-------+-------+--------+---------+-----------+\n",
"| Dataset | grad_norm | snip | grasp | fisher | synflow | jacob_cov |\n",
"+---------+-----------+-------+-------+--------+---------+-----------+\n",
"| CIFAR10 | 0.198 | 0.164 | 0.448 | 0.257 | 0.372 | 0.378 |\n",
"+---------+-----------+-------+-------+--------+---------+-----------+\n"
]
}
],
"source": [
"t=None\n",
"\n",
"print(d, len(runs))\n",
"metrics={}\n",
"for k in runs[0]['logmeasures'].keys():\n",
" metrics[k] = []\n",
"acc = []\n",
"hashes = []\n",
"\n",
"if t is None:\n",
" hl=['Dataset']\n",
" hl.extend(['grad_norm', 'snip', 'grasp', 'fisher', 'synflow', 'jacob_cov'])\n",
" t = PrettyTable(hl)\n",
"\n",
"for r in runs:\n",
" for k,v in r['logmeasures'].items():\n",
" metrics[k].append(v)\n",
" \n",
" acc.append(all_accur[r['hash']][0])\n",
" hashes.append(r['hash'])\n",
"\n",
"res = []\n",
"for k in hl:\n",
" if k=='Dataset':\n",
" continue\n",
" v = metrics[k]\n",
" cr = abs(stats.spearmanr(acc,v,nan_policy='omit').correlation)\n",
" #print(f'{k} = {cr}')\n",
" res.append(round(cr,3))\n",
"\n",
"ds = 'CIFAR10'\n",
"t.add_row([ds]+res)\n",
"\n",
"print(t)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}