From c7a54fd08bd48dc697e48529dc4aed7cd4f17f25 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 13 Oct 2020 00:06:31 +1100 Subject: [PATCH] Update README --- docs/NAS-Bench-201-PURE.md | 15 ++++++++------- docs/NAS-Bench-201.md | 15 ++++++++------- docs/NATS-Bench.md | 14 ++++++++++---- exps/experimental/vis-nats-bench-ws.py | 2 +- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/docs/NAS-Bench-201-PURE.md b/docs/NAS-Bench-201-PURE.md index e9980cb..896842d 100644 --- a/docs/NAS-Bench-201-PURE.md +++ b/docs/NAS-Bench-201-PURE.md @@ -70,17 +70,18 @@ api.show(2) # show the mean loss and accuracy of an architecture info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults` res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys -cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency +cost_metrics = info.get_compute_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency # get the detailed information results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) -print ('Latency : {:}'.format(results[0].get_latency())) -print ('Train Info : {:}'.format(results[0].get_train())) -print ('Valid Info : {:}'.format(results[0].get_eval('x-valid'))) -print ('Test Info : {:}'.format(results[0].get_eval('x-test'))) -# for the metric after a specific epoch -print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10))) +for seed, result in results.items(): + print ('Latency : {:}'.format(result.get_latency())) + print ('Train Info : {:}'.format(result.get_train())) + print ('Valid Info : {:}'.format(result.get_eval('x-valid'))) + print ('Test Info : {:}'.format(result.get_eval('x-test'))) + # for the metric after a specific epoch + print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10))) ``` 4. Query the index of an architecture by string diff --git a/docs/NAS-Bench-201.md b/docs/NAS-Bench-201.md index d4325fc..4a50e2f 100644 --- a/docs/NAS-Bench-201.md +++ b/docs/NAS-Bench-201.md @@ -68,17 +68,18 @@ api.show(2) # show the mean loss and accuracy of an architecture info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults` res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys -cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency +cost_metrics = info.get_compute_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency # get the detailed information results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) -print ('Latency : {:}'.format(results[0].get_latency())) -print ('Train Info : {:}'.format(results[0].get_train())) -print ('Valid Info : {:}'.format(results[0].get_eval('x-valid'))) -print ('Test Info : {:}'.format(results[0].get_eval('x-test'))) -# for the metric after a specific epoch -print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10))) +for seed, result in results.items(): + print ('Latency : {:}'.format(result.get_latency())) + print ('Train Info : {:}'.format(result.get_train())) + print ('Valid Info : {:}'.format(result.get_eval('x-valid'))) + print ('Test Info : {:}'.format(result.get_eval('x-test'))) + # for the metric after a specific epoch + print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10))) ``` 4. Query the index of an architecture by string diff --git a/docs/NATS-Bench.md b/docs/NATS-Bench.md index bdfc4b0..48ebc81 100644 --- a/docs/NATS-Bench.md +++ b/docs/NATS-Bench.md @@ -11,7 +11,7 @@ This facilitates a much larger community of researchers to focus on developing b The structure of this Markdown file: - [How to use NATS-Bench?](#How-to-Use-NATS-Bench) - [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch) -- [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nas-bench-201) +- [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nats-bench) ## How to Use [NATS-Bench](https://arxiv.org/pdf/2009.00437.pdf) @@ -77,8 +77,12 @@ params = api.get_net_param(12, 'cifar10', None) network.load_state_dict(next(iter(params.values()))) ``` + + ## How to Re-create NATS-Bench from Scratch +You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to re-create NATS-Bench from scratch. + ### The Size Search Space The following command will train all architecture candidate in the size search space with 90 epochs and use the random seed of `777`. @@ -108,7 +112,9 @@ python exps/NATS-Bench/tss-collect.py ``` -## To Reproduce 13 Baseline NAS Algorithms in NAS-Bench-201 +## To Reproduce 13 Baseline NAS Algorithms in NATS-Bench + +You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to run 13 baseline NAS methods. ### Reproduce NAS methods on the topology search space @@ -169,14 +175,14 @@ python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HO python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 -Run the search strategy in FBNet-V2 +Run the channel search strategy in FBNet-V2 python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777 -Run the search strategy in TuNAS: +Run the channel search strategy in TuNAS: python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0 python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 diff --git a/exps/experimental/vis-nats-bench-ws.py b/exps/experimental/vis-nats-bench-ws.py index b1d5014..8b01ebe 100644 --- a/exps/experimental/vis-nats-bench-ws.py +++ b/exps/experimental/vis-nats-bench-ws.py @@ -43,7 +43,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) - alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) + alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix) alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix) for alg, name in alg2name.items():