support get_net_config for NAS-Bench-201
This commit is contained in:
		| @@ -72,7 +72,16 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1 | ||||
| api.show(index) | ||||
| ``` | ||||
|  | ||||
| 5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. | ||||
| 5. Create the network from api: | ||||
| ``` | ||||
| config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset | ||||
| from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models | ||||
| network = get_cell_based_tiny_net(config) # create the network from the configuration | ||||
| print(network) # show the structure of this architecture | ||||
| ``` | ||||
| If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load them into the network. | ||||
|  | ||||
| 6. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. | ||||
|  | ||||
|  | ||||
| ### Detailed Instruction | ||||
|   | ||||
| @@ -16,6 +16,7 @@ from .cell_searchs import CellStructure, CellArchitectures | ||||
|  | ||||
| # Cell-based NAS Models | ||||
| def get_cell_based_tiny_net(config): | ||||
|   if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] | ||||
|   if super_type == 'basic' and config.name in group_names: | ||||
| @@ -30,7 +31,12 @@ def get_cell_based_tiny_net(config): | ||||
|                     config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||
|   elif config.name == 'infer.tiny': | ||||
|     from .cell_infers import TinyNetwork | ||||
|     return TinyNetwork(config.C, config.N, config.genotype, config.num_classes) | ||||
|     if hasattr(config, 'genotype'): | ||||
|       genotype = config.genotype | ||||
|     elif hasattr(config, 'arch_str'): | ||||
|       genotype = CellStructure.str2structure(config.arch_str) | ||||
|     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) | ||||
|     return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||
|   else: | ||||
|     raise ValueError('invalid network name : {:}'.format(config.name)) | ||||
|  | ||||
|   | ||||
| @@ -93,6 +93,8 @@ class NASBench201API(object): | ||||
|     else: arch_index = -1 | ||||
|     return arch_index | ||||
|  | ||||
|   # Overwrite all information of the 'index'-th architecture in the search space. | ||||
|   # It will load its data from 'archive_root'. | ||||
|   def reload(self, archive_root, index): | ||||
|     assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) | ||||
|     xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) | ||||
| @@ -123,9 +125,18 @@ class NASBench201API(object): | ||||
|       print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index)) | ||||
|       return None | ||||
|  | ||||
|   # query information with the training of 12 epochs or 200 epochs | ||||
|   # if dataname is None, return the ArchResults | ||||
|   # This 'query_by_index' function is used to query information with the training of 12 epochs or 200 epochs. | ||||
|   # ------ | ||||
|   # If use_12epochs_result=True, the model is trained for 12 epochs (see config in configs/nas-benchmark/LESS.config) | ||||
|   # If use_12epochs_result=False, the model is trained for 200 epochs (see config in configs/nas-benchmark/CIFAR.config) | ||||
|   # ------ | ||||
|   # If dataname is None, return the ArchResults | ||||
|   # else, return a dict with all trials on that dataset (the key is the seed) | ||||
|   # Options for dataname are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'. | ||||
|   #  -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||
|   #  -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||
|   #  -- cifar100 : training the model on the CIFAR-100 training set. | ||||
|   #  -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||
|   def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
| @@ -166,13 +177,41 @@ class NASBench201API(object): | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
|     return copy.deepcopy(self.meta_archs[index]) | ||||
|  | ||||
|   # obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed` | ||||
|   """ | ||||
|   This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed` | ||||
|   Args [seed]: | ||||
|     -- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights. | ||||
|     -- an integer : return the weights of a specific trial, whose seed is this integer. | ||||
|   Args [use_12epochs_result]: | ||||
|     -- True : train the model for 12 epochs | ||||
|     -- False : train the model for 200 epochs | ||||
|   """ | ||||
|   def get_net_param(self, index, dataset, seed, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
|     return archresult.get_net_param(dataset, seed) | ||||
|    | ||||
|   """ | ||||
|   This function is used to obtain the configuration for the `index`-th architecture on `dataset`. | ||||
|   Args [dataset] (4 possible options): | ||||
|     -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||
|     -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||
|     -- cifar100 : training the model on the CIFAR-100 training set. | ||||
|     -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||
|   This function will return a dict. | ||||
|   ========= Some examples for using this function: | ||||
|   config = api.get_net_config(128, 'cifar10') | ||||
|   """ | ||||
|   def get_net_config(self, index, dataset): | ||||
|     archresult = self.arch2infos_full[index] | ||||
|     all_results = archresult.query(dataset, None) | ||||
|     if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset)) | ||||
|     for seed, result in all_results.items(): | ||||
|       return result.get_config(None) | ||||
|       #print ('SEED [{:}] : {:}'.format(seed, result)) | ||||
|     raise ValueError('Impossible to reach here!') | ||||
|  | ||||
|   # obtain the cost metric for the `index`-th architecture on a dataset | ||||
|   def get_cost_info(self, index, dataset, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
| @@ -333,6 +372,7 @@ class NASBench201API(object): | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| class ArchResults(object): | ||||
|  | ||||
|   def __init__(self, arch_index, arch_str): | ||||
| @@ -615,11 +655,16 @@ class ResultsCount(object): | ||||
|   def get_net_param(self): | ||||
|     return self.net_state_dict | ||||
|  | ||||
|   # This function is used to obtain the config dict for this architecture. | ||||
|   def get_config(self, str2structure): | ||||
|     #return copy.deepcopy(self.arch_config) | ||||
|     return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ | ||||
|             'N'   : self.arch_config['num_cells'], \ | ||||
|             'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} | ||||
|     if str2structure is None: | ||||
|       return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ | ||||
|               'N'   : self.arch_config['num_cells'], \ | ||||
|               'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']} | ||||
|     else: | ||||
|       return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ | ||||
|               'N'   : self.arch_config['num_cells'], \ | ||||
|               'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} | ||||
|  | ||||
|   def state_dict(self): | ||||
|     _state_dict = {key: value for key, value in self.__dict__.items()} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user