Add process_results.py and wire it into reproduce.sh.

* README.md: show the table that ./reproduce.sh prints.
* environment.yml: add pandas (with python-dateutil and pytz) and tabulate,
  both imported by process_results.py. The machine-specific "prefix:" entry
  written by `conda env export` is intentionally left out.
* process_results.py: new script that loads the saved search results and
  prints one row of "mean +- std" accuracies plus the median search time.
* reproduce.sh: run three searches per benchmark, stop at the first failure,
  then print the summary table.
* search.py: result file names no longer contain mc_samples/alpha, so they
  match the names process_results.py looks for.

NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Blank context lines and the indentation of context
lines (in particular the closing brace in the search.py hunk) are best
guesses; confirm against the tree before applying. The "index" blob ids of
the files whose content was revised are stale.

diff --git a/README.md b/README.md
index e9488f0..5e2df1d 100644
--- a/README.md
+++ b/README.md
@@ -13,4 +13,10 @@ conda activate nas-wot
 ./reproduce.sh
 ```
 
+This will produce the following table:
+
+| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
+|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
+| Ours (N=100) | 18.35 | 89.18 +- 0.29 | 91.76 +- 1.28 | 67.17 +- 2.79 | 67.27 +- 2.68 | 40.84 +- 5.36 | 41.33 +- 5.74 |
+
 The code is licensed under the MIT licence.
diff --git a/environment.yml b/environment.yml
index df9968e..ccfa99d 100644
--- a/environment.yml
+++ b/environment.yml
@@ -30,10 +30,13 @@ dependencies:
   - numpy-base=1.18.1=py38hde5b4d6_1
   - olefile=0.46=py_0
   - openssl=1.1.1g=h7b6447c_0
+  - pandas=1.0.3=py38h0573a6f_0
   - pillow=7.1.2=py38hb39fc2d_0
   - pip=20.0.2=py38_3
   - python=3.8.3=hcff3b4d_0
+  - python-dateutil=2.8.1=py_0
   - pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0
+  - pytz=2020.1=py_0
   - readline=8.0=h7b6447c_0
   - setuptools=46.4.0=py38_0
   - six=1.14.0=py38_0
@@ -48,3 +51,4 @@ dependencies:
   - pip:
     - argparse==1.4.0
     - nas-bench-201==1.3
+    - tabulate==0.8.7
diff --git a/process_results.py b/process_results.py
new file mode 100644
index 0000000..a58422c
--- /dev/null
+++ b/process_results.py
@@ -0,0 +1,70 @@
+"""Summarise saved search results as a Markdown table.
+
+Loads the ``.t7`` result file of every benchmark split from ``--save_loc`` and
+prints a single row holding the median search time and the ``mean +- std``
+accuracy over ``--n_runs`` runs.
+"""
+import argparse
+from collections import OrderedDict
+from statistics import mean, median, stdev
+
+import pandas as pd
+import tabulate
+import torch
+
+parser = argparse.ArgumentParser(description='Produce tables')
+# data_loc, batch_size, GPU and trainval are not read by this script; they are
+# still accepted so that existing command lines keep working.
+parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
+parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
+parser.add_argument('--batch_size', default=256, type=int)
+parser.add_argument('--GPU', default='0', type=str)
+parser.add_argument('--seed', default=1, type=int)
+parser.add_argument('--trainval', action='store_true')
+parser.add_argument('--n_samples', default=100, type=int, help='how many samples to take')
+parser.add_argument('--n_runs', default=500, type=int)
+args = parser.parse_args()
+
+# Column title -> (result-file prefix, split name, flag). The third field is
+# not read by this script.
+datasets = OrderedDict()
+datasets['CIFAR-10 (val)'] = ('cifar10-valid', 'x-valid', True)
+datasets['CIFAR-10 (test)'] = ('cifar10', 'ori-test', False)
+datasets['CIFAR-100 (val)'] = ('cifar100', 'x-valid', False)
+datasets['CIFAR-100 (test)'] = ('cifar100', 'x-test', False)
+datasets['ImageNet16-120 (val)'] = ('ImageNet16-120', 'x-valid', False)
+datasets['ImageNet16-120 (test)'] = ('ImageNet16-120', 'x-test', False)
+
+
+def summarise(values):
+    """Return ``values`` formatted as ``'mean +- std'`` with two decimals.
+
+    The sample standard deviation needs at least two values, so a single run
+    is reported with a ``nan`` spread instead of raising ``StatisticsError``.
+    """
+    spread = stdev(values) if len(values) > 1 else float('nan')
+    return f"{mean(values):.2f} +- {spread:.2f}"
+
+
+row = OrderedDict()
+row['Method'] = f"Ours (N={args.n_samples})"
+row['Search time (s)'] = float('nan')
+
+for column, (dset, split, _) in datasets.items():
+    acc_type = 'accs' if 'test' in split else 'val_accs'
+    filename = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
+
+    # torch.load unpickles its input, so only point this at result files you
+    # produced yourself. Loading to CPU means no GPU is needed to summarise.
+    full_scores = torch.load(filename, map_location='cpu')
+    if column == 'CIFAR-10 (test)':
+        # The reported search time is taken from the CIFAR-10 test file only.
+        times = [float(t) for t in full_scores['times']]
+        row['Search time (s)'] = round(median(times), 2)
+
+    # Index run by run so a file with fewer than n_runs entries fails loudly.
+    accs = [float(full_scores[acc_type][n]) for n in range(args.n_runs)]
+    row[column] = summarise(accs)
+
+table = pd.DataFrame([row])
+print(tabulate.tabulate(table.values, table.columns, tablefmt="pipe"))
diff --git a/reproduce.sh b/reproduce.sh
old mode 100644
new mode 100755
index d23c495..fdf99dd
--- a/reproduce.sh
+++ b/reproduce.sh
@@ -1,4 +1,11 @@
-python search.py --dataset cifar10 --data_loc '../datasets/cifar10'
-python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10'
-python search.py --dataset cifar100 --data_loc '../datasets/cifar100'
-python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16'
+#!/usr/bin/env bash
+# Run every search (3 runs each), then summarise the results as a table.
+# Stop at the first failure so the table is never built from missing results.
+set -e
+
+python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs 3
+python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs 3
+python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs 3
+python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs 3
+
+python process_results.py --n_runs 3
diff --git a/search.py b/search.py
index c523720..6904765 100644
--- a/search.py
+++ b/search.py
@@ -153,5 +153,5 @@ state = {'accs': acc,
          }
 
 dset = args.dataset if not args.trainval else 'cifar10-valid'
-fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.mc_samples}_{args.alpha}_{args.seed}.t7"
+fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
 torch.save(state, fname)