add necessary comments for NAS-Bench-201.md
This commit is contained in:
		
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -116,7 +116,8 @@ cal.sh | |||||||
| aaa | aaa | ||||||
| cx.sh | cx.sh | ||||||
|  |  | ||||||
| NAS-Bench-102-v1_0.pth | NAS-Bench-*-v1_0.pth | ||||||
| lib/NAS-Bench-102-v1_0.pth | lib/NAS-Bench-*-v1_0.pth | ||||||
| others/TF | others/TF | ||||||
| scripts-search/l2s-algos | scripts-search/l2s-algos | ||||||
|  | TEMP-L.sh | ||||||
|   | |||||||
| @@ -72,7 +72,7 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1 | |||||||
| api.show(index) | api.show(index) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 5. For other usages, please see `lib/nas_201_api/api.py` | 5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. | ||||||
|  |  | ||||||
|  |  | ||||||
| ### Detailed Instruction | ### Detailed Instruction | ||||||
|   | |||||||
| @@ -77,6 +77,12 @@ class NASBench201API(object): | |||||||
|   def random(self): |   def random(self): | ||||||
|     return random.randint(0, len(self.meta_archs)-1) |     return random.randint(0, len(self.meta_archs)-1) | ||||||
|  |  | ||||||
|  |   # This function is used to query the index of an architecture in the search space. | ||||||
|  |   # The input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|' | ||||||
|  |   # or an instance that has the 'tostr' function that can generate the architecture string. | ||||||
|  |   # This function will return the index. | ||||||
|  |   #   If return -1, it means this architecture is not in the search space. | ||||||
|  |   #   Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space). | ||||||
|   def query_index_by_arch(self, arch): |   def query_index_by_arch(self, arch): | ||||||
|     if isinstance(arch, str): |     if isinstance(arch, str): | ||||||
|       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] |       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] | ||||||
| @@ -97,6 +103,11 @@ class NASBench201API(object): | |||||||
|     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) |     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) | ||||||
|     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) |     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) | ||||||
|    |    | ||||||
|  |   # This function is used to query the information of a specific archiitecture | ||||||
|  |   # 'arch' can be an architecture index or an architecture string | ||||||
|  |   # When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config' | ||||||
|  |   # When use_12epochs_result=False, the hyper-parameters used to train a model are in 'configs/nas-benchmark/LESS.config' | ||||||
|  |   # The difference between these two configurations are the number of training epochs, which is 200 in CIFAR.config and 12 in LESS.config. | ||||||
|   def query_by_arch(self, arch, use_12epochs_result=False): |   def query_by_arch(self, arch, use_12epochs_result=False): | ||||||
|     if isinstance(arch, int): |     if isinstance(arch, int): | ||||||
|       arch_index = arch |       arch_index = arch | ||||||
| @@ -354,6 +365,37 @@ class ArchResults(object): | |||||||
|       else: info[key] = None |       else: info[key] = None | ||||||
|     return info |     return info | ||||||
|  |  | ||||||
|  |   """ | ||||||
|  |   This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset. | ||||||
|  |   If not specify, each set refer to the proposed split in NAS-Bench-201 paper. | ||||||
|  |   If some args return None or raise error, then it is not avaliable. | ||||||
|  |   ======================================== | ||||||
|  |   Args [dataset] (4 possible options): | ||||||
|  |     -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||||
|  |     -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||||
|  |     -- cifar100 : training the model on the CIFAR-100 training set. | ||||||
|  |     -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||||
|  |   Args [setname] (each dataset has different setnames): | ||||||
|  |     -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test' | ||||||
|  |     ------ 'train' : the metric on the training set. | ||||||
|  |     ------ 'x-valid' : the metric on the validation set. | ||||||
|  |     ------ 'ori-test' : the metric on the validation + test set. | ||||||
|  |     -- When dataset = cifar10, you can use 'train', 'ori-test'. | ||||||
|  |     ------ 'train' : the metric on the training + validation set. | ||||||
|  |     ------ 'ori-test' : the metric on the test set. | ||||||
|  |     -- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test' | ||||||
|  |     ------ 'train' : the metric on the training set. | ||||||
|  |     ------ 'x-valid' : the metric on the validation set. | ||||||
|  |     ------ 'x-test' : the metric on the test set. | ||||||
|  |     ------ 'ori-test' : the metric on the validation + test set. | ||||||
|  |   Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs) | ||||||
|  |     ------ None : return the metric after the last training epoch. | ||||||
|  |     ------ an integer i : return the metric after the i-th training epoch. | ||||||
|  |   Args [is_random]: | ||||||
|  |     ------ True : return the metric of a randomly selected trial. | ||||||
|  |     ------ False : return the averaged metric of all avaliable trials. | ||||||
|  |     ------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random'). | ||||||
|  |   """ | ||||||
|   def get_metrics(self, dataset, setname, iepoch=None, is_random=False): |   def get_metrics(self, dataset, setname, iepoch=None, is_random=False): | ||||||
|     x_seeds = self.dataset_seed[dataset] |     x_seeds = self.dataset_seed[dataset] | ||||||
|     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] |     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user