update NAS-Bench-102 baselines / support track_running_stats
This commit is contained in:
		| @@ -135,6 +135,7 @@ def main(xargs): | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   logger.log('search-model :\n{:}'.format(search_model)) | ||||
|    | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) | ||||
|   a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||
| @@ -211,10 +212,9 @@ def main(xargs): | ||||
|     if find_best: | ||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     if api is not None: | ||||
|       logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|   | ||||
| @@ -17,6 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
|  | ||||
|  | ||||
| def _concat(xs): | ||||
| @@ -198,6 +199,7 @@ def main(xargs): | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   logger.log('search-model :\n{:}'.format(search_model)) | ||||
|    | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) | ||||
|   a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||
| @@ -208,6 +210,11 @@ def main(xargs): | ||||
|   flop, param  = get_model_infos(search_model, xshape) | ||||
|   #logger.log('{:}'.format(search_model)) | ||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|   if xargs.arch_nas_dataset is None: | ||||
|     api = None | ||||
|   else: | ||||
|     api = API(xargs.arch_nas_dataset) | ||||
|   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||
|  | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
| @@ -229,7 +236,7 @@ def main(xargs): | ||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} | ||||
|  | ||||
|   # start training | ||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup | ||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||
|   for epoch in range(start_epoch, total_epoch): | ||||
|     w_scheduler.update(epoch, 0.0) | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) | ||||
| @@ -238,7 +245,8 @@ def main(xargs): | ||||
|     logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR)) | ||||
|  | ||||
|     search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) | ||||
|     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) | ||||
|     search_time.update(time.time() - start_time) | ||||
|     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum)) | ||||
|     valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||
|     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||
|     # check the best accuracy | ||||
| @@ -271,29 +279,15 @@ def main(xargs): | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   logger.log('\n' + '-'*100) | ||||
|   # check the performance from the architecture dataset | ||||
|   """ | ||||
|   if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): | ||||
|     logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) | ||||
|   else: | ||||
|     nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset) | ||||
|     geno = genotypes[total_epoch-1] | ||||
|     logger.log('The last model is {:}'.format(geno)) | ||||
|     info = nas_bench.query_by_arch( geno ) | ||||
|     if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) | ||||
|     else           : logger.log('{:}'.format(info)) | ||||
|     logger.log('-'*100) | ||||
|     geno = genotypes['best'] | ||||
|     logger.log('The best model is {:}'.format(geno)) | ||||
|     info = nas_bench.query_by_arch( geno ) | ||||
|     if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) | ||||
|     else           : logger.log('{:}'.format(info)) | ||||
|   """ | ||||
|   logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| @@ -15,6 +17,7 @@ from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_che | ||||
| from utils        import get_model_infos, obtain_accuracy | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
| from models       import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_102_api  import NASBench102API as API | ||||
|  | ||||
|  | ||||
| def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger): | ||||
| @@ -103,6 +106,7 @@ def main(xargs): | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   logger.log('search-model :\n{:}'.format(search_model)) | ||||
|    | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) | ||||
|   a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||
| @@ -113,7 +117,12 @@ def main(xargs): | ||||
|   flop, param  = get_model_infos(search_model, xshape) | ||||
|   #logger.log('{:}'.format(search_model)) | ||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|   logger.log('search_space : {:}'.format(search_space)) | ||||
|   logger.log('search-space : {:}'.format(search_space)) | ||||
|   if xargs.arch_nas_dataset is None: | ||||
|     api = None | ||||
|   else: | ||||
|     api = API(xargs.arch_nas_dataset) | ||||
|   logger.log('{:} create API = {:} done'.format(time_string(), api)) | ||||
|  | ||||
|   last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best') | ||||
|   network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
| @@ -135,7 +144,7 @@ def main(xargs): | ||||
|     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {} | ||||
|  | ||||
|   # start training | ||||
|   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup | ||||
|   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup | ||||
|   for epoch in range(start_epoch, total_epoch): | ||||
|     w_scheduler.update(epoch, 0.0) | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) ) | ||||
| @@ -145,6 +154,7 @@ def main(xargs): | ||||
|  | ||||
|     search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \ | ||||
|               = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger) | ||||
|     search_time.update(time.time() - start_time) | ||||
|     logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5)) | ||||
|     logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 )) | ||||
|     # check the best accuracy | ||||
| @@ -177,24 +187,15 @@ def main(xargs): | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|  | ||||
|   logger.log('\n' + '-'*100) | ||||
|   # check the performance from the architecture dataset | ||||
|   """ | ||||
|   if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): | ||||
|     logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) | ||||
|   else: | ||||
|     nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset) | ||||
|     geno = genotypes[total_epoch-1] | ||||
|     logger.log('The last model is {:}'.format(geno)) | ||||
|     info = nas_bench.query_by_arch( geno ) | ||||
|     if info is None: logger.log('Did not find this architecture : {:}.'.format(geno)) | ||||
|     else           : logger.log('{:}'.format(info)) | ||||
|     logger.log('-'*100) | ||||
|   """ | ||||
|   logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1])) | ||||
|   if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) )) | ||||
|   logger.close() | ||||
|    | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user