diff --git a/.gitignore b/.gitignore index 9d3d2e9..ac31d15 100644 --- a/.gitignore +++ b/.gitignore @@ -116,7 +116,8 @@ cal.sh aaa cx.sh -NAS-Bench-102-v1_0.pth -lib/NAS-Bench-102-v1_0.pth +NAS-Bench-*-v1_0.pth +lib/NAS-Bench-*-v1_0.pth others/TF scripts-search/l2s-algos +TEMP-L.sh diff --git a/docs/NAS-Bench-201.md b/docs/NAS-Bench-201.md index 188a5b1..31cdd82 100644 --- a/docs/NAS-Bench-201.md +++ b/docs/NAS-Bench-201.md @@ -72,7 +72,7 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1 api.show(index) ``` -5. For other usages, please see `lib/nas_201_api/api.py` +5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. ### Detailed Instruction diff --git a/lib/nas_201_api/api.py b/lib/nas_201_api/api.py index a78412a..2222e71 100644 --- a/lib/nas_201_api/api.py +++ b/lib/nas_201_api/api.py @@ -77,6 +77,12 @@ class NASBench201API(object): def random(self): return random.randint(0, len(self.meta_archs)-1) + # This function is used to query the index of an architecture in the search space. + # The input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|' + # or an instance that has the 'tostr' function that can generate the architecture string. + # This function will return the index. + # If return -1, it means this architecture is not in the search space. + # Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space). def query_index_by_arch(self, arch): if isinstance(arch, str): if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] @@ -97,6 +103,11 @@ class NASBench201API(object): self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) + # This function is used to query the information of a specific archiitecture + # 'arch' can be an architecture index or an architecture string + # When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config' + # When use_12epochs_result=False, the hyper-parameters used to train a model are in 'configs/nas-benchmark/LESS.config' + # The difference between these two configurations are the number of training epochs, which is 200 in CIFAR.config and 12 in LESS.config. def query_by_arch(self, arch, use_12epochs_result=False): if isinstance(arch, int): arch_index = arch @@ -354,6 +365,37 @@ class ArchResults(object): else: info[key] = None return info + """ + This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset. + If not specify, each set refer to the proposed split in NAS-Bench-201 paper. + If some args return None or raise error, then it is not avaliable. + ======================================== + Args [dataset] (4 possible options): + -- cifar10-valid : training the model on the CIFAR-10 training set. + -- cifar10 : training the model on the CIFAR-10 training + validation set. + -- cifar100 : training the model on the CIFAR-100 training set. + -- ImageNet16-120 : training the model on the ImageNet16-120 training set. + Args [setname] (each dataset has different setnames): + -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test' + ------ 'train' : the metric on the training set. + ------ 'x-valid' : the metric on the validation set. + ------ 'ori-test' : the metric on the validation + test set. + -- When dataset = cifar10, you can use 'train', 'ori-test'. + ------ 'train' : the metric on the training + validation set. + ------ 'ori-test' : the metric on the test set. + -- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test' + ------ 'train' : the metric on the training set. + ------ 'x-valid' : the metric on the validation set. + ------ 'x-test' : the metric on the test set. + ------ 'ori-test' : the metric on the validation + test set. + Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs) + ------ None : return the metric after the last training epoch. + ------ an integer i : return the metric after the i-th training epoch. + Args [is_random]: + ------ True : return the metric of a randomly selected trial. + ------ False : return the averaged metric of all avaliable trials. + ------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random'). + """ def get_metrics(self, dataset, setname, iepoch=None, is_random=False): x_seeds = self.dataset_seed[dataset] results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]