update NAS-Bench-102 baselines / support track_running_stats
This commit is contained in:
parent
729ce136db
commit
2dc8dce6d3
@ -158,6 +158,8 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_
|
|||||||
We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our NAS-Bench-102.
|
We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our NAS-Bench-102.
|
||||||
If researchers can provide better results with different hyper-parameters, we are happy to update results according to the new experimental results. We also welcome more NAS algorithms to test on our dataset and would include them accordingly.
|
If researchers can provide better results with different hyper-parameters, we are happy to update results according to the new experimental results. We also welcome more NAS algorithms to test on our dataset and would include them accordingly.
|
||||||
|
|
||||||
|
Note that you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download)
|
||||||
|
|
||||||
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`
|
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`
|
||||||
- [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
|
- [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
|
||||||
- [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1`
|
- [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1`
|
||||||
|
@ -135,6 +135,7 @@ def main(xargs):
|
|||||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||||
'space' : search_space}, None)
|
'space' : search_space}, None)
|
||||||
search_model = get_cell_based_tiny_net(model_config)
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
|
logger.log('search-model :\n{:}'.format(search_model))
|
||||||
|
|
||||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||||
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||||
@ -211,10 +212,9 @@ def main(xargs):
|
|||||||
if find_best:
|
if find_best:
|
||||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
if api is not None:
|
|
||||||
logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -17,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
|
from nas_102_api import NASBench102API as API
|
||||||
|
|
||||||
|
|
||||||
def _concat(xs):
|
def _concat(xs):
|
||||||
@ -198,6 +199,7 @@ def main(xargs):
|
|||||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||||
'space' : search_space}, None)
|
'space' : search_space}, None)
|
||||||
search_model = get_cell_based_tiny_net(model_config)
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
|
logger.log('search-model :\n{:}'.format(search_model))
|
||||||
|
|
||||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||||
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||||
@ -208,6 +210,11 @@ def main(xargs):
|
|||||||
flop, param = get_model_infos(search_model, xshape)
|
flop, param = get_model_infos(search_model, xshape)
|
||||||
#logger.log('{:}'.format(search_model))
|
#logger.log('{:}'.format(search_model))
|
||||||
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||||
|
if xargs.arch_nas_dataset is None:
|
||||||
|
api = None
|
||||||
|
else:
|
||||||
|
api = API(xargs.arch_nas_dataset)
|
||||||
|
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||||
|
|
||||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||||
@ -229,7 +236,7 @@ def main(xargs):
|
|||||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||||
|
|
||||||
# start training
|
# start training
|
||||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||||
for epoch in range(start_epoch, total_epoch):
|
for epoch in range(start_epoch, total_epoch):
|
||||||
w_scheduler.update(epoch, 0.0)
|
w_scheduler.update(epoch, 0.0)
|
||||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||||
@ -238,7 +245,8 @@ def main(xargs):
|
|||||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))
|
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))
|
||||||
|
|
||||||
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
search_time.update(time.time() - start_time)
|
||||||
|
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||||
# check the best accuracy
|
# check the best accuracy
|
||||||
@ -271,29 +279,15 @@ def main(xargs):
|
|||||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
logger.log('\n' + '-'*100)
|
logger.log('\n' + '-'*100)
|
||||||
# check the performance from the architecture dataset
|
# check the performance from the architecture dataset
|
||||||
"""
|
logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
||||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
|
||||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
|
||||||
else:
|
|
||||||
nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
|
|
||||||
geno = genotypes[total_epoch-1]
|
|
||||||
logger.log('The last model is {:}'.format(geno))
|
|
||||||
info = nas_bench.query_by_arch( geno )
|
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
|
||||||
else : logger.log('{:}'.format(info))
|
|
||||||
logger.log('-'*100)
|
|
||||||
geno = genotypes['best']
|
|
||||||
logger.log('The best model is {:}'.format(geno))
|
|
||||||
info = nas_bench.query_by_arch( geno )
|
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
|
||||||
else : logger.log('{:}'.format(info))
|
|
||||||
"""
|
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
###########################################################################
|
||||||
|
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||||
|
###########################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -15,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
|
from nas_102_api import NASBench102API as API
|
||||||
|
|
||||||
|
|
||||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||||
@ -103,6 +106,7 @@ def main(xargs):
|
|||||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||||
'space' : search_space}, None)
|
'space' : search_space}, None)
|
||||||
search_model = get_cell_based_tiny_net(model_config)
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
|
logger.log('search-model :\n{:}'.format(search_model))
|
||||||
|
|
||||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||||
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||||
@ -113,7 +117,12 @@ def main(xargs):
|
|||||||
flop, param = get_model_infos(search_model, xshape)
|
flop, param = get_model_infos(search_model, xshape)
|
||||||
#logger.log('{:}'.format(search_model))
|
#logger.log('{:}'.format(search_model))
|
||||||
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||||
logger.log('search_space : {:}'.format(search_space))
|
logger.log('search-space : {:}'.format(search_space))
|
||||||
|
if xargs.arch_nas_dataset is None:
|
||||||
|
api = None
|
||||||
|
else:
|
||||||
|
api = API(xargs.arch_nas_dataset)
|
||||||
|
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||||
|
|
||||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||||
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||||
@ -135,7 +144,7 @@ def main(xargs):
|
|||||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||||
|
|
||||||
# start training
|
# start training
|
||||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||||
for epoch in range(start_epoch, total_epoch):
|
for epoch in range(start_epoch, total_epoch):
|
||||||
w_scheduler.update(epoch, 0.0)
|
w_scheduler.update(epoch, 0.0)
|
||||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||||
@ -145,6 +154,7 @@ def main(xargs):
|
|||||||
|
|
||||||
search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
|
search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
|
||||||
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||||
|
search_time.update(time.time() - start_time)
|
||||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 ))
|
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 ))
|
||||||
# check the best accuracy
|
# check the best accuracy
|
||||||
@ -177,24 +187,15 @@ def main(xargs):
|
|||||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
logger.log('\n' + '-'*100)
|
logger.log('\n' + '-'*100)
|
||||||
# check the performance from the architecture dataset
|
# check the performance from the architecture dataset
|
||||||
"""
|
logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
||||||
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
|
||||||
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
|
||||||
else:
|
|
||||||
nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
|
|
||||||
geno = genotypes[total_epoch-1]
|
|
||||||
logger.log('The last model is {:}'.format(geno))
|
|
||||||
info = nas_bench.query_by_arch( geno )
|
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
|
||||||
else : logger.log('{:}'.format(info))
|
|
||||||
logger.log('-'*100)
|
|
||||||
"""
|
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,9 +19,9 @@ class InferCell(nn.Module):
|
|||||||
cur_innod = []
|
cur_innod = []
|
||||||
for (op_name, op_in) in node_info:
|
for (op_name, op_in) in node_info:
|
||||||
if op_in == 0:
|
if op_in == 0:
|
||||||
layer = OPS[op_name](C_in , C_out, stride, True)
|
layer = OPS[op_name](C_in , C_out, stride, True, True)
|
||||||
else:
|
else:
|
||||||
layer = OPS[op_name](C_out, C_out, 1, True)
|
layer = OPS[op_name](C_out, C_out, 1, True, True)
|
||||||
cur_index.append( len(self.layers) )
|
cur_index.append( len(self.layers) )
|
||||||
cur_innod.append( op_in )
|
cur_innod.append( op_in )
|
||||||
self.layers.append( layer )
|
self.layers.append( layer )
|
||||||
|
@ -7,13 +7,13 @@ import torch.nn as nn
|
|||||||
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
|
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
|
||||||
|
|
||||||
OPS = {
|
OPS = {
|
||||||
'none' : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
|
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
|
||||||
'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
|
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
|
||||||
'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
|
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
|
||||||
'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine),
|
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
|
||||||
'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine),
|
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||||
'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine),
|
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
|
||||||
'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
|
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
|
||||||
}
|
}
|
||||||
|
|
||||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||||
@ -27,12 +27,12 @@ SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
|||||||
|
|
||||||
class ReLUConvBN(nn.Module):
|
class ReLUConvBN(nn.Module):
|
||||||
|
|
||||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine):
|
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||||
super(ReLUConvBN, self).__init__()
|
super(ReLUConvBN, self).__init__()
|
||||||
self.op = nn.Sequential(
|
self.op = nn.Sequential(
|
||||||
nn.ReLU(inplace=False),
|
nn.ReLU(inplace=False),
|
||||||
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
|
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
|
||||||
nn.BatchNorm2d(C_out, affine=affine)
|
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -77,12 +77,12 @@ class ResNetBasicblock(nn.Module):
|
|||||||
|
|
||||||
class POOLING(nn.Module):
|
class POOLING(nn.Module):
|
||||||
|
|
||||||
def __init__(self, C_in, C_out, stride, mode, affine=True):
|
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
|
||||||
super(POOLING, self).__init__()
|
super(POOLING, self).__init__()
|
||||||
if C_in == C_out:
|
if C_in == C_out:
|
||||||
self.preprocess = None
|
self.preprocess = None
|
||||||
else:
|
else:
|
||||||
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine)
|
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine, track_running_stats)
|
||||||
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
|
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
|
||||||
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
|
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
|
||||||
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
|
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
|
||||||
@ -127,7 +127,7 @@ class Zero(nn.Module):
|
|||||||
|
|
||||||
class FactorizedReduce(nn.Module):
|
class FactorizedReduce(nn.Module):
|
||||||
|
|
||||||
def __init__(self, C_in, C_out, stride, affine):
|
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
|
||||||
super(FactorizedReduce, self).__init__()
|
super(FactorizedReduce, self).__init__()
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.C_in = C_in
|
self.C_in = C_in
|
||||||
@ -142,7 +142,7 @@ class FactorizedReduce(nn.Module):
|
|||||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Invalid stride : {:}'.format(stride))
|
raise ValueError('Invalid stride : {:}'.format(stride))
|
||||||
self.bn = nn.BatchNorm2d(C_out, affine=affine)
|
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.relu(x)
|
x = self.relu(x)
|
||||||
|
@ -11,7 +11,7 @@ from ..cell_operations import OPS
|
|||||||
|
|
||||||
class SearchCell(nn.Module):
|
class SearchCell(nn.Module):
|
||||||
|
|
||||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names):
|
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
|
||||||
super(SearchCell, self).__init__()
|
super(SearchCell, self).__init__()
|
||||||
|
|
||||||
self.op_names = deepcopy(op_names)
|
self.op_names = deepcopy(op_names)
|
||||||
@ -23,9 +23,9 @@ class SearchCell(nn.Module):
|
|||||||
for j in range(i):
|
for j in range(i):
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
if j == 0:
|
if j == 0:
|
||||||
xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
|
xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names]
|
||||||
else:
|
else:
|
||||||
xlists = [OPS[op_name](C_in , C_out, 1, False) for op_name in op_names]
|
xlists = [OPS[op_name](C_in , C_out, 1, affine, track_running_stats) for op_name in op_names]
|
||||||
self.edges[ node_str ] = nn.ModuleList( xlists )
|
self.edges[ node_str ] = nn.ModuleList( xlists )
|
||||||
self.edge_keys = sorted(list(self.edges.keys()))
|
self.edge_keys = sorted(list(self.edges.keys()))
|
||||||
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
||||||
|
@ -28,7 +28,7 @@ else
|
|||||||
mode=cover
|
mode=cover
|
||||||
fi
|
fi
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \
|
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-102/main.py \
|
||||||
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
||||||
--use_less ${use_less} \
|
--use_less ${use_less} \
|
||||||
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
||||||
|
@ -19,6 +19,7 @@ seed=$2
|
|||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
|
space=nas-bench-102
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
@ -26,11 +27,12 @@ else
|
|||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
save_dir=./output/cell-search-tiny/DARTS-V2-${dataset}
|
save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name aa-nas \
|
--search_space_name ${space} \
|
||||||
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
Loading…
Reference in New Issue
Block a user