update NAS-Bench
This commit is contained in:
		| @@ -3,11 +3,9 @@ | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| import numpy as np | ||||
| import sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| @@ -107,7 +105,6 @@ def main(xargs): | ||||
|   logger.log('w-scheduler : {:}'.format(w_scheduler)) | ||||
|   logger.log('criterion   : {:}'.format(criterion)) | ||||
|   flop, param  = get_model_infos(search_model, xshape) | ||||
|   #logger.log('{:}'.format(search_model)) | ||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|   logger.log('search-space [{:} ops] : {:}'.format(len(search_space), search_space)) | ||||
|   if xargs.arch_nas_dataset is None: | ||||
|   | ||||
| @@ -3,7 +3,7 @@ | ||||
| ###################################################################################### | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # | ||||
| ###################################################################################### | ||||
| import os, sys, time, glob, random, argparse | ||||
| import sys, time, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| @@ -93,8 +93,7 @@ def get_best_arch(xloader, network, n_samples): | ||||
|       _, logits = network(inputs) | ||||
|       val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5)) | ||||
|  | ||||
|       valid_accs.append( val_top1.item() ) | ||||
|       #print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1)) | ||||
|       valid_accs.append(val_top1.item()) | ||||
|  | ||||
|     best_idx = np.argmax(valid_accs) | ||||
|     best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||
| @@ -142,10 +141,13 @@ def main(xargs): | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|  | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space, | ||||
|                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   if xargs.model_config is None: | ||||
|     model_config = dict2config( | ||||
|       dict(name='SETN', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num, | ||||
|            space=search_space, affine=False, track_running_stats=bool(xargs.track_running_stats)), None) | ||||
|   else: | ||||
|     model_config = load_config(xargs.model_config, dict(num_classes=class_num, space=search_space, affine=False, | ||||
|                                                         track_running_stats=bool(xargs.track_running_stats)), None) | ||||
|   logger.log('search space : {:}'.format(search_space)) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|    | ||||
| @@ -156,7 +158,6 @@ def main(xargs): | ||||
|   logger.log('w-scheduler : {:}'.format(w_scheduler)) | ||||
|   logger.log('criterion   : {:}'.format(criterion)) | ||||
|   flop, param  = get_model_infos(search_model, xshape) | ||||
|   #logger.log('{:}'.format(search_model)) | ||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|   logger.log('search-space : {:}'.format(search_space)) | ||||
|   if xargs.arch_nas_dataset is None: | ||||
| @@ -233,7 +234,7 @@ def main(xargs): | ||||
|           'last_checkpoint': save_path, | ||||
|           }, logger.path('info'), logger) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|       logger.log('{:}'.format(search_model.show_alphas())) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user