Update NATS-Bench (tss version 0.99)

This commit is contained in:
D-X-Y 2020-09-05 10:40:29 +00:00
parent 8d64afd4a3
commit bd9288f45d
9 changed files with 379 additions and 56 deletions

View File

@ -99,6 +99,12 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
If you find that this project helps your research, please consider citing some of the following papers:
```
@article{dong2020nats,
title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437},
year={2020}
}
@inproceedings{dong2020nasbench201,
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi},

View File

@ -99,6 +99,12 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
如果您发现该项目对您的科研或工程有帮助,请考虑引用下列的某些文献:
```
@inproceedings{dong2020nasbench201,
@article{dong2020nats,
title={NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size},
author={Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
journal={arXiv preprint arXiv:2009.00437},
year={2020}
}
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)},

View File

@ -77,17 +77,17 @@ def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text],
def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults):
# calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
cifar010_latency = (api.get_latency(arch_index, 'cifar10-valid', False) + api.get_latency(arch_index, 'cifar10', False)) / 2
cifar010_latency = (api.get_latency(arch_index, 'cifar10-valid', hp='200') + api.get_latency(arch_index, 'cifar10', hp='200')) / 2
arch_info_full.reset_latency('cifar10-valid', None, cifar010_latency)
arch_info_full.reset_latency('cifar10', None, cifar010_latency)
arch_info_less.reset_latency('cifar10-valid', None, cifar010_latency)
arch_info_less.reset_latency('cifar10', None, cifar010_latency)
cifar100_latency = api.get_latency(arch_index, 'cifar100', False)
cifar100_latency = api.get_latency(arch_index, 'cifar100', hp='200')
arch_info_full.reset_latency('cifar100', None, cifar100_latency)
arch_info_less.reset_latency('cifar100', None, cifar100_latency)
image_latency = api.get_latency(arch_index, 'ImageNet16-120', False)
image_latency = api.get_latency(arch_index, 'ImageNet16-120', hp='200')
arch_info_full.reset_latency('ImageNet16-120', None, image_latency)
arch_info_less.reset_latency('ImageNet16-120', None, image_latency)

View File

@ -1,7 +1,7 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# This file is used to re-orangize all checkpoints (created by main-sss.py) #
# into a single benchmark file. Besides, for each trial, we will merge the #
@ -25,6 +25,7 @@ from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file
NATS_SSS_BASE_NAME = 'NATS-sss-v1_0' # 2020.08.28

View File

@ -85,13 +85,16 @@ def test_api(api, is_301=True):
if __name__ == '__main__':
# api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
for fast_mode in [True, False]:
for verbose in [True, False]:
api201 = create(None, 'tss', fast_mode=fast_mode, verbose=True)
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
test_api(api201, False)
for fast_mode in [True, False]:
for verbose in [True, False]:
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
api301 = create(None, 'size', fast_mode=fast_mode, verbose=True)
print('{:} --->>> {:}'.format(time_string(), api301))
test_api(api301, True)
# api201 = create(None, 'topology', True) # use the default file path
# test_api(api201, False)
# print ('Test {:} done'.format(api201))

View File

