Update NATS-Bench (sss version 1.2)
This commit is contained in:
parent
469a207945
commit
5f151d1970
@ -13,6 +13,14 @@ This facilitates a much larger community of researchers to focus on developing b
|
|||||||
|
|
||||||
## How to Use NATS-Bench
|
## How to Use NATS-Bench
|
||||||
|
|
||||||
|
### Preparation and Download
|
||||||
|
The **latest** benchmark file of NATS-Bench can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zjB6wMANiKwB2A1yil2hQ8H_qyeSe2yt?usp=sharing).
|
||||||
|
|
||||||
|
1, create the benchmark instance:
|
||||||
|
```
|
||||||
|
api = create(None, 'sss', fast_mode=True, verbose=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## The Procedure of Creating NATS-Bench
|
## The Procedure of Creating NATS-Bench
|
||||||
|
|
||||||
@ -36,34 +44,34 @@ The checkpoint of all candidates are located at `output/NATS-Bench-size` by defa
|
|||||||
|
|
||||||
```
|
```
|
||||||
DARTS (V1):
|
DARTS (V1):
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1
|
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1
|
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1
|
||||||
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
|
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
|
||||||
|
|
||||||
DARTS (V2):
|
DARTS (V2):
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2
|
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
|
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
|
||||||
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2
|
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2
|
||||||
|
|
||||||
GDAS:
|
GDAS:
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas
|
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
|
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
|
||||||
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16
|
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16
|
||||||
|
|
||||||
SETN:
|
SETN:
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn
|
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
|
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
|
||||||
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
|
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
|
||||||
|
|
||||||
Random Search with Weight Sharing:
|
Random Search with Weight Sharing:
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random
|
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
|
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
|
||||||
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
|
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
|
||||||
|
|
||||||
ENAS:
|
ENAS:
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
|
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
|
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
|
||||||
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
|
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001
|
||||||
```
|
```
|
||||||
|
|
||||||
### Reproduce NAS methods on the size search space
|
### Reproduce NAS methods on the size search space
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
# python exps/NATS-Bench/sss-collect.py #
|
# python exps/NATS-Bench/sss-collect.py #
|
||||||
##############################################################################
|
##############################################################################
|
||||||
import os, re, sys, time, shutil, argparse, collections
|
import os, re, sys, time, shutil, argparse, collections
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -22,7 +21,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
|||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from config_utils import dict2config
|
from config_utils import dict2config
|
||||||
from models import CellStructure, get_cell_based_tiny_net
|
from models import CellStructure, get_cell_based_tiny_net
|
||||||
from nas_201_api import ArchResults, ResultsCount
|
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
|
||||||
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
|
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
|
||||||
from utils import get_md5_file
|
from utils import get_md5_file
|
||||||
|
|
||||||
@ -193,8 +192,8 @@ def simplify(save_dir, save_name, nets, total):
|
|||||||
arch_str = nets[index]
|
arch_str = nets[index]
|
||||||
hp2info = OrderedDict()
|
hp2info = OrderedDict()
|
||||||
|
|
||||||
full_save_path = full_save_dir / '{:06d}.npy'.format(index)
|
full_save_path = full_save_dir / '{:06d}.pickle'.format(index)
|
||||||
simple_save_path = simple_save_dir / '{:06d}.npy'.format(index)
|
simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index)
|
||||||
|
|
||||||
for hp in hps:
|
for hp in hps:
|
||||||
sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
|
sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
|
||||||
@ -213,13 +212,13 @@ def simplify(save_dir, save_name, nets, total):
|
|||||||
to_save_data = OrderedDict({'01': hp2info['01'].state_dict(),
|
to_save_data = OrderedDict({'01': hp2info['01'].state_dict(),
|
||||||
'12': hp2info['12'].state_dict(),
|
'12': hp2info['12'].state_dict(),
|
||||||
'90': hp2info['90'].state_dict()})
|
'90': hp2info['90'].state_dict()})
|
||||||
np.save(str(full_save_path), to_save_data)
|
pickle_save(to_save_data, str(full_save_path))
|
||||||
|
|
||||||
for hp in hps: hp2info[hp].clear_params()
|
for hp in hps: hp2info[hp].clear_params()
|
||||||
to_save_data = OrderedDict({'01': hp2info['01'].state_dict(),
|
to_save_data = OrderedDict({'01': hp2info['01'].state_dict(),
|
||||||
'12': hp2info['12'].state_dict(),
|
'12': hp2info['12'].state_dict(),
|
||||||
'90': hp2info['90'].state_dict()})
|
'90': hp2info['90'].state_dict()})
|
||||||
np.save(str(simple_save_path), to_save_data)
|
pickle_save(to_save_data, str(simple_save_path))
|
||||||
arch2infos[index] = to_save_data
|
arch2infos[index] = to_save_data
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
arch_time.update(time.time() - end_time)
|
arch_time.update(time.time() - end_time)
|
||||||
@ -231,18 +230,23 @@ def simplify(save_dir, save_name, nets, total):
|
|||||||
'total_archs': total,
|
'total_archs': total,
|
||||||
'arch2infos' : arch2infos,
|
'arch2infos' : arch2infos,
|
||||||
'evaluated_indexes': evaluated_indexes}
|
'evaluated_indexes': evaluated_indexes}
|
||||||
save_file_name = save_dir / '{:}.npy'.format(save_name)
|
save_file_name = save_dir / '{:}.pickle'.format(save_name)
|
||||||
np.save(str(save_file_name), final_infos)
|
pickle_save(final_infos, str(save_file_name))
|
||||||
# move the benchmark file to a new path
|
# move the benchmark file to a new path
|
||||||
hd5sum = get_md5_file(save_file_name)
|
hd5sum = get_md5_file(str(save_file_name) + '.pbz2')
|
||||||
hd5_file_name = save_dir / '{:}-{:}.npy'.format(NATS_TSS_BASE_NAME, hd5sum)
|
hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum)
|
||||||
shutil.move(save_file_name, hd5_file_name)
|
shutil.move(str(save_file_name) + '.pbz2', hd5_file_name)
|
||||||
print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))
|
print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))
|
||||||
# move the directory to a new path
|
# move the directory to a new path
|
||||||
hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
|
hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
|
||||||
hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
|
hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
|
||||||
shutil.move(full_save_dir, hd5_full_save_dir)
|
shutil.move(full_save_dir, hd5_full_save_dir)
|
||||||
shutil.move(simple_save_dir, hd5_simple_save_dir)
|
shutil.move(simple_save_dir, hd5_simple_save_dir)
|
||||||
|
# save the meta information for simple and full
|
||||||
|
final_infos['arch2infos'] = None
|
||||||
|
final_infos['evaluated_indexes'] = set()
|
||||||
|
pickle_save(final_infos, str(hd5_full_save_dir / 'meta.pickle'))
|
||||||
|
pickle_save(final_infos, str(hd5_simple_save_dir / 'meta.pickle'))
|
||||||
|
|
||||||
|
|
||||||
def traverse_net(candidates: List[int], N: int):
|
def traverse_net(candidates: List[int], N: int):
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
###############################################################
|
##############################################################################
|
||||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||||
###############################################################
|
##############################################################################
|
||||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 #
|
||||||
###############################################################
|
##############################################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
# Usage: python exps/NATS-Bench/test-nats-api.py #
|
||||||
###############################################################
|
##############################################################################
|
||||||
# Usage: python exps/NAS-Bench-201/test-nas-api.py #
|
|
||||||
###############################################################
|
|
||||||
import os, sys, time, torch, argparse
|
import os, sys, time, torch, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import List, Text, Dict, Any
|
from typing import List, Text, Dict, Any
|
||||||
@ -61,10 +59,12 @@ def test_api(api, is_301=True):
|
|||||||
print('{:}\n'.format(info))
|
print('{:}\n'.format(info))
|
||||||
info = api.get_latency(12, 'cifar10')
|
info = api.get_latency(12, 'cifar10')
|
||||||
print('{:}\n'.format(info))
|
print('{:}\n'.format(info))
|
||||||
|
for index in [13, 15, 19, 200]:
|
||||||
|
info = api.get_latency(index, 'cifar10')
|
||||||
|
|
||||||
# Count the number of architectures
|
# Count the number of architectures
|
||||||
info = api.statistics('cifar100', '12')
|
info = api.statistics('cifar100', '12')
|
||||||
print('{:}\n'.format(info))
|
print('{:} statistics results : {:}\n'.format(time_string(), info))
|
||||||
|
|
||||||
# Show the information of the 123-th architecture
|
# Show the information of the 123-th architecture
|
||||||
api.show(123)
|
api.show(123)
|
||||||
@ -80,33 +80,18 @@ def test_api(api, is_301=True):
|
|||||||
print('Compute the adjacency matrix of {:}'.format(arch_str))
|
print('Compute the adjacency matrix of {:}'.format(arch_str))
|
||||||
print(matrix)
|
print(matrix)
|
||||||
info = api.simulate_train_eval(123, 'cifar10')
|
info = api.simulate_train_eval(123, 'cifar10')
|
||||||
print('simulate_train_eval : {:}'.format(info))
|
print('simulate_train_eval : {:}\n\n'.format(info))
|
||||||
|
|
||||||
|
|
||||||
def test_issue_81_82(api):
|
|
||||||
results = api.query_by_index(0, 'cifar10-valid', hp='12')
|
|
||||||
results = api.query_by_index(0, 'cifar10-valid', hp='200')
|
|
||||||
print(list(results.keys()))
|
|
||||||
print(results[888].get_eval('valid'))
|
|
||||||
print(results[888].get_eval('x-valid'))
|
|
||||||
result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False)
|
|
||||||
info = api.query_by_arch('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', '200')
|
|
||||||
print(info)
|
|
||||||
structure = CellStructure.str2structure('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|')
|
|
||||||
info = api.query_by_arch(structure, '200')
|
|
||||||
print(info)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
api201 = create(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), 'topology', True)
|
for fast_mode in [True, False]:
|
||||||
test_issue_81_82(api201)
|
for verbose in [True, False]:
|
||||||
print ('Test {:} done'.format(api201))
|
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
|
||||||
|
api301 = create(None, 'size', fast_mode=fast_mode, verbose=True)
|
||||||
|
print('{:} --->>> {:}'.format(time_string(), api301))
|
||||||
|
test_api(api301, True)
|
||||||
|
|
||||||
api201 = create(None, 'topology', True) # use the default file path
|
# api201 = create(None, 'topology', True) # use the default file path
|
||||||
test_issue_81_82(api201)
|
# test_api(api201, False)
|
||||||
test_api(api201, False)
|
# print ('Test {:} done'.format(api201))
|
||||||
print ('Test {:} done'.format(api201))
|
|
||||||
|
|
||||||
api301 = create(None, 'size', True)
|
|
||||||
test_api(api301, True)
|
|
@ -5,8 +5,8 @@
|
|||||||
# required to install hpbandster ##################################
|
# required to install hpbandster ##################################
|
||||||
# pip install hpbandster ##################################
|
# pip install hpbandster ##################################
|
||||||
###################################################################
|
###################################################################
|
||||||
# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
# OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
||||||
# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
# OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
||||||
###################################################################
|
###################################################################
|
||||||
import os, sys, time, random, argparse, collections
|
import os, sys, time, random, argparse, collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -167,7 +167,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
api = create(None, args.search_space, verbose=False)
|
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||||
|
|
||||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'BOHB')
|
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'BOHB')
|
||||||
print('save-dir : {:}'.format(args.save_dir))
|
print('save-dir : {:}'.format(args.save_dir))
|
@ -3,9 +3,9 @@
|
|||||||
##############################################################################
|
##############################################################################
|
||||||
# Random Search for Hyper-Parameter Optimization, JMLR 2012 ##################
|
# Random Search for Hyper-Parameter Optimization, JMLR 2012 ##################
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# python ./exps/algos-v2/random_wo_share.py --dataset cifar10 --search_space tss
|
# python ./exps/NATS-algos/random_wo_share.py --dataset cifar10 --search_space tss
|
||||||
# python ./exps/algos-v2/random_wo_share.py --dataset cifar100 --search_space tss
|
# python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss
|
||||||
# python ./exps/algos-v2/random_wo_share.py --dataset ImageNet16-120 --search_space tss
|
# python ./exps/NATS-algos/random_wo_share.py --dataset ImageNet16-120 --search_space tss
|
||||||
##############################################################################
|
##############################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np, collections
|
import numpy as np, collections
|
||||||
@ -71,7 +71,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
api = create(None, args.search_space, verbose=False)
|
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||||
|
|
||||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'RANDOM')
|
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'RANDOM')
|
||||||
print('save-dir : {:}'.format(args.save_dir))
|
print('save-dir : {:}'.format(args.save_dir))
|
@ -3,12 +3,12 @@
|
|||||||
##################################################################
|
##################################################################
|
||||||
# Regularized Evolution for Image Classifier Architecture Search #
|
# Regularized Evolution for Image Classifier Architecture Search #
|
||||||
##################################################################
|
##################################################################
|
||||||
# python ./exps/algos-v2/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||||
# python ./exps/algos-v2/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||||
# python ./exps/algos-v2/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||||
# python ./exps/algos-v2/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||||
# python ./exps/algos-v2/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||||
# python ./exps/algos-v2/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||||
##################################################################
|
##################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np, collections
|
import numpy as np, collections
|
||||||
@ -198,7 +198,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
api = create(None, args.search_space, verbose=False)
|
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||||
|
|
||||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size))
|
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size))
|
||||||
print('save-dir : {:}'.format(args.save_dir))
|
print('save-dir : {:}'.format(args.save_dir))
|
@ -3,12 +3,12 @@
|
|||||||
#####################################################################################################
|
#####################################################################################################
|
||||||
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
|
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
|
||||||
#####################################################################################################
|
#####################################################################################################
|
||||||
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01
|
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01
|
||||||
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01
|
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01
|
||||||
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01
|
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01
|
||||||
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01
|
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01
|
||||||
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
|
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
|
||||||
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01
|
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01
|
||||||
#####################################################################################################
|
#####################################################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np, collections
|
import numpy as np, collections
|
47
exps/NATS-algos/run-all.sh
Normal file
47
exps/NATS-algos/run-all.sh
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# bash ./exps/NATS-algos/run-all.sh mul
|
||||||
|
# bash ./exps/NATS-algos/run-all.sh ws
|
||||||
|
set -e
|
||||||
|
echo script name: $0
|
||||||
|
echo $# arguments
|
||||||
|
if [ "$#" -ne 1 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
echo "Need 1 parameters for type of algorithms."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
datasets="cifar10 cifar100 ImageNet16-120"
|
||||||
|
alg_type=$1
|
||||||
|
|
||||||
|
if [ "$alg_type" == "mul" ]; then
|
||||||
|
search_spaces="tss sss"
|
||||||
|
|
||||||
|
for dataset in ${datasets}
|
||||||
|
do
|
||||||
|
for search_space in ${search_spaces}
|
||||||
|
do
|
||||||
|
python ./exps/NATS-algos/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
|
||||||
|
python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
|
||||||
|
python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
|
||||||
|
python ./exps/NATS-algos/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
python exps/experimental/vis-bench-algos.py --search_space tss
|
||||||
|
python exps/experimental/vis-bench-algos.py --search_space sss
|
||||||
|
else
|
||||||
|
seeds="777 888 999"
|
||||||
|
algos="darts-v1 darts-v2 gdas setn random enas"
|
||||||
|
epoch=200
|
||||||
|
for seed in ${seeds}
|
||||||
|
do
|
||||||
|
for alg in ${algos}
|
||||||
|
do
|
||||||
|
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||||
|
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||||
|
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||||
|
done
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
@ -1,29 +1,29 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
######################################################################################
|
######################################################################################
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 777
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
|
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 777
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2
|
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 777
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas
|
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 777
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
|
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
|
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||||
######################################################################################
|
######################################################################################
|
||||||
import os, sys, time, random, argparse
|
import os, sys, time, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -364,7 +364,7 @@ def main(xargs):
|
|||||||
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
||||||
logger.log('search-space : {:}'.format(search_space))
|
logger.log('search-space : {:}'.format(search_space))
|
||||||
if bool(xargs.use_api):
|
if bool(xargs.use_api):
|
||||||
api = create(None, 'topology', verbose=False)
|
api = create(None, 'topology', fast_mode=True, verbose=False)
|
||||||
else:
|
else:
|
||||||
api = None
|
api = None
|
||||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
@ -1,17 +1,17 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
######################################################################################
|
######################################################################################
|
||||||
# python ./exps/algos-v2/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
|
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
|
||||||
# python ./exps/algos-v2/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
|
||||||
# python ./exps/algos-v2/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
|
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
|
||||||
######################################################################################
|
######################################################################################
|
||||||
import os, sys, time, random, argparse
|
import os, sys, time, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -176,7 +176,7 @@ def main(xargs):
|
|||||||
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
||||||
logger.log('search-space : {:}'.format(search_space))
|
logger.log('search-space : {:}'.format(search_space))
|
||||||
if bool(xargs.use_api):
|
if bool(xargs.use_api):
|
||||||
api = create(None, 'size', verbose=False)
|
api = create(None, 'size', fast_mode=True, verbose=False)
|
||||||
else:
|
else:
|
||||||
api = None
|
api = None
|
||||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
@ -1,47 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# bash ./exps/algos-v2/run-all.sh mul
|
|
||||||
# bash ./exps/algos-v2/run-all.sh ws
|
|
||||||
set -e
|
|
||||||
echo script name: $0
|
|
||||||
echo $# arguments
|
|
||||||
if [ "$#" -ne 1 ] ;then
|
|
||||||
echo "Input illegal number of parameters " $#
|
|
||||||
echo "Need 1 parameters for type of algorithms."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
datasets="cifar10 cifar100 ImageNet16-120"
|
|
||||||
alg_type=$1
|
|
||||||
|
|
||||||
if [ "$alg_type" == "mul" ]; then
|
|
||||||
search_spaces="tss sss"
|
|
||||||
|
|
||||||
for dataset in ${datasets}
|
|
||||||
do
|
|
||||||
for search_space in ${search_spaces}
|
|
||||||
do
|
|
||||||
python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
|
|
||||||
python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
|
|
||||||
python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
|
|
||||||
python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
|
||||||
done
|
|
||||||
done
|
|
||||||
|
|
||||||
python exps/experimental/vis-bench-algos.py --search_space tss
|
|
||||||
python exps/experimental/vis-bench-algos.py --search_space sss
|
|
||||||
else
|
|
||||||
seeds="777 888 999"
|
|
||||||
algos="darts-v1 darts-v2 gdas setn random enas"
|
|
||||||
epoch=200
|
|
||||||
for seed in ${seeds}
|
|
||||||
do
|
|
||||||
for alg in ${algos}
|
|
||||||
do
|
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
|
||||||
python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
|
||||||
python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
|
||||||
done
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
|
|
@ -1,25 +1,36 @@
|
|||||||
#####################################################
|
##############################################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 ##########################
|
||||||
#####################################################
|
##############################################################################
|
||||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
|
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||||
#####################################################
|
##############################################################################
|
||||||
#
|
# The official Application Programming Interface (API) for NATS-Bench. #
|
||||||
#
|
##############################################################################
|
||||||
|
from .api_utils import pickle_save, pickle_load
|
||||||
from .api_utils import ArchResults, ResultsCount
|
from .api_utils import ArchResults, ResultsCount
|
||||||
from .api_topology import NATStopology
|
from .api_topology import NATStopology
|
||||||
from .api_size import NATSsize
|
from .api_size import NATSsize
|
||||||
|
|
||||||
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.07.30]
|
|
||||||
|
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.28]
|
||||||
|
|
||||||
|
|
||||||
def version():
|
def version():
|
||||||
return NATS_BENCH_API_VERSIONs[-1]
|
return NATS_BENCH_API_VERSIONs[-1]
|
||||||
|
|
||||||
|
|
||||||
def create(file_path_or_dict, search_space, verbose=True):
|
def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
|
||||||
|
"""Create the instead for NATS API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path_or_dict: None or a file path or a directory path.
|
||||||
|
search_space: This is a string indicates the search space in NATS-Bench.
|
||||||
|
fast_mode: If True, we will not load all the data at initialization, instead, the data for each candidate architecture will be loaded when quering it;
|
||||||
|
If False, we will load all the data during initialization.
|
||||||
|
verbose: This is a flag to indicate whether log additional information.
|
||||||
|
"""
|
||||||
if search_space in ['tss', 'topology']:
|
if search_space in ['tss', 'topology']:
|
||||||
return NATStopology(file_path_or_dict, verbose)
|
return NATStopology(file_path_or_dict, fast_mode, verbose)
|
||||||
elif search_space in ['sss', 'size']:
|
elif search_space in ['sss', 'size']:
|
||||||
return NATSsize(file_path_or_dict, verbose)
|
return NATSsize(file_path_or_dict, fast_mode, verbose)
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid search space : {:}'.format(search_space))
|
raise ValueError('invalid search space : {:}'.format(search_space))
|
||||||
|
@ -1,21 +1,23 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||||||
############################################################################################
|
##############################################################################
|
||||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
|
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||||
############################################################################################
|
#####################################################################################
|
||||||
# The history of benchmark files:
|
# The history of benchmark files (the name is NATS-tss-[version]-[md5].pickle.pbz2) #
|
||||||
#
|
# [2020.08.28] NATS-tss-v1_0-50262.pickle.pbz2 #
|
||||||
import os, copy, random, torch, numpy as np
|
#####################################################################################
|
||||||
|
import os, copy, random, numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Text, Union, Dict, Optional
|
from typing import List, Text, Union, Dict, Optional
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
|
from .api_utils import pickle_load
|
||||||
from .api_utils import ArchResults
|
from .api_utils import ArchResults
|
||||||
from .api_utils import NASBenchMetaAPI
|
from .api_utils import NASBenchMetaAPI
|
||||||
from .api_utils import remap_dataset_set_names
|
from .api_utils import remap_dataset_set_names
|
||||||
|
|
||||||
|
|
||||||
ALL_BENCHMARK_FILES = ['NAS-Bench-301-v1_0-363be7.pth']
|
PICKLE_EXT = 'pickle.pbz2'
|
||||||
ALL_ARCHIVE_DIRS = ['NAS-Bench-301-v1_0-archive']
|
ALL_BASE_NAMES = ['NATS-tss-v1_0-50262']
|
||||||
|
|
||||||
|
|
||||||
def print_information(information, extra_info=None, show=False):
|
def print_information(information, extra_info=None, show=False):
|
||||||
@ -54,42 +56,65 @@ This is the class for the API of size search space in NATS-Bench.
|
|||||||
class NATSsize(NASBenchMetaAPI):
|
class NATSsize(NASBenchMetaAPI):
|
||||||
|
|
||||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||||
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
|
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, fast_mode: bool=False, verbose: bool=True):
|
||||||
self.filename = None
|
self.filename = None
|
||||||
self._search_space_name = 'size'
|
self._search_space_name = 'size'
|
||||||
|
self._fast_mode = fast_mode
|
||||||
|
self._archive_dir = None
|
||||||
self.reset_time()
|
self.reset_time()
|
||||||
if file_path_or_dict is None:
|
if file_path_or_dict is None:
|
||||||
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
|
if self._fast_mode:
|
||||||
print ('Try to use the default NATS-Bench (size) path from {:}.'.format(file_path_or_dict))
|
self._archive_dir = os.path.join(os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
|
||||||
|
else:
|
||||||
|
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], '{:}.{:}'.format(ALL_BASE_NAMES[-1], PICKLE_EXT))
|
||||||
|
print ('Try to use the default NATS-Bench (size) path from fast_mode={:} and path={:}.'.format(self._fast_mode, file_path_or_dict))
|
||||||
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
||||||
file_path_or_dict = str(file_path_or_dict)
|
file_path_or_dict = str(file_path_or_dict)
|
||||||
if verbose: print('try to create the NATS-Bench (size) api from {:}'.format(file_path_or_dict))
|
if verbose:
|
||||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
print('Try to create the NATS-Bench (size) api from {:} with fast_mode={:}'.format(file_path_or_dict, fast_mode))
|
||||||
|
if not os.path.isfile(file_path_or_dict) and not os.path.isdir(file_path_or_dict):
|
||||||
|
raise ValueError('{:} is neither a file or a dir.'.format(file_path_or_dict))
|
||||||
self.filename = Path(file_path_or_dict).name
|
self.filename = Path(file_path_or_dict).name
|
||||||
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
|
if fast_mode:
|
||||||
|
if os.path.isfile(file_path_or_dict):
|
||||||
|
raise ValueError('fast_mode={:} must feed the path for directory : {:}'.format(fast_mode, file_path_or_dict))
|
||||||
|
else:
|
||||||
|
self._archive_dir = file_path_or_dict
|
||||||
|
else:
|
||||||
|
if os.path.isdir(file_path_or_dict):
|
||||||
|
raise ValueError('fast_mode={:} must feed the path for file : {:}'.format(fast_mode, file_path_or_dict))
|
||||||
|
else:
|
||||||
|
file_path_or_dict = pickle_load(file_path_or_dict)
|
||||||
elif isinstance(file_path_or_dict, dict):
|
elif isinstance(file_path_or_dict, dict):
|
||||||
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
file_path_or_dict = copy.deepcopy(file_path_or_dict)
|
||||||
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
|
self.verbose = verbose
|
||||||
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
if isinstance(file_path_or_dict, dict):
|
||||||
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
|
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
||||||
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
||||||
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
self.meta_archs = copy.deepcopy(file_path_or_dict['meta_archs'])
|
||||||
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
|
||||||
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
|
self.arch2infos_dict = OrderedDict()
|
||||||
self.arch2infos_dict = OrderedDict()
|
self._avaliable_hps = set()
|
||||||
self._avaliable_hps = set()
|
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
||||||
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
all_infos = file_path_or_dict['arch2infos'][xkey]
|
||||||
all_infos = file_path_or_dict['arch2infos'][xkey]
|
hp2archres = OrderedDict()
|
||||||
hp2archres = OrderedDict()
|
for hp_key, results in all_infos.items():
|
||||||
for hp_key, results in all_infos.items():
|
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
|
||||||
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
|
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
|
||||||
self._avaliable_hps.add(hp_key) # save the avaliable hyper-parameter
|
self.arch2infos_dict[xkey] = hp2archres
|
||||||
self.arch2infos_dict[xkey] = hp2archres
|
self.evaluated_indexes = set(file_path_or_dict['evaluated_indexes'])
|
||||||
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
|
elif self.archive_dir is not None:
|
||||||
|
benchmark_meta = pickle_load('{:}/meta.{:}'.format(self.archive_dir, PICKLE_EXT))
|
||||||
|
self.meta_archs = copy.deepcopy(benchmark_meta['meta_archs'])
|
||||||
|
self.arch2infos_dict = OrderedDict()
|
||||||
|
self._avaliable_hps = set()
|
||||||
|
self.evaluated_indexes = set()
|
||||||
|
else:
|
||||||
|
raise ValueError('file_path_or_dict [{:}] must be a dict or archive_dir must be set'.format(type(file_path_or_dict)))
|
||||||
self.archstr2index = {}
|
self.archstr2index = {}
|
||||||
for idx, arch in enumerate(self.meta_archs):
|
for idx, arch in enumerate(self.meta_archs):
|
||||||
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
|
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
|
||||||
self.archstr2index[ arch ] = idx
|
self.archstr2index[arch] = idx
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print('Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(len(self.evaluated_indexes), len(self.meta_archs)))
|
print('Create NATS-Bench (size) done with {:}/{:} architectures avaliable.'.format(len(self.evaluated_indexes), len(self.meta_archs)))
|
||||||
|
|
||||||
@ -100,7 +125,7 @@ class NATSsize(NASBenchMetaAPI):
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print('Call clear_params with archive_root={:} and index={:}'.format(archive_root, index))
|
print('Call clear_params with archive_root={:} and index={:}'.format(archive_root, index))
|
||||||
if archive_root is None:
|
if archive_root is None:
|
||||||
archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
|
archive_root = os.path.join(os.environ['TORCH_HOME'], '{:}-full'.format(ALL_BASE_NAMES[-1]))
|
||||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||||
if index is None:
|
if index is None:
|
||||||
indexes = list(range(len(self)))
|
indexes = list(range(len(self)))
|
||||||
@ -108,16 +133,17 @@ class NATSsize(NASBenchMetaAPI):
|
|||||||
indexes = [index]
|
indexes = [index]
|
||||||
for idx in indexes:
|
for idx in indexes:
|
||||||
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
|
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
|
||||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
|
xfile_path = os.path.join(archive_root, '{:06d}.{:}'.format(idx, PICKLE_EXT))
|
||||||
if not os.path.isfile(xfile_path):
|
if not os.path.isfile(xfile_path):
|
||||||
xfile_path = os.path.join(archive_root, '{:d}-FULL.pth'.format(idx))
|
xfile_path = os.path.join(archive_root, '{:d}.{:}'.format(idx, PICKLE_EXT))
|
||||||
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
|
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
|
||||||
xdata = torch.load(xfile_path, map_location='cpu')
|
xdata = pickle_load(xfile_path)
|
||||||
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
|
assert isinstance(xdata, dict), 'invalid format of data in {:}'.format(xfile_path)
|
||||||
|
self.evaluated_indexes.add(idx)
|
||||||
hp2archres = OrderedDict()
|
hp2archres = OrderedDict()
|
||||||
for hp_key, results in xdata.items():
|
for hp_key, results in xdata.items():
|
||||||
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
|
hp2archres[hp_key] = ArchResults.create_from_state_dict(results)
|
||||||
|
self._avaliable_hps.add(hp_key)
|
||||||
self.arch2infos_dict[idx] = hp2archres
|
self.arch2infos_dict[idx] = hp2archres
|
||||||
|
|
||||||
def query_info_str_by_arch(self, arch, hp: Text='12'):
|
def query_info_str_by_arch(self, arch, hp: Text='12'):
|
||||||
@@ -153,6 +179,7 @@ class NATSsize(NASBenchMetaAPI):
     if self.verbose:
       print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
     index = self.query_index_by_arch(index)  # to avoid the case where the input is a string or an arch instance
+    self._prepare_info(index)
     if index not in self.arch2infos_dict:
       raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
     archresult = self.arch2infos_dict[index][str(hp)]
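
With `self._prepare_info(index)` now called at the top of `get_more_info`, a fast-mode API instance only reads the archive for architectures it is actually asked about. A minimal usage sketch, assuming the benchmark file has been downloaded as described earlier; the `'test-accuracy'` key is an assumption about the returned dict:
```
from nats_bench import create

api = create(None, 'sss', fast_mode=True, verbose=False)
# The first query on index 12 misses arch2infos_dict, so _prepare_info calls
# reload(); later queries on the same index are served from memory.
info = api.get_more_info(12, 'cifar10', hp='12', is_random=False)
print(info['test-accuracy'])
```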
@@ -3,7 +3,7 @@
 ############################################################################################
 # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
 ############################################################################################
-import os, copy, random, torch, numpy as np
+import os, copy, random, numpy as np
 from pathlib import Path
 from typing import List, Text, Union, Dict, Optional
 from collections import OrderedDict, defaultdict
@@ -62,7 +62,7 @@ class NATStopology(NASBenchMetaAPI):
       if verbose: print('try to create the NATS-Bench (topology) api from {:}'.format(file_path_or_dict))
       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
       self.filename = Path(file_path_or_dict).name
-      file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
+      file_path_or_dict = np.load(file_path_or_dict)
     elif isinstance(file_path_or_dict, dict):
       file_path_or_dict = copy.deepcopy(file_path_or_dict)
     else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
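
The topology API gets the same treatment: the module-level `torch` import is dropped and the benchmark file is read without `torch.load`. A usage sketch, assuming `'tss'` is the abbreviation of the topology search space accepted by `create`:
```
from nats_bench import create

# With fast_mode=True nothing is evaluated at construction time, so the repr
# should report 0 evaluated architectures until the first query.
api_tss = create(None, 'tss', fast_mode=True, verbose=False)
print(api_tss)
```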
@@ -10,15 +10,30 @@
 # History:
 # [2020.07.31] The first version, where most content reused nas_201_api/api_utils.py
 #
-import abc, copy, random, numpy as np
+import os, abc, copy, random, numpy as np
+import bz2, pickle
 import importlib, warnings
 from typing import List, Text, Union, Dict, Optional
 from collections import OrderedDict, defaultdict
-USE_TORCH = importlib.find_loader('torch') is not None
-if USE_TORCH:
-  import torch
-else:
-  warnings.warn('Can not find PyTorch, and thus some features maybe invalid.')
+
+
+def pickle_save(obj, file_path, ext='.pbz2', protocol=4):
+  """Use pickle to save data (obj) into file_path.
+
+  According to https://docs.python.org/3/library/pickle.html#data-stream-format, protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. It is the default protocol starting with Python 3.8.
+  """
+  # with open(file_path, 'wb') as cfile:
+  with bz2.BZ2File(str(file_path) + ext, 'wb') as cfile:
+    pickle.dump(obj, cfile, protocol=protocol)
+
+
+def pickle_load(file_path, ext='.pbz2'):
+  # return pickle.load(open(file_path, "rb"))
+  if os.path.isfile(str(file_path)):
+    xfile_path = str(file_path)
+  else:
+    xfile_path = str(file_path) + ext
+  with bz2.BZ2File(xfile_path, 'rb') as cfile:
+    return pickle.load(cfile)
+
+
 def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
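
These two helpers are the new serialization backbone: `pickle_save` writes a bz2-compressed pickle and `pickle_load` accepts either the full file name or the extension-less stem. A round-trip sketch, assuming the module is importable as `nats_bench.api_utils`:
```
from nats_bench.api_utils import pickle_save, pickle_load

data = {'12': {'test-accuracy': 84.2}}        # toy payload
pickle_save(data, '/tmp/demo')                # writes /tmp/demo.pbz2
assert pickle_load('/tmp/demo') == data       # stem given: the '.pbz2' fallback applies
assert pickle_load('/tmp/demo.pbz2') == data  # the full name works as well
```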
@@ -60,7 +75,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
     return len(self.meta_archs)
 
   def __repr__(self):
-    return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
+    return ('{name}({num}/{total} architectures, fast_mode={fast_mode}, file={filename})'.format(
+        name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs),
+        fast_mode=self.fast_mode, filename=self.filename))
 
   @property
   def avaliable_hps(self):
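
The extended `__repr__` surfaces the new fast-mode flag, which makes partially loaded instances easy to tell apart from fully loaded ones. A sketch; the exact file name depends on the downloaded benchmark:
```
from nats_bench import create

api = create(None, 'sss', fast_mode=True, verbose=False)
print(api)  # e.g. NATSsize(0/32768 architectures, fast_mode=True, file=...)
```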
@@ -74,6 +91,20 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
   def search_space_name(self):
     return self._search_space_name
 
+  @property
+  def fast_mode(self):
+    return self._fast_mode
+
+  @property
+  def archive_dir(self):
+    return self._archive_dir
+
+  def reset_archive_dir(self, archive_dir):
+    self._archive_dir = archive_dir
+
+  def reset_fast_mode(self, fast_mode):
+    self._fast_mode = fast_mode
+
   def reset_time(self):
     self._used_time = 0
 
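
The new properties and reset helpers let callers inspect and flip the lazy-loading behaviour after construction. A sketch:
```
from nats_bench import create

api = create(None, 'sss', fast_mode=True, verbose=False)
print(api.fast_mode, api.archive_dir)  # read-only views of the internal state

api.reset_fast_mode(False)   # stop lazy loading for subsequent queries
api.reset_archive_dir(None)  # forget the on-disk archive location
```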
@@ -121,9 +152,24 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
     return arch_index
 
   def query_by_arch(self, arch, hp):
-    # This is to make the current version be compatible with the old version.
+    """This is to make the current version be compatible with the old version."""
     return self.query_info_str_by_arch(arch, hp)
 
+  def _prepare_info(self, index):
+    """Load the data of the 'index'-th architecture from disk when using fast mode."""
+    if index not in self.arch2infos_dict:
+      if self.fast_mode and self.archive_dir is not None:
+        self.reload(self.archive_dir, index)
+      elif not self.fast_mode:
+        if self.verbose:
+          print('Call _prepare_info with index={:}: skip because it is not the fast mode.'.format(index))
+      else:
+        raise ValueError('Invalid status: fast_mode={:} and archive_dir={:}'.format(self.fast_mode, self.archive_dir))
+    else:
+      assert index in self.evaluated_indexes, 'The index of {:} is not in self.evaluated_indexes, there must be something wrong.'.format(index)
+      if self.verbose:
+        print('Call _prepare_info with index={:}: skip because it is in arch2infos_dict.'.format(index))
+
   @abc.abstractmethod
   def reload(self, archive_root: Text = None, index: int = None):
     """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
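
`_prepare_info` is the single choke point for lazy loading: on a cache miss it calls `reload(self.archive_dir, index)` in fast mode, silently skips in full mode, and raises when fast mode has no archive directory. Callers can also preload explicitly; a sketch:
```
from nats_bench import create

api = create(None, 'sss', fast_mode=True, verbose=False)
api.reload(index=12)  # archive_root=None falls back to the default under $TORCH_HOME
api.reload()          # index=None loads every architecture (slow; defeats fast mode)
```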
@@ -140,7 +186,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
     """
     if self.verbose:
       print('Call clear_params with index={:} and hp={:}'.format(index, hp))
-    if hp is None:
+    if index not in self.arch2infos_dict:
+      warnings.warn('The {:}-th architecture is not in the benchmark data yet, no need to clear params.'.format(index))
+    elif hp is None:
       for key, result in self.arch2infos_dict[index].items():
         result.clear_params()
     else:
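
`clear_params` now tolerates indexes that have not been loaded yet, which matters in fast mode where most of `arch2infos_dict` is empty. A sketch, assuming `clear_params(index, hp)` accepts these arguments:
```
from nats_bench import create

api = create(None, 'sss', fast_mode=True, verbose=False)
api.get_more_info(12, 'cifar10', hp='12', is_random=False)  # loads arch 12
api.clear_params(12, None)  # hp=None: drop the weights of every hp setting of arch 12
api.clear_params(13, None)  # not loaded yet -> warnings.warn instead of a KeyError
```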
@@ -154,6 +202,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
 
   def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
     arch_index = self.query_index_by_arch(arch)
+    self._prepare_info(arch_index)
     if arch_index in self.arch2infos_dict:
       if hp not in self.arch2infos_dict[arch_index]:
         raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[arch_index].keys()), hp))
@@ -161,13 +210,14 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
       strings = print_information(info, 'arch-index={:}'.format(arch_index))
       return '\n'.join(strings)
     else:
-      print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
+      warnings.warn('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
       return None
 
   def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
     """Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
     if self.verbose:
       print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
+    self._prepare_info(arch_index)
     if arch_index in self.arch2infos_dict:
       if hp not in self.arch2infos_dict[arch_index]:
         raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
@@ -207,7 +257,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
       print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
     dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
     best_index, highest_accuracy = -1, None
-    for i, arch_index in enumerate(self.evaluated_indexes):
+    evaluated_indexes = sorted(list(self.evaluated_indexes))
+    for i, arch_index in enumerate(evaluated_indexes):
       arch_info = self.arch2infos_dict[arch_index][hp]
       info = arch_info.get_compute_costs(dataset)  # the information of costs
       flop, param, latency = info['flops'], info['params'], info['latency']
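
Iterating over `sorted(list(self.evaluated_indexes))` makes `find_best` deterministic even though `evaluated_indexes` is an unordered set that fast mode fills incrementally. A sketch, assuming the function returns the `best_index, highest_accuracy` pair it accumulates:
```
from nats_bench import create

api = create(None, 'sss', fast_mode=False, verbose=False)  # full mode: all indexes evaluated
best_index, highest_accuracy = api.find_best('cifar10', 'test', hp='90')
print(best_index, highest_accuracy)
```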
@@ -254,10 +305,11 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
     """
     if self.verbose:
       print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
+    self._prepare_info(index)
     if index in self.arch2infos_dict:
       info = self.arch2infos_dict[index]
     else:
-      raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(arch_index))
+      raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
     info = next(iter(info.values()))
     results = info.query(dataset, None)
     results = next(iter(results.values()))
@@ -267,6 +319,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
     """To obtain the cost metric for the `index`-th architecture on a dataset."""
     if self.verbose:
       print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
+    self._prepare_info(index)
     info = self.query_meta_info_by_index(index, hp)
     return info.get_compute_costs(dataset)
 
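
Both `get_net_config` and `get_cost_info` now self-load in fast mode via `_prepare_info`, so neither requires a prior `reload`. A sketch; the cost keys follow `get_compute_costs` ('flops', 'params', 'latency'):
```
from nats_bench import create

api = create(None, 'sss', fast_mode=True, verbose=False)
cost = api.get_cost_info(12, 'cifar10', hp='12')  # compute costs on CIFAR-10
config = api.get_net_config(12, 'cifar10')        # dict describing the network
print(cost['flops'], cost['params'], config)
```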
@@ -296,8 +349,9 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
     """
     if index < 0:  # show all architectures
       print(self)
-      for i, idx in enumerate(self.evaluated_indexes):
-        print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
+      evaluated_indexes = sorted(list(self.evaluated_indexes))
+      for i, idx in enumerate(evaluated_indexes):
+        print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(evaluated_indexes), idx) + '-'*10)
         print('arch : {:}'.format(self.meta_archs[idx]))
         for key, result in self.arch2infos_dict[index].items():
           strings = print_information(result)
@@ -325,7 +379,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
     if dataset not in valid_datasets:
       raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
     nums, hp = defaultdict(lambda: 0), str(hp)
-    for index in range(len(self)):
+    # for index in range(len(self)):
+    for index in self.evaluated_indexes:
       archInfo = self.arch2infos_dict[index][hp]
       dataset_seed = archInfo.dataset_seed
       if dataset not in dataset_seed:
@@ -550,9 +605,7 @@ class ArchResults(object):
   def create_from_state_dict(state_dict_or_file):
     x = ArchResults(-1, -1)
     if isinstance(state_dict_or_file, str):  # a file path
-      if not USE_TORCH:
-        raise ValueError('Since torch is not imported, this logic can not be used.')
-      state_dict = torch.load(state_dict_or_file, map_location='cpu')
+      state_dict = pickle_load(state_dict_or_file)
     elif isinstance(state_dict_or_file, dict):
       state_dict = state_dict_or_file
     else:
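
With the `USE_TORCH` branch removed, a string argument to `create_from_state_dict` is loaded with `pickle_load` instead of `torch.load`. In this commit the usual caller is `reload()`, which passes the already-loaded per-hp dicts; a sketch with a placeholder file name:
```
from nats_bench.api_utils import ArchResults, pickle_load

# '000012.pickle.pbz2' is a placeholder following the reload() naming pattern;
# the file holds a dict mapping each hp setting to an ArchResults state dict.
xdata = pickle_load('000012.pickle.pbz2')
hp2archres = {hp_key: ArchResults.create_from_state_dict(results)
              for hp_key, results in xdata.items()}
```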