Update NATS-Bench (tss version 0.9)

This commit is contained in:
D-X-Y 2020-09-02 07:34:12 +00:00
parent e04808c14e
commit 8d64afd4a3
5 changed files with 88 additions and 4 deletions

View File

@ -5,4 +5,4 @@
- [2019.09.28] [f8f3f38] TAS and SETN codes were publicly released.
- [2019.01.31] [13e908f] GDAS codes were publicly released.
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
- [2020.07.30] [ ] Create NATS-BENCH.
- [2020.08.30] [ ] Create NATS-BENCH.

View File

@ -20,6 +20,7 @@ The structure of this Markdown file:
### Preparation and Download
The **latest** benchmark file of NATS-Bench can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zjB6wMANiKwB2A1yil2hQ8H_qyeSe2yt?usp=sharing).
After downloading `NATS-[tss/sss]-[version]-[md5sum]-simple.tar`, please uncompress it using `tar xvf [file_name]`.
We highly recommend putting the downloaded benchmark file (`NATS-sss-v1_0-50262.pickle.pbz2`) or the uncompressed archive (`NATS-sss-v1_0-50262-simple`) into `$TORCH_HOME`.
In this way, our API will automatically find the path to these benchmark files, which is convenient for users. Otherwise, you need to manually specify the file when creating the benchmark instance.
@ -27,10 +28,12 @@ The history of benchmark files are as follows, `tss` indicates the topology sear
The benchmark file is used when creating the NATS-Bench instance with `fast_mode=False`.
The archive is used when `fast_mode=True`, where `archive` is a directory that contains 15,625 files for tss or 32,768 files for sss. Each file contains all the information for a specific architecture candidate.
The `full archive` is similar to `archive`, while each file in `full archive` contains **the trained weights**.
Since the full archive is too large, we use `split -b 30G file_name file_name` to split it into multiple 30G chunks.
To merge the chunks into the original full archive, you can use `cat file_name* > file_name`.
| Date | benchmark file (tss) | archive (tss) | full archive (tss) | benchmark file (sss) | archive (sss) | full archive (sss) |
|:-----------|:---------------------:|:-------------:|:------------------:|:-------------------------------:|:--------------------------:|:------------------:|
| 2020.08.31 | | | | NATS-sss-v1_0-50262.pickle.pbz2 | NATS-sss-v1_0-50262-simple | [xx]-full |
| 2020.08.31 | | | | [NATS-sss-v1_0-50262.pickle.pbz2](https://drive.google.com/file/d/1IabIvzWeDdDAWICBzFtTCMXxYWPIOIOX/view?usp=sharing) | [NATS-sss-v1_0-50262-simple.tar](https://drive.google.com/file/d/1scOMTUwcQhAMa_IMedp9lTzwmgqHLGgA/view?usp=sharing) | NATS-sss-v1_0-50262-full |
1, create the benchmark instance:
@ -114,6 +117,7 @@ Four multi-trial based methods:
python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space sss
python ./exps/NATS-algos/bohb.py --dataset cifar100 --search_space sss --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
Run Transformable Architecture Search (TAS), proposed in Network Pruning via Transformable Architecture Search, NeurIPS 2019

View File

@ -1,7 +1,7 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/sss-file-manager.py --mode check #
##############################################################################

View File

@ -0,0 +1,80 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# Usage: python exps/NATS-Bench/tss-file-manager.py --mode check #
##############################################################################
import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from procedures import bench_evaluate_for_seed
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]):
    """Scan `save_dir` for per-architecture checkpoints and report the coverage.

    The checkpoint of architecture index `i` for a given `seed` is expected at
    `<save_dir>/arch-{i:06d}-seed-{seed:04d}.pth`.

    Args:
      save_dir: directory that is searched for checkpoint files.
      total: number of architecture indexes to check (indexes 0 .. total-1).
      possible_seeds: the seeds to look for on every architecture index.

    Returns:
      A pair of plain dicts `(seed2ckps, miss2ckps)`. `seed2ckps[seed]` lists
      the indexes whose checkpoint file exists and `miss2ckps[seed]` lists the
      indexes whose file is absent. A seed is a key of a dict only when its
      list is non-empty.
    """
    seed2ckps = defaultdict(list)
    miss2ckps = defaultdict(list)
    for i in range(total):
        for seed in possible_seeds:
            path = os.path.join(save_dir, 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed))
            if os.path.exists(path):
                seed2ckps[seed].append(i)
            else:
                miss2ckps[seed].append(i)
    # Report every requested seed. Iterating over `seed2ckps` would silently
    # skip a seed that has no checkpoint at all, which is exactly the case the
    # user most needs to see. `.get` is used so that the lookup does not insert
    # empty entries into the returned mapping.
    for seed in possible_seeds:
        num_found = len(seed2ckps.get(seed, ()))
        print('[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}'.format(save_dir, seed, num_found, total, total - num_found, total))
    return dict(seed2ckps), dict(miss2ckps)
