simplify DARTS codes and update affine/track
This commit is contained in:
		| @@ -114,7 +114,8 @@ def main(xargs): | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|                               'space'    : search_space, | ||||
|                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   logger.log('search-model :\n{:}'.format(search_model)) | ||||
|    | ||||
| @@ -217,6 +218,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||
|   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||
|   # architecture leraning rate | ||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
|   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
|   | ||||
| @@ -177,7 +177,8 @@ def main(xargs): | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|                               'space'    : search_space, | ||||
|                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|   logger.log('search-model :\n{:}'.format(search_model)) | ||||
|    | ||||
| @@ -282,6 +283,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||
|   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||
|   # architecture leraning rate | ||||
|   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') | ||||
|   parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding') | ||||
|   | ||||
| @@ -198,7 +198,8 @@ def main(xargs): | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|                               'space'    : search_space, | ||||
|                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   shared_cnn = get_cell_based_tiny_net(model_config) | ||||
|   controller = shared_cnn.create_controller() | ||||
|    | ||||
| @@ -319,6 +320,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--data_path',          type=str,   help='Path to dataset') | ||||
|   parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') | ||||
|   # channels and number-of-cells | ||||
|   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||
|   parser.add_argument('--search_space_name',  type=str,   help='The search space name.') | ||||
|   parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.') | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   | ||||
| @@ -126,7 +126,8 @@ def main(xargs): | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells, | ||||
|                               'max_nodes': xargs.max_nodes, 'num_classes': class_num, | ||||
|                               'space'    : search_space}, None) | ||||
|                               'space'    : search_space, | ||||
|                               'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None) | ||||
|   search_model = get_cell_based_tiny_net(model_config) | ||||
|    | ||||
|   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config) | ||||
| @@ -222,6 +223,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--channel',            type=int,   help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.') | ||||
|   parser.add_argument('--select_num',         type=int,   help='The number of selected architectures to evaluate.') | ||||
|   parser.add_argument('--track_running_stats',type=int,   choices=[0,1],help='Whether use track_running_stats or not in the BN layer.') | ||||
|   # log | ||||
|   parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)') | ||||
|   parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user