Update NATS-Bench (tss version 1.0) and remove the trace of 301
This commit is contained in:
		| @@ -68,7 +68,7 @@ class NATSsize(NASBenchMetaAPI): | ||||
|         self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1])) | ||||
|       else: | ||||
|         file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) | ||||
|       print ('Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(self._fast_mode, file_path_or_dict)) | ||||
|       print ('{:} Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict)) | ||||
|     if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): | ||||
|       file_path_or_dict = str(file_path_or_dict) | ||||
|       if verbose: | ||||
| @@ -125,10 +125,15 @@ class NATSsize(NASBenchMetaAPI): | ||||
|        If index is None, overwrite all ckps. | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('{:} Call clear_params with archive_root={:} and index={:}'.format(time_string(), archive_root, index)) | ||||
|       print('{:} Call clear_params with archive_root={:} and index={:}'.format( | ||||
|             time_string(), archive_root, index)) | ||||
|     if archive_root is None: | ||||
|       archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1])) | ||||
|     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) | ||||
|       if not os.path.isdir(archive_root): | ||||
|         warnings.warn('The input archive_root is None and the default archive_root path ({:}) does not exist, try to use self.archive_dir.'.format(archive_root)) | ||||
|       archive_root = self.archive_dir | ||||
|     if archive_root is None or not os.path.isdir(archive_root): | ||||
|       raise ValueError('Invalid archive_root : {:}'.format(archive_root)) | ||||
|     if index is None: | ||||
|       indexes = list(range(len(self))) | ||||
|     else: | ||||
|   | ||||
| @@ -4,7 +4,7 @@ | ||||
| # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # | ||||
| ##################################################################################### | ||||
| # The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) # | ||||
| # [2020.08.31]                                                                      # | ||||
| # [2020.08.31] NATS-tss-v1_0-3ffb9.pickle.pbz2                                      # | ||||
| ##################################################################################### | ||||
| import os, copy, random, numpy as np | ||||
| from pathlib import Path | ||||
| @@ -19,14 +19,14 @@ from .api_utils import remap_dataset_set_names | ||||
|  | ||||
|  | ||||
| PICKLE_EXT = 'pickle.pbz2' | ||||
| ALL_BASE_NAMES = ['NATS-tss-v1_0-xxxxx'] | ||||
| ALL_BASE_NAMES = ['NATS-tss-v1_0-3ffb9'] | ||||
|  | ||||
|  | ||||
| def print_information(information, extra_info=None, show=False): | ||||
|   dataset_names = information.get_dataset_names() | ||||
|   strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)] | ||||
|   def metric2str(loss, acc): | ||||
|     return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc) | ||||
|     return 'loss = {:.3f} & top1 = {:.2f}%'.format(loss, acc) | ||||
|  | ||||
|   for ida, dataset in enumerate(dataset_names): | ||||
|     metric = information.get_compute_costs(dataset) | ||||
| @@ -61,12 +61,15 @@ class NATStopology(NASBenchMetaAPI): | ||||
|     self._archive_dir = None | ||||
|     self.reset_time() | ||||
|     if file_path_or_dict is None: | ||||
|       file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1]) | ||||
|       if self._fast_mode: | ||||
|         self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1])) | ||||
|       else: | ||||
|         file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT)) | ||||
|       print ('{:} Try to use the default NATS-Bench (topology) path from {:}.'.format(time_string(), file_path_or_dict)) | ||||
|     if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path): | ||||
|       file_path_or_dict = str(file_path_or_dict) | ||||
|       if verbose: | ||||
|         print('{:} Try to create the NATS-Bench (topology) api from {:}'.format(time_string(), file_path_or_dict)) | ||||
|         print('{:} Try to create the NATS-Bench (topology) api from {:} with fast_mode={:}'.format(time_string(), file_path_or_dict, fast_mode)) | ||||
|       if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict): | ||||
|         raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict)) | ||||
|       self.filename = Path(file_path_or_dict).name | ||||
| @@ -82,7 +85,7 @@ class NATStopology(NASBenchMetaAPI): | ||||
|           file_path_or_dict = pickle_load(file_path_or_dict) | ||||
|     elif isinstance(file_path_or_dict, dict): | ||||
|       file_path_or_dict = copy.deepcopy(file_path_or_dict) | ||||
|     self.verbose = verbose # [TODO] a flag indicating whether to print more logs | ||||
|     self.verbose = verbose | ||||
|     if isinstance(file_path_or_dict, dict): | ||||
|       keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') | ||||
|       for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) | ||||
| @@ -91,13 +94,13 @@ class NATStopology(NASBenchMetaAPI): | ||||
|       self.arch2infos_dict = OrderedDict() | ||||
|       self._avaliable_hps = set() | ||||
|       for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): | ||||
|         all_info = file_path_or_dict['arch2infos'][xkey] | ||||
|         all_infos = file_path_or_dict['arch2infos'][xkey] | ||||
|         hp2archres = OrderedDict() | ||||
|         for hp_key, results in all_infos.items(): | ||||
|           hp2archres[hp_key] = ArchResults.create_from_state_dict(results) | ||||
|           self._avaliable_hps.add(hp_key)  # save the avaliable hyper-parameter | ||||
|         self.arch2infos_dict[xkey] = hp2archres | ||||
|       self.evaluated_indexes = list(file_path_or_dict['evaluated_indexes']) | ||||
|       self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes']) | ||||
|     elif self.archive_dir is not None: | ||||
|       benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT)) | ||||
|       self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs']) | ||||
| @@ -116,7 +119,7 @@ class NATStopology(NASBenchMetaAPI): | ||||
|  | ||||
|   def reload(self, archive_root: Text = None, index: int = None): | ||||
|     """Overwrite all information of the 'index'-th architecture in the search space. | ||||
|          It will load its data from 'archive_root'. | ||||
|        If index is None, overwrite all ckps. | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('{:} Call clear_params with archive_root={:} and index={:}'.format( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user