fix NAS-Bench-201 comments and add more

This commit is contained in:
D-X-Y 2020-02-04 18:22:08 +11:00
parent 25e529f788
commit 11d3c21467
2 changed files with 47 additions and 6 deletions

View File

@ -14,7 +14,7 @@ Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
You can simply type `pip install nas-bench-201` to install our api. Please see source codes of `nas-bench-201` module in [this repo](https://github.com/D-X-Y/NAS-Bench-201).
If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/AutoDL-Projects/issues) or email me.
**If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/AutoDL-Projects/issues) or email me.**
### Preparation and Download

View File

@ -292,6 +292,11 @@ class NASBench201API(object):
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
return xifo
"""
This function will print the information of a specific (or all) architecture(s).
If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th archiitecture.
"""
def show(self, index=-1):
if index < 0: # show all architectures
print(self)
@ -299,10 +304,10 @@ class NASBench201API(object):
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
print('arch : {:}'.format(self.meta_archs[idx]))
strings = print_information(self.arch2infos_full[idx])
print('>' * 40 + ' 200 epochs ' + '>' * 40)
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[idx].get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
strings = print_information(self.arch2infos_less[idx])
print('>' * 40 + ' 12 epochs ' + '>' * 40)
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[idx].get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
@ -310,10 +315,10 @@ class NASBench201API(object):
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
else:
strings = print_information(self.arch2infos_full[index])
print('>' * 40 + ' 200 epochs ' + '>' * 40)
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[index].get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
strings = print_information(self.arch2infos_less[index])
print('>' * 40 + ' 12 epochs ' + '>' * 40)
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[index].get_total_epoch()) + '>' * 40)
print('\n'.join(strings))
print('<' * 40 + '------------' + '<' * 40)
else:
@ -419,7 +424,7 @@ class ArchResults(object):
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'ori-test' : the metric on the validation + test set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar10, you can use 'train', 'ori-test'.
------ 'train' : the metric on the training + validation set.
------ 'ori-test' : the metric on the test set.
@ -472,6 +477,11 @@ class ArchResults(object):
def get_dataset_seeds(self, dataset):
return copy.deepcopy( self.dataset_seed[dataset] )
"""
This function will return the trained network's weights on the 'dataset'.
When the 'seed' is None, it will return the weights for every run trial in the form of a dict.
When the
"""
def get_net_param(self, dataset, seed=None):
if seed is None:
x_seeds = self.dataset_seed[dataset]
@ -479,6 +489,21 @@ class ArchResults(object):
else:
return self.all_results[(dataset, seed)].get_net_param()
# get the total number of training epochs
def get_total_epoch(self, dataset=None):
if dataset is None:
epochss = []
for xdata, x_seeds in self.dataset_seed.items():
epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
elif isinstance(dataset, str):
x_seeds = self.dataset_seed[dataset]
epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
else:
raise ValueError('invalid dataset={:}'.format(dataset))
if len(set(epochss)) > 1: raise ValueError('Each trial mush have the same number of training epochs : {:}'.format(epochss))
return epochss[-1]
# return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'
def query(self, dataset, seed=None):
if seed is None:
x_seeds = self.dataset_seed[dataset]
@ -537,6 +562,8 @@ class ArchResults(object):
x.load_state_dict(state_dict)
return x
# This function is used to clear the weights saved in each 'result'
# This can help reduce the memory footprint.
def clear_params(self):
for key, result in self.all_results.items():
result.net_state_dict = None
@ -547,6 +574,11 @@ class ArchResults(object):
"""
This class (ResultsCount) is used to save the information of one trial for a single architecture.
I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called.
If you have any question regarding this class, please open an issue or email me.
"""
class ResultsCount(object):
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
@ -604,10 +636,17 @@ class ResultsCount(object):
set_name = '[' + ', '.join(self.eval_names) + ']'
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
# get the total number of training epochs
def get_total_epoch(self):
return copy.deepcopy(self.epochs)
# get the latency
# -1 represents not avaliable ; otherwise it should be a float value
def get_latency(self):
if self.latency is None: return -1
else: return sum(self.latency) / len(self.latency)
# get the information regarding time
def get_times(self):
if self.train_times is not None and isinstance(self.train_times, dict):
train_times = list( self.train_times.values() )
@ -626,6 +665,7 @@ class ResultsCount(object):
def get_eval_set(self):
return self.eval_names
# get the training information
def get_train(self, iepoch=None):
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
@ -639,6 +679,7 @@ class ResultsCount(object):
'cur_time': xtime,
'all_time': atime}
# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
def get_eval(self, name, iepoch=None):
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)