From 7d02870bf8b5bf8290bde7cc470c3d11d3a9a800 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 5 Jul 2020 23:14:15 +0000 Subject: [PATCH] Update weight watcher codes --- CHANGE-LOG.md | 1 + exps/experimental/test-ww-bench.py | 1 + lib/nas_201_api/api_utils.py | 6 +++++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGE-LOG.md b/CHANGE-LOG.md index a2d8df0..0aad0b8 100644 --- a/CHANGE-LOG.md +++ b/CHANGE-LOG.md @@ -4,3 +4,4 @@ - [2019.12.20] [69ca086] Release NAS-Bench-201. - [2019.09.28] [f8f3f38] TAS and SETN codes were publicly released. - [2019.01.31] [13e908f] GDAS codes were publicly released. +- [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version. diff --git a/exps/experimental/test-ww-bench.py b/exps/experimental/test-ww-bench.py index 570c10c..4351571 100644 --- a/exps/experimental/test-ww-bench.py +++ b/exps/experimental/test-ww-bench.py @@ -9,6 +9,7 @@ # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10 # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100 # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120 +# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1 --dataset cifar10 ########################################################################################################################################################### import os, gc, sys, math, argparse, psutil import numpy as np diff --git a/lib/nas_201_api/api_utils.py b/lib/nas_201_api/api_utils.py index 904b825..428fab1 100644 --- a/lib/nas_201_api/api_utils.py +++ b/lib/nas_201_api/api_utils.py @@ -411,7 +411,11 @@ class ArchResults(object): x_seeds = self.dataset_seed[dataset] return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds} else: - return self.all_results[(dataset, seed)].get_net_param() + xkey = (dataset, seed) + if xkey in self.all_results: + return self.all_results[xkey].get_net_param() + else: + raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys()))) def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None: """This function is used to reset the latency in all corresponding ResultsCount(s)."""