diff --git a/docs/NAS-Bench-201.md b/docs/NAS-Bench-201.md index 7042bca..2a3e72d 100644 --- a/docs/NAS-Bench-201.md +++ b/docs/NAS-Bench-201.md @@ -14,7 +14,7 @@ Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`. You can simply type `pip install nas-bench-201` to install our api. Please see source codes of `nas-bench-201` module in [this repo](https://github.com/D-X-Y/NAS-Bench-201). -If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/AutoDL-Projects/issues) or email me. +**If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/AutoDL-Projects/issues) or email me.** ### Preparation and Download diff --git a/lib/nas_201_api/api.py b/lib/nas_201_api/api.py index 9712fe2..1c9e111 100644 --- a/lib/nas_201_api/api.py +++ b/lib/nas_201_api/api.py @@ -292,6 +292,11 @@ class NASBench201API(object): xifo['est-valid-accuracy'] = est_valid_info['accuracy'] return xifo + """ + This function will print the information of a specific (or all) architecture(s). + If the index < 0: it will loop for all architectures and print their information one by one. + else: it will print the information of the 'index'-th architecture. + """ def show(self, index=-1): if index < 0: # show all architectures print(self) @@ -299,10 +304,10 @@ class NASBench201API(object): print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! 
'.format(i, len(self.evaluated_indexes), idx) + '-'*10) print('arch : {:}'.format(self.meta_archs[idx])) strings = print_information(self.arch2infos_full[idx]) - print('>' * 40 + ' 200 epochs ' + '>' * 40) + print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[idx].get_total_epoch()) + '>' * 40) print('\n'.join(strings)) strings = print_information(self.arch2infos_less[idx]) - print('>' * 40 + ' 12 epochs ' + '>' * 40) + print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[idx].get_total_epoch()) + '>' * 40) print('\n'.join(strings)) print('<' * 40 + '------------' + '<' * 40) else: @@ -310,10 +315,10 @@ class NASBench201API(object): if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index)) else: strings = print_information(self.arch2infos_full[index]) - print('>' * 40 + ' 200 epochs ' + '>' * 40) + print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[index].get_total_epoch()) + '>' * 40) print('\n'.join(strings)) strings = print_information(self.arch2infos_less[index]) - print('>' * 40 + ' 12 epochs ' + '>' * 40) + print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[index].get_total_epoch()) + '>' * 40) print('\n'.join(strings)) print('<' * 40 + '------------' + '<' * 40) else: @@ -419,7 +424,7 @@ class ArchResults(object): -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test' ------ 'train' : the metric on the training set. ------ 'x-valid' : the metric on the validation set. - ------ 'ori-test' : the metric on the validation + test set. + ------ 'ori-test' : the metric on the test set. -- When dataset = cifar10, you can use 'train', 'ori-test'. ------ 'train' : the metric on the training + validation set. ------ 'ori-test' : the metric on the test set. 
@@ -472,6 +477,11 @@ class ArchResults(object): def get_dataset_seeds(self, dataset): return copy.deepcopy( self.dataset_seed[dataset] ) + """ + This function will return the trained network's weights on the 'dataset'. + When the 'seed' is None, it will return the weights for every run trial in the form of a dict. + When the 'seed' is not None, it will return the weights of the single trial identified by ('dataset', 'seed'). + """ def get_net_param(self, dataset, seed=None): if seed is None: x_seeds = self.dataset_seed[dataset] @@ -479,6 +489,21 @@ class ArchResults(object): else: return self.all_results[(dataset, seed)].get_net_param() + # get the total number of training epochs + def get_total_epoch(self, dataset=None): + if dataset is None: + epochss = [] + for xdata, x_seeds in self.dataset_seed.items(): + epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds] + elif isinstance(dataset, str): + x_seeds = self.dataset_seed[dataset] + epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds] + else: + raise ValueError('invalid dataset={:}'.format(dataset)) + if len(set(epochss)) > 1: raise ValueError('Each trial mush have the same number of training epochs : {:}'.format(epochss)) + return epochss[-1] + + # return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed' def query(self, dataset, seed=None): if seed is None: x_seeds = self.dataset_seed[dataset] @@ -537,6 +562,8 @@ class ArchResults(object): x.load_state_dict(state_dict) return x + # This function is used to clear the weights saved in each 'result' + # This can help reduce the memory footprint. def clear_params(self): for key, result in self.all_results.items(): result.net_state_dict = None @@ -547,6 +574,11 @@ class ArchResults(object): +""" +This class (ResultsCount) is used to save the information of one trial for a single architecture. +I did not write many comments for this class, because it is the lowest-level class in NAS-Bench-201 API, which will rarely be called. 
+If you have any questions regarding this class, please open an issue or email me. +""" class ResultsCount(object): def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency): @@ -604,10 +636,17 @@ class ResultsCount(object): set_name = '[' + ', '.join(self.eval_names) + ']' return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name)) + # get the total number of training epochs + def get_total_epoch(self): + return copy.deepcopy(self.epochs) + + # get the latency + # -1 represents not available ; otherwise it should be a float value def get_latency(self): if self.latency is None: return -1 else: return sum(self.latency) / len(self.latency) + # get the information regarding time def get_times(self): if self.train_times is not None and isinstance(self.train_times, dict): train_times = list( self.train_times.values() ) @@ -626,6 +665,7 @@ class ResultsCount(object): def get_eval_set(self): return self.eval_names + # get the training information def get_train(self, iepoch=None): if iepoch is None: iepoch = self.epochs-1 assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) @@ -639,6 +679,7 @@ class ResultsCount(object): 'cur_time': xtime, 'all_time': atime} + # get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument). def get_eval(self, name, iepoch=None): if iepoch is None: iepoch = self.epochs-1 assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)