support get_net_config for NAS-Bench-201
This commit is contained in:
		| @@ -72,7 +72,16 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1 | |||||||
| api.show(index) | api.show(index) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. | 5. Create the network from api: | ||||||
|  | ``` | ||||||
|  | config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset | ||||||
|  | from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models | ||||||
|  | network = get_cell_based_tiny_net(config) # create the network from configurration | ||||||
|  | print(network) # show the structure of this architecture | ||||||
|  | ``` | ||||||
|  | If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network. | ||||||
|  |  | ||||||
|  | 6. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. | ||||||
|  |  | ||||||
|  |  | ||||||
| ### Detailed Instruction | ### Detailed Instruction | ||||||
|   | |||||||
| @@ -16,6 +16,7 @@ from .cell_searchs import CellStructure, CellArchitectures | |||||||
|  |  | ||||||
| # Cell-based NAS Models | # Cell-based NAS Models | ||||||
| def get_cell_based_tiny_net(config): | def get_cell_based_tiny_net(config): | ||||||
|  |   if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict | ||||||
|   super_type = getattr(config, 'super_type', 'basic') |   super_type = getattr(config, 'super_type', 'basic') | ||||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] |   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] | ||||||
|   if super_type == 'basic' and config.name in group_names: |   if super_type == 'basic' and config.name in group_names: | ||||||
| @@ -30,7 +31,12 @@ def get_cell_based_tiny_net(config): | |||||||
|                     config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) |                     config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||||
|   elif config.name == 'infer.tiny': |   elif config.name == 'infer.tiny': | ||||||
|     from .cell_infers import TinyNetwork |     from .cell_infers import TinyNetwork | ||||||
|     return TinyNetwork(config.C, config.N, config.genotype, config.num_classes) |     if hasattr(config, 'genotype'): | ||||||
|  |       genotype = config.genotype | ||||||
|  |     elif hasattr(config, 'arch_str'): | ||||||
|  |       genotype = CellStructure.str2structure(config.arch_str) | ||||||
|  |     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) | ||||||
|  |     return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||||
|   else: |   else: | ||||||
|     raise ValueError('invalid network name : {:}'.format(config.name)) |     raise ValueError('invalid network name : {:}'.format(config.name)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -93,6 +93,8 @@ class NASBench201API(object): | |||||||
|     else: arch_index = -1 |     else: arch_index = -1 | ||||||
|     return arch_index |     return arch_index | ||||||
|  |  | ||||||
|  |   # Overwrite all information of the 'index'-th architecture in the search space. | ||||||
|  |   # It will load its data from 'archive_root'. | ||||||
|   def reload(self, archive_root, index): |   def reload(self, archive_root, index): | ||||||
|     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) |     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) | ||||||
|     xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) |     xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) | ||||||
| @@ -123,9 +125,18 @@ class NASBench201API(object): | |||||||
|       print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index)) |       print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index)) | ||||||
|       return None |       return None | ||||||
|  |  | ||||||
|   # query information with the training of 12 epochs or 200 epochs |   # This 'query_by_index' function is used to query information with the training of 12 epochs or 200 epochs. | ||||||
|   # if dataname is None, return the ArchResults |   # ------ | ||||||
|  |   # If use_12epochs_result=True, we train the model by 12 epochs (see config in configs/nas-benchmark/LESS.config) | ||||||
|  |   # If use_12epochs_result=False, we train the model by 200 epochs (see config in configs/nas-benchmark/CIFAR.config) | ||||||
|  |   # ------ | ||||||
|  |   # If dataname is None, return the ArchResults | ||||||
|   # else, return a dict with all trials on that dataset (the key is the seed) |   # else, return a dict with all trials on that dataset (the key is the seed) | ||||||
|  |   # Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'. | ||||||
|  |   #  -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||||
|  |   #  -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||||
|  |   #  -- cifar100 : training the model on the CIFAR-100 training set. | ||||||
|  |   #  -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||||
|   def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False): |   def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False): | ||||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
| @@ -166,12 +177,40 @@ class NASBench201API(object): | |||||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) |     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||||
|     return copy.deepcopy(self.meta_archs[index]) |     return copy.deepcopy(self.meta_archs[index]) | ||||||
|  |  | ||||||
|   # obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed` |   """ | ||||||
|  |   This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed` | ||||||
|  |   Args [seed]: | ||||||
|  |     -- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights. | ||||||
|  |     -- a interger : return the weights of a specific trial, whose seed is this interger. | ||||||
|  |   Args [use_12epochs_result]: | ||||||
|  |     -- True : train the model by 12 epochs | ||||||
|  |     -- False : train the model by 200 epochs | ||||||
|  |   """ | ||||||
|   def get_net_param(self, index, dataset, seed, use_12epochs_result=False): |   def get_net_param(self, index, dataset, seed, use_12epochs_result=False): | ||||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less |     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
|     archresult = arch2infos[index] |     archresult = arch2infos[index] | ||||||
|     return archresult.get_net_param(dataset, seed) |     return archresult.get_net_param(dataset, seed) | ||||||
|  |    | ||||||
|  |   """ | ||||||
|  |   This function is used to obtain the configuration for the `index`-th architecture on `dataset`. | ||||||
|  |   Args [dataset] (4 possible options): | ||||||
|  |     -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||||
|  |     -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||||
|  |     -- cifar100 : training the model on the CIFAR-100 training set. | ||||||
|  |     -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||||
|  |   This function will return a dict. | ||||||
|  |   ========= Some examlpes for using this function: | ||||||
|  |   config = api.get_net_config(128, 'cifar10') | ||||||
|  |   """ | ||||||
|  |   def get_net_config(self, index, dataset): | ||||||
|  |     archresult = self.arch2infos_full[index] | ||||||
|  |     all_results = archresult.query(dataset, None) | ||||||
|  |     if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset)) | ||||||
|  |     for seed, result in all_results.items(): | ||||||
|  |       return result.get_config(None) | ||||||
|  |       #print ('SEED [{:}] : {:}'.format(seed, result)) | ||||||
|  |     raise ValueError('Impossible to reach here!') | ||||||
|  |  | ||||||
|   # obtain the cost metric for the `index`-th architecture on a dataset |   # obtain the cost metric for the `index`-th architecture on a dataset | ||||||
|   def get_cost_info(self, index, dataset, use_12epochs_result=False): |   def get_cost_info(self, index, dataset, use_12epochs_result=False): | ||||||
| @@ -333,6 +372,7 @@ class NASBench201API(object): | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ArchResults(object): | class ArchResults(object): | ||||||
|  |  | ||||||
|   def __init__(self, arch_index, arch_str): |   def __init__(self, arch_index, arch_str): | ||||||
| @@ -615,11 +655,16 @@ class ResultsCount(object): | |||||||
|   def get_net_param(self): |   def get_net_param(self): | ||||||
|     return self.net_state_dict |     return self.net_state_dict | ||||||
|  |  | ||||||
|  |   # This function is used to obtain the config dict for this architecture. | ||||||
|   def get_config(self, str2structure): |   def get_config(self, str2structure): | ||||||
|     #return copy.deepcopy(self.arch_config) |     if str2structure is None: | ||||||
|     return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ |       return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ | ||||||
|             'N'   : self.arch_config['num_cells'], \ |               'N'   : self.arch_config['num_cells'], \ | ||||||
|             'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} |               'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']} | ||||||
|  |     else: | ||||||
|  |       return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ | ||||||
|  |               'N'   : self.arch_config['num_cells'], \ | ||||||
|  |               'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} | ||||||
|  |  | ||||||
|   def state_dict(self): |   def state_dict(self): | ||||||
|     _state_dict = {key: value for key, value in self.__dict__.items()} |     _state_dict = {key: value for key, value in self.__dict__.items()} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user