update codes
This commit is contained in:
		| @@ -88,12 +88,17 @@ def main(xargs): | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|  | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space, | ||||
|                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   if xargs.model_config is None: | ||||
|     model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                                 'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                                 'space'    : search_space, | ||||
|                                 'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   else: | ||||
|     model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space'    : search_space, | ||||
|                                                     'affine'     : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   logger.log('search-model :\n{:}'.format(search_model)) | ||||
|   logger.log('model-config : {:}'.format(model_config)) | ||||
|    | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config) | ||||
|   a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay) | ||||
| @@ -104,7 +109,7 @@ def main(xargs): | ||||
|   flop, param  = get_model_infos(search_model, xshape) | ||||
|   #logger.log('{:}'.format(search_model)) | ||||
|   logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|   logger.log('search-space : {:}'.format(search_space)) | ||||
|   logger.log('search-space [{:} ops] : {:}'.format(len(search_space), search_space)) | ||||
|   if xargs.arch_nas_dataset is None: | ||||
|     api = None | ||||
|   else: | ||||
| @@ -173,7 +178,7 @@ def main(xargs): | ||||
|       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) | ||||
|       copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|     with torch.no_grad(): | ||||
|       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|       logger.log('{:}'.format(search_model.show_alphas())) | ||||
|     if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) | ||||
|     # measure elapsed time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
| @@ -198,6 +203,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||
|   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||
|   parser.add_argument('--config_path',        type=str,   help='The path of the configuration.') | ||||
|   parser.add_argument('--model_config',       type=str,   help='The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.') | ||||
|   # architecture leraning rate | ||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
|   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user