############################################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                                        #
############################################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size               #
############################################################################################
# In this Python file, we define NASBenchMetaAPI, the abstract class for benchmark APIs.   #
# We also define the class ArchResults, which contains all information of a single         #
# architecture trained with one set of hyper-parameters on three datasets.                 #
# We also define the class ResultsCount, which contains all information of a single trial  #
# for a single architecture.                                                               #
############################################################################################
# History:
# [2020.07.31] The first version, where most of the content is reused from nas_201_api/api_utils.py
#
import os, abc, time, copy, random
import bz2, pickle
import warnings
import numpy as np
from typing import List, Text, Union, Dict, Optional
from collections import OrderedDict, defaultdict


_FILE_SYSTEM = 'default'
PICKLE_EXT = 'pickle.pbz2'


def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
  """Use pickle to save data (obj) into file_path.

  According to https://docs.python.org/3/library/pickle.html#data-stream-format,
  protocol version 4 was added in Python 3.4. It adds support for very large objects,
  pickling more kinds of objects, and some data format optimizations.
  It is the default protocol starting with Python 3.8.
  """
  with bz2.BZ2File(str(file_path) + ext, 'wb') as cfile:
    pickle.dump(obj, cfile, protocol=protocol)


def pickle_load(file_path, ext='.pbz2'):
  if os.path.isfile(str(file_path)):
    xfile_path = str(file_path)
  else:
    xfile_path = str(file_path) + ext
  with bz2.BZ2File(xfile_path, 'rb') as cfile:
    return pickle.load(cfile)
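# A minimal usage sketch for the two helpers above (the '/tmp/nats-demo' path is only
# an illustration): pickle_save appends the extension when writing, and pickle_load
# re-appends it when the bare path does not exist on disk.
#
#   data = {'12': {'accuracy': 91.5}}
#   pickle_save(data, '/tmp/nats-demo')        # writes /tmp/nats-demo.pbz2
#   restored = pickle_load('/tmp/nats-demo')   # resolves to /tmp/nats-demo.pbz2
#   assert restored == data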


def time_string():
  ISOTIMEFORMAT = '%Y-%m-%d %X'
  string = '[{:}]'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
  return string


def reset_file_system(lib: Text = 'default'):
  # Declare _FILE_SYSTEM as global; otherwise the assignment would only create
  # a local variable and silently leave the module-level value unchanged.
  global _FILE_SYSTEM
  _FILE_SYSTEM = lib


def get_file_system():
  return _FILE_SYSTEM


def nats_is_dir(file_path):
  if _FILE_SYSTEM == 'default':
    return os.path.isdir(file_path)
  elif _FILE_SYSTEM == 'google':
    import tensorflow as tf
    return tf.gfile.isdir(file_path)
  else:
    raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM))


def nats_is_file(file_path):
  if _FILE_SYSTEM == 'default':
    return os.path.isfile(file_path)
  elif _FILE_SYSTEM == 'google':
    import tensorflow as tf
    return tf.gfile.exists(file_path) and not tf.gfile.isdir(file_path)
  else:
    raise ValueError('Unknown file system lib: {:}'.format(_FILE_SYSTEM))


def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
  """Re-map a (dataset, metric_on_set) pair to the internal keys used by the benchmark."""
  if verbose:
    print('Call internal function remap_dataset_set_names with dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
  if dataset == 'cifar10' and metric_on_set == 'valid':
    dataset, metric_on_set = 'cifar10-valid', 'x-valid'
  elif dataset == 'cifar10' and metric_on_set == 'test':
    dataset, metric_on_set = 'cifar10', 'ori-test'
  elif dataset == 'cifar10' and metric_on_set == 'train':
    dataset, metric_on_set = 'cifar10', 'train'
  elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'valid':
    metric_on_set = 'x-valid'
  elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'test':
    metric_on_set = 'x-test'
  if verbose:
    print('  return dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
  return dataset, metric_on_set
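# For example (a quick sketch of the mapping): querying the validation metric on
# CIFAR-10 is internally redirected to the 'cifar10-valid' benchmark entry, while
# CIFAR-100 / ImageNet16-120 only re-map the set name:
#
#   remap_dataset_set_names('cifar10', 'valid')   # -> ('cifar10-valid', 'x-valid')
#   remap_dataset_set_names('cifar100', 'test')   # -> ('cifar100', 'x-test')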


class NASBenchMetaAPI(metaclass=abc.ABCMeta):

  @abc.abstractmethod
  def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]] = None, verbose: bool = True):
    """The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""

  def __getitem__(self, index: int):
    return copy.deepcopy(self.meta_archs[index])

  def arch(self, index: int):
    """Return the topology structure of the `index`-th architecture."""
    if self.verbose:
      print('Call the arch function with index={:}'.format(index))
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])

  def __len__(self):
    return len(self.meta_archs)

  def __repr__(self):
    return ('{name}({num}/{total} architectures, fast_mode={fast_mode}, file={filename})'.format(
        name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs),
        fast_mode=self.fast_mode, filename=self.filename))

  @property
  def avaliable_hps(self):
    # The 'avaliable' spelling is kept so that existing callers do not break.
    return list(copy.deepcopy(self._avaliable_hps))

  @property
  def used_time(self):
    return self._used_time

  @property
  def search_space_name(self):
    return self._search_space_name

  @property
  def fast_mode(self):
    return self._fast_mode

  @property
  def archive_dir(self):
    return self._archive_dir

  def reset_archive_dir(self, archive_dir):
    self._archive_dir = archive_dir

  def reset_fast_mode(self, fast_mode):
    self._fast_mode = fast_mode

  def reset_time(self):
    self._used_time = 0

  def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
    """Simulate training `arch` on `dataset`; return the validation accuracy, latency, time cost, and total simulated time."""
    index = self.query_index_by_arch(arch)
    all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
    assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
    if dataset == 'cifar10':
      info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
    else:
      info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
    valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
    latency = self.get_latency(index, dataset)
    if account_time:
      self._used_time += time_cost
    return valid_acc, latency, time_cost, self._used_time
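  # A hedged sketch of how simulate_train_eval supports simulated NAS search
  # (`api` stands for a concrete subclass instance; the time budget below is
  # only an illustrative value, in seconds of simulated training time):
  #
  #   total_budget, best_acc, best_arch = 12000, -1.0, None
  #   while api.used_time < total_budget:
  #     arch = api.random()
  #     acc, latency, cost, used = api.simulate_train_eval(arch, 'cifar10', hp='12')
  #     if acc > best_acc:
  #       best_acc, best_arch = acc, arch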

  def random(self):
    """Return a random index of all architectures."""
    return random.randint(0, len(self.meta_archs) - 1)

  def reload(self, archive_root: Text = None, index: int = None):
    """Overwrite all information of the `index`-th architecture in the search space,
       where the data will be loaded from `archive_root`.
       If archive_root is None, it will try to load from the default path:
       os.environ['TORCH_HOME'] / '{:}-full'.format(self.ALL_BASE_NAMES[-1]).
       If index is None, overwrite the checkpoints of all architectures.
    """
    if self.verbose:
      print('{:} Call reload with archive_root={:} and index={:}'.format(
          time_string(), archive_root, index))
    if archive_root is None:
      archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(self.ALL_BASE_NAMES[-1]))
      if not nats_is_dir(archive_root):
        warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root))
        archive_root = self.archive_dir
    if archive_root is None or not nats_is_dir(archive_root):
      raise ValueError('Invalid archive_root : {:}'.format(archive_root))
    if index is None:
      indexes = list(range(len(self)))
    else:
      indexes = [index]
    for idx in indexes:
      assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
      xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
      if not nats_is_file(xfile_path):
        xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
      assert nats_is_file(xfile_path), 'invalid data path : {:}'.format(xfile_path)
      xdata = pickle_load(xfile_path)
      assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
      self.evaluated_indexes.add(idx)
      hp2archres = OrderedDict()
      for hp_key, results in xdata.items():
        hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
        self._avaliable_hps.add(hp_key)
      self.arch2infos_dict[idx] = hp2archres

  def query_index_by_arch(self, arch):
    """This function is used to query the index of an architecture in the search space.

    In the topology search space, the input arch can be an architecture string
    such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
    or an instance that has a `tostr` function that generates the architecture string;
    or directly an architecture index, in which case we check whether it is valid.
    This function returns the index: -1 means this architecture is not in the search space;
    otherwise, it is an int in [0, the-number-of-candidates-in-the-search-space).
    """
    if self.verbose:
      print('{:} Call query_index_by_arch with arch={:}'.format(time_string(), arch))
    if isinstance(arch, int):
      if 0 <= arch < len(self):
        return arch
      else:
        raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(arch, 0, len(self)))
    elif isinstance(arch, str):
      arch_index = self.archstr2index.get(arch, -1)
    elif hasattr(arch, 'tostr'):
      arch_index = self.archstr2index.get(arch.tostr(), -1)
    else:
      arch_index = -1
    return arch_index
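  # For instance (hedged; `api` is a concrete subclass instance and `model` is any
  # object exposing a .tostr() method), all three forms resolve to the same index
  # when the architecture exists in the search space:
  #
  #   api.query_index_by_arch(123)   # an index, validated and returned as-is
  #   api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
  #   api.query_index_by_arch(model)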

  def query_by_arch(self, arch, hp):
    """This is to keep the current version compatible with the old version."""
    return self.query_info_str_by_arch(arch, hp)

  def _prepare_info(self, index):
    """Load the data of the `index`-th architecture from disk when using fast mode."""
    if index not in self.arch2infos_dict:
      if self.fast_mode and self.archive_dir is not None:
        self.reload(self.archive_dir, index)
      elif not self.fast_mode:
        if self.verbose:
          print('{:} Call _prepare_info with index={:}: skip because it is not the fast mode.'.format(time_string(), index))
      else:
        raise ValueError('Invalid status: fast_mode={:} and archive_dir={:}'.format(self.fast_mode, self.archive_dir))
    else:
      assert index in self.evaluated_indexes, 'The index of {:} is not in self.evaluated_indexes, there must be something wrong.'.format(index)
      if self.verbose:
        print('{:} Call _prepare_info with index={:}: skip because it is in arch2infos_dict.'.format(time_string(), index))

  def clear_params(self, index: int, hp: Optional[Text] = None):
    """Remove the architecture's weights to save memory.

    :arg
      index: the index of the target architecture
      hp: a flag to control how to clear the parameters.
        -- None: clear all the weights in '01'/'12'/'90', which indicate the number of training epochs.
        -- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
    """
    if self.verbose:
      print('{:} Call clear_params with index={:} and hp={:}'.format(time_string(), index, hp))
    if index not in self.arch2infos_dict:
      warnings.warn('The {:}-th architecture is not in the benchmark data yet, no need to clear params.'.format(index))
    elif hp is None:
      for key, result in self.arch2infos_dict[index].items():
        result.clear_params()
    else:
      if str(hp) not in self.arch2infos_dict[index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[index].keys()), hp))
      self.arch2infos_dict[index][str(hp)].clear_params()

  @abc.abstractmethod
  def query_info_str_by_arch(self, arch, hp: Text = '12'):
    """This function is used to query the information of a specific architecture."""

  def _query_info_str_by_arch(self, arch, hp: Text = '12', print_information=None):
    arch_index = self.query_index_by_arch(arch)
    self._prepare_info(arch_index)
    if arch_index in self.arch2infos_dict:
      if hp not in self.arch2infos_dict[arch_index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
      info = self.arch2infos_dict[arch_index][hp]
      strings = print_information(info, 'arch-index={:}'.format(arch_index))
      return '\n'.join(strings)
    else:
      warnings.warn('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
      return None

  def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
    """Return the ArchResults for the `arch_index`-th architecture. This function is similar to query_by_index."""
    if self.verbose:
      print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
    self._prepare_info(arch_index)
    if arch_index in self.arch2infos_dict:
      if hp not in self.arch2infos_dict[arch_index]:
        raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
      info = self.arch2infos_dict[arch_index][hp]
    else:
      raise ValueError('arch_index [{:}] is not in arch2infos'.format(arch_index))
    return copy.deepcopy(info)

  def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = '12'):
    """This 'query_by_index' function is used to query information with the training of 01 epochs, 12 epochs, 90 epochs, or 200 epochs.

    ------
    If hp=01, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/01E.config)
    If hp=12, we train the model by 12 epochs (see config in configs/nas-benchmark/hyper-opts/12E.config)
    If hp=90, we train the model by 90 epochs (see config in configs/nas-benchmark/hyper-opts/90E.config)
    If hp=200, we train the model by 200 epochs (see config in configs/nas-benchmark/hyper-opts/200E.config)
    ------
    If dataname is None, return the ArchResults;
    else, return a dict with all trials on that dataset (the key is the seed).
    Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
      -- cifar10-valid : training the model on the CIFAR-10 training set.
      -- cifar10 : training the model on the CIFAR-10 training + validation set.
      -- cifar100 : training the model on the CIFAR-100 training set.
      -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
    """
    if self.verbose:
      print('{:} Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(time_string(), arch_index, dataname, hp))
    info = self.query_meta_info_by_index(arch_index, hp)
    if dataname is None:
      return info
    else:
      if dataname not in info.get_dataset_names():
        raise ValueError('invalid dataset-name : {:} vs. {:}'.format(dataname, info.get_dataset_names()))
      return info.query(dataname)

  def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
    """Find the architecture with the highest accuracy based on some constraints."""
    if self.verbose:
      print('{:} Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(
          time_string(), dataset, metric_on_set, hp, FLOP_max, Param_max))
    dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
    best_index, highest_accuracy = -1, None
    evaluated_indexes = sorted(list(self.evaluated_indexes))
    for i, arch_index in enumerate(evaluated_indexes):
      arch_info = self.arch2infos_dict[arch_index][hp]
      info = arch_info.get_compute_costs(dataset)  # the information of costs
      flop, param, latency = info['flops'], info['params'], info['latency']
      if FLOP_max is not None and flop > FLOP_max:
        continue
      if Param_max is not None and param > Param_max:
        continue
      xinfo = arch_info.get_metrics(dataset, metric_on_set)  # the information of loss and accuracy
      loss, accuracy = xinfo['loss'], xinfo['accuracy']
      if best_index == -1 or highest_accuracy < accuracy:
        best_index, highest_accuracy = arch_index, accuracy
    if self.verbose:
      print('  the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(best_index, self.arch(best_index), highest_accuracy))
    return best_index, highest_accuracy
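  # A hedged example of constrained search (`api` is a concrete subclass instance;
  # the FLOP budget below is only illustrative): find the most accurate architecture
  # on the CIFAR-100 validation set among those under 50 MFLOPs.
  #
  #   index, accuracy = api.find_best('cifar100', 'valid', FLOP_max=50, hp='12')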

  def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'):
    """Obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`.

    Args [seed]:
      -- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
      -- an integer : return the weights of a specific trial, whose seed is this integer.
    Args [hp]:
      -- 01 : train the model by 01 epochs
      -- 12 : train the model by 12 epochs
      -- 90 : train the model by 90 epochs
      -- 200 : train the model by 200 epochs
    """
    if self.verbose:
      print('{:} Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(time_string(), index, dataset, seed, hp))
    info = self.query_meta_info_by_index(index, hp)
    return info.get_net_param(dataset, seed)

  def get_net_config(self, index: int, dataset: Text):
    """Obtain the configuration for the `index`-th architecture on `dataset`.

    Args [dataset] (4 possible options):
      -- cifar10-valid : training the model on the CIFAR-10 training set.
      -- cifar10 : training the model on the CIFAR-10 training + validation set.
      -- cifar100 : training the model on the CIFAR-100 training set.
      -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
    This function will return a dict.
    ========= An example of using this function:
      config = api.get_net_config(128, 'cifar10')
    """
    if self.verbose:
      print('{:} Call the get_net_config function with index={:}, dataset={:}.'.format(time_string(), index, dataset))
    self._prepare_info(index)
    if index in self.arch2infos_dict:
      info = self.arch2infos_dict[index]
    else:
      raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
    info = next(iter(info.values()))
    results = info.query(dataset, None)
    results = next(iter(results.values()))
    return results.get_config(None)
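  # A hedged sketch of turning a config into a PyTorch network; it assumes the
  # companion model factory from the XAutoDL package is installed (this helper
  # lives outside this file):
  #
  #   from xautodl.models import get_cell_based_tiny_net
  #   config = api.get_net_config(128, 'cifar10')
  #   network = get_cell_based_tiny_net(config)
  #   params = api.get_net_param(128, 'cifar10', seed=None)
  #   network.load_state_dict(next(iter(params.values())))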

  def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
    """Obtain the cost metrics for the `index`-th architecture on a dataset."""
    if self.verbose:
      print('{:} Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
    self._prepare_info(index)
    info = self.query_meta_info_by_index(index, hp)
    return info.get_compute_costs(dataset)

  def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float:
    """Obtain the latency of the network (by default, the latency with a batch size of 256).

    :param index: the index of the target architecture
    :param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
    :return: a float value in seconds
    """
    if self.verbose:
      print('{:} Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(time_string(), index, dataset, hp))
    cost_dict = self.get_cost_info(index, dataset, hp)
    return cost_dict['latency']
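  # For reference (a sketch of the expected shape, derived from how
  # ArchResults.get_compute_costs assembles its result below): the dict returned
  # by get_cost_info contains 'flops', 'params', 'latency', plus per-set timing
  # entries such as 'T-train@epoch' and 'T-train@total' from ResultsCount.get_times.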

  @abc.abstractmethod
  def show(self, index=-1):
    """This function will print the information of a specific (or all) architecture(s)."""

  def _show(self, index=-1, print_information=None) -> None:
    """Print the information of a specific (or all) architecture(s).

    :param index: If index < 0, loop over all architectures and print their information one by one;
                  otherwise, print the information of the `index`-th architecture.
    :return: nothing
    """
    if index < 0:  # show all architectures
      print(self)
      evaluated_indexes = sorted(list(self.evaluated_indexes))
      for i, idx in enumerate(evaluated_indexes):
        print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(evaluated_indexes), idx) + '-' * 10)
        print('arch : {:}'.format(self.meta_archs[idx]))
        for key, result in self.arch2infos_dict[idx].items():
          strings = print_information(result)
          print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
          print('\n'.join(strings))
        print('<' * 40 + '------------' + '<' * 40)
    else:
      if 0 <= index < len(self.meta_archs):
        if index not in self.evaluated_indexes:
          print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
        else:
          for key, result in self.arch2infos_dict[index].items():
            strings = print_information(result)
            print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
            print('\n'.join(strings))
          print('<' * 40 + '------------' + '<' * 40)
      else:
        print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))

  def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]:
    """Count the trials per architecture: the returned dict maps a number of trials to the number of architectures with that many trials."""
    if self.verbose:
      print('Call the statistics function with dataset={:} and hp={:}.'.format(dataset, hp))
    valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
    if dataset not in valid_datasets:
      raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
    nums, hp = defaultdict(lambda: 0), str(hp)
    for index in self.evaluated_indexes:
      archInfo = self.arch2infos_dict[index][hp]
      dataset_seed = archInfo.dataset_seed
      if dataset not in dataset_seed:
        nums[0] += 1
      else:
        nums[len(dataset_seed[dataset])] += 1
    return dict(nums)
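  # For example (a sketch of how to read the result): a return value of
  # {3: 1000, 1: 500} means that 1000 evaluated architectures have three trials
  # (i.e., three seeds) on `dataset` and 500 have a single trial.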


class ArchResults(object):

  def __init__(self, arch_index, arch_str):
    self.arch_index = int(arch_index)
    self.arch_str = copy.deepcopy(arch_str)
    self.all_results = dict()
    self.dataset_seed = dict()
    self.clear_net_done = False

  def get_compute_costs(self, dataset):
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[(dataset, seed)] for seed in x_seeds]

    flops = [result.flop for result in results]
    params = [result.params for result in results]
    latencies = [result.get_latency() for result in results]
    latencies = [x for x in latencies if x > 0]
    mean_latency = np.mean(latencies) if len(latencies) > 0 else None
    time_infos = defaultdict(list)
    for result in results:
      time_info = result.get_times()
      for key, value in time_info.items():
        time_infos[key].append(value)

    info = {'flops': np.mean(flops),
            'params': np.mean(params),
            'latency': mean_latency}
    for key, value in time_infos.items():
      if len(value) > 0 and value[0] is not None:
        info[key] = np.mean(value)
      else:
        info[key] = None
    return info

  def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
    """Obtain the loss, accuracy, etc. information on a specific dataset.

    If not specified, each set refers to the proposed split in the NAS-Bench-201 paper.
    If some args return None or raise an error, the corresponding metric is not available.
    ========================================
    Args [dataset] (4 possible options):
      -- cifar10-valid : training the model on the CIFAR-10 training set.
      -- cifar10 : training the model on the CIFAR-10 training + validation set.
      -- cifar100 : training the model on the CIFAR-100 training set.
      -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
    Args [setname] (each dataset has different setnames):
      -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
      ------ 'train' : the metric on the training set.
      ------ 'x-valid' : the metric on the validation set.
      ------ 'ori-test' : the metric on the test set.
      -- When dataset = cifar10, you can use 'train', 'ori-test'.
      ------ 'train' : the metric on the training + validation set.
      ------ 'ori-test' : the metric on the test set.
      -- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
      ------ 'train' : the metric on the training set.
      ------ 'x-valid' : the metric on the validation set.
      ------ 'x-test' : the metric on the test set.
      ------ 'ori-test' : the metric on the validation + test set.
    Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)):
      ------ None : return the metric after the last training epoch.
      ------ an integer i : return the metric after the i-th training epoch.
    Args [is_random]:
      ------ True : return the metric of a randomly selected trial.
      ------ False : return the averaged metric of all available trials.
      ------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
    """
    x_seeds = self.dataset_seed[dataset]
    results = [self.all_results[(dataset, seed)] for seed in x_seeds]
    infos = defaultdict(list)
    for result in results:
      if setname == 'train':
        info = result.get_train(iepoch)
      else:
        info = result.get_eval(setname, iepoch)
      for key, value in info.items():
        infos[key].append(value)
    return_info = dict()
    if isinstance(is_random, bool) and is_random:  # randomly select one trial
      index = random.randint(0, len(results) - 1)
      for key, value in infos.items():
        return_info[key] = value[index]
    elif isinstance(is_random, bool) and not is_random:  # average over all trials
      for key, value in infos.items():
        if len(value) > 0 and value[0] is not None:
          return_info[key] = np.mean(value)
        else:
          return_info[key] = None
    elif isinstance(is_random, int):  # specify the seed
      if is_random not in x_seeds:
        raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
      index = x_seeds.index(is_random)
      for key, value in infos.items():
        return_info[key] = value[index]
    else:
      raise ValueError('invalid value for is_random: {:}'.format(is_random))
    return return_info
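  # A quick sketch (`arch_results` is an ArchResults instance): the averaged
  # validation metrics on CIFAR-10 after the last training epoch; the returned
  # dict carries 'iepoch', 'loss', 'accuracy', 'cur_time', and 'all_time'.
  #
  #   xinfo = arch_results.get_metrics('cifar10-valid', 'x-valid', iepoch=None, is_random=False)
  #   print(xinfo['accuracy'], xinfo['loss'])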

  def show(self, is_print=False):
    return print_information(self, None, is_print)

  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  def get_dataset_seeds(self, dataset):
    return copy.deepcopy(self.dataset_seed[dataset])

  def get_net_param(self, dataset: Text, seed: Union[None, int] = None):
    """Return the trained network's weights on the 'dataset'.

    :arg
      dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
      seed: an integer indicating the seed value, or None, which indicates returning the weights of all trials.
    """
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
      return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
    else:
      xkey = (dataset, seed)
      if xkey in self.all_results:
        return self.all_results[xkey].get_net_param()
      else:
        raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))

  def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
    """Reset the latency in all corresponding ResultsCount(s)."""
    if seed is None:
      for seed in self.dataset_seed[dataset]:
        self.all_results[(dataset, seed)].update_latency([latency])
    else:
      self.all_results[(dataset, seed)].update_latency([latency])

  def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
    """Reset the train-times in all corresponding ResultsCount(s)."""
    if seed is None:
      for seed in self.dataset_seed[dataset]:
        self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
    else:
      self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)

  def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
    """Reset the eval-times in all corresponding ResultsCount(s)."""
    if seed is None:
      for seed in self.dataset_seed[dataset]:
        self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
    else:
      self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)

  def get_latency(self, dataset: Text) -> float:
    """Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
    latencies = []
    for seed in self.dataset_seed[dataset]:
      latency = self.all_results[(dataset, seed)].get_latency()
      if not isinstance(latency, float) or latency <= 0:
        raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(dataset, seed, latency))
      latencies.append(latency)
    return sum(latencies) / len(latencies)

  def get_total_epoch(self, dataset=None):
    """Return the total number of training epochs."""
    if dataset is None:
      epochss = []
      for xdata, x_seeds in self.dataset_seed.items():
        epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
    elif isinstance(dataset, str):
      x_seeds = self.dataset_seed[dataset]
      epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
    else:
      raise ValueError('invalid dataset={:}'.format(dataset))
    if len(set(epochss)) > 1:
      raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss))
    return epochss[-1]

  def query(self, dataset, seed=None):
    """Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'."""
    if seed is None:
      x_seeds = self.dataset_seed[dataset]
      return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
    else:
      return self.all_results[(dataset, seed)]

  def arch_idx_str(self):
    return '{:06d}'.format(self.arch_index)

  def update(self, dataset_name, seed, result):
    if dataset_name not in self.dataset_seed:
      self.dataset_seed[dataset_name] = []
    assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
    self.dataset_seed[dataset_name].append(seed)
    self.dataset_seed[dataset_name] = sorted(self.dataset_seed[dataset_name])
    assert (dataset_name, seed) not in self.all_results
    self.all_results[(dataset_name, seed)] = result
    self.clear_net_done = False

  def state_dict(self):
    state_dict = dict()
    for key, value in self.__dict__.items():
      if key == 'all_results':  # contains instances of ResultsCount
        xvalue = dict()
        assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
        for _k, _v in value.items():
          assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
          xvalue[_k] = _v.state_dict()
      else:
        xvalue = value
      state_dict[key] = xvalue
    return state_dict

  def load_state_dict(self, state_dict):
    new_state_dict = dict()
    for key, value in state_dict.items():
      if key == 'all_results':  # convert back into instances of ResultsCount
        xvalue = dict()
        assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
        for _k, _v in value.items():
          xvalue[_k] = ResultsCount.create_from_state_dict(_v)
      else:
        xvalue = value
      new_state_dict[key] = xvalue
    self.__dict__.update(new_state_dict)

  @staticmethod
  def create_from_state_dict(state_dict_or_file):
    x = ArchResults(-1, -1)
    if isinstance(state_dict_or_file, str):  # a file path
      state_dict = pickle_load(state_dict_or_file)
    elif isinstance(state_dict_or_file, dict):
      state_dict = state_dict_or_file
    else:
      raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
    x.load_state_dict(state_dict)
    return x

  def clear_params(self):
    """Clear the weights saved in each 'result', which helps reduce the memory footprint."""
    for key, result in self.all_results.items():
      del result.net_state_dict
      result.net_state_dict = None
    self.clear_net_done = True
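  # A hedged round-trip sketch: ArchResults serializes to plain dicts, so it can
  # be stored with the pickle helpers defined at the top of this file (the path
  # below is only an illustration):
  #
  #   pickle_save(arch_results.state_dict(), '/tmp/arch-000001')
  #   restored = ArchResults.create_from_state_dict('/tmp/arch-000001.pbz2')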

  def debug_test(self):
    """A function for debugging and testing, which calls most methods in this class."""
    all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
    for dataset in all_dataset:
      print('---->>>> {:}'.format(dataset))
      print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
      for seed in self.dataset_seed[dataset]:
        result = self.all_results[(dataset, seed)]
        print('  ==>> result = {:}'.format(result))
        print('  ==>> cost = {:}'.format(result.get_times()))

  def __repr__(self):
    return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(
        name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str,
        num=len(self.all_results), clear=self.clear_net_done))
"""
|
|
This class (ResultsCount) is used to save the information of one trial for a single architecture.
|
|
I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called.
|
|
If you have any question regarding this class, please open an issue or email me.
|
|
"""
|
|
class ResultsCount(object):
|
|
|
|
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
|
|
self.name = name
|
|
self.net_state_dict = state_dict
|
|
self.train_acc1es = copy.deepcopy(train_accs)
|
|
self.train_acc5es = None
|
|
self.train_losses = copy.deepcopy(train_losses)
|
|
self.train_times = None
|
|
self.arch_config = copy.deepcopy(arch_config)
|
|
self.params = params
|
|
self.flop = flop
|
|
self.seed = seed
|
|
self.epochs = epochs
|
|
self.latency = latency
|
|
# evaluation results
|
|
self.reset_eval()
|
|
|
|

  def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
    self.train_acc1es = train_acc1es
    self.train_acc5es = train_acc5es
    self.train_losses = train_losses
    self.train_times = train_times

  def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
    """Assign the training times."""
    train_times = OrderedDict()
    for i in range(self.epochs):
      train_times[i] = estimated_per_epoch_time
    self.train_times = train_times

  def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
    """Assign the evaluation times."""
    if eval_name not in self.eval_names:
      raise ValueError('invalid eval name : {:}'.format(eval_name))
    for i in range(self.epochs):
      self.eval_times['{:}@{:}'.format(eval_name, i)] = estimated_per_epoch_time

  def reset_eval(self):
    self.eval_names = []
    self.eval_acc1es = {}
    self.eval_times = {}
    self.eval_losses = {}

  def update_latency(self, latency):
    self.latency = copy.deepcopy(latency)

  def get_latency(self) -> float:
    """Return the latency value in seconds; -1 represents not available, otherwise it is a float value."""
    if self.latency is None:
      return -1.0
    else:
      return sum(self.latency) / len(self.latency)

  def update_eval(self, accs, losses, times):  # new version
    data_names = set([x.split('@')[0] for x in accs.keys()])
    for data_name in data_names:
      assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
      self.eval_names.append(data_name)
      for iepoch in range(self.epochs):
        xkey = '{:}@{:}'.format(data_name, iepoch)
        self.eval_acc1es[xkey] = accs[xkey]
        self.eval_losses[xkey] = losses[xkey]
        self.eval_times[xkey] = times[xkey]

  def update_OLD_eval(self, name, accs, losses):  # old version
    assert name not in self.eval_names, '{:} has already been added'.format(name)
    self.eval_names.append(name)
    for iepoch in range(self.epochs):
      if iepoch in accs:
        self.eval_acc1es['{:}@{:}'.format(name, iepoch)] = accs[iepoch]
        self.eval_losses['{:}@{:}'.format(name, iepoch)] = losses[iepoch]

  def __repr__(self):
    num_eval = len(self.eval_names)
    set_name = '[' + ', '.join(self.eval_names) + ']'
    return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(
        name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'],
        flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))

  def get_total_epoch(self):
    return copy.deepcopy(self.epochs)

  def get_times(self):
    """Obtain the information regarding both training and evaluation time."""
    if self.train_times is not None and isinstance(self.train_times, dict):
      train_times = list(self.train_times.values())
      time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
    else:
      time_info = {'T-train@epoch': None, 'T-train@total': None}
    for name in self.eval_names:
      try:
        xtimes = [self.eval_times['{:}@{:}'.format(name, i)] for i in range(self.epochs)]
        time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
        time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
      except Exception:
        time_info['T-{:}@epoch'.format(name)] = None
        time_info['T-{:}@total'.format(name)] = None
    return time_info

  def get_eval_set(self):
    return self.eval_names

  def get_train(self, iepoch=None):
    """Get the training information for the `iepoch`-th epoch (the last epoch if None)."""
    if iepoch is None:
      iepoch = self.epochs - 1
    assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
    if self.train_times is not None:
      xtime = self.train_times[iepoch]
      atime = sum([self.train_times[i] for i in range(iepoch + 1)])
    else:
      xtime, atime = None, None
    return {'iepoch': iepoch,
            'loss': self.train_losses[iepoch],
            'accuracy': self.train_acc1es[iepoch],
            'cur_time': xtime,
            'all_time': atime}

  def get_eval(self, name, iepoch=None):
    """Get the evaluation information; there could be multiple evaluation sets (identified by the `name` argument)."""
    if iepoch is None:
      iepoch = self.epochs - 1
    assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)

    def _internal_query(xname):
      if isinstance(self.eval_times, dict) and len(self.eval_times) > 0:
        xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
        atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch + 1)])
      else:
        xtime, atime = None, None
      return {'iepoch': iepoch,
              'loss': self.eval_losses['{:}@{:}'.format(xname, iepoch)],
              'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
              'cur_time': xtime,
              'all_time': atime}

    if name == 'valid':
      return _internal_query('x-valid')
    else:
      return _internal_query(name)

  def get_net_param(self, clone=False):
    if clone:
      return copy.deepcopy(self.net_state_dict)
    else:
      return self.net_state_dict

  def get_config(self, str2structure):
    """Obtain the config dict for this architecture."""
    if str2structure is None:
      # In this case, this is an architecture in the size search space of NATS-Bench.
      if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
        return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
                'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']}
      # In this case, this is an architecture in the topology search space of NATS-Bench.
      else:
        return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
                'N': self.arch_config['num_cells'],
                'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
    else:
      # In this case, this is an architecture in the size search space of NATS-Bench.
      if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
        return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
                'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']}
      # In this case, this is an architecture in the topology search space of NATS-Bench.
      else:
        return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
                'N': self.arch_config['num_cells'],
                'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}

  def state_dict(self):
    _state_dict = {key: value for key, value in self.__dict__.items()}
    return _state_dict

  def load_state_dict(self, state_dict):
    self.__dict__.update(state_dict)

  @staticmethod
  def create_from_state_dict(state_dict):
    x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
    x.load_state_dict(state_dict)
    return x