Update NATS-Bench (tss version 0.9)
This commit is contained in:
parent
e04808c14e
commit
8d64afd4a3
@ -5,4 +5,4 @@
|
||||
- [2019.09.28] [f8f3f38] TAS and SETN codes were publicly released.
|
||||
- [2019.01.31] [13e908f] GDAS codes were publicly released.
|
||||
- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version.
|
||||
- [2020.07.30] [ ] Create NATS-BENCH.
|
||||
- [2020.08.30] [ ] Create NATS-BENCH.
|
||||
|
@ -20,6 +20,7 @@ The structure of this Markdown file:
|
||||
|
||||
### Preparation and Download
|
||||
The **latest** benchmark file of NATS-Bench can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zjB6wMANiKwB2A1yil2hQ8H_qyeSe2yt?usp=sharing).
|
||||
After download `NATS-[tss/sss]-[version]-[md5sum]-simple.tar`, please uncompress it by using `tar xvf [file_name]`.
|
||||
We highly recommend to put the downloaded benchmark file (`NATS-sss-v1_0-50262.pickle.pbz2`) or uncompressed archive (`NATS-sss-v1_0-50262-simple`) into `$TORCH_HOME`.
|
||||
In this way, our api will automatically find the path for these benchmarkfiles, which is convenient for the users. Otherwise, you need to manually indicate the file when creating the benchmark instance.
|
||||
|
||||
@ -27,10 +28,12 @@ The history of benchmark files are as follows, `tss` indicates the topology sear
|
||||
The benchmark file is used when create the NATS-Bench instance with `fast_mode=False`.
|
||||
The archive is used when `fast_mode=True`, where `archive` is a directory contains 15,625 files for tss or contains 32,768 files for sss. Each file contains all the information for a specific architecture candidate.
|
||||
The `full archive` is similar to `archive`, while each file in `full archive` contains **the trained weights**.
|
||||
Since the full archive is too large, we use `split -b 30G file_name file_name` to split it into multiple 30G chunks.
|
||||
To merge the chunks into the original full archive, you can use `cat file_name* > file_name`.
|
||||
|
||||
| Date | benchmark file (tss) | archive (tss) | full archive (tss) | benchmark file (sss) | archive (sss) | full archive (sss) |
|
||||
|:-----------|:---------------------:|:-------------:|:------------------:|:-------------------------------:|:--------------------------:|:------------------:|
|
||||
| 2020.08.31 | | | | NATS-sss-v1_0-50262.pickle.pbz2 | NATS-sss-v1_0-50262-simple | [xx]-full |
|
||||
| 2020.08.31 | | | | [NATS-sss-v1_0-50262.pickle.pbz2](https://drive.google.com/file/d/1IabIvzWeDdDAWICBzFtTCMXxYWPIOIOX/view?usp=sharing) | [NATS-sss-v1_0-50262-simple.tar](https://drive.google.com/file/d/1scOMTUwcQhAMa_IMedp9lTzwmgqHLGgA/view?usp=sharing) | NATS-sss-v1_0-50262-full |
|
||||
|
||||
|
||||
1, create the benchmark instance:
|
||||
@ -114,6 +117,7 @@ Four multi-trial based methods:
|
||||
python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
|
||||
python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3
|
||||
python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space sss
|
||||
python ./exps/NATS-algos/bohb.py --dataset cifar100 --search_space sss --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
||||
|
||||
|
||||
Run Transformable Architecture Search (TAS), proposed in Network Pruning via Transformable Architecture Search, NeurIPS 2019
|
||||
|
@ -1,7 +1,7 @@
|
||||
##############################################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||
##############################################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
|
||||
##############################################################################
|
||||
# Usage: python exps/NATS-Bench/sss-file-manager.py --mode check #
|
||||
##############################################################################
|
||||
|
80
exps/NATS-Bench/tss-file-manager.py
Normal file
80
exps/NATS-Bench/tss-file-manager.py
Normal file
@ -0,0 +1,80 @@
|
||||
##############################################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||
##############################################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
|
||||
##############################################################################
|
||||
# Usage: python exps/NATS-Bench/tss-file-manager.py --mode check #
|
||||
##############################################################################
|
||||
import os, sys, time, torch, argparse
|
||||
from typing import List, Text, Dict, Any
|
||||
from shutil import copyfile
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import dict2config, load_config
|
||||
from procedures import bench_evaluate_for_seed
|
||||
from procedures import get_machine_info
|
||||
from datasets import get_datasets
|
||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]):
|
||||
seed2ckps = defaultdict(list)
|
||||
miss2ckps = defaultdict(list)
|
||||
for i in range(total):
|
||||
for seed in possible_seeds:
|
||||
path = os.path.join(save_dir, 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed))
|
||||
if os.path.exists(path):
|
||||
seed2ckps[seed].append(i)
|
||||
else:
|
||||
miss2ckps[seed].append(i)
|
||||
for seed, xlist in seed2ckps.items():
|
||||
print('[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}'.format(save_dir, seed, len(xlist), total, total-len(xlist), total))
|
||||
return dict(seed2ckps), dict(miss2ckps)
|
||||
|
||||
|
||||
def copy_data(source_dir, target_dir, meta_path):
|
||||
target_dir = Path(target_dir)
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
miss2ckps = torch.load(meta_path)['miss2ckps']
|
||||
s2t = {}
|
||||
for seed, xlist in miss2ckps.items():
|
||||
for i in xlist:
|
||||
file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed)
|
||||
source_path = os.path.join(source_dir, file_name)
|
||||
target_path = os.path.join(target_dir, file_name)
|
||||
if os.path.exists(source_path):
|
||||
s2t[source_path] = target_path
|
||||
print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t)))
|
||||
for s, t in s2t.items():
|
||||
copyfile(s, t)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NATS-Bench (topology search space) file manager.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.')
|
||||
parser.add_argument('--save_dir', type=str, default='output/NATS-Bench-topology', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--check_N', type=int, default=15625, help='For safety.')
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
||||
possible_configs = ['12', '200']
|
||||
possible_seedss = [[111, 777], [777, 888, 999]]
|
||||
if args.mode == 'check':
|
||||
for config, possible_seeds in zip(possible_configs, possible_seedss):
|
||||
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
|
||||
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds)
|
||||
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config))
|
||||
elif args.mode == 'copy':
|
||||
for config in possible_configs:
|
||||
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
|
||||
cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config)
|
||||
cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config)
|
||||
if os.path.exists(cur_meta_path):
|
||||
copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
|
||||
else:
|
||||
print('Do not find : {:}'.format(cur_meta_path))
|
||||
else:
|
||||
raise ValueError('invalid mode : {:}'.format(args.mode))
|
@ -192,7 +192,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
|
||||
api = create(None, args.search_space, verbose=False)
|
||||
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'REINFORCE-{:}'.format(args.learning_rate))
|
||||
print('save-dir : {:}'.format(args.save_dir))
|
||||
|
Loading…
Reference in New Issue
Block a user