Remove call to .eval(), update results in README, ignore .t7 files

This commit is contained in:
jack-willturner 2020-07-03 12:19:44 +01:00
parent 07a991b3b1
commit 8a593d7f3f
3 changed files with 8 additions and 7 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
*.pth
__pycache__
*.t7

View File

@ -24,15 +24,16 @@ Each command will finish by calling `process_results.py`, which will print a tab
| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
| Ours (N=10) | 1.73435 | 89.25 +- 0.08 | 92.21 +- 0.11 | 68.53 +- 0.17 | 68.40 +- 0.14 | 40.42 +- 1.15 | 40.66 +- 0.97 |
| Ours (N=100) | 17.4139 | 89.18 +- 0.29 | 91.76 +- 1.28 | 67.17 +- 2.79 | 67.27 +- 2.68 | 40.84 +- 5.36 | 41.33 +- 5.74
| Ours (N=10) | 1.75 | 89.50 +- 0.51 | 92.98 +- 0.82 | 69.80 +- 2.46 | 69.86 +- 2.21 | 42.35 +- 1.19 | 42.38 +- 1.37 |
| Ours (N=100) | 17.76 | 87.44 +- 1.45 | 92.27 +- 1.53 | 70.26 +- 1.09 | 69.86 +- 0.60 | 43.30 +- 1.62 | 43.51 +- 1.40
`./reproduce 500` will produce the following table (which is the same as what we report in the paper):
`./reproduce 500` will produce the following table:
| Method | Search time (s) | CIFAR-10 (val) | CIFAR-10 (test) | CIFAR-100 (val) | CIFAR-100 (test) | ImageNet16-120 (val) | ImageNet16-120 (test) |
|:-------------|------------------:|:-----------------|:------------------|:------------------|:-------------------|:-----------------------|:------------------------|
| Ours (N=10) | 1.73435 | 88.47 +- 1.33 | 91.53 +- 1.62 | 66.49 +- 3.08 | 66.63 +- 3.14 | 38.33 +- 4.98 | 38.33 +- 5.22 |
| Ours (N=100) | 17.4139 | 88.45 +- 1.46 | 91.61 +- 1.71 | 66.42 +- 3.27 | 66.56 +- 3.28 | 36.56 +- 6.70 | 36.37 +- 6.97
| Ours (N=10) | 1.67 | 88.61 +- 1.58 | 91.58 +- 1.70 | 67.03 +- 3.01 | 67.15 +- 3.08 | 39.74 +- 4.17 | 39.76 +- 4.39 |
| Ours (N=100) | 17.12 | 88.43 +- 1.67 | 91.24 +- 1.70 | 67.04 +- 2.91 | 67.12 +- 2.98 | 40.68 +- 3.41 | 40.67 +- 3.55 |
To try different sample sizes, simply change the `--n_samples` argument in the call to `search.py`, and update the list of sample sizes [this line](https://github.com/BayesWatch/nas-without-training/blob/master/process_results.py#L51) of `process_results.py`.

View File

@ -8,7 +8,7 @@ from statistics import mean
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../datasets/cifar', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='NAS-Bench-201-v1_1-096897.pth',
parser.add_argument('--api_loc', default='../datasets/NAS-Bench-201-v1_1-096897.pth',
type=str, help='path to API')
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
parser.add_argument('--batch_size', default=256, type=int)
@ -116,7 +116,6 @@ for N in runs:
network = get_cell_based_tiny_net(config) # create the network from configuration
network = network.to(device)
network.eval()
jacobs, labels= get_batch_jacobian(network, x, target, 1, device, args)
jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()