support first-order DARTS on the NASNet search space
This commit is contained in:
		| @@ -112,10 +112,14 @@ def main(xargs): | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|  | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space, | ||||
|                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   if xargs.model_config is None: | ||||
|     model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                                 'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                                 'space'    : search_space, | ||||
|                                 'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   else: | ||||
|     model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space'    : search_space, | ||||
|                                                     'affine'     : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   logger.log('search-model :\n{:}'.format(search_model)) | ||||
|    | ||||
| @@ -213,12 +217,13 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||
|   # channels and number-of-cells | ||||
|   parser.add_argument('--config_path',        type=str,   help='The config path.') | ||||
|   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | ||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||
|   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||
|   parser.add_argument('--config_path',        type=str,   help='The config path.') | ||||
|   parser.add_argument('--model_config',       type=str,   help='The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.') | ||||
|   # architecture leraning rate | ||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
|   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user