Final check
This commit is contained in:
parent
656e75cba8
commit
6fbaee7a6d
17
README.md
17
README.md
@ -10,22 +10,25 @@ To reproduce our results:
|
||||
conda env create -f environment.yml
|
||||
|
||||
conda activate nas-wot
|
||||
./reproduce.sh
|
||||
./reproduce.sh 3 # average accuracy over 3 runs
|
||||
./reproduce.sh 500 # average accuracy over 500 runs (this will take longer)
|
||||
```
|
||||
|
||||
For a quick run you can set `--n_runs 3` to get results after 3 runs:
|
||||
Each command will finish by calling `process_results.py`, which will print a table. `./reproduce.sh 3` should print the following table:
|
||||
|
||||
| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|
||||
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
|
||||
| Ours (N=10) | 1.73435 | 88.99 $\pm$ 0.24 | 92.42 $\pm$ 0.33 | 67.86 $\pm$ 0.49 | 67.54 $\pm$ 0.75 | 41.16 $\pm$ 2.31 | 40.98 $\pm$ 2.72 |
|
||||
| Ours (N=100) | 17.4139 | 89.18 $\pm$ 0.29 | 91.76 $\pm$ 1.28 | 67.17 $\pm$ 2.79 | 67.27 $\pm$ 2.68 | 40.84 $\pm$ 5.36 | 41.33 $\pm$ 5.74
|
||||
| Ours (N=10) | 1.73435 | 88.99 +- 0.24 | 92.42 +- 0.33 | 67.86 +- 0.49 | 67.54 +- 0.75 | 41.16 +- 2.31 | 40.98 +- 2.72 |
|
||||
| Ours (N=100) | 17.4139 | 89.18 +- 0.29 | 91.76 +- 1.28 | 67.17 +- 2.79 | 67.27 +- 2.68 | 40.84 +- 5.36 | 41.33 +- 5.74
|
||||
|
||||
The size of `N` is set with `--n_samples 10`. To produce the results in the paper, set `--n_runs 500`:
|
||||
`./reproduce 500` will produce the following table (which is the same as what we report in the paper):
|
||||
|
||||
| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|
||||
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
|
||||
| Ours (N=10) | 1.73435 | 89.25 $\pm$ 0.08 | 92.21 $\pm$ 0.11 | 68.53 $\pm$ 0.17 | 68.40 $\pm$ 0.14 | 40.42 $\pm$ 1.15 | 40.66 $\pm$ 0.97 |
|
||||
| Ours (N=100) | 17.4139 | 88.45 $\pm$ 1.46 | 91.61 $\pm$ 1.71 | 66.42 $\pm$ 3.27 | 66.56 $\pm$ 3.28 | 36.56 $\pm$ 6.70 | 36.37 $\pm$ 6.97
|
||||
| Ours (N=10) | 1.73435 | 89.25 +- 0.08 | 92.21 +- 0.11 | 68.53 +- 0.17 | 68.40 +- 0.14 | 40.42 +- 1.15 | 40.66 +- 0.97 |
|
||||
| Ours (N=100) | 17.4139 | 88.45 +- 1.46 | 91.61 +- 1.71 | 66.42 +- 3.27 | 66.56 +- 3.28 | 36.56 +- 6.70 | 36.37 +- 6.97
|
||||
|
||||
|
||||
To try different sample sizes, simply change the `--n_samples` argument in the call to `search.py`, and update the list of sample sizes on line 51 of `process_results.py`.
|
||||
|
||||
The code is licensed under the MIT licence.
|
||||
|
@ -51,5 +51,4 @@ dependencies:
|
||||
- pip:
|
||||
- argparse==1.4.0
|
||||
- nas-bench-201==1.3
|
||||
prefix: /home/jturner/miniconda3/envs/nas-wot
|
||||
|
||||
- tabulate==0.8.7
|
||||
|
@ -63,25 +63,25 @@ for n_samples in [10, 100]:
|
||||
full_scores = torch.load(filename)
|
||||
if dataset == 'CIFAR-10 (test)':
|
||||
time = median(full_scores['times'])
|
||||
dataset_top1s['Search time (s)'] = time
|
||||
time = f"{time:.2f}"
|
||||
accs = []
|
||||
for n in range(args.n_runs):
|
||||
acc = full_scores[acc_type][n]
|
||||
accs.append(acc)
|
||||
dataset_top1s[dataset] = accs
|
||||
|
||||
cifar10_val = f"{mean(dataset_top1s['CIFAR-10 (val)']):.2f} $\pm$ {std(dataset_top1s['CIFAR-10 (val)']):.2f}"
|
||||
cifar10_test = f"{mean(dataset_top1s['CIFAR-10 (test)']):.2f} $\pm$ {std(dataset_top1s['CIFAR-10 (test)']):.2f}"
|
||||
cifar10_val = f"{mean(dataset_top1s['CIFAR-10 (val)']):.2f} +- {std(dataset_top1s['CIFAR-10 (val)']):.2f}"
|
||||
cifar10_test = f"{mean(dataset_top1s['CIFAR-10 (test)']):.2f} +- {std(dataset_top1s['CIFAR-10 (test)']):.2f}"
|
||||
|
||||
cifar100_val = f"{mean(dataset_top1s['CIFAR-100 (val)']):.2f} $\pm$ {std(dataset_top1s['CIFAR-100 (val)']):.2f}"
|
||||
cifar100_test = f"{mean(dataset_top1s['CIFAR-100 (test)']):.2f} $\pm$ {std(dataset_top1s['CIFAR-100 (test)']):.2f}"
|
||||
cifar100_val = f"{mean(dataset_top1s['CIFAR-100 (val)']):.2f} +- {std(dataset_top1s['CIFAR-100 (val)']):.2f}"
|
||||
cifar100_test = f"{mean(dataset_top1s['CIFAR-100 (test)']):.2f} +- {std(dataset_top1s['CIFAR-100 (test)']):.2f}"
|
||||
|
||||
imagenet_val = f"{mean(dataset_top1s['ImageNet16-120 (val)']):.2f} $\pm$ {std(dataset_top1s['ImageNet16-120 (val)']):.2f}"
|
||||
imagenet_test = f"{mean(dataset_top1s['ImageNet16-120 (test)']):.2f} $\pm$ {std(dataset_top1s['ImageNet16-120 (test)']):.2f}"
|
||||
imagenet_val = f"{mean(dataset_top1s['ImageNet16-120 (val)']):.2f} +- {std(dataset_top1s['ImageNet16-120 (val)']):.2f}"
|
||||
imagenet_test = f"{mean(dataset_top1s['ImageNet16-120 (test)']):.2f} +- {std(dataset_top1s['ImageNet16-120 (test)']):.2f}"
|
||||
|
||||
df.append([method, time, cifar10_val, cifar10_test, cifar100_val, cifar100_test, imagenet_val, imagenet_test])
|
||||
|
||||
|
||||
df = pd.DataFrame(df, columns=['Method','Search time (s)','CIFAR-10 (val)','CIFAR-10 (test)','CIFAR-100 (val)','CIFAR-100 (test)','ImageNet16-120 (val)','ImageNet16-120 (test)' ])
|
||||
df.round(2)
|
||||
|
||||
print(tabulate.tabulate(df.values,df.columns, tablefmt="pipe"))
|
||||
|
20
reproduce.sh
20
reproduce.sh
@ -1,11 +1,13 @@
|
||||
#python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs 3 --n_samples 10
|
||||
#python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs 3 --n_samples 10
|
||||
#python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs 3 --n_samples 10
|
||||
#python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs 3 --n_samples 10
|
||||
#!/bin/bash
|
||||
|
||||
python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs 3 --n_samples 100
|
||||
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs 3 --n_samples 100
|
||||
python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs 3 --n_samples 100
|
||||
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs 3 --n_samples 100
|
||||
python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 10
|
||||
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 10
|
||||
python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs $1 --n_samples 10
|
||||
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs $1 --n_samples 10
|
||||
|
||||
python process_results.py --n_runs 3
|
||||
python search.py --dataset cifar10 --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 100
|
||||
python search.py --dataset cifar10 --trainval --data_loc '../datasets/cifar10' --n_runs $1 --n_samples 100
|
||||
python search.py --dataset cifar100 --data_loc '../datasets/cifar100' --n_runs $1 --n_samples 100
|
||||
python search.py --dataset ImageNet16-120 --data_loc '../datasets/ImageNet16' --n_runs $1 --n_samples 100
|
||||
|
||||
python process_results.py --n_runs $1
|
||||
|
Loading…
Reference in New Issue
Block a user