Update docs of NATS-Bench
This commit is contained in:
		| @@ -3,10 +3,10 @@ | ||||
| ########################################################################################################################################################### | ||||
| # Before run these commands, the files must be properly put. | ||||
| # | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1 --dataset cifar10 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar100 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset ImageNet16-120 | ||||
| # CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10 | ||||
| ########################################################################################################################################################### | ||||
| import os, gc, sys, math, argparse, psutil | ||||
| import numpy as np | ||||
| @@ -140,7 +140,7 @@ if __name__ == '__main__': | ||||
|   save_dir = Path(args.save_dir) | ||||
|   save_dir.mkdir(parents=True, exist_ok=True) | ||||
|   meta_file = Path(args.base_path + '.pth') | ||||
|   weight_dir = Path(args.base_path + '-archive') | ||||
|   weight_dir = Path(args.base_path + '-full') | ||||
|   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) | ||||
|   assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir) | ||||
|  | ||||
|   | ||||
| @@ -395,9 +395,9 @@ if __name__ == '__main__': | ||||
|   for xdata in datasets: | ||||
|     visualize_tss_info(api201, xdata, to_save_dir) | ||||
|  | ||||
|   api301 = create(None, 'size', verbose=True) | ||||
|   api_sss = create(None, 'size', verbose=True) | ||||
|   for xdata in datasets: | ||||
|     visualize_sss_info(api301, xdata, to_save_dir) | ||||
|     visualize_sss_info(api_sss, xdata, to_save_dir) | ||||
|  | ||||
|   visualize_info(None, to_save_dir, 'tss') | ||||
|   visualize_info(None, to_save_dir, 'sss') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user