154 lines
3.6 KiB
Plaintext
154 lines
3.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os, pickle, sys\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from scipy import stats\n",
|
|
"import numpy as np\n",
|
|
"import glob\n",
|
|
"from tqdm import tqdm\n",
|
|
"from prettytable import PrettyTable"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": 6,
 "metadata": {},
 "outputs": [
  {
   "name": "stderr",
   "output_type": "stream",
   "text": [
    "100%|██████████| 96/96 [00:03<00:00, 30.17it/s]\n"
   ]
  }
 ],
 "source": [
  "d = '../results_release/nasbench1/proxies'\n",
  "runs = []\n",
  "processed = set()\n",
  "\n",
  "# Each file holds a stream of pickled records: keep reading until EOF.\n",
  "# Files are visited in sorted order so that 'first occurrence of a hash wins'\n",
  "# is deterministic (os.listdir order is filesystem-dependent).\n",
  "# NOTE: pickle.load can execute arbitrary code -- only load trusted result files.\n",
  "for fname in tqdm(sorted(os.listdir(d))):\n",
  "    with open(os.path.join(d, fname), 'rb') as pf:\n",
  "        while True:\n",
  "            try:\n",
  "                p = pickle.load(pf)\n",
  "            except EOFError:\n",
  "                break\n",
  "            if p['hash'] in processed:  # duplicate record: keep the first one seen\n",
  "                continue\n",
  "            processed.add(p['hash'])\n",
  "            runs.append(p)\n",
  "\n",
  "with open('../data/nasbench1_accuracy.p', 'rb') as f:\n",
  "    all_accur = pickle.load(f)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 7,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "423624 423624\n"
   ]
  }
 ],
 "source": [
  "print(len(runs), len(all_accur))\n",
  "\n",
  "# Every loaded run needs a ground-truth accuracy entry; otherwise a later\n",
  "# all_accur[r['hash']] lookup fails with an uninformative KeyError.\n",
  "missing = [r['hash'] for r in runs if r['hash'] not in all_accur]\n",
  "assert not missing, f'{len(missing)} runs have no accuracy entry, e.g. {missing[:3]}'"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 8,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "../results_release/nasbench1/proxies 423624\n",
    "+---------+-----------+-------+-------+--------+---------+-----------+\n",
    "| Dataset | grad_norm | snip | grasp | fisher | synflow | jacob_cov |\n",
    "+---------+-----------+-------+-------+--------+---------+-----------+\n",
    "| CIFAR10 | 0.198 | 0.164 | 0.448 | 0.257 | 0.372 | 0.378 |\n",
    "+---------+-----------+-------+-------+--------+---------+-----------+\n"
   ]
  }
 ],
 "source": [
  "# Proxy measures to report, in table-column order.\n",
  "PROXIES = ['grad_norm', 'snip', 'grasp', 'fisher', 'synflow', 'jacob_cov']\n",
  "ds = 'CIFAR10'\n",
  "\n",
  "print(d, len(runs))\n",
  "assert runs, f'no proxy results loaded from {d}'\n",
  "\n",
  "# Gather every logged score per measure, aligned index-by-index with `acc` and `hashes`.\n",
  "metrics = {k: [] for k in runs[0]['logmeasures']}\n",
  "acc = []\n",
  "hashes = []\n",
  "for r in runs:\n",
  "    for k, v in r['logmeasures'].items():\n",
  "        metrics[k].append(v)  # KeyError here means runs log different measure sets\n",
  "    # [0] takes the first field of the accuracy record (presumably the accuracy itself -- confirm).\n",
  "    acc.append(all_accur[r['hash']][0])\n",
  "    hashes.append(r['hash'])\n",
  "assert all(len(v) == len(acc) for v in metrics.values()), 'some runs are missing measure scores'\n",
  "\n",
  "# Absolute Spearman rank correlation between each measure and accuracy; NaN scores are omitted.\n",
  "res = [\n",
  "    round(abs(stats.spearmanr(acc, metrics[k], nan_policy='omit').correlation), 3)\n",
  "    for k in PROXIES\n",
  "]\n",
  "\n",
  "hl = ['Dataset'] + PROXIES\n",
  "t = PrettyTable(hl)\n",
  "t.add_row([ds] + res)\n",
  "print(t)"
 ]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|