With table generator
This commit is contained in:
		| @@ -13,4 +13,10 @@ conda activate nas-wot | |||||||
| ./reproduce.sh | ./reproduce.sh | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  | Will produce the following table: | ||||||
|  |  | ||||||
|  | | Method       |   Search time (s) | CIFAR-10 (val)   | CIFAR-10 (test)   | CIFAR-100 (val)   | CIFAR-100 (test)   | ImageNet16-120 (val)   | ImageNet16-120 (test)   | | ||||||
|  | |:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------| | ||||||
|  | | Ours (N=100) |             18.35 | 89.18 +- 0.29    | 91.76 +- 1.28     | 67.17 +- 2.79     | 67.27 +- 2.68      | 40.84 +- 5.36          | 41.33 +- 5.74           | | ||||||
|  |  | ||||||
| The code is licensed under the MIT licence. | The code is licensed under the MIT licence. | ||||||
|   | |||||||
| @@ -30,10 +30,13 @@ dependencies: | |||||||
|   - numpy-base=1.18.1=py38hde5b4d6_1 |   - numpy-base=1.18.1=py38hde5b4d6_1 | ||||||
|   - olefile=0.46=py_0 |   - olefile=0.46=py_0 | ||||||
|   - openssl=1.1.1g=h7b6447c_0 |   - openssl=1.1.1g=h7b6447c_0 | ||||||
|  |   - pandas=1.0.3=py38h0573a6f_0 | ||||||
|   - pillow=7.1.2=py38hb39fc2d_0 |   - pillow=7.1.2=py38hb39fc2d_0 | ||||||
|   - pip=20.0.2=py38_3 |   - pip=20.0.2=py38_3 | ||||||
|   - python=3.8.3=hcff3b4d_0 |   - python=3.8.3=hcff3b4d_0 | ||||||
|  |   - python-dateutil=2.8.1=py_0 | ||||||
|   - pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0 |   - pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0 | ||||||
|  |   - pytz=2020.1=py_0 | ||||||
|   - readline=8.0=h7b6447c_0 |   - readline=8.0=h7b6447c_0 | ||||||
|   - setuptools=46.4.0=py38_0 |   - setuptools=46.4.0=py38_0 | ||||||
|   - six=1.14.0=py38_0 |   - six=1.14.0=py38_0 | ||||||
| @@ -48,3 +51,5 @@ dependencies: | |||||||
|   - pip: |   - pip: | ||||||
|     - argparse==1.4.0 |     - argparse==1.4.0 | ||||||
|     - nas-bench-201==1.3 |     - nas-bench-201==1.3 | ||||||
|  | prefix: /home/jturner/miniconda3/envs/nas-wot | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										86
									
								
								process_results.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								process_results.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | |||||||
