With table generator
This commit is contained in:
parent
6f97d1be37
commit
45901dd7ec
@ -13,4 +13,10 @@ conda activate nas-wot
|
||||
./reproduce.sh
|
||||
```
|
||||
|
||||
Will produce the following table:
|
||||
|
||||
| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|
||||
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
|
||||
| Ours (N=100) | 18.35 | 89.18 +- 0.29 | 91.76 +- 1.28 | 67.17 +- 2.79 | 67.27 +- 2.68 | 40.84 +- 5.36 | 41.33 +- 5.74 |
|
||||
|
||||
The code is licensed under the MIT licence.
|
||||
|
@ -30,10 +30,13 @@ dependencies:
|
||||
- numpy-base=1.18.1=py38hde5b4d6_1
|
||||
- olefile=0.46=py_0
|
||||
- openssl=1.1.1g=h7b6447c_0
|
||||
- pandas=1.0.3=py38h0573a6f_0
|
||||
- pillow=7.1.2=py38hb39fc2d_0
|
||||
- pip=20.0.2=py38_3
|
||||
- python=3.8.3=hcff3b4d_0
|
||||
- python-dateutil=2.8.1=py_0
|
||||
- pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0
|
||||
- pytz=2020.1=py_0
|
||||
- readline=8.0=h7b6447c_0
|
||||
- setuptools=46.4.0=py38_0
|
||||
- six=1.14.0=py38_0
|
||||
@ -48,3 +51,5 @@ dependencies:
|
||||
- pip:
|
||||
- argparse==1.4.0
|
||||
- nas-bench-201==1.3
|
||||
prefix: /home/jturner/miniconda3/envs/nas-wot
|
||||
|
||||
|
86
process_results.py
Normal file
86
process_results.py
Normal file
@ -0,0 +1,86 @@
|
||||
import numpy as np
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import pandas as pd
|
||||
from collections import OrderedDict
|
||||
|
||||
import tabulate
|
||||
parser = argparse.ArgumentParser(description='Produce tables')
|
||||
parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
|
||||
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
|
||||
|
||||
parser.add_argument('--batch_size', default=256, type=int)
|
||||
parser.add_argument('--GPU', default='0', type=str)
|
||||
|
||||
parser.add_argument('--seed', default=1, type=int)
|
||||
parser.add_argument('--trainval', action='store_true')
|
||||
|
||||
parser.add_argument('--n_samples', default=100, type=int, help='how many samples to take')
|
||||
parser.add_argument('--n_runs', default=500, type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
|
||||
|
||||
from statistics import mean, median, stdev as std
|
||||
|
||||
import torch
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
df = []
|
||||
|
||||
datasets = OrderedDict()
|
||||
|
||||
datasets['CIFAR-10 (val)'] = ('cifar10-valid', 'x-valid', True)
|
||||
datasets['CIFAR-10 (test)'] = ('cifar10', 'ori-test', False)
|
||||
|
||||
### CIFAR-100
|
||||
datasets['CIFAR-100 (val)'] = ('cifar100', 'x-valid', False)
|
||||
datasets['CIFAR-100 (test)'] = ('cifar100', 'x-test', False)
|
||||
|
||||
datasets['ImageNet16-120 (val)'] = ('ImageNet16-120', 'x-valid', False)
|
||||
datasets['ImageNet16-120 (test)'] = ('ImageNet16-120', 'x-test', False)
|
||||
|
||||
|
||||
dataset_top1s = OrderedDict()
|
||||
dataset_top1s['Method'] = f"Ours (N={args.n_samples})"
|
||||
dataset_top1s['Search time (s)'] = np.nan
|
||||
|
||||
time = 0.
|
||||
|
||||
for dataset, params in datasets.items():
|
||||
top1s = []
|
||||
|
||||
dset = params[0]
|
||||
acc_type = 'accs' if 'test' in params[1] else 'val_accs'
|
||||
filename = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
|
||||
|
||||
full_scores = torch.load(filename)
|
||||
if dataset == 'CIFAR-10 (test)':
|
||||
time = median(full_scores['times'])
|
||||
dataset_top1s['Search time (s)'] = time
|
||||
accs = []
|
||||
for n in range(args.n_runs):
|
||||
acc = full_scores[acc_type][n]
|
||||
accs.append(acc)
|
||||
dataset_top1s[dataset] = accs
|
||||
|
||||
df = pd.DataFrame(dataset_top1s)
|
||||
|
||||
df['CIFAR-10 (val)'] = f"{mean(df['CIFAR-10 (val)']):.2f} +- {std(df['CIFAR-10 (val)']):.2f}"
|
||||
df['CIFAR-10 (test)'] = f"{mean(df['CIFAR-10 (test)']):.2f} +- {std(df['CIFAR-10 (test)']):.2f}"
|
||||
|
||||
df['CIFAR-100 (val)'] = f"{mean(df['CIFAR-100 (val)']):.2f} +- {std(df['CIFAR-100 (val)']):.2f}"
|
||||
df['CIFAR-100 (test)'] = f"{mean(df['CIFAR-100 (test)']):.2f} +- {std(df['CIFAR-100 (test)']):.2f}"
|
||||
|
||||
df['ImageNet16-120 (val)'] = f"{mean(df['ImageNet16-120 (val)']):.2f} +- {std(df['ImageNet16-120 (val)']):.2f}"
|
||||
df['ImageNet16-120 (test)'] = f"{mean(df['ImageNet16-120 (test)']):.2f} +- {std(df['ImageNet16-120 (test)']):.2f}"
|
||||
|
||||
df = df.round(2)
|
||||
df = df.iloc[:1]
|
||||
print(tabulate.tabulate(df.values,df.columns, tablefmt="pipe"))
|
10
reproduce.sh
Normal file → Executable file
10
reproduce.sh
Normal file → Executable file
@ -1,4 +1,6 @@
|
||||
python search.py --dataset cifar10 --data_loc '../datasets/cifar10'
|
||||
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10'
|
||||
python search.py --dataset cifar100 --data_loc '../datasets/cifar100'
|
||||
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16'
|
||||
python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs 3
|
||||
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs 3
|
||||
python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs 3
|
||||
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs 3
|
||||
|
||||
python process_results.py --n_runs 3
|
||||
|
@ -153,5 +153,5 @@ state = {'accs': acc,
|
||||
}
|
||||
|
||||
dset = args.dataset if not args.trainval else 'cifar10-valid'
|
||||
fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.mc_samples}_{args.alpha}_{args.seed}.t7"
|
||||
fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
|
||||
torch.save(state, fname)
|
||||
|
Loading…
Reference in New Issue
Block a user