@ -0,0 +1,262 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
##############################################################################
# This file is used to re-orangize all checkpoints (created by main-tss.py) #
# into a single benchmark file. Besides, for each trial, we will merge the #
# information of all its trials into a single file. #
# #
# Usage: #
# python exps/NATS-Bench/tss-collect.py #
##############################################################################
import os, re, sys, time, random, argparse, collections
import numpy as np
from copy import deepcopy
import torch
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import load_config, dict2config
from datasets import get_datasets
from models import CellStructure, get_cell_based_tiny_net, get_search_spaces
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from nas_201_api import NASBench201API
api = NASBench201API('{:}/.torch/NAS-Bench-201-v1_0-e61699.pth'.format(os.environ['HOME']))
NATS_TSS_BASE_NAME = 'NATS-tss-v1_0' # 2020.08.28
def create_result_count(used_seed: int, dataset: Text, arch_config: Dict[Text, Any],
results: Dict[Text, Any], dataloader_dict: Dict[Text, Any]) -> ResultsCount:
xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'],
results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes': arch_config['class_num']}, None)
if 'train_times' in results: # new version
xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times'])
xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times'])
else:
network = get_cell_based_tiny_net(net_config)
network.load_state_dict(xresult.get_net_param())
if dataset == 'cifar10-valid':
xresult.update_OLD_eval('x-valid' , results['valid_acc1es'], results['valid_losses'])
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format('cifar10', 'test')], network.cuda())
xresult.update_OLD_eval('ori-test', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
xresult.update_latency(latencies)
elif dataset == 'cifar10':
xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda())
xresult.update_latency(latencies)
elif dataset == 'cifar100' or dataset == 'ImageNet16-120':
xresult.update_OLD_eval('ori-test', results['valid_acc1es'], results['valid_losses'])
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'valid')], network.cuda())
xresult.update_OLD_eval('x-valid', {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
loss, top1, top5, latencies = pure_evaluate(dataloader_dict['{:}@{:}'.format(dataset, 'test')], network.cuda())
xresult.update_OLD_eval('x-test' , {results['total_epoch']-1: top1}, {results['total_epoch']-1: loss})
xresult.update_latency(latencies)
else:
raise ValueError('invalid dataset name : {:}'.format(dataset))
return xresult
def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict):
information = ArchResults(arch_index, arch_str)
for checkpoint_path in checkpoints:
checkpoint = torch.load(checkpoint_path, map_location='cpu')
used_seed = checkpoint_path.name.split('-')[-1].split('.')[0]
ok_dataset = 0
for dataset in datasets:
if dataset not in checkpoint:
print('Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, checkpoint_path))
continue
else:
ok_dataset += 1
results = checkpoint[dataset]
assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, checkpoint_path)
arch_config = {'channel': results['channel'], 'num_cells': results['num_cells'], 'arch_str': arch_str, 'class_num': results['config']['class_num']}
xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict)
information.update(dataset, int(used_seed), xresult)
if ok_dataset == 0: raise ValueError('{:} does not find any data'.format(checkpoint_path))
return information
def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResults]):
# calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
cifar010_latency = (api.get_latency(arch_index, 'cifar10-valid', hp='200') + api.get_latency(arch_index, 'cifar10', hp='200')) / 2
cifar100_latency = api.get_latency(arch_index, 'cifar100', hp='200')
image_latency = api.get_latency(arch_index, 'ImageNet16-120', hp='200')
for hp, arch_info in arch_infos.items():
arch_info.reset_latency('cifar10-valid', None, cifar010_latency)
arch_info.reset_latency('cifar10', None, cifar010_latency)
arch_info.reset_latency('cifar100', None, cifar100_latency)
arch_info.reset_latency('ImageNet16-120', None, image_latency)
train_per_epoch_time = list(arch_infos['12'].query('cifar10-valid', 777).train_times.values())
train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
eval_ori_test_time, eval_x_valid_time = [], []
for key, value in arch_infos['12'].query('cifar10-valid', 777).eval_times.items():
if key.startswith('ori-test@'):
eval_ori_test_time.append(value)
elif key.startswith('x-valid@'):
eval_x_valid_time.append(value)
else: raise ValueError('-- {:} --'.format(key))
eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time))
nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000, 'ImageNet16-120-test': 6000,
'cifar10-valid-train': 25000, 'cifar10-valid-valid': 25000,
'cifar10-train': 50000, 'cifar10-test': 10000,
'cifar100-train': 50000, 'cifar100-test': 10000, 'cifar100-valid': 5000}
eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums['cifar10-valid-valid'] + nums['cifar10-test'])
for hp, arch_info in arch_infos.items():
arch_info.reset_pseudo_train_times('cifar10-valid', None,
train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar10-valid-train'])
arch_info.reset_pseudo_train_times('cifar10', None,
train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar10-train'])
arch_info.reset_pseudo_train_times('cifar100', None,
train_per_epoch_time / nums['cifar10-valid-train'] * nums['cifar100-train'])
arch_info.reset_pseudo_train_times('ImageNet16-120', None,
train_per_epoch_time / nums['cifar10-valid-train'] * nums['ImageNet16-120-train'])
arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'x-valid', eval_per_sample*nums['cifar10-valid-valid'])
arch_info.reset_pseudo_eval_times('cifar10-valid', None, 'ori-test', eval_per_sample * nums['cifar10-test'])
arch_info.reset_pseudo_eval_times('cifar10', None, 'ori-test', eval_per_sample * nums['cifar10-test'])
arch_info.reset_pseudo_eval_times('cifar100', None, 'x-valid', eval_per_sample * nums['cifar100-valid'])
arch_info.reset_pseudo_eval_times('cifar100', None, 'x-test', eval_per_sample * nums['cifar100-valid'])
arch_info.reset_pseudo_eval_times('cifar100', None, 'ori-test', eval_per_sample * nums['cifar100-test'])
arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-valid', eval_per_sample * nums['ImageNet16-120-valid'])
arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'x-test', eval_per_sample * nums['ImageNet16-120-valid'])
arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'ori-test', eval_per_sample * nums['ImageNet16-120-test'])
return arch_infos
def simplify(save_dir, save_name, nets, total, sup_config):
dataloader_dict = get_nas_bench_loaders(6)
hps, seeds = ['12', '200'], set()
for hp in hps:
sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
ckps = sorted(list(sub_save_dir.glob('arch-*-seed-*.pth')))
seed2names = defaultdict(list)
for ckp in ckps:
parts = re.split('-|\.', ckp.name)
seed2names[parts[3]].append(ckp.name)
print('DIR : {:}'.format(sub_save_dir))
nums = []
for seed, xlist in seed2names.items():
seeds.add(seed)
nums.append(len(xlist))
print(' [seed={:}] there are {:} checkpoints.'.format(seed, len(xlist)))
assert len(nets) == total == max(nums), 'there are some missed files : {:} vs {:}'.format(max(nums), total)
print('{:} start simplify the checkpoint.'.format(time_string()))
datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')
# Create the directory to save the processed data
# full_save_dir contains all benchmark files with trained weights.
# simplify_save_dir contains all benchmark files without trained weights.
full_save_dir = save_dir / (save_name + '-FULL')
simple_save_dir = save_dir / (save_name + '-SIMPLIFY')
full_save_dir.mkdir(parents=True, exist_ok=True)
simple_save_dir.mkdir(parents=True, exist_ok=True)
# all data in memory
arch2infos, evaluated_indexes = dict(), set()
end_time, arch_time = time.time(), AverageMeter()
# save the meta information
temp_final_infos = {'meta_archs' : nets,
'total_archs': total,
'arch2infos' : None,
'evaluated_indexes': set()}
pickle_save(temp_final_infos, str(full_save_dir / 'meta.pickle'))
pickle_save(temp_final_infos, str(simple_save_dir / 'meta.pickle'))
for index in tqdm(range(total)):
arch_str = nets[index]
hp2info = OrderedDict()
full_save_path = full_save_dir / '{:06d}.pickle'.format(index)
simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index)
for hp in hps:
sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
ckps = [sub_save_dir / 'arch-{:06d}-seed-{:}.pth'.format(index, seed) for seed in seeds]
ckps = [x for x in ckps if x.exists()]
if len(ckps) == 0:
raise ValueError('Invalid data : index={:}, hp={:}'.format(index, hp))
arch_info = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict)
hp2info[hp] = arch_info
hp2info = correct_time_related_info(index, hp2info)
evaluated_indexes.add(index)
to_save_data = OrderedDict({'12': hp2info['12'].state_dict(),
'200': hp2info['200'].state_dict()})
pickle_save(to_save_data, str(full_save_path))
for hp in hps: hp2info[hp].clear_params()
to_save_data = OrderedDict({'12': hp2info['12'].state_dict(),
'200': hp2info['200'].state_dict()})
pickle_save(to_save_data, str(simple_save_path))
arch2infos[index] = to_save_data
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
need_time = '{:}'.format(convert_secs2time(arch_time.avg * (total-index-1), True))
# print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
print('{:} {:} done.'.format(time_string(), save_name))
final_infos = {'meta_archs' : nets,
'total_archs': total,
'arch2infos' : arch2infos,
'evaluated_indexes': evaluated_indexes}
save_file_name = save_dir / '{:}.pickle'.format(save_name)
pickle_save(final_infos, str(save_file_name))
# move the benchmark file to a new path
hd5sum = get_md5_file(str(save_file_name) + '.pbz2')
hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum)
shutil.move(str(save_file_name) + '.pbz2', hd5_file_name)
print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))
# move the directory to a new path
hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
shutil.move(full_save_dir, hd5_full_save_dir)
shutil.move(simple_save_dir, hd5_simple_save_dir)
# save the meta information for simple and full
# final_infos['arch2infos'] = None
# final_infos['evaluated_indexes'] = set()
def traverse_net(max_node):
aa_nas_bench_ss = get_search_spaces('cell', 'nats-bench')
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2)))
random.seed( 88 ) # please do not change this line for reproducibility
random.shuffle( archs )
assert archs[0 ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0])
assert archs[9 ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9])
assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123])
return [x.tostr() for x in archs]
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NATS-Bench (topology search space)', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--base_save_dir', type=str, default='./output/NATS-Bench-topology', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.')
parser.add_argument('--channel' , type=int, default=16, help='The number of channels.')
parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
parser.add_argument('--check_N' , type=int, default=15625, help='For safety.')
parser.add_argument('--save_name' , type=str, default='process', help='The save directory.')
args = parser.parse_args()
nets = traverse_net(args.max_node)
if len(nets) != args.check_N:
raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N))
save_dir = Path(args.base_save_dir)
simplify(save_dir, args.save_name, nets, args.check_N, {'name': 'infer.tiny', 'channel': args.channel, 'num_cells': args.num_cells})

