fix NAS-Bench-201 comments and add more
This commit is contained in:
parent
25e529f788
commit
11d3c21467
@ -14,7 +14,7 @@ Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
||||
|
||||
You can simply type `pip install nas-bench-201` to install our api. Please see source codes of `nas-bench-201` module in [this repo](https://github.com/D-X-Y/NAS-Bench-201).
|
||||
|
||||
If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/AutoDL-Projects/issues) or email me.
|
||||
**If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/AutoDL-Projects/issues) or email me.**
|
||||
|
||||
### Preparation and Download
|
||||
|
||||
|
@ -292,6 +292,11 @@ class NASBench201API(object):
|
||||
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
|
||||
return xifo
|
||||
|
||||
"""
|
||||
This function will print the information of a specific (or all) architecture(s).
|
||||
If the index < 0: it will loop for all architectures and print their information one by one.
|
||||
else: it will print the information of the 'index'-th archiitecture.
|
||||
"""
|
||||
def show(self, index=-1):
|
||||
if index < 0: # show all architectures
|
||||
print(self)
|
||||
@ -299,10 +304,10 @@ class NASBench201API(object):
|
||||
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
|
||||
print('arch : {:}'.format(self.meta_archs[idx]))
|
||||
strings = print_information(self.arch2infos_full[idx])
|
||||
print('>' * 40 + ' 200 epochs ' + '>' * 40)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[idx].get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
strings = print_information(self.arch2infos_less[idx])
|
||||
print('>' * 40 + ' 12 epochs ' + '>' * 40)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[idx].get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
@ -310,10 +315,10 @@ class NASBench201API(object):
|
||||
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
|
||||
else:
|
||||
strings = print_information(self.arch2infos_full[index])
|
||||
print('>' * 40 + ' 200 epochs ' + '>' * 40)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[index].get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
strings = print_information(self.arch2infos_less[index])
|
||||
print('>' * 40 + ' 12 epochs ' + '>' * 40)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[index].get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
@ -419,7 +424,7 @@ class ArchResults(object):
|
||||
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'ori-test' : the metric on the validation + test set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
-- When dataset = cifar10, you can use 'train', 'ori-test'.
|
||||
------ 'train' : the metric on the training + validation set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
@ -472,6 +477,11 @@ class ArchResults(object):
|
||||
def get_dataset_seeds(self, dataset):
|
||||
return copy.deepcopy( self.dataset_seed[dataset] )
|
||||
|
||||
"""
|
||||
This function will return the trained network's weights on the 'dataset'.
|
||||
When the 'seed' is None, it will return the weights for every run trial in the form of a dict.
|
||||
When the
|
||||
"""
|
||||
def get_net_param(self, dataset, seed=None):
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
@ -479,6 +489,21 @@ class ArchResults(object):
|
||||
else:
|
||||
return self.all_results[(dataset, seed)].get_net_param()
|
||||
|
||||
# get the total number of training epochs
|
||||
def get_total_epoch(self, dataset=None):
|
||||
if dataset is None:
|
||||
epochss = []
|
||||
for xdata, x_seeds in self.dataset_seed.items():
|
||||
epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
|
||||
elif isinstance(dataset, str):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
|
||||
else:
|
||||
raise ValueError('invalid dataset={:}'.format(dataset))
|
||||
if len(set(epochss)) > 1: raise ValueError('Each trial mush have the same number of training epochs : {:}'.format(epochss))
|
||||
return epochss[-1]
|
||||
|
||||
# return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'
|
||||
def query(self, dataset, seed=None):
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
@ -537,6 +562,8 @@ class ArchResults(object):
|
||||
x.load_state_dict(state_dict)
|
||||
return x
|
||||
|
||||
# This function is used to clear the weights saved in each 'result'
|
||||
# This can help reduce the memory footprint.
|
||||
def clear_params(self):
|
||||
for key, result in self.all_results.items():
|
||||
result.net_state_dict = None
|
||||
@ -547,6 +574,11 @@ class ArchResults(object):
|
||||
|
||||
|
||||
|
||||
"""
|
||||
This class (ResultsCount) is used to save the information of one trial for a single architecture.
|
||||
I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called.
|
||||
If you have any question regarding this class, please open an issue or email me.
|
||||
"""
|
||||
class ResultsCount(object):
|
||||
|
||||
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
|
||||
@ -604,10 +636,17 @@ class ResultsCount(object):
|
||||
set_name = '[' + ', '.join(self.eval_names) + ']'
|
||||
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
|
||||
|
||||
# get the total number of training epochs
|
||||
def get_total_epoch(self):
|
||||
return copy.deepcopy(self.epochs)
|
||||
|
||||
# get the latency
|
||||
# -1 represents not avaliable ; otherwise it should be a float value
|
||||
def get_latency(self):
|
||||
if self.latency is None: return -1
|
||||
else: return sum(self.latency) / len(self.latency)
|
||||
|
||||
# get the information regarding time
|
||||
def get_times(self):
|
||||
if self.train_times is not None and isinstance(self.train_times, dict):
|
||||
train_times = list( self.train_times.values() )
|
||||
@ -626,6 +665,7 @@ class ResultsCount(object):
|
||||
def get_eval_set(self):
|
||||
return self.eval_names
|
||||
|
||||
# get the training information
|
||||
def get_train(self, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
@ -639,6 +679,7 @@ class ResultsCount(object):
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
# get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).
|
||||
def get_eval(self, name, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
|
Loading…
Reference in New Issue
Block a user