Upgrade API of NAS-Bench-201

This commit is contained in:
D-X-Y 2020-03-10 19:08:56 +11:00
parent c8f2a93ecf
commit d783193392
10 changed files with 623 additions and 178 deletions

View File

@ -30,18 +30,14 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1 CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
``` ```
### Searching on the NASNet search space
Please use the following scripts to use SETN to search as in the original paper:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NASNet-space-search-by-SETN.sh cifar10 1 -1
```
### Searching on the NAS-Bench-201 search space ### Searching on the NAS-Bench-201 search space
The searching codes of SETN on a small search space (NAS-Bench-201). The searching codes of SETN on a small search space (NAS-Bench-201).
``` ```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 1 -1 CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 1 -1
``` ```
**Searching on the NASNet search space** is not ready yet.
# Citation # Citation

View File

@ -21,9 +21,12 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w). The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
You can move it to anywhere you want and send its path to our API for initialization. You can move it to anywhere you want and send its path to our API for initialization.
- [2020.02.25] v1.0: `NAS-Bench-201-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial. - [2020.02.25] v1.0: `NAS-Bench-201-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
- [2020.02.25] v1.0: The full data of each architecture can be download from [Google Drive](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights. - [2020.02.25] v1.0: The full data of each architecture can be download from [
NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
- [2020.02.25] v1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi). - [2020.02.25] v1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
- [2020.03.08] v2.0: coming soon (results of two set of hyper-parameters avaliable on all three datasets) - [2020.03.09] v1.2: More robust API with more functions and descriptions
- [2020.04.01] v2.0: coming soon (results of two sets of hyper-parameters available on all three datasets)
The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ). The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-201 or similar NAS datasets or training models by yourself, you need these data. It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-201 or similar NAS datasets or training models by yourself, you need these data.

View File

@ -1,7 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
##################################################### #####################################################
# python exps/NAS-Bench-201/check.py --base_save_dir # python exps/NAS-Bench-201/check.py --base_str C16-N5-LESS
##################################################### #####################################################
import sys, time, argparse, collections import sys, time, argparse, collections
import torch import torch
@ -13,10 +13,9 @@ from log_utils import AverageMeter, time_string, convert_secs2time
def check_files(save_dir, meta_file, basestr): def check_files(save_dir, meta_file, basestr):
meta_infos = torch.load(meta_file, map_location='cpu') meta_infos = torch.load(meta_file, map_location='cpu')
meta_archs = meta_infos['archs'] meta_archs = meta_infos['archs']
meta_num_archs = meta_infos['total'] meta_num_archs = meta_infos['total']
meta_max_node = meta_infos['max_node']
assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
@ -43,7 +42,12 @@ def check_files(save_dir, meta_file, basestr):
dir2ckps, dir2ckp_exists = dict(), dict() dir2ckps, dir2ckp_exists = dict(), dict()
start_time, epoch_time = time.time(), AverageMeter() start_time, epoch_time = time.time(), AverageMeter()
for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()): for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
seeds = [777, 888, 999] if basestr == 'C16-N5':
seeds = [777, 888, 999]
elif basestr == 'C16-N5-LESS':
seeds = [111, 777]
else:
raise ValueError('Invalid base str : {:}'.format(basestr))
numrs = defaultdict(lambda: 0) numrs = defaultdict(lambda: 0)
all_checkpoints, all_ckp_exists = [], [] all_checkpoints, all_ckp_exists = [], []
for arch_index in arch_indexes: for arch_index in arch_indexes:
@ -66,17 +70,15 @@ def check_files(save_dir, meta_file, basestr):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.') parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.') parser.add_argument('--meta_path', type=str, default='./output/NAS-BENCH-201-4/meta-node-4.pth', help='The meta file path.')
parser.add_argument('--channel', type=int, default=16, help='The number of channels.') parser.add_argument('--base_str', type=str, default='C16-N5', help='The basic string.')
parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
args = parser.parse_args() args = parser.parse_args()
save_dir = Path( args.base_save_dir ) save_dir = Path(args.base_save_dir)
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node) meta_path = Path(args.meta_path)
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir) assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
print ('check NAS-Bench-201 in {:}'.format(save_dir)) print ('check NAS-Bench-201 in {:}'.format(save_dir))
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells) check_files(save_dir, meta_path, args.base_str)
check_files(save_dir, meta_path, basestr)

View File

@ -1,6 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
##################################################### #####################################################
# [2020.03.09] Upgrade to v1.2
import os import os
from setuptools import setup from setuptools import setup
@ -12,7 +13,7 @@ def read(fname='README.md'):
setup( setup(
name = "nas_bench_201", name = "nas_bench_201",
version = "1.1", version = "1.2",
author = "Xuanyi Dong", author = "Xuanyi Dong",
author_email = "dongxuanyi888@gmail.com", author_email = "dongxuanyi888@gmail.com",
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).", description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",

View File

@ -0,0 +1,283 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import os, sys, time, argparse, collections
import numpy as np
import torch
from pathlib import Path
from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import dict2config
# NAS-Bench-201 related module or function
from models import CellStructure, get_cell_based_tiny_net
from nas_201_api import NASBench201API, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
# Reference benchmark file (v1.0), loaded once when the module is imported.
# `correct_time_related_info` reads latency values from it for calibration.
# Fix: the original called `.firmat(...)`, which raises AttributeError on `str`.
api = NASBench201API('{:}/.torch/NAS-Bench-201-v1_0-e61699.pth'.format(os.environ['HOME']))
def create_result_count(used_seed: int, dataset: Text, arch_config: Dict[Text, Any],
                        results: Dict[Text, Any], dataloader_dict: Dict[Text, Any]) -> ResultsCount:
    """Wrap the raw record of one (architecture, seed, dataset) trial in a ResultsCount.

    Args:
      used_seed: the seed of the trial, forwarded unchanged to ``ResultsCount``.
      dataset: one of 'cifar10-valid', 'cifar10', 'cifar100' or 'ImageNet16-120'
        (only checked for old-format records).
      arch_config: dict with the keys 'channel', 'num_cells', 'arch_str' and 'class_num'.
      results: the per-dataset entry of a checkpoint; a record that contains
        'train_times' is treated as the new format.
      dataloader_dict: maps '<dataset>@<split>' to a data loader; only read for
        old-format records, which need a fresh evaluation.

    Returns:
      The populated ``ResultsCount``.

    Raises:
      ValueError: for an old-format record with an unknown dataset name.
    """
    record = ResultsCount(
        dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'],
        results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
    net_config = dict2config(
        {'name': 'infer.tiny',
         'C': arch_config['channel'],
         'N': arch_config['num_cells'],
         'genotype': CellStructure.str2structure(arch_config['arch_str']),
         'num_classes': arch_config['class_num']},
        None)
    # The network is rebuilt and its weights loaded for both record formats.
    network = get_cell_based_tiny_net(net_config)
    network.load_state_dict(record.get_net_param())

    if 'train_times' in results:  # new version
        record.update_train_info(results['train_acc1es'], results['train_acc5es'],
                                 results['train_losses'], results['train_times'])
        record.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times'])
        return record

    def run_eval(loader_key):
        # returns (loss, top1, top5, latencies) as unpacked by the callers below
        return pure_evaluate(dataloader_dict[loader_key], network.cuda())

    def at_last_epoch(value):
        # old-format evaluations are stored under the index of the final epoch
        return {results['total_epoch'] - 1: value}

    if dataset == 'cifar10-valid':
        record.update_OLD_eval('x-valid', results['valid_acc1es'], results['valid_losses'])
        loss, top1, _, latencies = run_eval('{:}@{:}'.format('cifar10', 'test'))
        record.update_OLD_eval('ori-test', at_last_epoch(top1), at_last_epoch(loss))
        record.update_latency(latencies)
    elif dataset == 'cifar10':
        record.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
        # only the latency of this evaluation is kept
        _, _, _, latencies = run_eval('{:}@{:}'.format(dataset, 'test'))
        record.update_latency(latencies)
    elif dataset in ('cifar100', 'ImageNet16-120'):
        record.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
        loss, top1, _, latencies = run_eval('{:}@{:}'.format(dataset, 'valid'))
        record.update_OLD_eval('x-valid', at_last_epoch(top1), at_last_epoch(loss))
        loss, top1, _, latencies = run_eval('{:}@{:}'.format(dataset, 'test'))
        record.update_OLD_eval('x-test', at_last_epoch(top1), at_last_epoch(loss))
        # latency comes from the last (test) evaluation
        record.update_latency(latencies)
    else:
        raise ValueError('invalid dataset name : {:}'.format(dataset))
    return record
def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text],
                     datasets: List[Text], dataloader_dict: Dict[Text, Any]) -> ArchResults:
    """Gather the results of every checkpoint (one per seed) of a single architecture.

    Args:
      arch_index: index of the architecture, forwarded to ``ArchResults``.
      arch_str: architecture string of this architecture.
      checkpoints: checkpoint paths; ``.name`` is read on each, so path objects
        are expected. The seed is parsed from the file name: the text after the
        last '-' and before the first following '.'.
      datasets: dataset names to look up in each checkpoint.
      dataloader_dict: forwarded to ``create_result_count``.

    Returns:
      An ``ArchResults`` updated with one entry per (dataset, seed) found.

    Raises:
      ValueError: if a checkpoint contains none of the requested datasets.
    """
    arch_results = ArchResults(arch_index, arch_str)
    for ckp_path in checkpoints:
        ckp_data = torch.load(ckp_path, map_location='cpu')
        seed_text = ckp_path.name.rsplit('-', 1)[-1].partition('.')[0]
        num_found = 0
        for dataset in datasets:
            if dataset not in ckp_data:
                print('Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, ckp_path))
                continue
            num_found += 1
            results = ckp_data[dataset]
            assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, seed_text, dataset, ckp_path)
            arch_config = dict(channel=results['channel'],
                               num_cells=results['num_cells'],
                               arch_str=arch_str,
                               class_num=results['config']['class_num'])
            # the seed is passed as parsed text here and as an int to `update`
            xresult = create_result_count(seed_text, dataset, arch_config, results, dataloader_dict)
            arch_results.update(dataset, int(seed_text), xresult)
        if num_found == 0:
            raise ValueError('{:} does not find any data'.format(ckp_path))
    return arch_results
def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults):
    """Overwrite latency, training-time and evaluation-time fields of both records.

    Latencies are taken from the module-level ``api`` (NAS-Bench-201-v1_0-e61699.pth).
    Training and evaluation times are derived from the seed-777 'cifar10-valid'
    run stored in ``arch_info_less`` and scaled by the number of samples of each
    split.

    Args:
      arch_index: index used to query ``api`` for latency.
      arch_info_full: record that is updated in place.
      arch_info_less: record that is updated in place and provides the timings.

    Returns:
      The tuple ``(arch_info_full, arch_info_less)``.

    Raises:
      ValueError: if an eval-time key starts with neither 'ori-test@' nor 'x-valid@'.
    """
    both_infos = (arch_info_full, arch_info_less)

    # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
    cifar10_latency = (api.get_latency(arch_index, 'cifar10-valid', False)
                       + api.get_latency(arch_index, 'cifar10', False)) / 2
    for info in both_infos:
        for name in ('cifar10-valid', 'cifar10'):
            info.reset_latency(name, None, cifar10_latency)
    for name in ('cifar100', 'ImageNet16-120'):
        latency = api.get_latency(arch_index, name, False)
        for info in both_infos:
            info.reset_latency(name, None, latency)

    # average time of one training epoch on cifar10-valid (seed 777)
    epoch_times = list(arch_info_less.query('cifar10-valid', 777).train_times.values())
    train_per_epoch_time = sum(epoch_times) / len(epoch_times)

    # split the recorded evaluation times by the evaluation set they belong to
    ori_test_times, x_valid_times = [], []
    for key, value in arch_info_less.query('cifar10-valid', 777).eval_times.items():
        if key.startswith('ori-test@'):
            ori_test_times.append(value)
        elif key.startswith('x-valid@'):
            x_valid_times.append(value)
        else:
            raise ValueError('-- {:} --'.format(key))
    eval_ori_test_time = float(np.mean(ori_test_times))
    eval_x_valid_time = float(np.mean(x_valid_times))

    nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000, 'ImageNet16-120-test': 6000,
            'cifar10-valid-train': 25000, 'cifar10-valid-valid': 25000,
            'cifar10-train': 50000, 'cifar10-test': 10000,
            'cifar100-train': 50000, 'cifar100-test': 10000, 'cifar100-valid': 5000}
    train_per_sample = train_per_epoch_time / nums['cifar10-valid-train']
    eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums['cifar10-valid-valid'] + nums['cifar10-test'])

    # (dataset, key in `nums` that gives the size of its training set)
    train_plan = (('cifar10-valid', 'cifar10-valid-train'),
                  ('cifar10', 'cifar10-train'),
                  ('cifar100', 'cifar100-train'),
                  ('ImageNet16-120', 'ImageNet16-120-train'))
    # (dataset, evaluation-set name, key in `nums` used as its size)
    eval_plan = (('cifar10-valid', 'x-valid', 'cifar10-valid-valid'),
                 ('cifar10-valid', 'ori-test', 'cifar10-test'),
                 ('cifar10', 'ori-test', 'cifar10-test'),
                 ('cifar100', 'x-valid', 'cifar100-valid'),
                 ('cifar100', 'x-test', 'cifar100-valid'),
                 ('cifar100', 'ori-test', 'cifar100-test'),
                 ('ImageNet16-120', 'x-valid', 'ImageNet16-120-valid'),
                 ('ImageNet16-120', 'x-test', 'ImageNet16-120-valid'),
                 ('ImageNet16-120', 'ori-test', 'ImageNet16-120-test'))
    for info in (arch_info_less, arch_info_full):
        for name, size_key in train_plan:
            info.reset_pseudo_train_times(name, None, train_per_sample * nums[size_key])
        for name, eval_name, size_key in eval_plan:
            info.reset_pseudo_eval_times(name, None, eval_name, eval_per_sample * nums[size_key])
    return arch_info_full, arch_info_less
def simplify(save_dir, meta_file, basestr, target_dir):
    """Summarise the checkpoints of one sub-directory into per-architecture files.

    The function scans every sub-directory of ``save_dir`` matching
    ``*-*-<basestr>`` to report how many architectures/seeds exist, then
    processes the architectures found in ``save_dir / target_dir`` together with
    their counterparts in ``save_dir / '<target_dir>-LESS'``.

    For every architecture it writes, under ``save_dir/simplifies/architectures``:
      * ``<index>-FULL.pth``   : state dicts of the 'full' and 'less' records;
      * ``<index>-SIMPLE.pth`` : the same records after ``clear_params()``.
    Finally ``save_dir/simplifies/<target_dir>.pth`` receives the summary dict.

    Args:
      save_dir: path object of the base directory (``glob`` and ``/`` are used on it).
      meta_file: file loaded with ``torch.load``; must provide 'archs' and 'total'.
      basestr: suffix that selects the checkpoint sub-directories.
      target_dir: name of the sub-directory (relative to ``save_dir``) to process.

    Raises:
      AssertionError: on inconsistent meta information, an unknown ``target_dir``
        or a duplicated / out-of-range architecture index.
    """
    meta_infos = torch.load(meta_file, map_location='cpu')
    meta_archs = meta_infos['archs']  # a list of architecture strings
    meta_num_archs = meta_infos['total']
    assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))

    sub_model_dirs = sorted(save_dir.glob('*-*-{:}'.format(basestr)))
    print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))

    subdir2archs, num_evaluated_arch = OrderedDict(), 0
    num_seeds = defaultdict(lambda: 0)
    for sub_dir in sub_model_dirs:
        arch_indexes = set()
        for checkpoint in sub_dir.glob('arch-*-seed-*.pth'):
            temp_names = checkpoint.name.split('-')
            assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
            arch_indexes.add(temp_names[1])
        # NOTE: the indexes are kept as strings, so this sort is lexicographic
        subdir2archs[sub_dir] = sorted(arch_indexes)
        num_evaluated_arch += len(arch_indexes)
        # count number of seeds for each architecture
        for arch_index in arch_indexes:
            num_seeds[len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index))))] += 1
    print('{:} There are {:5d} architectures that have been evaluated ({:} in total).'.format(time_string(), num_evaluated_arch, meta_num_archs))
    for key in sorted(num_seeds.keys()):
        print ('{:} There are {:5d} architectures that are evaluated {:} times.'.format(time_string(), num_seeds[key], key))

    dataloader_dict = get_nas_bench_loaders(6)

    to_save_simply = save_dir / 'simplifies'
    to_save_allarc = save_dir / 'simplifies' / 'architectures'
    if not to_save_simply.exists(): to_save_simply.mkdir(parents=True, exist_ok=True)
    if not to_save_allarc.exists(): to_save_allarc.mkdir(parents=True, exist_ok=True)

    assert (save_dir / target_dir) in subdir2archs, 'can not find {:}'.format(target_dir)
    arch2infos, datasets = {}, ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')
    evaluated_indexes = set()
    target_full_dir = save_dir / target_dir
    target_less_dir = save_dir / '{:}-LESS'.format(target_dir)
    arch_indexes = subdir2archs[target_full_dir]
    # restart the statistics: from here on it counts seeds of the target directory only
    num_seeds = defaultdict(lambda: 0)
    end_time = time.time()
    arch_time = AverageMeter()
    for idx, arch_index in enumerate(arch_indexes):
        checkpoints = list(target_full_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))
        ckps_less = list(target_less_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))
        # create the arch info for each architecture
        try:
            arch_info_full = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict)
            arch_info_less = account_one_arch(arch_index, meta_archs[int(arch_index)], ckps_less, datasets, dataloader_dict)
            num_seeds[len(checkpoints)] += 1
        except Exception as error:
            # Best-effort: a broken architecture is reported and skipped.
            # `Exception` (not a bare except) lets Ctrl-C / SystemExit stop the run.
            print('Loading {:} failed, : {:}'.format(arch_index, checkpoints))
            print('  reason : {:} : {:}'.format(type(error).__name__, error))
            continue
        assert int(arch_index) not in evaluated_indexes, 'conflict arch-index : {:}'.format(arch_index)
        assert 0 <= int(arch_index) < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
        arch_info = {'full': arch_info_full, 'less': arch_info_less}
        evaluated_indexes.add(int(arch_index))
        arch2infos[int(arch_index)] = arch_info
        # to correct the latency and training_time info.
        arch_info_full, arch_info_less = correct_time_related_info(int(arch_index), arch_info_full, arch_info_less)
        to_save_data = OrderedDict(full=arch_info_full.state_dict(), less=arch_info_less.state_dict())
        torch.save(to_save_data, to_save_allarc / '{:}-FULL.pth'.format(arch_index))
        arch_info['full'].clear_params()
        arch_info['less'].clear_params()
        # Rebuild the state dicts AFTER clearing: the snapshot taken above was
        # created before `clear_params()` and must not be reused for the SIMPLE file.
        to_save_data = OrderedDict(full=arch_info_full.state_dict(), less=arch_info_less.state_dict())
        torch.save(to_save_data, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index))
        # measure elapsed time
        arch_time.update(time.time() - end_time)
        end_time = time.time()
        need_time = '{:}'.format(convert_secs2time(arch_time.avg * (len(arch_indexes)-idx-1), True))
        print('{:} {:} [{:03d}/{:03d}] : {:} still need {:}'.format(time_string(), target_dir, idx, len(arch_indexes), arch_index, need_time))
    # measure time
    xstrs = ['{:}:{:03d}'.format(key, num_seeds[key]) for key in sorted(num_seeds.keys())]
    print('{:} {:} done : {:}'.format(time_string(), target_dir, xstrs))
    final_infos = {'meta_archs' : meta_archs,
                   'total_archs': meta_num_archs,
                   'basestr'    : basestr,
                   'arch2infos' : arch2infos,
                   'evaluated_indexes': evaluated_indexes}
    save_file_name = to_save_simply / '{:}.pth'.format(target_dir)
    torch.save(final_infos, save_file_name)
    print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
def merge_all(save_dir, meta_file, basestr):
    """Merge the per-directory summaries into a single file.

    For every sub-directory of ``save_dir`` matching ``*-*-<basestr>`` the file
    ``simplifies/<sub-directory name>.pth`` next to it is loaded and its records
    are converted to state dicts. The merged result is saved as
    ``save_dir/simplifies/<basestr>-final-infos.pth``.

    Args:
      save_dir: path object of the base directory.
      meta_file: file loaded with ``torch.load``; must provide 'archs' and 'total'.
      basestr: suffix that selects the checkpoint sub-directories.

    Raises:
      ValueError: if the summary file of a sub-directory does not exist.
      AssertionError: on inconsistent meta information or a duplicated index.
    """
    meta_infos = torch.load(meta_file, map_location='cpu')
    meta_archs = meta_infos['archs']
    meta_num_archs = meta_infos['total']
    assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))

    sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
    num_dirs = len(sub_model_dirs)
    print ('{:} find {:} directories used to save checkpoints'.format(time_string(), num_dirs))
    for position, sub_dir in enumerate(sub_model_dirs):
        num_runs = len(list(sub_dir.glob('arch-*-seed-*.pth')))
        print ('The {:02d}/{:02d}-th directory : {:} : {:} runs.'.format(position, num_dirs, sub_dir, num_runs))

    arch2infos, evaluated_indexes = dict(), set()
    for position, sub_dir in enumerate(sub_model_dirs):
        ckp_path = sub_dir.parent / 'simplifies' / '{:}.pth'.format(sub_dir.name)
        if not ckp_path.exists():
            raise ValueError('Can not find {:}'.format(ckp_path))
        sub_ckps = torch.load(ckp_path, map_location='cpu')
        assert sub_ckps['total_archs'] == meta_num_archs and sub_ckps['basestr'] == basestr
        sub_arch2infos = sub_ckps['arch2infos']
        sub_indexes = sub_ckps['evaluated_indexes']
        for eval_index in sub_indexes:
            # every architecture index may appear in one sub-directory only
            assert eval_index not in evaluated_indexes and eval_index not in arch2infos
            pair = sub_arch2infos[eval_index]
            arch2infos[eval_index] = {'full': pair['full'].state_dict(),
                                      'less': pair['less'].state_dict()}
            evaluated_indexes.add(eval_index)
        print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), position, num_dirs, ckp_path, len(sub_indexes)))

    evaluated_indexes = sorted(evaluated_indexes)
    print ('Finally, there are {:} architectures that have been trained and evaluated.'.format(len(evaluated_indexes)))

    to_save_simply = save_dir / 'simplifies'
    if not to_save_simply.exists():
        to_save_simply.mkdir(parents=True, exist_ok=True)
    final_infos = {'meta_archs' : meta_archs,
                   'total_archs': meta_num_archs,
                   'arch2infos' : arch2infos,
                   'evaluated_indexes': evaluated_indexes}
    save_file_name = to_save_simply / '{:}-final-infos.pth'.format(basestr)
    torch.save(final_infos, save_file_name)
    print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
if __name__ == '__main__':
    # Command-line entry: `--mode cal` summarises one target directory,
    # `--mode merge` merges all the per-directory summaries.
    parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--mode', type=str, choices=['cal', 'merge'], help='The running mode for this script.')
    parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
    parser.add_argument('--target_dir', type=str, help='The target directory.')
    parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.')
    parser.add_argument('--channel', type=int, default=16, help='The number of channels.')
    parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
    args = parser.parse_args()

    save_dir = Path(args.base_save_dir)
    # the meta file is expected inside the base directory
    meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
    assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
    assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
    print ('start the statistics of our nas-benchmark from {:} using {:}.'.format(save_dir, args.target_dir))
    basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)

    # `--mode` has no default, so an omitted flag (None) ends up in the error below
    if args.mode not in ('cal', 'merge'):
        raise ValueError('invalid mode : {:}'.format(args.mode))
    if args.mode == 'cal':
        simplify(save_dir, meta_path, basestr, args.target_dir)
    else:
        merge_all(save_dir, meta_path, basestr)

View File

@ -4,7 +4,6 @@
import os, sys, time, argparse, collections import os, sys, time, argparse, collections
from copy import deepcopy from copy import deepcopy
import torch import torch
import torch.nn as nn
from pathlib import Path from pathlib import Path
from collections import defaultdict from collections import defaultdict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
@ -15,8 +14,7 @@ from datasets import get_datasets
# NAS-Bench-201 related module or function # NAS-Bench-201 related module or function
from models import CellStructure, get_cell_based_tiny_net from models import CellStructure, get_cell_based_tiny_net
from nas_201_api import ArchResults, ResultsCount from nas_201_api import ArchResults, ResultsCount
from functions import pure_evaluate from procedures import bench_pure_evaluate as pure_evaluate
def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict): def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict):
@ -69,7 +67,6 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
return information return information
def GET_DataLoaders(workers): def GET_DataLoaders(workers):
torch.set_num_threads(workers) torch.set_num_threads(workers)
@ -137,7 +134,6 @@ def GET_DataLoaders(workers):
return loaders return loaders
def simplify(save_dir, meta_file, basestr, target_dir): def simplify(save_dir, meta_file, basestr, target_dir):
meta_infos = torch.load(meta_file, map_location='cpu') meta_infos = torch.load(meta_file, map_location='cpu')
meta_archs = meta_infos['archs'] # a list of architecture strings meta_archs = meta_infos['archs'] # a list of architecture strings
@ -221,7 +217,6 @@ def simplify(save_dir, meta_file, basestr, target_dir):
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
def merge_all(save_dir, meta_file, basestr): def merge_all(save_dir, meta_file, basestr):
meta_infos = torch.load(meta_file, map_location='cpu') meta_infos = torch.load(meta_file, map_location='cpu')
meta_archs = meta_infos['archs'] meta_archs = meta_infos['archs']
@ -268,7 +263,6 @@ def merge_all(save_dir, meta_file, basestr):
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@ -280,7 +274,7 @@ if __name__ == '__main__':
parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.') parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
args = parser.parse_args() args = parser.parse_args()
save_dir = Path( args.base_save_dir ) save_dir = Path(args.base_save_dir)
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node) meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir) assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
@ -292,4 +286,4 @@ if __name__ == '__main__':
elif args.mode == 'merge': elif args.mode == 'merge':
merge_all(save_dir, meta_path, basestr) merge_all(save_dir, meta_path, basestr)
else: else:
raise ValueError('invalid mode : {:}'.format(args.mode)) raise ValueError('invalid mode : {:}'.format(args.mode))

View File

@ -4,4 +4,5 @@
from .api import NASBench201API from .api import NASBench201API
from .api import ArchResults, ResultsCount from .api import ArchResults, ResultsCount
NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25] # NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]

View File

@ -8,7 +8,7 @@
# #
# #
import os, copy, random, torch, numpy as np import os, copy, random, torch, numpy as np
from typing import List, Text, Union, Dict, Any from typing import List, Text, Union, Dict
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
@ -19,8 +19,7 @@ def print_information(information, extra_info=None, show=False):
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc) return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
for ida, dataset in enumerate(dataset_names): for ida, dataset in enumerate(dataset_names):
#flop, param, latency = information.get_comput_costs(dataset) metric = information.get_compute_costs(dataset)
metric = information.get_comput_costs(dataset)
flop, param, latency = metric['flops'], metric['params'], metric['latency'] flop, param, latency = metric['flops'], metric['params'], metric['latency']
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None) str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
train_info = information.get_metrics(dataset, 'train') train_info = information.get_metrics(dataset, 'train')
@ -80,6 +79,7 @@ class NASBench201API(object):
return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
def random(self): def random(self):
"""Return a random index of all architectures."""
return random.randint(0, len(self.meta_archs)-1) return random.randint(0, len(self.meta_archs)-1)
# This function is used to query the index of an architecture in the search space. # This function is used to query the index of an architecture in the search space.
@ -166,7 +166,7 @@ class NASBench201API(object):
else : basestr, arch2infos = '200epochs', self.arch2infos_full else : basestr, arch2infos = '200epochs', self.arch2infos_full
best_index, highest_accuracy = -1, None best_index, highest_accuracy = -1, None
for i, idx in enumerate(self.evaluated_indexes): for i, idx in enumerate(self.evaluated_indexes):
info = arch2infos[idx].get_comput_costs(dataset) info = arch2infos[idx].get_compute_costs(dataset)
flop, param, latency = info['flops'], info['params'], info['latency'] flop, param, latency = info['flops'], info['params'], info['latency']
if FLOP_max is not None and flop > FLOP_max : continue if FLOP_max is not None and flop > FLOP_max : continue
if Param_max is not None and param > Param_max: continue if Param_max is not None and param > Param_max: continue
@ -178,38 +178,40 @@ class NASBench201API(object):
best_index, highest_accuracy = idx, accuracy best_index, highest_accuracy = idx, accuracy
return best_index, highest_accuracy return best_index, highest_accuracy
# return the topology structure of the `index`-th architecture
def arch(self, index: int): def arch(self, index: int):
"""Return the topology structure of the `index`-th architecture."""
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
return copy.deepcopy(self.meta_archs[index]) return copy.deepcopy(self.meta_archs[index])
"""
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
Args [seed]:
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
-- a interger : return the weights of a specific trial, whose seed is this interger.
Args [use_12epochs_result]:
-- True : train the model by 12 epochs
-- False : train the model by 200 epochs
"""
def get_net_param(self, index, dataset, seed, use_12epochs_result=False): def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less """
else : basestr, arch2infos = '200epochs', self.arch2infos_full This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
archresult = arch2infos[index] Args [seed]:
return archresult.get_net_param(dataset, seed) -- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
 -- an integer : return the weights of a specific trial, whose seed is this integer.
Args [use_12epochs_result]:
-- True : train the model by 12 epochs
-- False : train the model by 200 epochs
"""
if use_12epochs_result: arch2infos = self.arch2infos_less
else: arch2infos = self.arch2infos_full
arch_result = arch2infos[index]
return arch_result.get_net_param(dataset, seed)
"""
This function is used to obtain the configuration for the `index`-th architecture on `dataset`. def get_net_config(self, index: int, dataset: Text):
Args [dataset] (4 possible options): """
-- cifar10-valid : training the model on the CIFAR-10 training set. This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
-- cifar10 : training the model on the CIFAR-10 training + validation set. Args [dataset] (4 possible options):
-- cifar100 : training the model on the CIFAR-100 training set. -- cifar10-valid : training the model on the CIFAR-10 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set. -- cifar10 : training the model on the CIFAR-10 training + validation set.
This function will return a dict. -- cifar100 : training the model on the CIFAR-100 training set.
========= Some examlpes for using this function: -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
config = api.get_net_config(128, 'cifar10') This function will return a dict.
""" ========= Some examlpes for using this function:
def get_net_config(self, index, dataset): config = api.get_net_config(128, 'cifar10')
"""
archresult = self.arch2infos_full[index] archresult = self.arch2infos_full[index]
all_results = archresult.query(dataset, None) all_results = archresult.query(dataset, None)
if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset)) if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset))
@ -218,12 +220,25 @@ class NASBench201API(object):
#print ('SEED [{:}] : {:}'.format(seed, result)) #print ('SEED [{:}] : {:}'.format(seed, result))
raise ValueError('Impossible to reach here!') raise ValueError('Impossible to reach here!')
# obtain the cost metric for the `index`-th architecture on a dataset
def get_cost_info(self, index, dataset, use_12epochs_result=False): def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]:
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less """To obtain the cost metric for the `index`-th architecture on a dataset."""
else : basestr, arch2infos = '200epochs', self.arch2infos_full if use_12epochs_result: arch2infos = self.arch2infos_less
archresult = arch2infos[index] else: arch2infos = self.arch2infos_full
return archresult.get_comput_costs(dataset) arch_result = arch2infos[index]
return arch_result.get_compute_costs(dataset)
def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float:
    """
    Return the 'latency' entry of the cost information of one architecture.

    This is a thin convenience wrapper: it forwards all three arguments to
    `self.get_cost_info` and picks the 'latency' key out of the returned dict.
    NOTE(review): the previous docstring states the value is in seconds and was
    measured with a batch size of 256 -- that cannot be confirmed from this method.
    :param index: the index of the target architecture
    :param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
    :param use_12epochs_result: forwarded unchanged to `get_cost_info`
    :return: whatever `get_cost_info` stored under the 'latency' key
    """
    return self.get_cost_info(index, dataset, use_12epochs_result)['latency']
# obtain the metric for the `index`-th architecture # obtain the metric for the `index`-th architecture
# `dataset` indicates the dataset: # `dataset` indicates the dataset:
@ -298,12 +313,15 @@ class NASBench201API(object):
xifo['est-valid-accuracy'] = est_valid_info['accuracy'] xifo['est-valid-accuracy'] = est_valid_info['accuracy']
return xifo return xifo
"""
This function will print the information of a specific (or all) architecture(s).
If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th archiitecture.
"""
def show(self, index: int = -1) -> None: def show(self, index: int = -1) -> None:
"""
This function will print the information of a specific (or all) architecture(s).
:param index: If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th architecture.
:return: nothing
"""
if index < 0: # show all architectures if index < 0: # show all architectures
print(self) print(self)
for i, idx in enumerate(self.evaluated_indexes): for i, idx in enumerate(self.evaluated_indexes):
@ -330,19 +348,27 @@ class NASBench201API(object):
else: else:
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
# This func shows how to read the string-based architecture encoding
# the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
# Usage:
# arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
# print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
# for i, node in enumerate(arch):
# print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
@staticmethod @staticmethod
def str2lists(xstr: Text) -> List[Any]: def str2lists(arch_str: Text) -> List[tuple]:
# assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) """
nodestrs = xstr.split('+') This function shows how to read the string-based architecture encoding.
It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
:param
arch_str: the input is a string indicates the architecture topology, such as
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
:return: a list of tuple, contains multiple (op, input_node_index) pairs.
:usage
arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
for i, node in enumerate(arch):
print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
"""
node_strs = arch_str.split('+')
genotypes = [] genotypes = []
for i, node_str in enumerate(nodestrs): for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|'))) inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs ) inputs = ( xi.split('~') for xi in inputs )
@ -350,40 +376,47 @@ class NASBench201API(object):
genotypes.append( input_infos ) genotypes.append( input_infos )
return genotypes return genotypes
# This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101
# Usage:
# # this will return a numpy matrix (2-D np.array)
# matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
# # This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
# [ [0, 0, 0, 0], # the first line represents the input (0-th) node
# [2, 0, 0, 0], # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
# [0, 0, 0, 0], # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
# [0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
# In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect'
# 2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
@staticmethod @staticmethod
def str2matrix(xstr): def str2matrix(arch_str: Text,
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
# this only support NAS-Bench-201 search space """
# this defination will be consistant with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24 This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
# If a node has two input-edges from the same node, this function does not work. One edge will be overleaped.
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] :param
nodestrs = xstr.split('+') arch_str: the input is a string indicates the architecture topology, such as
num_nodes = len(nodestrs) + 1 |nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
matrix = np.zeros((num_nodes,num_nodes)) search_space: a list of operation string, the default list is the search space for NAS-Bench-201
for i, node_str in enumerate(nodestrs): the default value should be be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/models/cell_operations.py#L24
:return
the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
:usage
matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
[ [0, 0, 0, 0], # the first line represents the input (0-th) node
[2, 0, 0, 0], # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
[0, 0, 0, 0], # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
[0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect',
2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
:(NOTE)
If a node has two input-edges from the same node, this function does not work. One edge will be overlapped.
"""
node_strs = arch_str.split('+')
num_nodes = len(node_strs) + 1
matrix = np.zeros((num_nodes, num_nodes))
for i, node_str in enumerate(node_strs):
inputs = list(filter(lambda x: x != '', node_str.split('|'))) inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
for xi in inputs: for xi in inputs:
op, idx = xi.split('~') op, idx = xi.split('~')
if op not in NAS_BENCH_201: raise ValueError('this op ({:}) is not in {:}'.format(op, NAS_BENCH_201)) if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
op_idx, node_idx = NAS_BENCH_201.index(op), int(idx) op_idx, node_idx = search_space.index(op), int(idx)
matrix[i+1, node_idx] = op_idx matrix[i+1, node_idx] = op_idx
return matrix return matrix
class ArchResults(object): class ArchResults(object):
def __init__(self, arch_index, arch_str): def __init__(self, arch_index, arch_str):
@ -393,15 +426,15 @@ class ArchResults(object):
self.dataset_seed = dict() self.dataset_seed = dict()
self.clear_net_done = False self.clear_net_done = False
def get_comput_costs(self, dataset): def get_compute_costs(self, dataset):
x_seeds = self.dataset_seed[dataset] x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
flops = [result.flop for result in results] flops = [result.flop for result in results]
params = [result.params for result in results] params = [result.params for result in results]
lantencies = [result.get_latency() for result in results] latencies = [result.get_latency() for result in results]
lantencies = [x for x in lantencies if x > 0] latencies = [x for x in latencies if x > 0]
mean_latency = np.mean(lantencies) if len(lantencies) > 0 else None mean_latency = np.mean(latencies) if len(latencies) > 0 else None
time_infos = defaultdict(list) time_infos = defaultdict(list)
for result in results: for result in results:
time_info = result.get_times() time_info = result.get_times()
@ -416,38 +449,38 @@ class ArchResults(object):
else: info[key] = None else: info[key] = None
return info return info
"""
This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset.
If not specify, each set refer to the proposed split in NAS-Bench-201 paper.
If some args return None or raise error, then it is not avaliable.
========================================
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
Args [setname] (each dataset has different setnames):
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar10, you can use 'train', 'ori-test'.
------ 'train' : the metric on the training + validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'x-test' : the metric on the test set.
------ 'ori-test' : the metric on the validation + test set.
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
------ None : return the metric after the last training epoch.
------ an integer i : return the metric after the i-th training epoch.
Args [is_random]:
------ True : return the metric of a randomly selected trial.
------ False : return the averaged metric of all avaliable trials.
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
"""
def get_metrics(self, dataset, setname, iepoch=None, is_random=False): def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
"""
This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset.
If not specify, each set refer to the proposed split in NAS-Bench-201 paper.
If some args return None or raise error, then it is not avaliable.
========================================
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
Args [setname] (each dataset has different setnames):
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar10, you can use 'train', 'ori-test'.
------ 'train' : the metric on the training + validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'x-test' : the metric on the test set.
------ 'ori-test' : the metric on the validation + test set.
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
------ None : return the metric after the last training epoch.
------ an integer i : return the metric after the i-th training epoch.
Args [is_random]:
------ True : return the metric of a randomly selected trial.
------ False : return the averaged metric of all avaliable trials.
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
"""
x_seeds = self.dataset_seed[dataset] x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
infos = defaultdict(list) infos = defaultdict(list)
@ -483,20 +516,55 @@ class ArchResults(object):
def get_dataset_seeds(self, dataset): def get_dataset_seeds(self, dataset):
return copy.deepcopy( self.dataset_seed[dataset] ) return copy.deepcopy( self.dataset_seed[dataset] )
""" def get_net_param(self, dataset: Text, seed: Union[None, int] =None):
This function will return the trained network's weights on the 'dataset'. """
When the 'seed' is None, it will return the weights for every run trial in the form of a dict. This function will return the trained network's weights on the 'dataset'.
When the :arg
""" dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
def get_net_param(self, dataset, seed=None): seed: an integer indicates the seed value or None that indicates returing all trials.
"""
if seed is None: if seed is None:
x_seeds = self.dataset_seed[dataset] x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds} return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
else: else:
return self.all_results[(dataset, seed)].get_net_param() return self.all_results[(dataset, seed)].get_net_param()
# get the total number of training epochs def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
if seed is None:
for seed in self.dataset_seed[dataset]:
self.all_results[(dataset, seed)].update_latency([latency])
else:
self.all_results[(dataset, seed)].update_latency([latency])
def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
    """
    Overwrite the training times of the trial(s) recorded for `dataset`.

    :param dataset: key into `self.dataset_seed` / first element of the `self.all_results` key
    :param seed: None to update every seed listed in `self.dataset_seed[dataset]`,
                 otherwise the single seed whose trial is updated
    :param estimated_per_epoch_time: value handed to each trial's `reset_pseudo_train_times`
    :return: nothing
    """
    # `self.dataset_seed` is only consulted when no specific seed was requested.
    target_seeds = self.dataset_seed[dataset] if seed is None else [seed]
    for one_seed in target_seeds:
        trial = self.all_results[(dataset, one_seed)]
        trial.reset_pseudo_train_times(estimated_per_epoch_time)
def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
    """
    Overwrite the evaluation times of the trial(s) recorded for `dataset`.

    :param dataset: key into `self.dataset_seed` / first element of the `self.all_results` key
    :param seed: None to update every seed listed in `self.dataset_seed[dataset]`,
                 otherwise the single seed whose trial is updated
    :param eval_name: forwarded to each trial's `reset_pseudo_eval_times`
    :param estimated_per_epoch_time: forwarded to each trial's `reset_pseudo_eval_times`
    :return: nothing
    """
    # `self.dataset_seed` is only consulted when no specific seed was requested.
    target_seeds = self.dataset_seed[dataset] if seed is None else [seed]
    for one_seed in target_seeds:
        trial = self.all_results[(dataset, one_seed)]
        trial.reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
def get_latency(self, dataset: Text) -> float:
    """
    Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]

    The latency of every trial listed in `self.dataset_seed[dataset]` is read via
    that trial's `get_latency()` and the arithmetic mean is returned.
    :param dataset: key into `self.dataset_seed` / first element of the `self.all_results` key
    :return: the mean latency over all trials of `dataset`
    :raises ValueError: if a trial reports a latency that is not a positive float,
                        or if no trial is recorded for `dataset`
    """
    latencies = []
    for seed in self.dataset_seed[dataset]:
        latency = self.all_results[(dataset, seed)].get_latency()
        if not isinstance(latency, float) or latency <= 0:
            # The message has three placeholders; supplying only `dataset` made
            # str.format raise IndexError instead of the intended ValueError.
            raise ValueError('invalid latency of {:} for {:} with {:}'.format(latency, dataset, seed))
        latencies.append(latency)
    if not latencies:
        # Without this guard an empty seed list ends in ZeroDivisionError below.
        raise ValueError('can not find any trial to compute the latency for {:}'.format(dataset))
    return sum(latencies) / len(latencies)
def get_total_epoch(self, dataset=None): def get_total_epoch(self, dataset=None):
"""Return the total number of training epochs."""
if dataset is None: if dataset is None:
epochss = [] epochss = []
for xdata, x_seeds in self.dataset_seed.items(): for xdata, x_seeds in self.dataset_seed.items():
@ -509,13 +577,13 @@ class ArchResults(object):
if len(set(epochss)) > 1: raise ValueError('Each trial mush have the same number of training epochs : {:}'.format(epochss)) if len(set(epochss)) > 1: raise ValueError('Each trial mush have the same number of training epochs : {:}'.format(epochss))
return epochss[-1] return epochss[-1]
# return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'
def query(self, dataset, seed=None): def query(self, dataset, seed=None):
"""Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'"""
if seed is None: if seed is None:
x_seeds = self.dataset_seed[dataset] x_seeds = self.dataset_seed[dataset]
return {seed: self.all_results[ (dataset, seed) ] for seed in x_seeds} return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
else: else:
return self.all_results[ (dataset, seed) ] return self.all_results[(dataset, seed)]
def arch_idx_str(self): def arch_idx_str(self):
return '{:06d}'.format(self.arch_index) return '{:06d}'.format(self.arch_index)
@ -573,7 +641,18 @@ class ArchResults(object):
def clear_params(self): def clear_params(self):
for key, result in self.all_results.items(): for key, result in self.all_results.items():
result.net_state_dict = None result.net_state_dict = None
self.clear_net_done = True self.clear_net_done = True
def debug_test(self):
    """
    Debugging helper: print a summary for each of the four benchmark datasets.

    For every dataset name it prints a header line, the value returned by
    `self.get_latency`, and then, for each seed in `self.dataset_seed`, the
    trial object itself and the result of its `get_times()`.
    """
    for name in ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'):
        print('---->>>> {:}'.format(name))
        latency = self.get_latency(name)
        print('The latency on {:} is {:} s'.format(name, latency))
        for one_seed in self.dataset_seed[name]:
            trial = self.all_results[(name, one_seed)]
            print(' ==>> result = {:}'.format(trial))
            print(' ==>> cost = {:}'.format(trial.get_times()))
def __repr__(self): def __repr__(self):
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done)) return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
@ -603,12 +682,25 @@ class ResultsCount(object):
# evaluation results # evaluation results
self.reset_eval() self.reset_eval()
def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times): def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
self.train_acc1es = train_acc1es self.train_acc1es = train_acc1es
self.train_acc5es = train_acc5es self.train_acc5es = train_acc5es
self.train_losses = train_losses self.train_losses = train_losses
self.train_times = train_times self.train_times = train_times
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
    """
    Replace `self.train_times` with one identical entry per epoch.

    :param estimated_per_epoch_time: the value stored for every epoch index
    :return: nothing; `self.train_times` becomes an OrderedDict keyed by
             0 .. self.epochs-1
    """
    self.train_times = OrderedDict(
        (epoch, estimated_per_epoch_time) for epoch in range(self.epochs)
    )
def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
    """
    Store the same evaluation time for every epoch of one evaluation set.

    :param eval_name: must already be present in `self.eval_names`
    :param estimated_per_epoch_time: the value written for each epoch
    :return: nothing; entries '<eval_name>@<epoch>' of `self.eval_times` are set
             for epoch in 0 .. self.epochs-1
    :raises ValueError: if `eval_name` is not in `self.eval_names`
    """
    known = eval_name in self.eval_names
    if not known:
        raise ValueError('invalid eval name : {:}'.format(eval_name))
    for epoch in range(self.epochs):
        key = '{:}@{:}'.format(eval_name, epoch)
        self.eval_times[key] = estimated_per_epoch_time
def reset_eval(self): def reset_eval(self):
self.eval_names = [] self.eval_names = []
self.eval_acc1es = {} self.eval_acc1es = {}
@ -618,6 +710,11 @@ class ResultsCount(object):
def update_latency(self, latency): def update_latency(self, latency):
self.latency = copy.deepcopy( latency ) self.latency = copy.deepcopy( latency )
def get_latency(self) -> float:
    """
    Return the mean of the recorded latency values.

    The previous docstring describes the value as seconds.
    :return: -1.0 when no latency is available (`self.latency` is None or holds
             no measurement); otherwise sum(self.latency) / len(self.latency)
    """
    # An empty container used to fall through to the division and raise
    # ZeroDivisionError; it is now reported as "not available" like None.
    if self.latency is None or len(self.latency) == 0:
        return -1.0
    return sum(self.latency) / len(self.latency)
def update_eval(self, accs, losses, times): # new version def update_eval(self, accs, losses, times): # new version
data_names = set([x.split('@')[0] for x in accs.keys()]) data_names = set([x.split('@')[0] for x in accs.keys()])
for data_name in data_names: for data_name in data_names:
@ -642,28 +739,22 @@ class ResultsCount(object):
set_name = '[' + ', '.join(self.eval_names) + ']' set_name = '[' + ', '.join(self.eval_names) + ']'
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name)) return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
# get the total number of training epochs
def get_total_epoch(self): def get_total_epoch(self):
return copy.deepcopy(self.epochs) return copy.deepcopy(self.epochs)
# get the latency
# -1 represents not avaliable ; otherwise it should be a float value
def get_latency(self):
if self.latency is None: return -1
else: return sum(self.latency) / len(self.latency)
# get the information regarding time
def get_times(self): def get_times(self):
"""Obtain the information regarding both training and evaluation time."""
if self.train_times is not None and isinstance(self.train_times, dict): if self.train_times is not None and isinstance(self.train_times, dict):
train_times = list( self.train_times.values() ) train_times = list( self.train_times.values() )
time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)} time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
for name in self.eval_names: else:
time_info = {'T-train@epoch': None, 'T-train@total': None }
for name in self.eval_names:
try:
xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)] xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes) time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
time_info['T-{:}@total'.format(name)] = np.sum(xtimes) time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
else: except:
time_info = {'T-train@epoch': None, 'T-train@total': None }
for name in self.eval_names:
time_info['T-{:}@epoch'.format(name)] = None time_info['T-{:}@epoch'.format(name)] = None
time_info['T-{:}@total'.format(name)] = None time_info['T-{:}@total'.format(name)] = None
return time_info return time_info
@ -699,18 +790,19 @@ class ResultsCount(object):
'cur_time': xtime, 'cur_time': xtime,
'all_time': atime} 'all_time': atime}
def get_net_param(self): def get_net_param(self, clone=False):
return self.net_state_dict if clone: return copy.deepcopy(self.net_state_dict)
else: return self.net_state_dict
# This function is used to obtain the config dict for this architecture. # This function is used to obtain the config dict for this architecture.
def get_config(self, str2structure): def get_config(self, str2structure):
if str2structure is None: if str2structure is None:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'], \ 'N' : self.arch_config['num_cells'],
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']} 'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
else: else:
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
'N' : self.arch_config['num_cells'], \ 'N' : self.arch_config['num_cells'],
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} 'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
def state_dict(self): def state_dict(self):

View File

@ -5,6 +5,7 @@ from .starts import prepare_seed, prepare_logger, get_machine_info, save_che
from .optimizers import get_optim_scheduler from .optimizers import get_optim_scheduler
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders
def get_procedures(procedure): def get_procedures(procedure):
from .basic_main import basic_train, basic_valid from .basic_main import basic_train, basic_valid

View File

@ -1,14 +1,17 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
##################################################### #####################################################
import time, torch import os, time, copy, torch, pathlib
import datasets
from config_utils import load_config
from procedures import prepare_seed, get_optim_scheduler from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net from models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate'] __all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()): def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
@ -127,3 +130,72 @@ def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed
'finish-train': True 'finish-train': True
} }
return info_seed return info_seed
def get_nas_bench_loaders(workers):
    """Create the data loaders for CIFAR-10, CIFAR-100 and ImageNet16-120.

    Config and split files are read from ``<root>/configs/nas-benchmark``,
    where ``<root>`` is two directories above this file. Datasets are read
    from ``$TORCH_HOME/cifar.python`` (ImageNet16-120 from its ``ImageNet16``
    sub-directory). Progress information is printed to stdout.

    Args:
        workers: passed to ``torch.set_num_threads`` (a process-wide setting)
            and used as ``num_workers`` for every loader.

    Returns:
        A dict mapping '<dataset>@<split>' to a ``torch.utils.data.DataLoader``
        with the keys 'cifar10@trainval', 'cifar10@train', 'cifar10@valid',
        'cifar10@test', 'cifar100@train', 'cifar100@valid', 'cifar100@test',
        'ImageNet16-120@train', 'ImageNet16-120@valid', 'ImageNet16-120@test'.

    Raises:
        KeyError: if the ``TORCH_HOME`` environment variable is not set.
        AssertionError: if a split file does not start with the expected indexes.
    """
    torch.set_num_threads(workers)

    root_dir = (pathlib.Path(__file__).parent / '..' / '..').resolve()
    config_dir = root_dir / 'configs' / 'nas-benchmark'
    torch_dir = pathlib.Path(os.environ['TORCH_HOME'])
    cifar_root = str(torch_dir / 'cifar.python')

    def full_loader(dataset, batch_size, shuffle):
        # Loader that iterates over the whole dataset.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=workers, pin_memory=True)

    def subset_loader(dataset, batch_size, indexes):
        # Loader that randomly samples only the given indexes of the dataset.
        sampler = torch.utils.data.sampler.SubsetRandomSampler(indexes)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, sampler=sampler,
            num_workers=workers, pin_memory=True)

    # ------------------------------ CIFAR-10 ------------------------------
    cifar_config = load_config(config_dir / 'CIFAR.config', None, None)
    cifar_bs = cifar_config.batch_size
    break_line = '-' * 150
    print('{:} Create data-loader for all datasets'.format(time_string()))
    print(break_line)
    TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = datasets.get_datasets('cifar10', cifar_root, -1)
    print('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num))
    cifar10_splits = load_config(config_dir / 'cifar-split.txt', None, None)
    # Sanity check that the expected split file was loaded.
    assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] \
        and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14], \
        'unexpected content in cifar-split.txt'
    # The validation images are a subset of the training set, but they are
    # served through a copy that uses the test-set transform.
    temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
    temp_dataset.transform = VALID_CIFAR10.transform
    trainval_cifar10_loader = full_loader(TRAIN_CIFAR10, cifar_bs, True)
    train_cifar10_loader = subset_loader(TRAIN_CIFAR10, cifar_bs, cifar10_splits.train)
    valid_cifar10_loader = subset_loader(temp_dataset, cifar_bs, cifar10_splits.valid)
    test__cifar10_loader = full_loader(VALID_CIFAR10, cifar_bs, False)
    print('CIFAR-10 : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_bs))
    print('CIFAR-10 : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_bs))
    print('CIFAR-10 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_bs))
    print('CIFAR-10 : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_bs))
    print(break_line)

    # ------------------------------ CIFAR-100 -----------------------------
    TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = datasets.get_datasets('cifar100', cifar_root, -1)
    print('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num))
    cifar100_splits = load_config(config_dir / 'cifar100-test-split.txt', None, None)
    assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] \
        and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24], \
        'unexpected content in cifar100-test-split.txt'
    # Validation and test loaders both draw from the second dataset returned by
    # get_datasets, each restricted to its own index list.
    train_cifar100_loader = full_loader(TRAIN_CIFAR100, cifar_bs, True)
    valid_cifar100_loader = subset_loader(VALID_CIFAR100, cifar_bs, cifar100_splits.xvalid)
    test__cifar100_loader = subset_loader(VALID_CIFAR100, cifar_bs, cifar100_splits.xtest)
    print('CIFAR-100 : train-loader has {:3d} batch'.format(len(train_cifar100_loader)))
    print('CIFAR-100 : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader)))
    print('CIFAR-100 : test--loader has {:3d} batch'.format(len(test__cifar100_loader)))
    print(break_line)

    # --------------------------- ImageNet16-120 ---------------------------
    # Anchored at root_dir like every other config file, so the function does
    # not depend on the current working directory.
    imagenet16_config = load_config(config_dir / 'ImageNet-16.config', None, None)
    imagenet_bs = imagenet16_config.batch_size
    TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = datasets.get_datasets('ImageNet16-120', str(torch_dir / 'cifar.python' / 'ImageNet16'), -1)
    print('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num))
    imagenet_splits = load_config(config_dir / 'imagenet-16-120-test-split.txt', None, None)
    assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] \
        and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20], \
        'unexpected content in imagenet-16-120-test-split.txt'
    train_imagenet_loader = full_loader(TRAIN_ImageNet16_120, imagenet_bs, True)
    valid_imagenet_loader = subset_loader(VALID_ImageNet16_120, imagenet_bs, imagenet_splits.xvalid)
    test__imagenet_loader = subset_loader(VALID_ImageNet16_120, imagenet_bs, imagenet_splits.xtest)
    print('ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet_bs))
    print('ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet_bs))
    print('ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet_bs))

    # 'cifar10', 'cifar100', 'ImageNet16-120'
    loaders = {
        'cifar10@trainval': trainval_cifar10_loader,
        'cifar10@train': train_cifar10_loader,
        'cifar10@valid': valid_cifar10_loader,
        'cifar10@test': test__cifar10_loader,
        'cifar100@train': train_cifar100_loader,
        'cifar100@valid': valid_cifar100_loader,
        'cifar100@test': test__cifar100_loader,
        'ImageNet16-120@train': train_imagenet_loader,
        'ImageNet16-120@valid': valid_imagenet_loader,
        'ImageNet16-120@test': test__imagenet_loader,
    }
    return loaders