View File

@ -10,6 +10,7 @@ import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
from .api_utils import time_string
from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
@ -71,7 +72,7 @@ class NATSsize(NASBenchMetaAPI):
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose:
print('Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(file_path_or_dict, fast_mode))
print('{:} Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode))
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
self.filename = Path(file_path_or_dict).name
@ -116,14 +117,15 @@ class NATSsize(NASBenchMetaAPI):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[arch] = idx
if self.verbose:
print('Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(len(self.evaluated_indexes), len(self.meta_archs)))
print('{:} Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
If index is None, overwrite all ckps.
"""
if self.verbose:
print('Call clear_params with archive_root={:} and index={:}'.format(archive_root, index))
print('{:} Call clear_params with archive_root={:} and index={:}'.format(time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
@ -155,7 +157,7 @@ class NATSsize(NASBenchMetaAPI):
The difference between these three configurations are the number of training epochs.
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
@ -177,7 +179,8 @@ class NATSsize(NASBenchMetaAPI):
When is_random=False, the performanceo of all trials will be averaged.
"""
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index)
if index not in self.arch2infos_dict:

View File

@ -10,6 +10,8 @@ import os, copy, random, numpy as np
from pathlib import Path
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
import warnings
from .api_utils import time_string
from .api_utils import pickle_load
from .api_utils import ArchResults
from .api_utils import NASBenchMetaAPI
@ -60,58 +62,89 @@ class NATStopology(NASBenchMetaAPI):
self.reset_time()
if file_path_or_dict is None:
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
print ('Try to use the default NATS-Bench (topology) path from {:}.'.format(file_path_or_dict))
print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict))
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
file_path_or_dict = str(file_path_or_dict)
if verbose: print('try to create the NATS-Bench (topology) api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
if verbose:
print('{:} Try to create the NATS-Bench (topology) api from {:}'.format(time_string(), file_path_or_dict))
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
self.filename = Path(file_path_or_dict).name
file_path_or_dict = np.load(file_path_or_dict)
if fast_mode:
if os.path.isfile(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
else:
self._archive_dir = file_path_or_dict
else:
if os.path.isdir(file_path_or_dict):
raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
else:
file_path_or_dict = pickle_load(file_path_or_dict)
elif isinstance(file_path_or_dict, dict):
file_path_or_dict = copy.deepcopy(file_path_or_dict)
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set(['12', '200'])
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
# self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
# self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
hp2archres['12'] = ArchResults.create_from_state_dict(all_info['less'])
hp2archres['200'] = ArchResults.create_from_state_dict(all_info['full'])
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
if isinstance(file_path_or_dict, dict):
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
all_info = file_path_or_dict['arch2infos'][xkey]
hp2archres = OrderedDict()
for hp_key, results in all_infos.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
self.arch2infos_dict[xkey] = hp2archres
self.evaluated_indexes = list(file_path_or_dict['evaluated_indexes'])
elif self.archive_dir is not None:
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
self.arch2infos_dict = OrderedDict()
self._avaliable_hps = set()
self.evaluated_indexes = set()
else:
raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict)))
self.archstr2index = {}
for idx, arch in enumerate(self.meta_archs):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
self.archstr2index[arch] = idx
if self.verbose:
print('{:} Create NATS-Bench (topology) done with {:}/{:} architectures avaliable.'.format(
time_string(), len(self.evaluated_indexes), len(self.meta_archs)))
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space.
It will load its data from 'archive_root'.
"""
if self.verbose:
print('{:} Call clear_params with archive_root={:} and index={:}'.format(
time_string(), archive_root, index))
if archive_root is None:
archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
if not os.path.isdir(archive_root):
warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
archive_root = self.archive_dir
if archive_root is None or not os.path.isdir(archive_root):
raise ValueError('Invalid archive_root : {:}'.format(archive_root))
if index is None:
indexes = list(range(len(self)))
else:
indexes = [index]
for idx in indexes:
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
if not os.path.isfile(xfile_path):
xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
xdata = torch.load(xfile_path, map_location='cpu')
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
xdata = pickle_load(xfile_path)
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
self.evaluated_indexes.add(idx)
hp2archres = OrderedDict()
hp2archres['12'] = ArchResults.create_from_state_dict(xdata['less'])
hp2archres['200'] = ArchResults.create_from_state_dict(xdata['full'])
for hp_key, results in xdata.items():
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
self._avaliable_hps.add(hp_key)
self.arch2infos_dict[idx] = hp2archres
def query_info_str_by_arch(self, arch, hp: Text='12'):
@ -122,7 +155,7 @@ class NATStopology(NASBenchMetaAPI):
The difference between these three configurations are the number of training epochs.
"""
if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
print('{:} Call query_info_str_by_arch with arch={:} and hp={:}'.format(time_string(), arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
# obtain the metric for the `index`-th architecture
@ -142,8 +175,10 @@ class NATStopology(NASBenchMetaAPI):
# When is_random=False, the performanceo of all trials will be averaged.
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
print('{:} Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(
time_string(), index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
self._prepare_info(index)
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]

View File

@ -10,9 +10,9 @@
# History:
# [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
#
import os, abc, copy, random, numpy as np
import os, abc, time, copy, random, numpy as np
import bz2, pickle
import importlib, warnings
import warnings
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict
@ -36,6 +36,12 @@ def pickle_load(file_path, ext='.pbz2'):
return pickle.load(cfile)
def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
"""re-map the metric_on_set to internal keys"""
if verbose:
@ -136,7 +142,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
"""
if self.verbose:
print('Call query_index_by_arch with arch={:}'.format(arch))
print('{:} Call query_index_by_arch with arch={:}'.format(time_string(), arch))
if isinstance(arch, int):
if 0 <= arch < len(self):
return arch
@ -162,13 +168,13 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
self.reload(self.archive_dir, index)
elif not self.fast_mode:
if self.verbose:
print('Call _prepare_info with index={:} skip because it is not the fast mode.'.format(index))
print('{:} Call _prepare_info with index={:} skip because it is not the fast mode.'.format(time_string(), index))
else:
raise ValueError('Invalid status: fast_mode={:} and archive_dir={:}'.format(self.fast_mode, self.archive_dir))
else:
assert index in self.evaluated_indexes, 'The index of {:} is not in self.evaluated_indexes, there must be something wrong.'.format(index)
if self.verbose:
print('Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(index))
print('{:} Call _prepare_info with index={:} skip because it is in arch2infos_dict'.format(time_string(), index))
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
@ -185,7 +191,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
"""
if self.verbose:
print('Call clear_params with index={:} and hp={:}'.format(index, hp))
print('{:} Call clear_params with index={:} and hp={:}'.format(time_string(), index, hp))
if index not in self.arch2infos_dict:
warnings.warn('The {:}-th architecture is not in the benchmark data yet, no need to clear params.'.format(index))
elif hp is None:
@ -243,7 +249,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
"""
if self.verbose:
print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp))
print('{:} Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(time_string(), arch_index, dataname, hp))
info = self.query_meta_info_by_index(arch_index, hp)
if dataname is None: return info
else:
@ -254,7 +260,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
"""Find the architecture with the highest accuracy based on some constraints."""
if self.verbose:
print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
print('{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(
time_string(), dataset, metric_on_set, hp, FLOP_max, Param_max))
dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
best_index, highest_accuracy = -1, None
evaluated_indexes = sorted(list(self.evaluated_indexes))
@ -287,7 +294,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
-- 200 : train the model by 200 epochs
"""
if self.verbose:
print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp))
print('{:} Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(time_string(), index, dataset, seed, hp))
info = self.query_meta_info_by_index(index, hp)
return info.get_net_param(dataset, seed)
@ -304,7 +311,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
config = api.get_net_config(128, 'cifar10')
"""
if self.verbose:
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
print('{:} Call the get_net_config function with index={:}, dataset={:}.'.format(time_string(), index, dataset))
self._prepare_info(index)
if index in self.arch2infos_dict:
info = self.arch2infos_dict[index]
@ -318,7 +325,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
if self.verbose:
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
print('{:} Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
self._prepare_info(index)
info = self.query_meta_info_by_index(index, hp)
return info.get_compute_costs(dataset)
@ -331,7 +338,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
:return: return a float value in seconds
"""
if self.verbose:
print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
print('{:} Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
cost_dict = self.get_cost_info(index, dataset, hp)
return cost_dict['latency']