update NAS-Bench-102 baselines

D-X-Y 2019-12-24 17:36:47 +11:00
parent af4212b4db
commit 44a0d51449
18 changed files with 105 additions and 110 deletions

View File

@@ -1,10 +1,10 @@
 {
   "scheduler": ["str", "cos"],
-  "LR" : ["float", "0.025"],
   "eta_min" : ["float", "0.001"],
   "epochs" : ["int", "50"],
   "warmup" : ["int", "0"],
   "optim" : ["str", "SGD"],
+  "LR" : ["float", "0.025"],
   "decay" : ["float", "0.0005"],
   "momentum" : ["float", "0.9"],
   "nesterov" : ["bool", "1"],

View File

@@ -2,7 +2,7 @@
   "scheduler": ["str", "cos"],
   "LR" : ["float", "0.05"],
   "eta_min" : ["float", "0.0005"],
-  "epochs" : ["int", "310"],
+  "epochs" : ["int", "250"],
   "T_max" : ["int", "10"],
   "warmup" : ["int", "0"],
   "optim" : ["str", "SGD"],

View File

@@ -1,10 +1,10 @@
 {
   "scheduler": ["str", "cos"],
-  "LR" : ["float", "0.025"],
   "eta_min" : ["float", "0.001"],
   "epochs" : ["int", "250"],
   "warmup" : ["int", "0"],
   "optim" : ["str", "SGD"],
+  "LR" : ["float", "0.025"],
   "decay" : ["float", "0.0005"],
   "momentum" : ["float", "0.9"],
   "nesterov" : ["bool", "1"],

View File

@@ -1,10 +1,10 @@
 {
   "scheduler": ["str", "cos"],
-  "LR" : ["float", "0.025"],
   "eta_min" : ["float", "0.001"],
   "epochs" : ["int", "250"],
   "warmup" : ["int", "0"],
   "optim" : ["str", "SGD"],
+  "LR" : ["float", "0.025"],
   "decay" : ["float", "0.0005"],
   "momentum" : ["float", "0.9"],
   "nesterov" : ["bool", "1"],

View File

@@ -1,10 +1,10 @@
 {
   "scheduler": ["str", "cos"],
-  "LR" : ["float", "0.025"],
   "eta_min" : ["float", "0.001"],
   "epochs" : ["int", "250"],
   "warmup" : ["int", "0"],
   "optim" : ["str", "SGD"],
+  "LR" : ["float", "0.025"],
   "decay" : ["float", "0.0005"],
   "momentum" : ["float", "0.9"],
   "nesterov" : ["bool", "1"],

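The five config diffs above touch only plain JSON hyper-parameter files in which every value is stored as a ["type", "value"] pair and read through the repo's load_config helper. As a rough, minimal sketch of that format (the reader function below is an illustration, not the actual load_config; the bool handling is an assumption):

import json

_CAST = {'int': int, 'float': float, 'str': str, 'bool': lambda v: bool(int(v))}

def read_typed_config(path):
  # every entry is "key": ["type", "value"]; cast the string value to its declared type
  with open(path) as f:
    raw = json.load(f)
  return {key: _CAST[typ](val) for key, (typ, val) in raw.items()}

# e.g. read_typed_config('configs/nas-benchmark/algos/GDAS.config')['LR'] would give 0.025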
View File

@@ -15,6 +15,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
 from utils import get_model_infos, obtain_accuracy
 from log_utils import AverageMeter, time_string, convert_secs2time
 from models import get_cell_based_tiny_net, get_search_spaces
+from nas_102_api import NASBench102API as API


 def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger):
@@ -224,6 +225,12 @@ def main(xargs):
   #flop, param = get_model_infos(shared_cnn, xshape)
   #logger.log('{:}'.format(shared_cnn))
   #logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
+  logger.log('search-space : {:}'.format(search_space))
+  if xargs.arch_nas_dataset is None:
+    api = None
+  else:
+    api = API(xargs.arch_nas_dataset)
+  logger.log('{:} create API = {:} done'.format(time_string(), api))
   shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()

   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
@@ -247,7 +254,7 @@ def main(xargs):
     start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None

   # start training
-  start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
+  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
   for epoch in range(start_epoch, total_epoch):
     w_scheduler.update(epoch, 0.0)
     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
@@ -263,7 +270,8 @@ def main(xargs):
                                'ctl_entropy_w': xargs.controller_entropy_weight,
                                'ctl_bl_dec'   : xargs.controller_bl_dec}, None), \
                                epoch_str, xargs.print_freq, logger)
-    logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline))
+    search_time.update(time.time() - start_time)
+    logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum))
     best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
     shared_cnn.module.update_arch(best_arch)
     _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
@@ -298,6 +306,7 @@ def main(xargs):
     if find_best:
       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
       copy_checkpoint(model_base_path, model_best_path, logger)
+    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()
@@ -306,27 +315,15 @@ def main(xargs):
   logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
   logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
   logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
+  start_time = time.time()
   final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
+  search_time.update(time.time() - start_time)
   shared_cnn.module.update_arch(final_arch)
   final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
   logger.log('The Selected Final Architecture : {:}'.format(final_arch))
   logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
-  # check the performance from the architecture dataset
-  #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
-  #  logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
-  #else:
-  #  nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
-  #  geno = genotypes[total_epoch-1]
-  #  logger.log('The last model is {:}'.format(geno))
-  #  info = nas_bench.query_by_arch( geno )
-  #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-  #  else           : logger.log('{:}'.format(info))
-  #  logger.log('-'*100)
-  #  geno = genotypes['best']
-  #  logger.log('The best model is {:}'.format(geno))
-  #  info = nas_bench.query_by_arch( geno )
-  #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-  #  else           : logger.log('{:}'.format(info))
+  logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch))
+  if api is not None: logger.log('{:}'.format( api.query_by_arch(final_arch) ))
   logger.close()
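Every search script in this commit now follows the same reporting pattern that the ENAS diff above introduces: build an NASBench102API object when --arch_nas_dataset is given, and call query_by_arch on the current genotype to log its ground-truth result. A minimal stand-alone sketch of that step (the helper name and the logger argument are assumptions; API and query_by_arch are exactly the calls used in the diff):

from nas_102_api import NASBench102API as API

def report_arch(arch, arch_nas_dataset, logger):
  # look up the ground-truth NAS-Bench-102 record for `arch`, if the benchmark file is available
  api = None if arch_nas_dataset is None else API(arch_nas_dataset)
  if api is None:
    logger.log('no architecture dataset given, skip the benchmark lookup')
    return None
  info = api.query_by_arch(arch)
  if info is None: logger.log('Did not find this architecture : {:}.'.format(arch))
  else           : logger.log('{:}'.format(info))
  return info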

View File

@@ -93,8 +93,8 @@ def main(xargs):
     logger.log('Load split file from {:}'.format(split_Fpath))
   else:
     raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
-  config_path = 'configs/nas-benchmark/algos/GDAS.config'
-  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+  #config_path = 'configs/nas-benchmark/algos/GDAS.config'
+  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
   search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
   # data loader
   search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
@@ -105,7 +105,7 @@ def main(xargs):
   model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells,
                               'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                               'space'    : search_space,
-                              'affine'   : False, 'track_running_stats': True}, None)
+                              'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
   search_model = get_cell_based_tiny_net(model_config)
   logger.log('search-model :\n{:}'.format(search_model))
@@ -156,7 +156,7 @@ def main(xargs):
     search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
                 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
     search_time.update(time.time() - start_time)
-    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
+    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 ))
     # check the best accuracy
     valid_accuracies[epoch] = valid_a_top1
@@ -210,6 +210,8 @@ if __name__ == '__main__':
   parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
   parser.add_argument('--channel', type=int, help='The number of channels.')
   parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
+  parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
+  parser.add_argument('--config_path', type=str, help='The path of the configuration.')
   # architecture leraning rate
   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
   parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')

View File

@@ -15,6 +15,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
 from utils import get_model_infos, obtain_accuracy
 from log_utils import AverageMeter, time_string, convert_secs2time
 from models import get_cell_based_tiny_net, get_search_spaces
+from nas_102_api import NASBench102API as API


 def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
@@ -130,6 +131,9 @@ def main(xargs):
   logger.log('w-optimizer : {:}'.format(w_optimizer))
   logger.log('w-scheduler : {:}'.format(w_scheduler))
   logger.log('criterion   : {:}'.format(criterion))
+  if xargs.arch_nas_dataset is None: api = None
+  else                             : api = API(xargs.arch_nas_dataset)
+  logger.log('{:} create API = {:} done'.format(time_string(), api))
   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')

   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
@@ -149,7 +153,7 @@ def main(xargs):
     start_epoch, valid_accuracies = 0, {'best': -1}

   # start training
-  start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
+  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
   for epoch in range(start_epoch, total_epoch):
     w_scheduler.update(epoch, 0.0)
     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
@@ -157,7 +161,8 @@ def main(xargs):
     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

     search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
-    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
+    search_time.update(time.time() - start_time)
+    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
     # check the best accuracy
@@ -188,7 +193,8 @@ def main(xargs):
     start_time = time.time()

   logger.log('\n' + '-'*200)
+  logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
+  start_time = time.time()
   best_arch, best_acc = None, -1
   for iarch in range(xargs.select_num):
     arch = search_model.random_genotype( True )
@@ -197,24 +203,10 @@ def main(xargs):
     if best_arch is None or best_acc < valid_a_top1:
       best_arch, best_acc = arch, valid_a_top1
-  logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
-  logger.log('\n' + '-'*100)
-  """
-  # check the performance from the architecture dataset
-  if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
-    logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
-  else:
-    nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset)
-    geno = best_arch
-    logger.log('The last model is {:}'.format(geno))
-    info = nas_bench.query_by_arch( geno )
-    if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-    else           : logger.log('{:}'.format(info))
-    logger.log('-'*100)
+  search_time.update(time.time() - start_time)
+  logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
+  if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
   logger.close()
-  """

 if __name__ == '__main__':

View File

@@ -52,14 +52,18 @@ def main(xargs, nas_bench):
   random_arch = random_architecture_func(xargs.max_nodes, search_space)
   #x =random_arch() ; y = mutate_arch(x)
   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
-  best_arch, best_acc = None, -1
-  for idx in range(xargs.random_num):
+  best_arch, best_acc, total_time_cost, history = None, -1, 0, []
+  #for idx in range(xargs.random_num):
+  while total_time_cost < xargs.time_budget:
     arch = random_arch()
-    accuracy = train_and_eval(arch, nas_bench, extra_info)
+    accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info)
+    if total_time_cost + cost_time > xargs.time_budget: break
+    else: total_time_cost += cost_time
+    history.append(arch)
     if best_arch is None or best_acc < accuracy:
       best_acc, best_arch = accuracy, arch
-    logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy))
-  logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc))
+    logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
+  logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost))

   info = nas_bench.query_by_arch( best_arch )
   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
@@ -79,7 +83,8 @@ if __name__ == '__main__':
   parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
   parser.add_argument('--channel', type=int, help='The number of channels.')
   parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
-  parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
+  #parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
+  parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
   # log
   parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
   parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
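The RANDOM.py change above swaps the fixed --random_num loop for a --time_budget stop criterion: each evaluation is charged with its simulated training time, and a candidate that would overshoot the budget is discarded rather than counted. The same accounting, pulled out as a self-contained sketch (train_and_eval stands for any callable returning (accuracy, seconds), as in the diff):

def budgeted_random_search(random_arch, train_and_eval, time_budget):
  # sample random cells until the accumulated simulated time exceeds the budget
  best_arch, best_acc, total_time_cost, history = None, -1, 0, []
  while total_time_cost < time_budget:
    arch = random_arch()
    accuracy, cost_time = train_and_eval(arch)
    if total_time_cost + cost_time > time_budget:
      break  # do not charge an evaluation that would exceed the budget
    total_time_cost += cost_time
    history.append(arch)
    if accuracy > best_acc:
      best_arch, best_acc = arch, accuracy
  return best_arch, best_acc, total_time_cost, history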

View File

@@ -60,12 +60,12 @@ def train_and_eval(arch, nas_bench, extra_info):
     arch_index = nas_bench.query_index_by_arch( arch )
     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
     info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True)
-    import pdb; pdb.set_trace()
+    valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
     #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
   else:
     # train a model from scratch.
     raise ValueError('NOT IMPLEMENT YET')
-  return valid_acc
+  return valid_acc, time_cost


 def random_architecture_func(max_nodes, op_names):
@@ -101,7 +101,7 @@ def mutate_arch_func(op_names):
   return mutate_arch_func


-def regularized_evolution(cycles, population_size, sample_size, random_arch, mutate_arch, nas_bench, extra_info):
+def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info):
   """Algorithm for regularized evolution (i.e. aging evolution).

   Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
@@ -111,27 +111,30 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut
     cycles: the number of cycles the algorithm should run for.
     population_size: the number of individuals to keep in the population.
     sample_size: the number of individuals that should participate in each tournament.
+    time_budget: the upper bound of searching cost

   Returns:
     history: a list of `Model` instances, representing all the models computed
       during the evolution experiment.
   """
   population = collections.deque()
-  history = []  # Not used by the algorithm, only used to report results.
+  history, total_time_cost = [], 0  # Not used by the algorithm, only used to report results.

   # Initialize the population with random models.
   while len(population) < population_size:
     model = Model()
     model.arch = random_arch()
-    model.accuracy = train_and_eval(model.arch, nas_bench, extra_info)
+    model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info)
     population.append(model)
     history.append(model)
+    total_time_cost += time_cost

   # Carry out evolution in cycles. Each cycle produces a model and removes
   # another.
-  while len(history) < cycles:
+  #while len(history) < cycles:
+  while total_time_cost < time_budget:
     # Sample randomly chosen models from the current population.
-    sample = []
+    start_time, sample = time.time(), []
     while len(sample) < sample_size:
       # Inefficient, but written this way for clarity. In the case of neural
       # nets, the efficiency of this line is irrelevant because training neural
@@ -145,13 +148,18 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut
     # Create the child model and store it.
     child = Model()
     child.arch = mutate_arch(parent.arch)
-    child.accuracy = train_and_eval(child.arch, nas_bench, extra_info)
+    total_time_cost += time.time() - start_time
+    child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info)
+    if total_time_cost + time_cost > time_budget: # return
+      return history, total_time_cost
+    else:
+      total_time_cost += time_cost
     population.append(child)
     history.append(child)

     # Remove the oldest model.
     population.popleft()

-  return history
+  return history, total_time_cost


 def main(xargs, nas_bench):
@@ -188,8 +196,9 @@ def main(xargs, nas_bench):
   mutate_arch = mutate_arch_func(search_space)
   #x =random_arch() ; y = mutate_arch(x)
   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
-  history = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
-  logger.log('{:} regularized_evolution finish with history of {:} arch.'.format(time_string(), len(history)))
+  logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
+  history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
+  logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s.'.format(time_string(), len(history), total_cost))
   best_arch = max(history, key=lambda i: i.accuracy)
   best_arch = best_arch.arch
   logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
@@ -216,6 +225,7 @@ if __name__ == '__main__':
   parser.add_argument('--ea_population', type=int, help='The population size in EA.')
   parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
   parser.add_argument('--ea_fast_by_api', type=int, help='Use our API to speed up the experiments or not.')
+  parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
   # log
   parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
   parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
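R_EA.py (and RANDOM.py) now derive the per-evaluation cost from the benchmark itself instead of the wall clock: train_and_eval reads the recorded training and validation times of the queried cell. Condensed into one function, using only the calls and result keys that appear in the diff (the stand-alone function name is an assumption):

def simulated_eval(nas_bench, arch):
  # return (validation accuracy, simulated seconds) for a cell recorded in NAS-Bench-102
  arch_index = nas_bench.query_index_by_arch(arch)
  assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
  info = nas_bench.get_more_info(arch_index, 'cifar10-valid', True)
  return info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']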

View File

@@ -17,6 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
 from utils import get_model_infos, obtain_accuracy
 from log_utils import AverageMeter, time_string, convert_secs2time
 from models import get_cell_based_tiny_net, get_search_spaces
+from nas_102_api import NASBench102API as API


 def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
@@ -162,7 +163,8 @@ def main(xargs):
   search_space = get_search_spaces('cell', xargs.search_space_name)
   model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
                               'max_nodes': xargs.max_nodes, 'num_classes': class_num,
-                              'space'    : search_space}, None)
+                              'space'    : search_space,
+                              'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
   logger.log('search space : {:}'.format(search_space))
   search_model = get_cell_based_tiny_net(model_config)
@@ -175,6 +177,12 @@ def main(xargs):
   flop, param = get_model_infos(search_model, xshape)
   #logger.log('{:}'.format(search_model))
   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
+  logger.log('search-space : {:}'.format(search_space))
+  if xargs.arch_nas_dataset is None:
+    api = None
+  else:
+    api = API(xargs.arch_nas_dataset)
+  logger.log('{:} create API = {:} done'.format(time_string(), api))
   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')

   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
@@ -196,7 +204,7 @@ def main(xargs):
     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

   # start training
-  start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
+  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
   for epoch in range(start_epoch, total_epoch):
     w_scheduler.update(epoch, 0.0)
     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
@@ -205,7 +213,8 @@ def main(xargs):
     search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
               = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
-    logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
+    search_time.update(time.time() - start_time)
+    logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

     genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
@@ -243,52 +252,23 @@ def main(xargs):
                 }, logger.path('info'), logger)
     with torch.no_grad():
       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
+    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()

-  #logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
+  # the final post procedure : count the time
+  start_time = time.time()
   genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
+  search_time.update(time.time() - start_time)
   network.module.set_cal_mode('dynamic', genotype)
   valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
   logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))
-  # sampling
-  """
-  with torch.no_grad():
-    logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
-  selected_archs = set()
-  while len(selected_archs) < xargs.select_num:
-    architecture = search_model.dync_genotype()
-    selected_archs.add( architecture )
-  logger.log('select {:} architectures based on the learned arch-parameters'.format( len(selected_archs) ))
-  best_arch, best_acc = None, -1
-  state_dict = deepcopy( network.state_dict() )
-  for index, arch in enumerate(selected_archs):
-    with torch.no_grad():
-      search_model.set_cal_mode('dynamic', arch)
-      network.load_state_dict( deepcopy(state_dict) )
-      valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
-    logger.log('{:} [{:03d}/{:03d}] : {:125s}, loss={:.3f}, accuracy={:.3f}%'.format(time_string(), index, len(selected_archs), str(arch), valid_a_loss , valid_a_top1))
-    if best_arch is None or best_acc < valid_a_top1:
-      best_arch, best_acc = arch, valid_a_top1
-  logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
-  """

   logger.log('\n' + '-'*100)
   # check the performance from the architecture dataset
-  """
-  if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
-    logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
-  else:
-    nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset)
-    geno = best_arch
-    logger.log('The last model is {:}'.format(geno))
-    info = nas_bench.query_by_arch( geno )
-    if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
-    else           : logger.log('{:}'.format(info))
-    logger.log('-'*100)
-  """
+  logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
+  if api is not None: logger.log('{:}'.format( api.query_by_arch(genotype) ))
   logger.close()
@@ -303,7 +283,8 @@ if __name__ == '__main__':
   parser.add_argument('--channel', type=int, help='The number of channels.')
   parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
   parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
-  parser.add_argument('--config_path', type=str, help='.')
+  parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
+  parser.add_argument('--config_path', type=str, help='The path of the configuration.')
   # architecture leraning rate
   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
   parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')

View File

@@ -20,6 +20,9 @@ def get_cell_based_tiny_net(config):
   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
   if super_type == 'basic' and config.name in group_names:
     from .cell_searchs import nas_super_nets
-    return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
+    try:
+      return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
+    except:
+      return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space)
   elif super_type == 'l2s-base' and config.name in group_names:
     from .l2s_cell_searchs import nas_super_nets
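The try/except above exists because only some search models have been migrated to the constructor that also takes affine and track_running_stats (TinyNetworkGDAS and TinyNetworkSETN in this commit); the others still expose the original five-argument signature. An equivalent, more explicit dispatch for comparison (a sketch only; the inspect-based branching is an assumption, not what the repo does):

import inspect

def build_super_net(net_cls, config):
  # pass the BN flags only to constructors that already accept them
  params = inspect.signature(net_cls.__init__).parameters
  if 'affine' in params and 'track_running_stats' in params:
    return net_cls(config.C, config.N, config.max_nodes, config.num_classes,
                   config.space, config.affine, config.track_running_stats)
  return net_cls(config.C, config.N, config.max_nodes, config.num_classes, config.space)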

View File

@@ -11,7 +11,8 @@ from .genotypes import Structure

 class TinyNetworkGDAS(nn.Module):

-  def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
+  #def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
+  def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
     super(TinyNetworkGDAS, self).__init__()
     self._C = C
     self._layerN = N

View File

@@ -13,7 +13,7 @@ from .genotypes import Structure

 class TinyNetworkSETN(nn.Module):

-  def __init__(self, C, N, max_nodes, num_classes, search_space):
+  def __init__(self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats):
     super(TinyNetworkSETN, self).__init__()
     self._C = C
     self._layerN = N
@@ -31,7 +31,7 @@ class TinyNetworkSETN(nn.Module):
       if reduction:
         cell = ResNetBasicblock(C_prev, C_curr, 2)
       else:
-        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
+        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats)
       if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
       else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
       self.cells.append( cell )

View File

@@ -34,6 +34,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
   --dataset ${dataset} --data_path ${data_path} \
   --search_space_name ${space} \
   --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
-  --tau_max 10 --tau_min 0.1 \
+  --config_path configs/nas-benchmark/algos/GDAS.config \
+  --tau_max 10 --tau_min 0.1 --track_running_stats 1 \
   --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
   --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@@ -35,5 +35,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
   --dataset ${dataset} --data_path ${data_path} \
   --search_space_name ${space} \
   --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
-  --ea_cycles 30 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
+  --time_budget 12000 \
+  --ea_cycles 100 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
   --workers 4 --print_freq 200 --rand_seed ${seed}

View File

@@ -34,5 +34,6 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
   --dataset ${dataset} --data_path ${data_path} \
   --search_space_name ${space} \
   --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
-  --random_num 100 \
+  --time_budget 12000 \
   --workers 4 --print_freq 200 --rand_seed ${seed}
+# --random_num 100 \

View File

@@ -36,6 +36,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \
   --search_space_name ${space} \
   --arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
   --config_path configs/nas-benchmark/algos/SETN.config \
+  --track_running_stats 1 \
   --arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
   --select_num 100 \
   --workers 4 --print_freq 200 --rand_seed ${seed}