|  | import numpy as np | ||||||
|  | import argparse | ||||||
|  | import os | ||||||
|  | import random | ||||||
|  | import pandas as pd | ||||||
|  | from collections import OrderedDict | ||||||
|  |  | ||||||
|  | import tabulate | ||||||
|  | parser = argparse.ArgumentParser(description='Produce tables') | ||||||
|  | parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder') | ||||||
|  | parser.add_argument('--save_loc', default='results', type=str, help='folder to save results') | ||||||
|  |  | ||||||
|  | parser.add_argument('--batch_size', default=256, type=int) | ||||||
|  | parser.add_argument('--GPU', default='0', type=str) | ||||||
|  |  | ||||||
|  | parser.add_argument('--seed', default=1, type=int) | ||||||
|  | parser.add_argument('--trainval', action='store_true') | ||||||
|  |  | ||||||
|  | parser.add_argument('--n_samples', default=100, type=int, help='how many samples to take') | ||||||
|  | parser.add_argument('--n_runs', default=500, type=int) | ||||||
|  |  | ||||||
|  | args = parser.parse_args() | ||||||
|  | os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU | ||||||
|  |  | ||||||
|  | from statistics import mean, median, stdev as std | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | torch.backends.cudnn.deterministic = True | ||||||
|  | torch.backends.cudnn.benchmark = False | ||||||
|  | random.seed(args.seed) | ||||||
|  | np.random.seed(args.seed) | ||||||
|  | torch.manual_seed(args.seed) | ||||||
|  |  | ||||||
|  | df = [] | ||||||
|  |  | ||||||
|  | datasets = OrderedDict() | ||||||
|  |  | ||||||
|  | datasets['CIFAR-10 (val)'] = ('cifar10-valid', 'x-valid', True) | ||||||
|  | datasets['CIFAR-10 (test)'] = ('cifar10', 'ori-test', False) | ||||||
|  |  | ||||||
|  | ### CIFAR-100 | ||||||
|  | datasets['CIFAR-100 (val)'] = ('cifar100', 'x-valid', False) | ||||||
|  | datasets['CIFAR-100 (test)'] = ('cifar100', 'x-test', False) | ||||||
|  |  | ||||||
|  | datasets['ImageNet16-120 (val)'] = ('ImageNet16-120', 'x-valid', False) | ||||||
|  | datasets['ImageNet16-120 (test)'] = ('ImageNet16-120', 'x-test', False) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | dataset_top1s = OrderedDict() | ||||||
|  | dataset_top1s['Method'] = f"Ours (N={args.n_samples})" | ||||||
|  | dataset_top1s['Search time (s)'] = np.nan | ||||||
|  |  | ||||||
|  | time = 0. | ||||||
|  |  | ||||||
|  | for dataset, params in datasets.items(): | ||||||
|  |     top1s = [] | ||||||
|  |  | ||||||
|  |     dset =  params[0] | ||||||
|  |     acc_type = 'accs' if 'test' in params[1] else 'val_accs' | ||||||
|  |     filename = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7" | ||||||
|  |  | ||||||
|  |     full_scores = torch.load(filename) | ||||||
|  |     if dataset == 'CIFAR-10 (test)': | ||||||
|  |         time = median(full_scores['times']) | ||||||
|  |         dataset_top1s['Search time (s)'] = time | ||||||
|  |     accs = [] | ||||||
|  |     for n in range(args.n_runs): | ||||||
|  |         acc = full_scores[acc_type][n] | ||||||
|  |         accs.append(acc) | ||||||
|  |     dataset_top1s[dataset] = accs | ||||||
|  |  | ||||||
|  | df = pd.DataFrame(dataset_top1s) | ||||||
|  |  | ||||||
|  | df['CIFAR-10 (val)']  = f"{mean(df['CIFAR-10 (val)']):.2f} +- {std(df['CIFAR-10 (val)']):.2f}" | ||||||
|  | df['CIFAR-10 (test)']  = f"{mean(df['CIFAR-10 (test)']):.2f} +- {std(df['CIFAR-10 (test)']):.2f}" | ||||||
|  |  | ||||||
|  | df['CIFAR-100 (val)'] = f"{mean(df['CIFAR-100 (val)']):.2f} +- {std(df['CIFAR-100 (val)']):.2f}" | ||||||
|  | df['CIFAR-100 (test)'] = f"{mean(df['CIFAR-100 (test)']):.2f} +- {std(df['CIFAR-100 (test)']):.2f}" | ||||||
|  |  | ||||||
|  | df['ImageNet16-120 (val)']  = f"{mean(df['ImageNet16-120 (val)']):.2f} +- {std(df['ImageNet16-120 (val)']):.2f}" | ||||||
|  | df['ImageNet16-120 (test)'] = f"{mean(df['ImageNet16-120 (test)']):.2f} +- {std(df['ImageNet16-120 (test)']):.2f}" | ||||||
|  |  | ||||||
|  | df = df.round(2) | ||||||
|  | df = df.iloc[:1] | ||||||
|  | print(tabulate.tabulate(df.values,df.columns, tablefmt="pipe")) | ||||||
							
								
								
									
										10
									
								
								reproduce.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										10
									
								
								reproduce.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							| @@ -1,4 +1,6 @@ | |||||||
| python search.py --dataset cifar10 --data_loc '../datasets/cifar10' | python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs 3 | ||||||
| python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' | python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs 3 | ||||||
| python search.py --dataset cifar100 --data_loc '../datasets/cifar100' | python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs 3 | ||||||
| python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' | python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs 3 | ||||||
|  |  | ||||||
|  | python process_results.py --n_runs 3 | ||||||
|   | |||||||
| @@ -153,5 +153,5 @@ state = {'accs': acc, | |||||||
|          } |          } | ||||||
|  |  | ||||||
| dset = args.dataset if not args.trainval else 'cifar10-valid' | dset = args.dataset if not args.trainval else 'cifar10-valid' | ||||||
| fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.mc_samples}_{args.alpha}_{args.seed}.t7" | fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7" | ||||||
| torch.save(state, fname) | torch.save(state, fname) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user