add necessary comments for NAS-Bench-201.md
This commit is contained in:
		| @@ -77,6 +77,12 @@ class NASBench201API(object): | ||||
|   def random(self): | ||||
|     return random.randint(0, len(self.meta_archs)-1) | ||||
|  | ||||
|   # This function is used to query the index of an architecture in the search space. | ||||
|   # The input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|' | ||||
|   # or an instance that has the 'tostr' function that can generate the architecture string. | ||||
|   # This function will return the index. | ||||
|   #   If return -1, it means this architecture is not in the search space. | ||||
|   #   Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space). | ||||
|   def query_index_by_arch(self, arch): | ||||
|     if isinstance(arch, str): | ||||
|       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] | ||||
| @@ -97,6 +103,11 @@ class NASBench201API(object): | ||||
|     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) | ||||
|     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) | ||||
|    | ||||
|   # This function is used to query the information of a specific archiitecture | ||||
|   # 'arch' can be an architecture index or an architecture string | ||||
|   # When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config' | ||||
|   # When use_12epochs_result=False, the hyper-parameters used to train a model are in 'configs/nas-benchmark/LESS.config' | ||||
|   # The difference between these two configurations are the number of training epochs, which is 200 in CIFAR.config and 12 in LESS.config. | ||||
|   def query_by_arch(self, arch, use_12epochs_result=False): | ||||
|     if isinstance(arch, int): | ||||
|       arch_index = arch | ||||
| @@ -354,6 +365,37 @@ class ArchResults(object): | ||||
|       else: info[key] = None | ||||
|     return info | ||||
|  | ||||
|   """ | ||||
|   This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset. | ||||
|   If not specify, each set refer to the proposed split in NAS-Bench-201 paper. | ||||
|   If some args return None or raise error, then it is not avaliable. | ||||
|   ======================================== | ||||
|   Args [dataset] (4 possible options): | ||||
|     -- cifar10-valid : training the model on the CIFAR-10 training set. | ||||
|     -- cifar10 : training the model on the CIFAR-10 training + validation set. | ||||
|     -- cifar100 : training the model on the CIFAR-100 training set. | ||||
|     -- ImageNet16-120 : training the model on the ImageNet16-120 training set. | ||||
|   Args [setname] (each dataset has different setnames): | ||||
|     -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test' | ||||
|     ------ 'train' : the metric on the training set. | ||||
|     ------ 'x-valid' : the metric on the validation set. | ||||
|     ------ 'ori-test' : the metric on the validation + test set. | ||||
|     -- When dataset = cifar10, you can use 'train', 'ori-test'. | ||||
|     ------ 'train' : the metric on the training + validation set. | ||||
|     ------ 'ori-test' : the metric on the test set. | ||||
|     -- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test' | ||||
|     ------ 'train' : the metric on the training set. | ||||
|     ------ 'x-valid' : the metric on the validation set. | ||||
|     ------ 'x-test' : the metric on the test set. | ||||
|     ------ 'ori-test' : the metric on the validation + test set. | ||||
|   Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs) | ||||
|     ------ None : return the metric after the last training epoch. | ||||
|     ------ an integer i : return the metric after the i-th training epoch. | ||||
|   Args [is_random]: | ||||
|     ------ True : return the metric of a randomly selected trial. | ||||
|     ------ False : return the averaged metric of all avaliable trials. | ||||
|     ------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random'). | ||||
|   """ | ||||
|   def get_metrics(self, dataset, setname, iepoch=None, is_random=False): | ||||
|     x_seeds = self.dataset_seed[dataset] | ||||
|     results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user