update NAS-Bench-102
This commit is contained in:
parent
69ca0860aa
commit
95ec4d328e
3
.gitignore
vendored
3
.gitignore
vendored
@ -115,3 +115,6 @@ GPU-*.sh
|
|||||||
cal.sh
|
cal.sh
|
||||||
aaa
|
aaa
|
||||||
cx.sh
|
cx.sh
|
||||||
|
|
||||||
|
NAS-Bench-102-v1_0.pth
|
||||||
|
lib/NAS-Bench-102-v1_0.pth
|
||||||
|
@ -6,11 +6,16 @@ Each edge here is associated with an operation selected from a predefined operat
|
|||||||
For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
|
For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
|
||||||
|
|
||||||
In this Markdown file, we provide:
|
In this Markdown file, we provide:
|
||||||
- Detailed instruction to reproduce NAS-Bench-102.
|
- [How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
|
||||||
- 10 NAS algorithms evaluated in our paper.
|
- [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
|
||||||
|
- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
|
||||||
|
|
||||||
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
||||||
|
|
||||||
|
The data file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan].
|
||||||
|
|
||||||
|
The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan].
|
||||||
|
|
||||||
## How to Use NAS-Bench-102
|
## How to Use NAS-Bench-102
|
||||||
|
|
||||||
1. Creating an API instance from a file:
|
1. Creating an API instance from a file:
|
||||||
@ -35,8 +40,8 @@ api.show(2)
|
|||||||
|
|
||||||
# show the mean loss and accuracy of an architecture
|
# show the mean loss and accuracy of an architecture
|
||||||
info = api.query_meta_info_by_index(1)
|
info = api.query_meta_info_by_index(1)
|
||||||
loss, accuracy = info.get_metrics('cifar10', 'train')
|
res_metrics = info.get_metrics('cifar10', 'train')
|
||||||
flops, params, latency = info.get_comput_costs('cifar100')
|
cost_metrics = info.get_comput_costs('cifar100')
|
||||||
|
|
||||||
# get the detailed information
|
# get the detailed information
|
||||||
results = api.query_by_index(1, 'cifar100')
|
results = api.query_by_index(1, 'cifar100')
|
||||||
@ -55,7 +60,8 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1
|
|||||||
api.show(index)
|
api.show(index)
|
||||||
```
|
```
|
||||||
|
|
||||||
5. For other usages, please see `lib/aa_nas_api/api.py`
|
5. For other usages, please see `lib/nas_102_api/api.py`
|
||||||
|
|
||||||
|
|
||||||
### Detailed Instruction
|
### Detailed Instruction
|
||||||
|
|
||||||
@ -98,8 +104,10 @@ print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss
|
|||||||
```
|
```
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_102_api import NASBench102API as API
|
||||||
api = API('NAS-Bench-102-v1_0.pth')
|
api = API('NAS-Bench-102-v1_0.pth')
|
||||||
|
api.show(-1) # show info of all architectures
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Instruction to Re-Generate NAS-Bench-102
|
## Instruction to Re-Generate NAS-Bench-102
|
||||||
|
|
||||||
1. generate the meta file for NAS-Bench-102 using the following script, where `NAS-BENCH-102` indicates the name and `4` indicates the maximum number of nodes in a cell.
|
1. generate the meta file for NAS-Bench-102 using the following script, where `NAS-BENCH-102` indicates the name and `4` indicates the maximum number of nodes in a cell.
|
||||||
@ -139,6 +147,7 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh resnet
|
|||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## To Reproduce 10 Baseline NAS Algorithms in NAS-Bench-102
|
## To Reproduce 10 Baseline NAS Algorithms in NAS-Bench-102
|
||||||
|
|
||||||
We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our NAS-Bench-102.
|
We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our NAS-Bench-102.
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
#################################################################################
|
||||||
|
# NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search #
|
||||||
|
#################################################################################
|
||||||
import os, sys, copy, random, torch, numpy as np
|
import os, sys, copy, random, torch, numpy as np
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
|
|
||||||
@ -12,19 +14,21 @@ def print_information(information, extra_info=None, show=False):
|
|||||||
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
|
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
|
||||||
|
|
||||||
for ida, dataset in enumerate(dataset_names):
|
for ida, dataset in enumerate(dataset_names):
|
||||||
flop, param, latency = information.get_comput_costs(dataset)
|
#flop, param, latency = information.get_comput_costs(dataset)
|
||||||
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency > 0 else None)
|
metric = information.get_comput_costs(dataset)
|
||||||
train_loss, train_acc = information.get_metrics(dataset, 'train')
|
flop, param, latency = metric['flops'], metric['params'], metric['latency']
|
||||||
|
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
|
||||||
|
train_info = information.get_metrics(dataset, 'train')
|
||||||
if dataset == 'cifar10-valid':
|
if dataset == 'cifar10-valid':
|
||||||
valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid')
|
valid_info = information.get_metrics(dataset, 'x-valid')
|
||||||
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc))
|
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
|
||||||
elif dataset == 'cifar10':
|
elif dataset == 'cifar10':
|
||||||
test__loss, test__acc = information.get_metrics(dataset, 'ori-test')
|
test__info = information.get_metrics(dataset, 'ori-test')
|
||||||
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(test__loss, test__acc))
|
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
|
||||||
else:
|
else:
|
||||||
valid_loss, valid_acc = information.get_metrics(dataset, 'x-valid')
|
valid_info = information.get_metrics(dataset, 'x-valid')
|
||||||
test__loss, test__acc = information.get_metrics(dataset, 'x-test')
|
test__info = information.get_metrics(dataset, 'x-test')
|
||||||
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_loss, train_acc), metric2str(valid_loss, valid_acc), metric2str(test__loss, test__acc))
|
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
|
||||||
strings += [str1, str2]
|
strings += [str1, str2]
|
||||||
if show: print('\n'.join(strings))
|
if show: print('\n'.join(strings))
|
||||||
return strings
|
return strings
|
||||||
@ -34,19 +38,21 @@ class NASBench102API(object):
|
|||||||
|
|
||||||
def __init__(self, file_path_or_dict, verbose=True):
|
def __init__(self, file_path_or_dict, verbose=True):
|
||||||
if isinstance(file_path_or_dict, str):
|
if isinstance(file_path_or_dict, str):
|
||||||
if verbose: print('try to create NAS-Bench-102 api from {:}'.format(file_path_or_dict))
|
if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
|
||||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||||
file_path_or_dict = torch.load(file_path_or_dict)
|
file_path_or_dict = torch.load(file_path_or_dict)
|
||||||
else:
|
else:
|
||||||
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
||||||
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
||||||
import pdb; pdb.set_trace() # we will update this api soon
|
|
||||||
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
||||||
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
||||||
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
||||||
self.arch2infos = OrderedDict()
|
self.arch2infos_less = OrderedDict()
|
||||||
|
self.arch2infos_full = OrderedDict()
|
||||||
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
||||||
self.arch2infos[xkey] = ArchResults.create_from_state_dict( file_path_or_dict['arch2infos'][xkey] )
|
all_info = file_path_or_dict['arch2infos'][xkey]
|
||||||
|
self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
|
||||||
|
self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
|
||||||
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
|
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
|
||||||
self.archstr2index = {}
|
self.archstr2index = {}
|
||||||
for idx, arch in enumerate(self.meta_archs):
|
for idx, arch in enumerate(self.meta_archs):
|
||||||
@ -73,35 +79,46 @@ class NASBench102API(object):
|
|||||||
else: arch_index = -1
|
else: arch_index = -1
|
||||||
return arch_index
|
return arch_index
|
||||||
|
|
||||||
def query_by_arch(self, arch):
|
def query_by_arch(self, arch, use_12epochs_result=False):
|
||||||
|
if isinstance(arch, int):
|
||||||
|
arch_index = arch
|
||||||
|
else:
|
||||||
arch_index = self.query_index_by_arch(arch)
|
arch_index = self.query_index_by_arch(arch)
|
||||||
if arch_index == -1: return None
|
if arch_index == -1: return None # the following two lines are used to support few training epochs
|
||||||
if arch_index in self.arch2infos:
|
if use_12epochs_result: arch2infos = self.arch2infos_less
|
||||||
strings = print_information(self.arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index))
|
else : arch2infos = self.arch2infos_full
|
||||||
|
if arch_index in arch2infos:
|
||||||
|
strings = print_information(arch2infos[ arch_index ], 'arch-index={:}'.format(arch_index))
|
||||||
return '\n'.join(strings)
|
return '\n'.join(strings)
|
||||||
else:
|
else:
|
||||||
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def query_by_index(self, arch_index, dataname):
|
def query_by_index(self, arch_index, dataname, use_12epochs_result=False):
|
||||||
assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index)
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
|
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
|
||||||
|
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
|
||||||
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
|
assert dataname in archInfo.get_dataset_names(), 'invalid dataset-name : {:}'.format(dataname)
|
||||||
info = archInfo.query(dataname)
|
info = archInfo.query(dataname)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def query_meta_info_by_index(self, arch_index):
|
def query_meta_info_by_index(self, arch_index, use_12epochs_result=False):
|
||||||
assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index)
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
archInfo = copy.deepcopy( self.arch2infos[ arch_index ] )
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
|
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
|
||||||
|
archInfo = copy.deepcopy( arch2infos[ arch_index ] )
|
||||||
return archInfo
|
return archInfo
|
||||||
|
|
||||||
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None):
|
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
|
||||||
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
best_index, highest_accuracy = -1, None
|
best_index, highest_accuracy = -1, None
|
||||||
for i, idx in enumerate(self.evaluated_indexes):
|
for i, idx in enumerate(self.evaluated_indexes):
|
||||||
flop, param, latency = self.arch2infos[idx].get_comput_costs(dataset)
|
flop, param, latency = arch2infos[idx].get_comput_costs(dataset)
|
||||||
if FLOP_max is not None and flop > FLOP_max : continue
|
if FLOP_max is not None and flop > FLOP_max : continue
|
||||||
if Param_max is not None and param > Param_max: continue
|
if Param_max is not None and param > Param_max: continue
|
||||||
loss, accuracy = self.arch2infos[idx].get_metrics(dataset, metric_on_set)
|
loss, accuracy = arch2infos[idx].get_metrics(dataset, metric_on_set)
|
||||||
if best_index == -1:
|
if best_index == -1:
|
||||||
best_index, highest_accuracy = idx, accuracy
|
best_index, highest_accuracy = idx, accuracy
|
||||||
elif highest_accuracy < accuracy:
|
elif highest_accuracy < accuracy:
|
||||||
@ -113,21 +130,29 @@ class NASBench102API(object):
|
|||||||
return copy.deepcopy(self.meta_archs[index])
|
return copy.deepcopy(self.meta_archs[index])
|
||||||
|
|
||||||
def show(self, index=-1):
|
def show(self, index=-1):
|
||||||
if index == -1: # show all architectures
|
if index < 0: # show all architectures
|
||||||
print(self)
|
print(self)
|
||||||
for i, idx in enumerate(self.evaluated_indexes):
|
for i, idx in enumerate(self.evaluated_indexes):
|
||||||
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
|
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
|
||||||
print('arch : {:}'.format(self.meta_archs[idx]))
|
print('arch : {:}'.format(self.meta_archs[idx]))
|
||||||
strings = print_information(self.arch2infos[idx])
|
strings = print_information(self.arch2infos_full[idx])
|
||||||
print('>' * 20)
|
print('>' * 40 + ' 200 epochs ' + '>' * 40)
|
||||||
print('\n'.join(strings))
|
print('\n'.join(strings))
|
||||||
print('<' * 20)
|
strings = print_information(self.arch2infos_less[idx])
|
||||||
|
print('>' * 40 + ' 12 epochs ' + '>' * 40)
|
||||||
|
print('\n'.join(strings))
|
||||||
|
print('<' * 40 + '------------' + '<' * 40)
|
||||||
else:
|
else:
|
||||||
if 0 <= index < len(self.meta_archs):
|
if 0 <= index < len(self.meta_archs):
|
||||||
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
|
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
|
||||||
else:
|
else:
|
||||||
strings = print_information(self.arch2infos[index])
|
strings = print_information(self.arch2infos_full[index])
|
||||||
|
print('>' * 40 + ' 200 epochs ' + '>' * 40)
|
||||||
print('\n'.join(strings))
|
print('\n'.join(strings))
|
||||||
|
strings = print_information(self.arch2infos_less[index])
|
||||||
|
print('>' * 40 + ' 12 epochs ' + '>' * 40)
|
||||||
|
print('\n'.join(strings))
|
||||||
|
print('<' * 40 + '------------' + '<' * 40)
|
||||||
else:
|
else:
|
||||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user