Update docs of NATS-Bench

This commit is contained in:
D-X-Y 2020-09-16 09:25:19 +00:00
parent 4aaf431ede
commit 386184fef8
2 changed files with 7 additions and 2 deletions

View File

@ -19,7 +19,7 @@ The structure of this Markdown file:
### Preparation and Download
The **latest** benchmark file of NATS-Bench can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zjB6wMANiKwB2A1yil2hQ8H_qyeSe2yt?usp=sharing).
After download `NATS-[tss/sss]-[version]-[md5sum]-simple.tar`, please uncompress it by using `tar xvf [file_name]`.
We highly recommend to put the downloaded benchmark file (`NATS-sss-v1_0-50262.pickle.pbz2`) or uncompressed archive (`NATS-sss-v1_0-50262-simple`) into `$TORCH_HOME`.
We highly recommend to put the downloaded benchmark file (`NATS-sss-v1_0-50262.pickle.pbz2` / `NATS-tss-v1_0-3ffb9.pickle.pbz2`) or uncompressed archive (`NATS-sss-v1_0-50262-simple` / `NATS-tss-v1_0-3ffb9-simple`) into `$TORCH_HOME`.
In this way, our api will automatically find the path for these benchmarkfiles, which is convenient for the users. Otherwise, you need to manually indicate the file when creating the benchmark instance.
The history of benchmark files are as follows, `tss` indicates the topology search space and `sss` indicates the size search space.
@ -36,6 +36,7 @@ To merge the chunks into the original full archive, you can use `cat file_name*
1, create the benchmark instance:
```
from nats_bench import create
# Create the API instance for the size search space in NATS
api = create(None, 'sss', fast_mode=True, verbose=True)

View File

@ -5,7 +5,7 @@
##############################################################################
# Usage: python exps/NATS-Bench/test-nats-api.py #
##############################################################################
import os, sys, time, torch, argparse
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
@ -91,6 +91,8 @@ if __name__ == '__main__':
api_nats_tss = create(None, 'tss', fast_mode=fast_mode, verbose=True)
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
test_api(api_nats_tss, False)
del api_nats_tss
gc.collect()
for fast_mode in [True, False]:
for verbose in [True, False]:
@ -98,3 +100,5 @@ if __name__ == '__main__':
api_nats_sss = create(None, 'size', fast_mode=fast_mode, verbose=True)
print('{:} --->>> {:}'.format(time_string(), api_nats_sss))
test_api(api_nats_sss, True)
del api_nats_sss
gc.collect()