add necessary comments for NAS-Bench-201.md

This commit is contained in:
D-X-Y 2020-02-02 16:19:37 +11:00
parent eb77894bc9
commit 133fd21ecc
3 changed files with 46 additions and 3 deletions

5
.gitignore vendored
View File

@ -116,7 +116,8 @@ cal.sh
aaa
cx.sh
NAS-Bench-102-v1_0.pth
lib/NAS-Bench-102-v1_0.pth
NAS-Bench-*-v1_0.pth
lib/NAS-Bench-*-v1_0.pth
others/TF
scripts-search/l2s-algos
TEMP-L.sh

View File

@ -72,7 +72,7 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1
api.show(index)
```
5. For other usages, please see `lib/nas_201_api/api.py`
5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
### Detailed Instruction

View File

@ -77,6 +77,12 @@ class NASBench201API(object):
def random(self):
return random.randint(0, len(self.meta_archs)-1)
# This function is used to query the index of an architecture in the search space.
# The input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
# or an instance that has the 'tostr' function that can generate the architecture string.
# This function will return the index.
# If return -1, it means this architecture is not in the search space.
# Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
def query_index_by_arch(self, arch):
if isinstance(arch, str):
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
@ -97,6 +103,11 @@ class NASBench201API(object):
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
# This function is used to query the information of a specific archiitecture
# 'arch' can be an architecture index or an architecture string
# When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config'
# When use_12epochs_result=False, the hyper-parameters used to train a model are in 'configs/nas-benchmark/LESS.config'
# The difference between these two configurations are the number of training epochs, which is 200 in CIFAR.config and 12 in LESS.config.
def query_by_arch(self, arch, use_12epochs_result=False):
if isinstance(arch, int):
arch_index = arch
@ -354,6 +365,37 @@ class ArchResults(object):
else: info[key] = None
return info
"""
This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset.
If not specify, each set refer to the proposed split in NAS-Bench-201 paper.
If some args return None or raise error, then it is not avaliable.
========================================
Args [dataset] (4 possible options):
-- cifar10-valid : training the model on the CIFAR-10 training set.
-- cifar10 : training the model on the CIFAR-10 training + validation set.
-- cifar100 : training the model on the CIFAR-100 training set.
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
Args [setname] (each dataset has different setnames):
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'ori-test' : the metric on the validation + test set.
-- When dataset = cifar10, you can use 'train', 'ori-test'.
------ 'train' : the metric on the training + validation set.
------ 'ori-test' : the metric on the test set.
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
------ 'train' : the metric on the training set.
------ 'x-valid' : the metric on the validation set.
------ 'x-test' : the metric on the test set.
------ 'ori-test' : the metric on the validation + test set.
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
------ None : return the metric after the last training epoch.
------ an integer i : return the metric after the i-th training epoch.
Args [is_random]:
------ True : return the metric of a randomly selected trial.
------ False : return the averaged metric of all avaliable trials.
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
"""
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]