add necessary comments for NAS-Bench-201.md
This commit is contained in:
parent
eb77894bc9
commit
133fd21ecc
5
.gitignore
vendored
5
.gitignore
vendored
@ -116,7 +116,8 @@ cal.sh
|
||||
aaa
|
||||
cx.sh
|
||||
|
||||
NAS-Bench-102-v1_0.pth
|
||||
lib/NAS-Bench-102-v1_0.pth
|
||||
NAS-Bench-*-v1_0.pth
|
||||
lib/NAS-Bench-*-v1_0.pth
|
||||
others/TF
|
||||
scripts-search/l2s-algos
|
||||
TEMP-L.sh
|
||||
|
@ -72,7 +72,7 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1
|
||||
api.show(index)
|
||||
```
|
||||
|
||||
5. For other usages, please see `lib/nas_201_api/api.py`
|
||||
5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
|
||||
|
||||
|
||||
### Detailed Instruction
|
||||
|
@ -77,6 +77,12 @@ class NASBench201API(object):
|
||||
def random(self):
|
||||
return random.randint(0, len(self.meta_archs)-1)
|
||||
|
||||
# This function is used to query the index of an architecture in the search space.
|
||||
# The input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
|
||||
# or an instance that has the 'tostr' function that can generate the architecture string.
|
||||
# This function will return the index.
|
||||
# If return -1, it means this architecture is not in the search space.
|
||||
# Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
|
||||
def query_index_by_arch(self, arch):
|
||||
if isinstance(arch, str):
|
||||
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
|
||||
@ -97,6 +103,11 @@ class NASBench201API(object):
|
||||
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
|
||||
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
|
||||
|
||||
# This function is used to query the information of a specific archiitecture
|
||||
# 'arch' can be an architecture index or an architecture string
|
||||
# When use_12epochs_result=True, the hyper-parameters used to train a model are in 'configs/nas-benchmark/CIFAR.config'
|
||||
# When use_12epochs_result=False, the hyper-parameters used to train a model are in 'configs/nas-benchmark/LESS.config'
|
||||
# The difference between these two configurations are the number of training epochs, which is 200 in CIFAR.config and 12 in LESS.config.
|
||||
def query_by_arch(self, arch, use_12epochs_result=False):
|
||||
if isinstance(arch, int):
|
||||
arch_index = arch
|
||||
@ -354,6 +365,37 @@ class ArchResults(object):
|
||||
else: info[key] = None
|
||||
return info
|
||||
|
||||
"""
|
||||
This `get_metrics` function is used to obtain obtain the loss, accuracy, etc information on a specific dataset.
|
||||
If not specify, each set refer to the proposed split in NAS-Bench-201 paper.
|
||||
If some args return None or raise error, then it is not avaliable.
|
||||
========================================
|
||||
Args [dataset] (4 possible options):
|
||||
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
Args [setname] (each dataset has different setnames):
|
||||
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'ori-test' : the metric on the validation + test set.
|
||||
-- When dataset = cifar10, you can use 'train', 'ori-test'.
|
||||
------ 'train' : the metric on the training + validation set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'x-test' : the metric on the test set.
|
||||
------ 'ori-test' : the metric on the validation + test set.
|
||||
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
|
||||
------ None : return the metric after the last training epoch.
|
||||
------ an integer i : return the metric after the i-th training epoch.
|
||||
Args [is_random]:
|
||||
------ True : return the metric of a randomly selected trial.
|
||||
------ False : return the averaged metric of all avaliable trials.
|
||||
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
|
||||
"""
|
||||
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
|
Loading…
Reference in New Issue
Block a user