diff --git a/NAS-Bench-102.md b/NAS-Bench-102.md index 2b8112b..fd82be8 100644 --- a/NAS-Bench-102.md +++ b/NAS-Bench-102.md @@ -158,6 +158,8 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_ We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our NAS-Bench-102. If researchers can provide better results with different hyper-parameters, we are happy to update results according to the new experimental results. We also welcome more NAS algorithms to test on our dataset and would include them accordingly. +Note that you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download) + - [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1` - [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1` - [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1` diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py index 64037e9..6173748 100644 --- a/exps/algos/DARTS-V1.py +++ b/exps/algos/DARTS-V1.py @@ -135,6 +135,7 @@ def main(xargs): 'max_nodes': xargs.max_nodes, 'num_classes': class_num, 'space' : search_space}, None) search_model = get_cell_based_tiny_net(model_config) + logger.log('search-model :\n{:}'.format(search_model)) w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) @@ -211,10 +212,9 @@ def main(xargs): if find_best: logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) copy_checkpoint(model_base_path, model_best_path, logger) - if api is not None: - logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) with torch.no_grad(): logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) + if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) # measure elapsed time epoch_time.update(time.time() - start_time) start_time = time.time() diff --git a/exps/algos/DARTS-V2.py b/exps/algos/DARTS-V2.py index 1ee1215..667de61 100644 --- a/exps/algos/DARTS-V2.py +++ b/exps/algos/DARTS-V2.py @@ -17,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces +from nas_102_api import NASBench102API as API def _concat(xs): @@ -198,6 +199,7 @@ def main(xargs): 'max_nodes': xargs.max_nodes, 'num_classes': class_num, 'space' : search_space}, None) search_model = get_cell_based_tiny_net(model_config) + logger.log('search-model :\n{:}'.format(search_model)) w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) @@ -208,6 +210,11 @@ def main(xargs): flop, param = get_model_infos(search_model, xshape) #logger.log('{:}'.format(search_model)) logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log('{:} create API = {:} done'.format(time_string(), api)) last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() @@ -229,7 +236,7 @@ def main(xargs): start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} # start training - start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup + start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) @@ -238,7 +245,8 @@ def main(xargs): logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR)) search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) - logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) + search_time.update(time.time() - start_time) + logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion) logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) # check the best accuracy @@ -271,29 +279,15 @@ def main(xargs): copy_checkpoint(model_base_path, model_best_path, logger) with torch.no_grad(): logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) + if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) # measure elapsed time epoch_time.update(time.time() - start_time) start_time = time.time() logger.log('\n' + '-'*100) # check the performance from the architecture dataset - """ - if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): - logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) - else: - nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset) - geno = genotypes[total_epoch-1] - logger.log('The last model is {:}'.format(geno)) - info = nas_bench.query_by_arch( geno ) - if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) - else : logger.log('{:}'.format(info)) - logger.log('-'*100) - geno = genotypes['best'] - logger.log('The best model is {:}'.format(geno)) - info = nas_bench.query_by_arch( geno ) - if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) - else : logger.log('{:}'.format(info)) - """ + logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) + if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) logger.close() diff --git a/exps/algos/GDAS.py b/exps/algos/GDAS.py index 50c55e1..4e32cab 100644 --- a/exps/algos/GDAS.py +++ b/exps/algos/GDAS.py @@ -1,6 +1,8 @@ ################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## +########################################################################### +# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # +########################################################################### import os, sys, time, glob, random, argparse import numpy as np from copy import deepcopy @@ -15,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_cell_based_tiny_net, get_search_spaces +from nas_102_api import NASBench102API as API def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): @@ -103,6 +106,7 @@ def main(xargs): 'max_nodes': xargs.max_nodes, 'num_classes': class_num, 'space' : search_space}, None) search_model = get_cell_based_tiny_net(model_config) + logger.log('search-model :\n{:}'.format(search_model)) w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) @@ -113,7 +117,12 @@ def main(xargs): flop, param = get_model_infos(search_model, xshape) #logger.log('{:}'.format(search_model)) logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) - logger.log('search_space : {:}'.format(search_space)) + logger.log('search-space : {:}'.format(search_space)) + if xargs.arch_nas_dataset is None: + api = None + else: + api = API(xargs.arch_nas_dataset) + logger.log('{:} create API = {:} done'.format(time_string(), api)) last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() @@ -135,7 +144,7 @@ def main(xargs): start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} # start training - start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup + start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) @@ -145,6 +154,7 @@ def main(xargs): search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \ = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) + search_time.update(time.time() - start_time) logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 )) # check the best accuracy @@ -177,24 +187,15 @@ def main(xargs): copy_checkpoint(model_base_path, model_best_path, logger) with torch.no_grad(): logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) + if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) # measure elapsed time epoch_time.update(time.time() - start_time) start_time = time.time() logger.log('\n' + '-'*100) # check the performance from the architecture dataset - """ - if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): - logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) - else: - nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset) - geno = genotypes[total_epoch-1] - logger.log('The last model is {:}'.format(geno)) - info = nas_bench.query_by_arch( geno ) - if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) - else : logger.log('{:}'.format(info)) - logger.log('-'*100) - """ + logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) + if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) logger.close() diff --git a/lib/models/cell_infers/cells.py b/lib/models/cell_infers/cells.py index 2071d5c..ae26a79 100644 --- a/lib/models/cell_infers/cells.py +++ b/lib/models/cell_infers/cells.py @@ -19,9 +19,9 @@ class InferCell(nn.Module): cur_innod = [] for (op_name, op_in) in node_info: if op_in == 0: - layer = OPS[op_name](C_in , C_out, stride, True) + layer = OPS[op_name](C_in , C_out, stride, True, True) else: - layer = OPS[op_name](C_out, C_out, 1, True) + layer = OPS[op_name](C_out, C_out, 1, True, True) cur_index.append( len(self.layers) ) cur_innod.append( op_in ) self.layers.append( layer ) diff --git a/lib/models/cell_operations.py b/lib/models/cell_operations.py index 5e2b779..a454bea 100644 --- a/lib/models/cell_operations.py +++ b/lib/models/cell_operations.py @@ -7,13 +7,13 @@ import torch.nn as nn __all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames'] OPS = { - 'none' : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride), - 'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'), - 'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'), - 'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine), - 'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine), - 'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine), - 'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine), + 'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride), + 'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats), + 'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats), + 'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats), + 'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), + 'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats), + 'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), } CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] @@ -27,12 +27,12 @@ SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK, class ReLUConvBN(nn.Module): - def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine): + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): super(ReLUConvBN, self).__init__() self.op = nn.Sequential( nn.ReLU(inplace=False), nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False), - nn.BatchNorm2d(C_out, affine=affine) + nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) ) def forward(self, x): @@ -77,12 +77,12 @@ class ResNetBasicblock(nn.Module): class POOLING(nn.Module): - def __init__(self, C_in, C_out, stride, mode, affine=True): + def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True): super(POOLING, self).__init__() if C_in == C_out: self.preprocess = None else: - self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine) + self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine, track_running_stats) if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1) else : raise ValueError('Invalid mode={:} in POOLING'.format(mode)) @@ -127,7 +127,7 @@ class Zero(nn.Module): class FactorizedReduce(nn.Module): - def __init__(self, C_in, C_out, stride, affine): + def __init__(self, C_in, C_out, stride, affine, track_running_stats): super(FactorizedReduce, self).__init__() self.stride = stride self.C_in = C_in @@ -142,7 +142,7 @@ class FactorizedReduce(nn.Module): self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) else: raise ValueError('Invalid stride : {:}'.format(stride)) - self.bn = nn.BatchNorm2d(C_out, affine=affine) + self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) def forward(self, x): x = self.relu(x) diff --git a/lib/models/cell_searchs/search_cells.py b/lib/models/cell_searchs/search_cells.py index f5af162..121322e 100644 --- a/lib/models/cell_searchs/search_cells.py +++ b/lib/models/cell_searchs/search_cells.py @@ -11,7 +11,7 @@ from ..cell_operations import OPS class SearchCell(nn.Module): - def __init__(self, C_in, C_out, stride, max_nodes, op_names): + def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True): super(SearchCell, self).__init__() self.op_names = deepcopy(op_names) @@ -23,9 +23,9 @@ class SearchCell(nn.Module): for j in range(i): node_str = '{:}<-{:}'.format(i, j) if j == 0: - xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names] + xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names] else: - xlists = [OPS[op_name](C_in , C_out, 1, False) for op_name in op_names] + xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names] self.edges[ node_str ] = nn.ModuleList( xlists ) self.edge_keys = sorted(list(self.edges.keys())) self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} diff --git a/scripts-search/NAS-Bench-102/train-models.sh b/scripts-search/NAS-Bench-102/train-models.sh index b9ed9a5..d71714b 100644 --- a/scripts-search/NAS-Bench-102/train-models.sh +++ b/scripts-search/NAS-Bench-102/train-models.sh @@ -28,7 +28,7 @@ else mode=cover fi -OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \ +OMP_NUM_THREADS=4 python ./exps/NAS-Bench-102/main.py \ --mode ${mode} --save_dir ${save_dir} --max_node 4 \ --use_less ${use_less} \ --datasets cifar10 cifar10 cifar100 ImageNet16-120 \ diff --git a/scripts-search/algos/DARTS-V2.sh b/scripts-search/algos/DARTS-V2.sh index ed1c847..2d21149 100644 --- a/scripts-search/algos/DARTS-V2.sh +++ b/scripts-search/algos/DARTS-V2.sh @@ -19,6 +19,7 @@ seed=$2 channel=16 num_cells=5 max_nodes=4 +space=nas-bench-102 if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then data_path="$TORCH_HOME/cifar.python" @@ -26,11 +27,12 @@ else data_path="$TORCH_HOME/cifar.python/ImageNet16" fi -save_dir=./output/cell-search-tiny/DARTS-V2-${dataset} +save_dir=./output/search-cell-${space}/DARTS-V2-${dataset} OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \ --save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \ --dataset ${dataset} --data_path ${data_path} \ - --search_space_name aa-nas \ + --search_space_name ${space} \ + --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \ --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \ --workers 4 --print_freq 200 --rand_seed ${seed}