def copy_data(source_dir, target_dir, meta_path):
    """Copy checkpoints that were recorded as missing but now exist.

    Args:
      source_dir: directory that is searched for the checkpoint files.
      target_dir: directory that receives the copies; it is created (including
        parents) when it does not exist yet.
      meta_path: file loaded with `torch.load`; its 'miss2ckps' entry maps a
        seed to the architecture indexes that were recorded as missing.
    """
    destination = Path(target_dir)
    destination.mkdir(parents=True, exist_ok=True)
    missing_by_seed = torch.load(meta_path)['miss2ckps']

    # Collect all (source -> target) pairs first so that the summary line is
    # printed before any file is copied.
    pending = {}
    for seed, indexes in missing_by_seed.items():
        for index in indexes:
            ckp_name = 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
            src = os.path.join(source_dir, ckp_name)
            if not os.path.exists(src):
                continue  # not available in the source directory: nothing to copy
            pending[src] = os.path.join(destination, ckp_name)

    print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, destination, len(pending)))
    for src, dst in pending.items():
        copyfile(src, dst)
if __name__ == '__main__':
    # Command-line entry point.
    #   --mode check : scan `<save_dir>/raw-data-<config>` for every config and
    #                  save the found / missing indexes to `<save_dir>/meta-<config>.pth`.
    #   --mode copy  : use a previously saved meta file to copy the recorded
    #                  missing checkpoints that exist in `<save_dir>/raw-data-<config>`
    #                  into `<save_dir>/copy-<config>`.
    parser = argparse.ArgumentParser(
        description='NATS-Bench (topology search space) file manager.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.')
    parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-topology', help='Folder to save checkpoints and log.')
    parser.add_argument('--check_N', type=int, default=15625, help='For safety.')
    args = parser.parse_args()

    # Each config name is paired with the seeds expected for it.
    # NOTE(review): '12' / '200' presumably name training-budget settings -- confirm.
    config_and_seeds = (
        ('12', [111, 777]),
        ('200', [777, 888, 999]),
    )

    if args.mode == 'check':
        for config, seeds in config_and_seeds:
            raw_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
            meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config)
            found, missing = obtain_valid_ckp(raw_dir, args.check_N, seeds)
            torch.save(dict(seed2ckps=found, miss2ckps=missing), meta_path)
    elif args.mode == 'copy':
        for config, _ in config_and_seeds:
            raw_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
            copy_dir = '{:}/copy-{:}'.format(args.save_dir, config)
            meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config)
            if not os.path.exists(meta_path):
                print('Do not find : {:}'.format(meta_path))
                continue
            copy_data(raw_dir, copy_dir, meta_path)
    else:
        # Defensive only: argparse `choices` already rejects any other value.
        raise ValueError('invalid mode : {:}'.format(args.mode))

View File

@ -192,7 +192,7 @@ if __name__ == '__main__':
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
api = create(None, args.search_space, verbose=False)
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'REINFORCE-{:}'.format(args.learning_rate))
print('save-dir : {:}'.format(args.save_dir))