update NAS-Bench-102 baselines / support track_running_stats

D-X-Y 2019-12-23 13:32:20 +11:00
parent 729ce136db
commit 2dc8dce6d3
9 changed files with 56 additions and 57 deletions

View File

@@ -158,6 +158,8 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_
We have tried our best to implement each method. However, some algorithms might still obtain non-optimal results because their hyper-parameters may not fit our NAS-Bench-102.
If researchers can provide better results with different hyper-parameters, we are happy to update the results accordingly. We also welcome more NAS algorithms to be tested on our dataset and will include them as well.
+Note that you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download).
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`
- [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
- [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1`

View File

@@ -135,6 +135,7 @@ def main(xargs):
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space}, None)
search_model = get_cell_based_tiny_net(model_config)
+logger.log('search-model :\n{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
@@ -211,10 +212,9 @@ def main(xargs):
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
copy_checkpoint(model_base_path, model_best_path, logger)
-if api is not None:
-logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
with torch.no_grad():
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
+if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()

View File

@@ -17,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
+from nas_102_api import NASBench102API as API
def _concat(xs):
@@ -198,6 +199,7 @@ def main(xargs):
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space}, None)
search_model = get_cell_based_tiny_net(model_config)
+logger.log('search-model :\n{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
@@ -208,6 +210,11 @@ def main(xargs):
flop, param = get_model_infos(search_model, xshape)
#logger.log('{:}'.format(search_model))
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
+if xargs.arch_nas_dataset is None:
+api = None
+else:
+api = API(xargs.arch_nas_dataset)
+logger.log('{:} create API = {:} done'.format(time_string(), api))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
@@ -229,7 +236,7 @@ def main(xargs):
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
# start training
-start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
+start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
@@ -238,7 +245,8 @@ def main(xargs):
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
-logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
+search_time.update(time.time() - start_time)
+logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# check the best accuracy
@@ -271,29 +279,15 @@ def main(xargs):
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
+if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
-"""
-if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
-logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
-else:
-nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
-geno = genotypes[total_epoch-1]
-logger.log('The last model is {:}'.format(geno))
-info = nas_bench.query_by_arch( geno )
-if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-else : logger.log('{:}'.format(info))
-logger.log('-'*100)
-geno = genotypes['best']
-logger.log('The best model is {:}'.format(geno))
-info = nas_bench.query_by_arch( geno )
-if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-else : logger.log('{:}'.format(info))
-"""
+logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
+if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
logger.close()
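The hunk above replaces the old commented-out NASBenchmarkAPI lookup with an API object that is created once in main() and queried after every epoch and again at the end of the search. Below is a minimal sketch of that pattern; only the NASBench102API import and the query_by_arch call come from this diff, while the helper names are illustrative and not part of the repository.

# Hedged sketch of the benchmark-lookup wiring added to the search scripts.
from nas_102_api import NASBench102API as API

def create_api(arch_nas_dataset):
    # Mirrors the new branch in main(): no path given -> skip benchmark queries.
    return None if arch_nas_dataset is None else API(arch_nas_dataset)

def log_benchmark_entry(api, genotype, log=print):
    # Mirrors `if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype)))`.
    if api is not None:
        log('{:}'.format(api.query_by_arch(genotype)))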

View File

@@ -1,6 +1,8 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
-##################################################
+###########################################################################
+# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
+###########################################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
@@ -15,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
+from nas_102_api import NASBench102API as API
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
@@ -103,6 +106,7 @@ def main(xargs):
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space}, None)
search_model = get_cell_based_tiny_net(model_config)
+logger.log('search-model :\n{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
@@ -113,7 +117,12 @@ def main(xargs):
flop, param = get_model_infos(search_model, xshape)
#logger.log('{:}'.format(search_model))
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
-logger.log('search_space : {:}'.format(search_space))
+logger.log('search-space : {:}'.format(search_space))
+if xargs.arch_nas_dataset is None:
+api = None
+else:
+api = API(xargs.arch_nas_dataset)
+logger.log('{:} create API = {:} done'.format(time_string(), api))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
@@ -135,7 +144,7 @@ def main(xargs):
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
# start training
-start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
+start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
@@ -145,6 +154,7 @@ def main(xargs):
search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
+search_time.update(time.time() - start_time)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 ))
# check the best accuracy
@@ -177,24 +187,15 @@ def main(xargs):
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
+if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
-"""
-if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
-logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
-else:
-nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
-geno = genotypes[total_epoch-1]
-logger.log('The last model is {:}'.format(geno))
-info = nas_bench.query_by_arch( geno )
-if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-else : logger.log('{:}'.format(info))
-logger.log('-'*100)
-"""
+logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
+if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
logger.close()
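Both search scripts now keep a second AverageMeter, search_time, updated right after search_func, so the closing "cost {:.1f} s" message reports the accumulated search cost and excludes the validation passes. A tiny stand-in for that accumulator is sketched below, assuming nothing beyond the val/sum/update interface used in the diff; the real AverageMeter lives in log_utils and this class is only illustrative.

import time

class AverageMeterSketch:
    """Illustrative stand-in for log_utils.AverageMeter (only what the diff uses)."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

search_time = AverageMeterSketch()
for _ in range(3):                  # stands in for the epoch loop
    start_time = time.time()
    time.sleep(0.01)                # stands in for one call to search_func(...)
    search_time.update(time.time() - start_time)
print('cost {:.1f} s'.format(search_time.sum))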

View File

@@ -19,9 +19,9 @@ class InferCell(nn.Module):
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
-layer = OPS[op_name](C_in , C_out, stride, True)
+layer = OPS[op_name](C_in , C_out, stride, True, True)
else:
-layer = OPS[op_name](C_out, C_out, 1, True)
+layer = OPS[op_name](C_out, C_out, 1, True, True)
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )

View File

@@ -7,13 +7,13 @@ import torch.nn as nn
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
-'none' : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
-'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
-'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
-'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine),
-'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine),
-'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine),
-'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
+'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
+'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
+'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
+'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
+'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
+'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
+'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
@@ -27,12 +27,12 @@ SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
class ReLUConvBN(nn.Module):
-def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine):
+def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
-nn.BatchNorm2d(C_out, affine=affine)
+nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
)
def forward(self, x):
@@ -77,12 +77,12 @@ class ResNetBasicblock(nn.Module):
class POOLING(nn.Module):
-def __init__(self, C_in, C_out, stride, mode, affine=True):
+def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
-self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine)
+self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine, track_running_stats)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
@@ -127,7 +127,7 @@ class Zero(nn.Module):
class FactorizedReduce(nn.Module):
-def __init__(self, C_in, C_out, stride, affine):
+def __init__(self, C_in, C_out, stride, affine, track_running_stats):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
@@ -142,7 +142,7 @@ class FactorizedReduce(nn.Module):
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
-self.bn = nn.BatchNorm2d(C_out, affine=affine)
+self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
def forward(self, x):
x = self.relu(x)
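Every operation now threads a track_running_stats flag through to nn.BatchNorm2d. With track_running_stats=False, PyTorch allocates no running_mean/running_var buffers and the layer normalizes with the statistics of the current batch even in eval() mode, which is often what weight-sharing search wants; with True it behaves like the standard BatchNorm used in the inferred networks. A hedged sketch follows: the import path is an assumption, while the OPS entry and its new five-argument signature come from the hunk above.

import torch
# Assumed import path for the file patched above (lib/models/cell_operations.py).
from models.cell_operations import OPS

# Same 3x3 conv op, built with the two BatchNorm behaviours.
search_op = OPS['nor_conv_3x3'](16, 16, 1, False, False)  # affine=False, batch statistics only
infer_op  = OPS['nor_conv_3x3'](16, 16, 1, True,  True)   # affine=True, running statistics kept

x = torch.randn(2, 16, 8, 8)
search_op.eval(); infer_op.eval()
# `search_op` has no running_mean/running_var buffers, so even in eval mode its
# BatchNorm normalizes with the statistics of `x`; `infer_op` uses its buffers.
print(search_op(x).shape, infer_op(x).shape)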

View File

@@ -11,7 +11,7 @@ from ..cell_operations import OPS
class SearchCell(nn.Module):
-def __init__(self, C_in, C_out, stride, max_nodes, op_names):
+def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
super(SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
@@ -23,9 +23,9 @@ class SearchCell(nn.Module):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
-xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
+xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
else:
-xlists = [OPS[op_name](C_in , C_out, 1, False) for op_name in op_names]
+xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
self.edges[ node_str ] = nn.ModuleList( xlists )
self.edge_keys = sorted(list(self.edges.keys()))
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
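SearchCell now exposes the same two flags instead of hard-coding affine=False, so a search model can decide how the shared BatchNorm layers in every candidate op behave. A hedged construction example: the import path and the channel/stride values are assumptions, the constructor signature is from the hunk above, and the op names match CONNECT_NAS_BENCHMARK defined in cell_operations.py.

# Assumed import path for the patched search cell.
from models.cell_searchs.search_cells import SearchCell

# One cell of the tiny search space: 4 nodes, every edge holds one module per op.
cell = SearchCell(C_in=16, C_out=16, stride=1, max_nodes=4,
                  op_names=['none', 'skip_connect', 'nor_conv_3x3'],
                  affine=False, track_running_stats=True)
print(cell.edge_keys)   # e.g. ['1<-0', '2<-0', '2<-1', '3<-0', '3<-1', '3<-2']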

View File

@@ -28,7 +28,7 @@ else
mode=cover
fi
-OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \
+OMP_NUM_THREADS=4 python ./exps/NAS-Bench-102/main.py \
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
--use_less ${use_less} \
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \

View File

@@ -19,6 +19,7 @@ seed=$2
channel=16
num_cells=5
max_nodes=4
+space=nas-bench-102
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
data_path="$TORCH_HOME/cifar.python"
@@ -26,11 +27,12 @@ else
data_path="$TORCH_HOME/cifar.python/ImageNet16"
fi
-save_dir=./output/cell-search-tiny/DARTS-V2-${dataset}
+save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
--dataset ${dataset} --data_path ${data_path} \
---search_space_name aa-nas \
+--search_space_name ${space} \
+--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
--workers 4 --print_freq 200 --rand_seed ${seed}
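The script now forwards the benchmark file via --arch_nas_dataset; when the flag is omitted, the Python side simply sets api = None and skips the lookups. A hedged sketch of how that option is expected to be parsed: the option name and the file name come from the script above, while the default and help text are assumptions.

import argparse

parser = argparse.ArgumentParser('sketch of the search-script option touched here')
parser.add_argument('--arch_nas_dataset', type=str, default=None,
                    help='path to the NAS-Bench-102 benchmark file (optional)')
# Value taken from the updated DARTS-V2.sh; with no value, `api` stays None.
args = parser.parse_args(['--arch_nas_dataset', 'NAS-Bench-102-v1_0-e61699.pth'])
print(args.arch_nas_dataset)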