update NAS-Bench

D-X-Y 2020-03-09 19:38:00 +11:00
parent 9a83814a46
commit e59eb804cb
35 changed files with 693 additions and 64 deletions

View File

@@ -25,7 +25,6 @@ This project implements several neural architecture search (NAS) and hyper-para
At the moment, this project provides the following algorithms and scripts to run them. Please see the details in the link provided in the description column.
<table>
<tbody>
<tr align="center" valign="bottom">

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "12"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}
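Both hyper-parameter files (this 12-epoch one and the 90-epoch one below) store every field as a `["type", "value"]` pair. A minimal sketch of decoding such pairs into typed Python values — the helper name `decode` is hypothetical; the repository's actual loader is `load_config` in `lib/config_utils`, whose exact behavior may differ:
```
import json

def decode(pairs):  # hypothetical helper, mirrors the ["type", "value"] convention
  casts = {'int': int, 'float': float, 'str': str,
           'bool': lambda v: bool(int(v))}  # "1"/"0" strings, as in "nesterov"
  return {key: casts[t](v) for key, (t, v) in pairs.items()}

with open('configs/nas-benchmark/hyper-opts/12E.config') as f:
  config = decode(json.load(f))
print(config['epochs'], config['LR'], config['nesterov'])  # 12 0.1 True
```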

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "90"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

View File

@@ -39,10 +39,10 @@ If you are interested in the configs of each NAS-searched architecture, they are
### Searching on the NASNet search space
Please use the following script to run GDAS search as in the original paper:
```
-CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1
+CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NASNet-space-search-by-GDAS.sh cifar10 1 -1
```
-**After searching***, if you want to re-train the searched architecture found by the above script, you can use the following script:
+**After searching**, if you want to re-train the searched architecture found by the above script, you can use the following script:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts/retrain-searched-net.sh cifar10 gdas-searched \
output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth 96 -1

View File

@@ -30,6 +30,13 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
```
+### Searching on the NASNet search space
+Please use the following script to run SETN search as in the original paper:
+```
+CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NASNet-space-search-by-SETN.sh cifar10 1 -1
+```
+### Searching on the NAS-Bench-201 search space
The code for searching with SETN on the small NAS-Bench-201 search space:
```
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 1 -1

View File

@@ -146,6 +146,10 @@ api.get_more_info(112, 'cifar10', None, False, True)
api.get_more_info(112, 'ImageNet16-120', None, False, True) # the info of last training epoch for 112-th architecture (use 200-epoch-hyper-parameter and randomly select a trial)
```
+Please use the following script to show the best architectures on each dataset:
+```
+python exps/NAS-Bench-201/show-best.py
+```
## Instruction to Re-Generate NAS-Bench-201

View File

@@ -3,10 +3,8 @@
#####################################################
# python exps/NAS-Bench-201/check.py --base_save_dir
#####################################################
-import os, sys, time, argparse, collections
+import sys, time, argparse, collections
-from shutil import copyfile
import torch
-import torch.nn as nn
from pathlib import Path
from collections import defaultdict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()

View File

@@ -0,0 +1,39 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
################################################################################################
# python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth #
################################################################################################
import sys, argparse
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API as API

if __name__ == '__main__':
  parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
  parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
  args = parser.parse_args()

  meta_file = Path(args.api_path)
  assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
  api = API(str(meta_file))

  # show the best architecture (selected on the validation set) for each dataset
  arch_index, accuracy = api.find_best('cifar10-valid', 'x-valid', None, None, False)
  print('FOR CIFAR-010, using the hyper-parameters with 200 training epochs :::')
  print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index)))
  api.show(arch_index)
  print('')

  arch_index, accuracy = api.find_best('cifar100', 'x-valid', None, None, False)
  print('FOR CIFAR-100, using the hyper-parameters with 200 training epochs :::')
  print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index)))
  api.show(arch_index)
  print('')

  arch_index, accuracy = api.find_best('ImageNet16-120', 'x-valid', None, None, False)
  print('FOR ImageNet16-120, using the hyper-parameters with 200 training epochs :::')
  print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index)))
  api.show(arch_index)
  print('')

View File

@@ -0,0 +1,196 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###############################################################
import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from procedures import bench_evaluate_for_seed
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time

def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Text],
                          splits: List[Text], config_path: Text, seed: int, workers: int, logger):
  machine_info = get_machine_info()
  all_infos = {'info': machine_info}
  all_dataset_keys = []
  # loop over all the datasets
  for dataset, xpath, split in zip(datasets, xpaths, splits):
    # prepare the training and validation data
    train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
    # load the configuration
    if dataset == 'cifar10' or dataset == 'cifar100':
      split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
    elif dataset.startswith('ImageNet16'):
      split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
    else:
      raise ValueError('invalid dataset : {:}'.format(dataset))
    config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
    # check whether to use the split validation set
    if bool(split):
      assert dataset == 'cifar10'
      ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
      assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
      train_data_v2 = deepcopy(train_data)
      train_data_v2.transform = valid_data.transform
      valid_data = train_data_v2
      # data loader
      train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
      valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
      ValLoaders['x-valid'] = valid_loader
    else:
      # data loader
      train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
      valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
      if dataset == 'cifar10':
        ValLoaders = {'ori-test': valid_loader}
      elif dataset == 'cifar100':
        cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
        ValLoaders = {'ori-test': valid_loader,
                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True),
                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True)
                     }
      elif dataset == 'ImageNet16-120':
        imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
        ValLoaders = {'ori-test': valid_loader,
                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True),
                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True)
                     }
      else:
        raise ValueError('invalid dataset : {:}'.format(dataset))

    dataset_key = '{:}'.format(dataset)
    if bool(split): dataset_key = dataset_key + '-valid'
    logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
    logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
    for key, value in ValLoaders.items():
      logger.log('Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
    # arch-index= 9930, arch=|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
    # this genotype is the architecture with the highest accuracy on the CIFAR-100 validation set
    genotype = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|'
    arch_config = dict2config(dict(name='infer.shape.tiny', channels=channels, genotype=genotype, num_classes=class_num), None)
    results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger)
    all_infos[dataset_key] = results
    all_dataset_keys.append( dataset_key )
  all_infos['all_dataset_keys'] = all_dataset_keys
  return all_infos

def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
         splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any],
         srange: tuple, cover_mode: bool):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled = True
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads(workers)

  log_dir = save_dir / 'logs'
  log_dir.mkdir(parents=True, exist_ok=True)
  logger = Logger(str(log_dir), 0, False)

  logger.log('xargs : seeds      = {:}'.format(seeds))
  logger.log('xargs : cover_mode = {:}'.format(cover_mode))
  logger.log('-' * 100)
  logger.log('Start evaluating range =: {:06d} - {:06d} / {:06d} with cover-mode={:}'.format(srange[0], srange[1], len(nets), cover_mode))
  for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
    logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
  logger.log('--->>> optimization config : {:}'.format(opt_config))

  to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
  start_time, epoch_time = time.time(), AverageMeter()
  for i, index in enumerate(to_evaluate_indexes):
    channelstr = nets[index]
    logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}'.format(time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15))
    logger.log('{:} {:} {:}'.format('-' * 15, channelstr, '-' * 15))
    # test this arch on different datasets with different seeds
    has_continue = False
    for seed in seeds:
      to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
      if to_save_name.exists():
        if cover_mode:
          logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
          os.remove(str(to_save_name))
        else:
          logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
          has_continue = True
          continue
      results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger)
      torch.save(results, to_save_name)
      logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
    # measure elapsed time
    if not has_continue: epoch_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True))
    logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True)))
    logger.log('{:}'.format('*' * 100))
    logger.log('{:} {:74s} {:}'.format('*' * 10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, len(nets), need_time), '*' * 10))
    logger.log('{:}'.format('*' * 100))

  logger.close()

def traverse_net(candidates: List[int], N: int):
  nets = ['']
  for i in range(N):
    new_nets = []
    for net in nets:
      for C in candidates:
        new_nets.append(str(C) if net == '' else "{:}:{:}".format(net, C))
    nets = new_nets
  return nets

if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--mode', type=str, required=True, choices=['new', 'cover'], help='The script mode.')
  parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
  parser.add_argument('--candidateC', type=int, nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='The candidate channel numbers for each layer.')
  parser.add_argument('--num_layers', type=int, default=5, help='The number of layers in a network.')
  parser.add_argument('--check_N', type=int, default=32768, help='The expected number of architectures (for safety).')
  # arguments used for training the models
  parser.add_argument('--workers', type=int, default=8, help='The number of data loading workers.')
  parser.add_argument('--srange', type=str, required=True, help='The range of models to be evaluated, e.g., 00000-05000.')
  parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.')
  parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for each dataset.')
  parser.add_argument('--splits', type=int, nargs='+', help='Whether to use the split validation set for each dataset.')
  parser.add_argument('--hyper', type=str, default='12', choices=['12', '90'], help='The tag for hyper-parameters.')
  parser.add_argument('--seeds', type=int, nargs='+', help='The random seeds to evaluate with.')
  args = parser.parse_args()

  nets = traverse_net(args.candidateC, args.num_layers)
  if len(nets) != args.check_N: raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N))

  opt_config = './configs/nas-benchmark/hyper-opts/{:}E.config'.format(args.hyper)
  if not os.path.isfile(opt_config): raise ValueError('{:} is not a file.'.format(opt_config))
  save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper)
  save_dir.mkdir(parents=True, exist_ok=True)

  if not isinstance(args.srange, str) or len(args.srange.split('-')) != 2:
    raise ValueError('Invalid scheme for {:}'.format(args.srange))
  srange = args.srange.split('-')
  srange = (int(srange[0]), int(srange[1]))
  assert 0 <= srange[0] <= srange[1] < args.check_N, '{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N)

  assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds)
  assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
  assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)

  main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config,
       srange, args.mode == 'cover')
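As a sanity check on `--check_N`: `traverse_net` enumerates every assignment of the candidate channel counts to the `N` layers as a colon-joined string, so 8 candidates with `--num_layers 5` give exactly 8^5 = 32768 networks. A self-contained re-statement of that enumeration via `itertools.product` (the ordering matches the nested loops above):
```
import itertools

def traverse_net(candidates, N):
  # every channel assignment becomes a string like '8:16:24:8:64'
  return [':'.join(str(c) for c in combo)
          for combo in itertools.product(candidates, repeat=N)]

print(traverse_net([8, 16], 2))  # ['8:8', '8:16', '16:8', '16:16']
print(len(traverse_net([8, 16, 24, 32, 40, 48, 56, 64], 5)))  # 32768 = 8**5
```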

View File

@@ -3,11 +3,9 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
-import os, sys, time, random, argparse
+import sys, time, random, argparse
-import numpy as np
from copy import deepcopy
import torch
-import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
@@ -107,7 +105,6 @@ def main(xargs):
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion : {:}'.format(criterion))
  flop, param = get_model_infos(search_model, xshape)
-  #logger.log('{:}'.format(search_model))
  logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
  logger.log('search-space [{:} ops] : {:}'.format(len(search_space), search_space))
  if xargs.arch_nas_dataset is None:

View File

@@ -3,7 +3,7 @@
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
-import os, sys, time, glob, random, argparse
+import sys, time, random, argparse
import numpy as np
from copy import deepcopy
import torch
@@ -93,8 +93,7 @@ def get_best_arch(xloader, network, n_samples):
      _, logits = network(inputs)
      val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
-      valid_accs.append( val_top1.item() )
+      valid_accs.append(val_top1.item())
-      #print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1))

    best_idx = np.argmax(valid_accs)
    best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
@@ -142,10 +141,13 @@ def main(xargs):
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
  search_space = get_search_spaces('cell', xargs.search_space_name)
-  model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
-                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
-                              'space'    : search_space,
-                              'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
+  if xargs.model_config is None:
+    model_config = dict2config(
+        dict(name='SETN', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num,
+             space=search_space, affine=False, track_running_stats=bool(xargs.track_running_stats)), None)
+  else:
+    model_config = load_config(xargs.model_config, dict(num_classes=class_num, space=search_space, affine=False,
+                                                        track_running_stats=bool(xargs.track_running_stats)), None)
  logger.log('search space : {:}'.format(search_space))
  search_model = get_cell_based_tiny_net(model_config)
@@ -156,7 +158,6 @@ def main(xargs):
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion : {:}'.format(criterion))
  flop, param = get_model_infos(search_model, xshape)
-  #logger.log('{:}'.format(search_model))
  logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
  logger.log('search-space : {:}'.format(search_space))
  if xargs.arch_nas_dataset is None:
@@ -233,7 +234,7 @@ def main(xargs):
                'last_checkpoint': save_path,
                }, logger.path('info'), logger)
    with torch.no_grad():
-      logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
+      logger.log('{:}'.format(search_model.show_alphas()))
    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)

View File

@@ -1,4 +1,4 @@
-import os, sys, time, random, argparse
+import random, argparse
from .share_args import add_shared_args

def obtain_attention_args():

View File

@@ -1,7 +1,7 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
-import os, sys, time, random, argparse
+import random, argparse
from .share_args import add_shared_args

def obtain_basic_args():

View File

@@ -1,4 +1,4 @@
-import os, sys, time, random, argparse
+import random, argparse
from .share_args import add_shared_args

def obtain_cls_init_args():

View File

@@ -1,4 +1,4 @@
-import os, sys, time, random, argparse
+import random, argparse
from .share_args import add_shared_args

def obtain_cls_kd_args():

View File

@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
-import os, sys, json
+import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple

View File

@@ -39,6 +39,13 @@ def get_cell_based_tiny_net(config):
      genotype = CellStructure.str2structure(config.arch_str)
    else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
    return TinyNetwork(config.C, config.N, genotype, config.num_classes)
+  elif config.name == 'infer.shape.tiny':
+    from .shape_infers import DynamicShapeTinyNet
+    if isinstance(config.channels, str):
+      channels = tuple([int(x) for x in config.channels.split(':')])
+    else: channels = config.channels
+    genotype = CellStructure.str2structure(config.genotype)
+    return DynamicShapeTinyNet(channels, genotype, config.num_classes)
  elif config.name == 'infer.nasnet-cifar':
    from .cell_infers import NASNetonCIFAR
    raise NotImplementedError

View File

@@ -1,7 +1,6 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
-import torch
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell

View File

@@ -172,14 +172,19 @@ class FactorizedReduce(nn.Module):
      for i in range(2):
        self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
      self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
+    elif stride == 1:
+      self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
    else:
      raise ValueError('Invalid stride : {:}'.format(stride))
    self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)

  def forward(self, x):
-    x = self.relu(x)
-    y = self.pad(x)
-    out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
+    if self.stride == 2:
+      x = self.relu(x)
+      y = self.pad(x)
+      out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
+    else:
+      out = self.conv(x)
    out = self.bn(out)
    return out
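The net effect of this patch: with stride 2, `FactorizedReduce` still concatenates two offset 1x1 convolutions to halve the spatial size, while stride 1 now degenerates to a single 1x1 convolution. A hedged shape check — a sketch that assumes the repository's `lib/` is on `sys.path`, that the class lives in `models.cell_operations`, and that the constructor is `FactorizedReduce(C_in, C_out, stride, affine, track_running_stats)`:
```
import torch
from models.cell_operations import FactorizedReduce  # assumed module path

x = torch.randn(2, 16, 32, 32)
fr2 = FactorizedReduce(16, 32, 2, False, True)  # stride-2 branch: pad + two convs
fr1 = FactorizedReduce(16, 32, 1, False, True)  # new stride-1 branch: one 1x1 conv
print(fr2(x).shape)  # expected: torch.Size([2, 32, 16, 16])
print(fr1(x).shape)  # expected: torch.Size([2, 32, 32, 32])
```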

View File

@@ -14,11 +14,11 @@ from .search_model_darts_nasnet import NASNetworkDARTS
nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
-                     'DARTS-V2': TinyNetworkDarts,
-                     'GDAS'    : TinyNetworkGDAS,
-                     'SETN'    : TinyNetworkSETN,
-                     'ENAS'    : TinyNetworkENAS,
-                     'RANDOM'  : TinyNetworkRANDOM}
+                     "DARTS-V2": TinyNetworkDarts,
+                     "GDAS": TinyNetworkGDAS,
+                     "SETN": TinyNetworkSETN,
+                     "ENAS": TinyNetworkENAS,
+                     "RANDOM": TinyNetworkRANDOM}
-nasnet_super_nets = {'GDAS' : NASNetworkGDAS,
-                     'DARTS': NASNetworkDARTS}
+nasnet_super_nets = {"GDAS": NASNetworkGDAS,
+                     "DARTS": NASNetworkDARTS}

View File

@@ -1,5 +1,5 @@
####################
# DARTS, ICLR 2019 #
####################
import torch
import torch.nn as nn
@@ -11,7 +11,8 @@ from .search_cells import NASNetSearchCell as SearchCell

# The macro structure is based on NASNet
class NASNetworkDARTS(nn.Module):

-  def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
+  def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
+               num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
    super(NASNetworkDARTS, self).__init__()
    self._C = C
    self._layerN = N

View File

@@ -6,14 +6,15 @@
import torch
import torch.nn as nn
from copy import deepcopy
+from typing import List, Text, Dict
from .search_cells import NASNetSearchCell as SearchCell
+from .genotypes import Structure

# The macro structure is based on NASNet
class NASNetworkSETN(nn.Module):

-  def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
+  def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
+               num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
    super(NASNetworkSETN, self).__init__()
    self._C = C
    self._layerN = N
@@ -45,6 +46,16 @@ class NASNetworkSETN(nn.Module):
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
+    self.mode = 'urs'
+    self.dynamic_cell = None
+
+  def set_cal_mode(self, mode, dynamic_cell=None):
+    assert mode in ['urs', 'joint', 'select', 'dynamic']
+    self.mode = mode
+    if mode == 'dynamic':
+      self.dynamic_cell = deepcopy(dynamic_cell)
+    else:
+      self.dynamic_cell = None

  def get_weights(self):
    xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
@@ -70,6 +81,24 @@ class NASNetworkSETN(nn.Module):
  def extra_repr(self):
    return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

+  def dync_genotype(self, use_random=False):
+    genotypes = []
+    with torch.no_grad():
+      alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
+    for i in range(1, self.max_nodes):
+      xlist = []
+      for j in range(i):
+        node_str = '{:}<-{:}'.format(i, j)
+        if use_random:
+          op_name = random.choice(self.op_names)
+        else:
+          weights = alphas_cpu[ self.edge2index[node_str] ]
+          op_index = torch.multinomial(weights, 1).item()
+          op_name = self.op_names[ op_index ]
+        xlist.append((op_name, j))
+      genotypes.append( tuple(xlist) )
+    return Structure( genotypes )

  def genotype(self):
    def _parse(weights):
      gene = []
@@ -94,9 +123,6 @@ class NASNetworkSETN(nn.Module):
  def forward(self, inputs):
    normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
    reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
-    with torch.no_grad():
-      normal_hardwts_cpu = normal_hardwts.detach().cpu()
-      reduce_hardwts_cpu = reduce_hardwts.detach().cpu()
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
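The new `dync_genotype` draws one operation per edge: uniformly at random when `use_random` is true, otherwise from the softmax over that edge's architecture logits. A self-contained sketch of just that sampling step, with a toy operation set standing in for the real search space:
```
import random, torch
import torch.nn as nn

op_names = ['none', 'skip_connect', 'nor_conv_3x3']   # toy operation set
num_edges = 3
arch_parameters = 1e-3 * torch.randn(num_edges, len(op_names))

use_random = False
with torch.no_grad():
  alphas = nn.functional.softmax(arch_parameters, dim=-1)
sampled = []
for edge in range(num_edges):
  if use_random:
    sampled.append(random.choice(op_names))
  else:
    # draw one op per edge, proportional to its softmax weight
    sampled.append(op_names[torch.multinomial(alphas[edge], 1).item()])
print(sampled)
```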

View File

@@ -1,8 +1,9 @@
-import math, torch
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
+#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
-from ..SharedUtils import additive_func

class ConvBNReLU(nn.Module):

View File

@@ -1,8 +1,9 @@
-import math, torch
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
+#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
-from ..SharedUtils import additive_func

class ConvBNReLU(nn.Module):

View File

@@ -1,8 +1,9 @@
-import math
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
+#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
-from ..SharedUtils import additive_func

class ConvBNReLU(nn.Module):

View File

@@ -1,8 +1,9 @@
-import math, torch
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
+#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
-from ..SharedUtils import additive_func

class ConvBNReLU(nn.Module):

View File

@@ -1,7 +1,10 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
+#####################################################
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from ..initialization import initialize_resnet
-from ..SharedUtils import additive_func, parse_channel_info
+from ..SharedUtils import parse_channel_info

class ConvBNReLU(nn.Module):

View File

@@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from typing import List, Text, Any
import torch.nn as nn
from models.cell_operations import ResNetBasicblock
from models.cell_infers.cells import InferCell


class DynamicShapeTinyNet(nn.Module):

  def __init__(self, channels: List[int], genotype: Any, num_classes: int):
    super(DynamicShapeTinyNet, self).__init__()
    self._channels = channels
    if len(channels) % 3 != 2:
      raise ValueError('invalid number of layers : {:}'.format(len(channels)))
    self._num_stage = N = len(channels) // 3

    self.stem = nn.Sequential(
        nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(channels[0]))

    # layer_channels   = [C    ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    c_prev = channels[0]
    self.cells = nn.ModuleList()
    for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
      if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True)
      else         : cell = InferCell(genotype, c_prev, c_curr, 1)
      self.cells.append( cell )
      c_prev = cell.out_dim
    self._num_layer = len(self.cells)

    self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(c_prev, num_classes)

  def get_message(self) -> Text:
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, inputs):
    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
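A hedged instantiation sketch for the new network: it assumes you run from the repository root with `lib/` added to `sys.path`, and it routes through the `infer.shape.tiny` branch added to `get_cell_based_tiny_net` above; the 5-entry channel string satisfies the `len(channels) % 3 == 2` check:
```
import sys, torch
from pathlib import Path
sys.path.insert(0, str(Path('lib').resolve()))  # assumes the repo root as cwd
from models import get_cell_based_tiny_net
from config_utils import dict2config

# same genotype string as in xshapes.py above; channels follow the 'C:C:...' format
arch_config = dict2config(dict(
    name='infer.shape.tiny',
    channels='16:32:64:64:64',
    genotype='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|',
    num_classes=10), None)
net = get_cell_based_tiny_net(arch_config)
features, logits = net(torch.randn(2, 3, 32, 32))
print(logits.shape)  # expected: torch.Size([2, 10])
```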

View File

@@ -1,5 +1,9 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
+#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
+from .InferTinyCellNet import DynamicShapeTinyNet

View File

@@ -7,7 +7,8 @@
# [2020.03.08] Next version (coming soon)
#
#
-import os, sys, copy, random, torch, numpy as np
+import os, copy, random, torch, numpy as np
+from typing import List, Text, Union, Dict, Any
from collections import OrderedDict, defaultdict
@@ -43,7 +44,7 @@ This is the class for API of NAS-Bench-201.
class NASBench201API(object):

  """ The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
-  def __init__(self, file_path_or_dict, verbose=True):
+  def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
    if isinstance(file_path_or_dict, str):
      if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
      assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
@@ -69,7 +70,7 @@ class NASBench201API(object):
      assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
      self.archstr2index[ arch ] = idx

-  def __getitem__(self, index):
+  def __getitem__(self, index: int):
    return copy.deepcopy( self.meta_archs[index] )

  def __len__(self):
@@ -99,7 +100,7 @@ class NASBench201API(object):
  # Overwrite all information of the 'index'-th architecture in the search space.
  # It will load its data from 'archive_root'.
-  def reload(self, archive_root, index):
+  def reload(self, archive_root: Text, index: int):
    assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
    xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
    assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
@@ -141,7 +142,8 @@ class NASBench201API(object):
  # -- cifar10 : training the model on the CIFAR-10 training + validation set.
  # -- cifar100 : training the model on the CIFAR-100 training set.
  # -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
-  def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
+  def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None,
+                     use_12epochs_result: bool = False):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
    assert arch_index in arch2infos, 'arch_index [{:}] is not in arch2infos with {:}'.format(arch_index, basestr)
@@ -177,7 +179,7 @@ class NASBench201API(object):
    return best_index, highest_accuracy

  # return the topology structure of the `index`-th architecture
-  def arch(self, index):
+  def arch(self, index: int):
    assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
    return copy.deepcopy(self.meta_archs[index])
@@ -238,7 +240,7 @@ class NASBench201API(object):
  # `is_random`
  #   When is_random=True, the performance of a random architecture will be returned
  #   When is_random=False, the performance of all trials will be averaged.
-  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
+  def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
    archresult = arch2infos[index]
@@ -301,7 +303,7 @@ class NASBench201API(object):
  If the index < 0: it will loop over all architectures and print their information one by one.
  else: it will print the information of the 'index'-th architecture.
  """
-  def show(self, index=-1):
+  def show(self, index: int = -1) -> None:
    if index < 0: # show all architectures
      print(self)
      for i, idx in enumerate(self.evaluated_indexes):
@@ -336,8 +338,8 @@ class NASBench201API(object):
  # for i, node in enumerate(arch):
  #   print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
  @staticmethod
-  def str2lists(xstr):
-    assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
+  def str2lists(xstr: Text) -> List[Any]:
+    # assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
    nodestrs = xstr.split('+')
    genotypes = []
    for i, node_str in enumerate(nodestrs):
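For reference, the strings fed to `str2lists` follow the `|op~input|...+...` format used for the genotype in xshapes.py above. A self-contained sketch of the same decomposition, mirroring the split-on-`+` then split-on-`|` logic (the actual method may differ in details):
```
xstr = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|'
genotypes = []
for node_str in xstr.split('+'):               # one '+'-separated chunk per node
  ops = [x for x in node_str.split('|') if x]  # e.g. 'nor_conv_3x3~0'
  genotypes.append(tuple((op.split('~')[0], int(op.split('~')[1])) for op in ops))
print(genotypes[2])  # (('skip_connect', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))
```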

View File

@@ -3,6 +3,8 @@
##################################################
from .starts import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
from .optimizers import get_optim_scheduler
+from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
+from .funcs_nasbench import pure_evaluate as bench_pure_evaluate

def get_procedures(procedure):
  from .basic_main import basic_train, basic_valid

View File

@@ -0,0 +1,129 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import time, torch
from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net

__all__ = ['evaluate_for_seed', 'pure_evaluate']


def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
  data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
  losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
  latencies = []
  network.eval()
  with torch.no_grad():
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
      targets = targets.cuda(non_blocking=True)
      inputs = inputs.cuda(non_blocking=True)
      data_time.update(time.time() - end)
      # forward
      features, logits = network(inputs)
      loss = criterion(logits, targets)
      batch_time.update(time.time() - end)
      if batch is None or batch == inputs.size(0):
        batch = inputs.size(0)
        latencies.append( batch_time.val - data_time.val )
      # record loss and accuracy
      prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
      losses.update(loss.item(), inputs.size(0))
      top1.update(prec1.item(), inputs.size(0))
      top5.update(prec5.item(), inputs.size(0))
      end = time.time()
  if len(latencies) > 2: latencies = latencies[1:]
  return losses.avg, top1.avg, top5.avg, latencies


def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
  losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
  if mode == 'train'  : network.train()
  elif mode == 'valid': network.eval()
  else: raise ValueError("The mode is not right : {:}".format(mode))

  data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
  for i, (inputs, targets) in enumerate(xloader):
    if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))

    targets = targets.cuda(non_blocking=True)
    if mode == 'train': optimizer.zero_grad()
    # forward
    features, logits = network(inputs)
    loss = criterion(logits, targets)
    # backward
    if mode == 'train':
      loss.backward()
      optimizer.step()
    # record loss and accuracy
    prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    losses.update(loss.item(), inputs.size(0))
    top1.update(prec1.item(), inputs.size(0))
    top5.update(prec5.item(), inputs.size(0))
    # count time
    batch_time.update(time.time() - end)
    end = time.time()
  return losses.avg, top1.avg, top5.avg, batch_time.sum


def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger):
  prepare_seed(seed)  # set the random seed
  net = get_cell_based_tiny_net(arch_config)
  flop, param = get_model_infos(net, opt_config.xshape)
  logger.log('Network : {:}'.format(net.get_message()), False)
  logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
  logger.log('FLOP = {:} M, Param = {:} MB'.format(flop, param))
  # train and valid
  optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
  network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
  # start training
  start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
  train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
  train_times , valid_times, lrs = {}, {}, {}
  for epoch in range(total_epoch):
    scheduler.update(epoch, 0.0)
    lr = min(scheduler.get_lr())
    train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
    train_losses[epoch] = train_loss
    train_acc1es[epoch] = train_acc1
    train_acc5es[epoch] = train_acc5
    train_times [epoch] = train_tm
    lrs[epoch] = lr
    with torch.no_grad():
      for key, xloader in valid_loaders.items():
        valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloader, network, criterion, None, None, 'valid')
        valid_losses['{:}@{:}'.format(key, epoch)] = valid_loss
        valid_acc1es['{:}@{:}'.format(key, epoch)] = valid_acc1
        valid_acc5es['{:}@{:}'.format(key, epoch)] = valid_acc5
        valid_times ['{:}@{:}'.format(key, epoch)] = valid_tm
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) )
    logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5, lr))
  info_seed = {'flop' : flop,
               'param': param,
               'arch_config' : arch_config._asdict(),
               'opt_config'  : opt_config._asdict(),
               'total_epoch' : total_epoch,
               'train_losses': train_losses,
               'train_acc1es': train_acc1es,
               'train_acc5es': train_acc5es,
               'train_times' : train_times,
               'valid_losses': valid_losses,
               'valid_acc1es': valid_acc1es,
               'valid_acc5es': valid_acc5es,
               'valid_times' : valid_times,
               'learning_rates': lrs,
               'net_state_dict': net.state_dict(),
               'net_string'  : '{:}'.format(net),
               'finish-train': True
              }
  return info_seed
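For orientation, `evaluate_all_datasets` in xshapes.py stores one such `info_seed` dict per dataset key inside the checkpoint it saves. A hedged sketch of reading one back — the file name follows the `arch-{index:06d}-seed-{seed:04d}.pth` pattern above, and the keys are the ones listed in `info_seed`:
```
import torch

# path pattern from xshapes.py: <save_dir>/raw-data-<hyper>/arch-XXXXXX-seed-XXXX.pth
ckpt = torch.load('output/NAS-BENCH-202/raw-data-12/arch-000000-seed-0777.pth',
                  map_location='cpu')
for key in ckpt['all_dataset_keys']:  # e.g. 'cifar10-valid', 'cifar100', ...
  info = ckpt[key]
  last = info['total_epoch'] - 1
  print(key, 'final train acc@1 = {:.2f}%'.format(info['train_acc1es'][last]))
```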

View File

@@ -0,0 +1,38 @@
#!/bin/bash
# bash ./scripts-search/NASNet-space-search-by-GDAS.sh cifar10 1 -1
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 3 parameters for dataset, track_running_stats, and seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
dataset=$1
track_running_stats=$2
seed=$3
space=darts
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi
save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${track_running_stats}
OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
--save_dir ${save_dir} \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \
--config_path configs/search-opts/GDAS-NASNet-CIFAR.config \
--model_config configs/search-archs/GDAS-NASNet-CIFAR.config \
--tau_max 10 --tau_min 0.1 --track_running_stats ${track_running_stats} \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed}

View File

@@ -0,0 +1,40 @@
#!/bin/bash
# bash ./scripts-search/NASNet-space-search-by-SETN.sh cifar10 1 -1
# [TO BE DONE]
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 3 parameters for dataset, track_running_stats, and seed"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
dataset=$1
track_running_stats=$2
seed=$3
space=darts
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi
save_dir=./output/search-cell-${space}/SETN-${dataset}-BN${track_running_stats}
OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \
--save_dir ${save_dir} \
--dataset ${dataset} --data_path ${data_path} \
--search_space_name ${space} \
--config_path configs/search-opts/SETN-NASNet-CIFAR.config \
--model_config configs/search-archs/SETN-NASNet-CIFAR.config \
--track_running_stats ${track_running_stats} \
--select_num 1000 \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed}

View File

@@ -0,0 +1,44 @@
#!/bin/bash
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
#####################################################
# [mars6] CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/X-X/train-shapes.sh 00000-05000 12 777
# [mars6] bash ./scripts-search/X-X/train-shapes.sh 05001-10000 12 777
# [mars20] bash ./scripts-search/X-X/train-shapes.sh 10001-14500 12 777
# [mars20] bash ./scripts-search/X-X/train-shapes.sh 14501-19500 12 777
# bash ./scripts-search/X-X/train-shapes.sh 19501-23500 12 777
# bash ./scripts-search/X-X/train-shapes.sh 23501-27500 12 777
# bash ./scripts-search/X-X/train-shapes.sh 27501-30000 12 777
# bash ./scripts-search/X-X/train-shapes.sh 30001-32767 12 777
#
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 3 parameters for start-and-end, hyper-parameters-opt-file, and seeds"
exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
echo "Must set TORCH_HOME envoriment variable for data dir saving"
exit 1
else
echo "TORCH_HOME : $TORCH_HOME"
fi
srange=$1
opt=$2
all_seeds=$3
cpus=4
save_dir=./output/NAS-BENCH-202/
OMP_NUM_THREADS=${cpus} python exps/NAS-Bench-201/xshapes.py \
--mode new --srange ${srange} --hyper ${opt} --save_dir ${save_dir} \
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
--splits 1 0 0 0 \
--xpaths $TORCH_HOME/cifar.python \
$TORCH_HOME/cifar.python \
$TORCH_HOME/cifar.python \
$TORCH_HOME/cifar.python/ImageNet16 \
--workers ${cpus} \
--seeds ${all_seeds}