update NAS-Bench
This commit is contained in:
		| @@ -7,7 +7,8 @@ | ||||
| # [2020.03.08] Next version (coming soon) | ||||
| # | ||||
| # | ||||
| import os, sys, copy, random, torch, numpy as np | ||||
| import os, copy, random, torch, numpy as np | ||||
| from typing import List, Text, Union, Dict, Any | ||||
| from collections import OrderedDict, defaultdict | ||||
|  | ||||
|  | ||||
| @@ -43,7 +44,7 @@ This is the class for API of NAS-Bench-201. | ||||
| class NASBench201API(object): | ||||
|  | ||||
|   """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """ | ||||
|   def __init__(self, file_path_or_dict, verbose=True): | ||||
|   def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True): | ||||
|     if isinstance(file_path_or_dict, str): | ||||
|       if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict)) | ||||
|       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) | ||||
| @@ -69,7 +70,7 @@ class NASBench201API(object): | ||||
|       assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) | ||||
|       self.archstr2index[ arch ] = idx | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|   def __getitem__(self, index: int): | ||||
|     return copy.deepcopy( self.meta_archs[index] ) | ||||
|  | ||||
|   def __len__(self): | ||||
| @@ -99,7 +100,7 @@ class NASBench201API(object): | ||||
|  | ||||
|   # Overwrite all information of the 'index'-th architecture in the search space. | ||||
|   # It will load its data from 'archive_root'. | ||||
|   def reload(self, archive_root, index): | ||||
|   def reload(self, archive_root: Text, index: int): | ||||
|     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) | ||||
|     xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index) | ||||
| @@ -141,7 +142,8 @@ class NASBench201API(object): | ||||
|   #  -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||
|   #  -- cifar100 : training the model on the CIFAR-100 training set. | ||||
|   #  -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||
|   def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False): | ||||
|   def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, | ||||
|                      use_12epochs_result: bool = False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr) | ||||
| @@ -177,7 +179,7 @@ class NASBench201API(object): | ||||
|     return best_index, highest_accuracy | ||||
|  | ||||
|   # return the topology structure of the `index`-th architecture | ||||
|   def arch(self, index): | ||||
|   def arch(self, index: int): | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
| @@ -238,7 +240,7 @@ class NASBench201API(object): | ||||
|   # `is_random` | ||||
|   #   When is_random=True, the performance of a random architecture will be returned | ||||
|   #   When is_random=False, the performanceo of all trials will be averaged. | ||||
|   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||
|   def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
| @@ -301,7 +303,7 @@ class NASBench201API(object): | ||||
|   If the index < 0: it will loop for all architectures and print their information one by one. | ||||
|   else: it will print the information of the 'index'-th archiitecture. | ||||
|   """ | ||||
|   def show(self, index=-1): | ||||
|   def show(self, index: int = -1) -> None: | ||||
|     if index < 0: # show all architectures | ||||
|       print(self) | ||||
|       for i, idx in enumerate(self.evaluated_indexes): | ||||
| @@ -336,8 +338,8 @@ class NASBench201API(object): | ||||
|   #   for i, node in enumerate(arch): | ||||
|   #     print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node)) | ||||
|   @staticmethod | ||||
|   def str2lists(xstr): | ||||
|     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||
|   def str2lists(xstr: Text) -> List[Any]: | ||||
|     # assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||
|     nodestrs = xstr.split('+') | ||||
|     genotypes = [] | ||||
|     for i, node_str in enumerate(nodestrs): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user