Move str2bool to config_utils
.github/workflows/basic_test.yml (vendored): 1 line added
@@ -36,6 +36,7 @@ jobs:
           python -m black ./lib/spaces -l 88 --check --diff --verbose
           python -m black ./lib/trade_models -l 88 --check --diff --verbose
           python -m black ./lib/procedures -l 88 --check --diff --verbose
+          python -m black ./lib/config_utils -l 88 --check --diff --verbose

       - name: Test Search Space
         run: |

@@ -30,6 +30,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))

+from config_utils import arg_str2bool
 from procedures.q_exps import update_gpu
 from procedures.q_exps import update_market
 from procedures.q_exps import run_exp
@@ -182,6 +183,12 @@ if __name__ == "__main__":
         help="The market indicator.",
     )
     parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
+    parser.add_argument(
+        "--shared_dataset",
+        type=arg_str2bool,
+        default=False,
+        help="Whether to share the dataset for all algorithms?",
+    )
     parser.add_argument(
         "--gpu", type=int, default=0, help="The GPU ID used for train / test."
     )
@@ -189,9 +196,13 @@ if __name__ == "__main__":
         "--alg",
         type=str,
         choices=list(alg2configs.keys()),
+        nargs="+",
         required=True,
-        help="The algorithm name.",
+        help="The algorithm name(s).",
     )
     args = parser.parse_args()

-    main(args, alg2configs[args.alg])
+    if len(args.alg) == 1:
+        main(args, alg2configs[args.alg[0]])
+    else:
+        print("-")

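With nargs="+" the parsed args.alg is always a list, even for a single algorithm, which is why the call site now indexes args.alg[0]; the multi-algorithm branch is still a placeholder. A minimal, self-contained sketch of how the two new options parse; the parser, the choice names, and the inline copy of arg_str2bool below are illustrative, not the script's real objects:

    import argparse

    def arg_str2bool(v):  # same semantics as the helper this commit adds in lib/config_utils/args_utils.py
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser("demo")
    parser.add_argument("--shared_dataset", type=arg_str2bool, default=False)
    parser.add_argument("--alg", type=str, choices=["GRU", "LSTM"], nargs="+", required=True)

    ns = parser.parse_args(["--alg", "GRU", "LSTM", "--shared_dataset", "true"])
    print(ns.alg)             # ['GRU', 'LSTM']  (always a list under nargs="+")
    print(ns.shared_dataset)  # True
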
@@ -15,6 +15,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))

+from config_utils import arg_str2bool
 import qlib
 from qlib.config import REG_CN
 from qlib.workflow import R
@@ -184,16 +185,6 @@ if __name__ == "__main__":

     parser = argparse.ArgumentParser("Show Results")

-    def str2bool(v):
-        if isinstance(v, bool):
-            return v
-        elif v.lower() in ("yes", "true", "t", "y", "1"):
-            return True
-        elif v.lower() in ("no", "false", "f", "n", "0"):
-            return False
-        else:
-            raise argparse.ArgumentTypeError("Boolean value expected.")
-
     parser.add_argument(
         "--save_dir",
         type=str,
@@ -203,7 +194,7 @@
     )
     parser.add_argument(
         "--verbose",
-        type=str2bool,
+        type=arg_str2bool,
         default=False,
         help="Print detailed log information or not.",
     )
@@ -228,7 +219,7 @@
         info_dict["heads"],
         info_dict["values"],
         info_dict["names"],
-        space=14,
+        space=18,
         verbose=True,
         sort_key=True,
     )

@@ -1,13 +1,19 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
-from .configure_utils    import load_config, dict2config, configure2str
-from .basic_args         import obtain_basic_args
-from .attention_args     import obtain_attention_args
-from .random_baseline    import obtain_RandomSearch_args
-from .cls_kd_args        import obtain_cls_kd_args
-from .cls_init_args      import obtain_cls_init_args
+# general config related functions
+from .config_utils import load_config, dict2config, configure2str
+# the args setting for different experiments
+from .basic_args import obtain_basic_args
+from .attention_args import obtain_attention_args
+from .random_baseline import obtain_RandomSearch_args
+from .cls_kd_args import obtain_cls_kd_args
+from .cls_init_args import obtain_cls_init_args
 from .search_single_args import obtain_search_single_args
-from .search_args        import obtain_search_args
+from .search_args import obtain_search_args
+
 # for network pruning
-from .pruning_args       import obtain_pruning_args
+from .pruning_args import obtain_pruning_args
+
+# utils for args
+from .args_utils import arg_str2bool

lib/config_utils/args_utils.py (new file): 12 lines
@@ -0,0 +1,12 @@
+import argparse
+
+
+def arg_str2bool(v):
+    if isinstance(v, bool):
+        return v
+    elif v.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif v.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise argparse.ArgumentTypeError("Boolean value expected.")
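The helper mirrors the str2bool removed from the result-showing script above, so behaviour is unchanged. A small sketch of how it maps strings, assuming the repository's lib/ directory is on sys.path as the entry scripts arrange:

    import argparse
    from config_utils import arg_str2bool  # resolves to lib/config_utils/args_utils.py via the package __init__

    for raw in ("yes", "T", "1", "no", "F", "0"):
        print(raw, "->", arg_str2bool(raw))  # True for the first three, False for the last three
    print(arg_str2bool(True))                # real booleans pass through unchanged
    try:
        arg_str2bool("maybe")                # anything else is rejected
    except argparse.ArgumentTypeError as err:
        print("rejected:", err)
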
| @@ -1,22 +1,32 @@ | |||||||
| import random, argparse | import random, argparse | ||||||
| from .share_args import add_shared_args | from .share_args import add_shared_args | ||||||
|  |  | ||||||
| def obtain_attention_args(): |  | ||||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |  | ||||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') |  | ||||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') |  | ||||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') |  | ||||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') |  | ||||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') |  | ||||||
|   parser.add_argument('--att_channel' ,     type=int,                   help='.') |  | ||||||
|   parser.add_argument('--att_spatial' ,     type=str,                   help='.') |  | ||||||
|   parser.add_argument('--att_active'  ,     type=str,                   help='.') |  | ||||||
|   add_shared_args( parser ) |  | ||||||
|   # Optimization options |  | ||||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') |  | ||||||
|   args = parser.parse_args() |  | ||||||
|  |  | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: | def obtain_attention_args(): | ||||||
|     args.rand_seed = random.randint(1, 100000) |     parser = argparse.ArgumentParser( | ||||||
|   assert args.save_dir is not None, 'save-path argument can not be None' |         description="Train a classification model on typical image classification datasets.", | ||||||
|   return args |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|  |     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--model_config", type=str, help="The path to the model configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||||
|  |     parser.add_argument("--att_channel", type=int, help=".") | ||||||
|  |     parser.add_argument("--att_spatial", type=str, help=".") | ||||||
|  |     parser.add_argument("--att_active", type=str, help=".") | ||||||
|  |     add_shared_args(parser) | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=2, help="Batch size for training." | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "save-path argument can not be None" | ||||||
|  |     return args | ||||||
|   | |||||||
| @@ -4,21 +4,41 @@ | |||||||
| import random, argparse | import random, argparse | ||||||
| from .share_args import add_shared_args | from .share_args import add_shared_args | ||||||
|  |  | ||||||
| def obtain_basic_args(): |  | ||||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |  | ||||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') |  | ||||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') |  | ||||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') |  | ||||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') |  | ||||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') |  | ||||||
|   parser.add_argument('--model_source',     type=str,  default='normal',help='The source of model defination.') |  | ||||||
|   parser.add_argument('--extra_model_path', type=str,  default=None,    help='The extra model ckp file (help to indicate the searched architecture).') |  | ||||||
|   add_shared_args( parser ) |  | ||||||
|   # Optimization options |  | ||||||
|   parser.add_argument('--batch_size',       type=int,  default=2,       help='Batch size for training.') |  | ||||||
|   args = parser.parse_args() |  | ||||||
|  |  | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: | def obtain_basic_args(): | ||||||
|     args.rand_seed = random.randint(1, 100000) |     parser = argparse.ArgumentParser( | ||||||
|   assert args.save_dir is not None, 'save-path argument can not be None' |         description="Train a classification model on typical image classification datasets.", | ||||||
|   return args |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|  |     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--model_config", type=str, help="The path to the model configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--model_source", | ||||||
|  |         type=str, | ||||||
|  |         default="normal", | ||||||
|  |         help="The source of model defination.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--extra_model_path", | ||||||
|  |         type=str, | ||||||
|  |         default=None, | ||||||
|  |         help="The extra model ckp file (help to indicate the searched architecture).", | ||||||
|  |     ) | ||||||
|  |     add_shared_args(parser) | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=2, help="Batch size for training." | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "save-path argument can not be None" | ||||||
|  |     return args | ||||||
|   | |||||||
| @@ -1,20 +1,32 @@ | |||||||
| import random, argparse | import random, argparse | ||||||
| from .share_args import add_shared_args | from .share_args import add_shared_args | ||||||
|  |  | ||||||
| def obtain_cls_init_args(): |  | ||||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |  | ||||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') |  | ||||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') |  | ||||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') |  | ||||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') |  | ||||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') |  | ||||||
|   parser.add_argument('--init_checkpoint',  type=str,                   help='The checkpoint path to the initial model.') |  | ||||||
|   add_shared_args( parser ) |  | ||||||
|   # Optimization options |  | ||||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') |  | ||||||
|   args = parser.parse_args() |  | ||||||
|  |  | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: | def obtain_cls_init_args(): | ||||||
|     args.rand_seed = random.randint(1, 100000) |     parser = argparse.ArgumentParser( | ||||||
|   assert args.save_dir is not None, 'save-path argument can not be None' |         description="Train a classification model on typical image classification datasets.", | ||||||
|   return args |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|  |     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--model_config", type=str, help="The path to the model configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--init_checkpoint", type=str, help="The checkpoint path to the initial model." | ||||||
|  |     ) | ||||||
|  |     add_shared_args(parser) | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=2, help="Batch size for training." | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "save-path argument can not be None" | ||||||
|  |     return args | ||||||
|   | |||||||
| @@ -1,23 +1,43 @@ | |||||||
| import random, argparse | import random, argparse | ||||||
| from .share_args import add_shared_args | from .share_args import add_shared_args | ||||||
|  |  | ||||||
| def obtain_cls_kd_args(): |  | ||||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |  | ||||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') |  | ||||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') |  | ||||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') |  | ||||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') |  | ||||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') |  | ||||||
|   parser.add_argument('--KD_checkpoint',    type=str,                   help='The teacher checkpoint in knowledge distillation.') |  | ||||||
|   parser.add_argument('--KD_alpha'    ,     type=float,                 help='The alpha parameter in knowledge distillation.') |  | ||||||
|   parser.add_argument('--KD_temperature',   type=float,                 help='The temperature parameter in knowledge distillation.') |  | ||||||
|   #parser.add_argument('--KD_feature',       type=float,                 help='Knowledge distillation at the feature level.') |  | ||||||
|   add_shared_args( parser ) |  | ||||||
|   # Optimization options |  | ||||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') |  | ||||||
|   args = parser.parse_args() |  | ||||||
|  |  | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: | def obtain_cls_kd_args(): | ||||||
|     args.rand_seed = random.randint(1, 100000) |     parser = argparse.ArgumentParser( | ||||||
|   assert args.save_dir is not None, 'save-path argument can not be None' |         description="Train a classification model on typical image classification datasets.", | ||||||
|   return args |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|  |     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--model_config", type=str, help="The path to the model configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--KD_checkpoint", | ||||||
|  |         type=str, | ||||||
|  |         help="The teacher checkpoint in knowledge distillation.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--KD_temperature", | ||||||
|  |         type=float, | ||||||
|  |         help="The temperature parameter in knowledge distillation.", | ||||||
|  |     ) | ||||||
|  |     # parser.add_argument('--KD_feature',       type=float,                 help='Knowledge distillation at the feature level.') | ||||||
|  |     add_shared_args(parser) | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=2, help="Batch size for training." | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "save-path argument can not be None" | ||||||
|  |     return args | ||||||

lib/config_utils/config_utils.py (new file): 135 lines
							| @@ -0,0 +1,135 @@ | |||||||
|  | # Copyright (c) Facebook, Inc. and its affiliates. | ||||||
|  | # All rights reserved. | ||||||
|  | # | ||||||
|  | # This source code is licensed under the license found in the | ||||||
|  | # LICENSE file in the root directory of this source tree. | ||||||
|  | # | ||||||
|  | import os, json | ||||||
|  | from os import path as osp | ||||||
|  | from pathlib import Path | ||||||
|  | from collections import namedtuple | ||||||
|  |  | ||||||
|  | support_types = ("str", "int", "bool", "float", "none") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def convert_param(original_lists): | ||||||
|  |     assert isinstance(original_lists, list), "The type is not right : {:}".format( | ||||||
|  |         original_lists | ||||||
|  |     ) | ||||||
|  |     ctype, value = original_lists[0], original_lists[1] | ||||||
|  |     assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types) | ||||||
|  |     is_list = isinstance(value, list) | ||||||
|  |     if not is_list: | ||||||
|  |         value = [value] | ||||||
|  |     outs = [] | ||||||
|  |     for x in value: | ||||||
|  |         if ctype == "int": | ||||||
|  |             x = int(x) | ||||||
|  |         elif ctype == "str": | ||||||
|  |             x = str(x) | ||||||
|  |         elif ctype == "bool": | ||||||
|  |             x = bool(int(x)) | ||||||
|  |         elif ctype == "float": | ||||||
|  |             x = float(x) | ||||||
|  |         elif ctype == "none": | ||||||
|  |             if x.lower() != "none": | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "For the none type, the value must be none instead of {:}".format(x) | ||||||
|  |                 ) | ||||||
|  |             x = None | ||||||
|  |         else: | ||||||
|  |             raise TypeError("Does not know this type : {:}".format(ctype)) | ||||||
|  |         outs.append(x) | ||||||
|  |     if not is_list: | ||||||
|  |         outs = outs[0] | ||||||
|  |     return outs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def load_config(path, extra, logger): | ||||||
|  |     path = str(path) | ||||||
|  |     if hasattr(logger, "log"): | ||||||
|  |         logger.log(path) | ||||||
|  |     assert os.path.exists(path), "Can not find {:}".format(path) | ||||||
|  |     # Reading data back | ||||||
|  |     with open(path, "r") as f: | ||||||
|  |         data = json.load(f) | ||||||
|  |     content = {k: convert_param(v) for k, v in data.items()} | ||||||
|  |     assert extra is None or isinstance( | ||||||
|  |         extra, dict | ||||||
|  |     ), "invalid type of extra : {:}".format(extra) | ||||||
|  |     if isinstance(extra, dict): | ||||||
|  |         content = {**content, **extra} | ||||||
|  |     Arguments = namedtuple("Configure", " ".join(content.keys())) | ||||||
|  |     content = Arguments(**content) | ||||||
|  |     if hasattr(logger, "log"): | ||||||
|  |         logger.log("{:}".format(content)) | ||||||
|  |     return content | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def configure2str(config, xpath=None): | ||||||
|  |     if not isinstance(config, dict): | ||||||
|  |         config = config._asdict() | ||||||
|  |  | ||||||
|  |     def cstring(x): | ||||||
|  |         return '"{:}"'.format(x) | ||||||
|  |  | ||||||
|  |     def gtype(x): | ||||||
|  |         if isinstance(x, list): | ||||||
|  |             x = x[0] | ||||||
|  |         if isinstance(x, str): | ||||||
|  |             return "str" | ||||||
|  |         elif isinstance(x, bool): | ||||||
|  |             return "bool" | ||||||
|  |         elif isinstance(x, int): | ||||||
|  |             return "int" | ||||||
|  |         elif isinstance(x, float): | ||||||
|  |             return "float" | ||||||
|  |         elif x is None: | ||||||
|  |             return "none" | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid : {:}".format(x)) | ||||||
|  |  | ||||||
|  |     def cvalue(x, xtype): | ||||||
|  |         if isinstance(x, list): | ||||||
|  |             is_list = True | ||||||
|  |         else: | ||||||
|  |             is_list, x = False, [x] | ||||||
|  |         temps = [] | ||||||
|  |         for temp in x: | ||||||
|  |             if xtype == "bool": | ||||||
|  |                 temp = cstring(int(temp)) | ||||||
|  |             elif xtype == "none": | ||||||
|  |                 temp = cstring("None") | ||||||
|  |             else: | ||||||
|  |                 temp = cstring(temp) | ||||||
|  |             temps.append(temp) | ||||||
|  |         if is_list: | ||||||
|  |             return "[{:}]".format(", ".join(temps)) | ||||||
|  |         else: | ||||||
|  |             return temps[0] | ||||||
|  |  | ||||||
|  |     xstrings = [] | ||||||
|  |     for key, value in config.items(): | ||||||
|  |         xtype = gtype(value) | ||||||
|  |         string = "  {:20s} : [{:8s}, {:}]".format( | ||||||
|  |             cstring(key), cstring(xtype), cvalue(value, xtype) | ||||||
|  |         ) | ||||||
|  |         xstrings.append(string) | ||||||
|  |     Fstring = "{\n" + ",\n".join(xstrings) + "\n}" | ||||||
|  |     if xpath is not None: | ||||||
|  |         parent = Path(xpath).resolve().parent | ||||||
|  |         parent.mkdir(parents=True, exist_ok=True) | ||||||
|  |         if osp.isfile(xpath): | ||||||
|  |             os.remove(xpath) | ||||||
|  |         with open(xpath, "w") as text_file: | ||||||
|  |             text_file.write("{:}".format(Fstring)) | ||||||
|  |     return Fstring | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def dict2config(xdict, logger): | ||||||
|  |     assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict)) | ||||||
|  |     Arguments = namedtuple("Configure", " ".join(xdict.keys())) | ||||||
|  |     content = Arguments(**xdict) | ||||||
|  |     if hasattr(logger, "log"): | ||||||
|  |         logger.log("{:}".format(content)) | ||||||
|  |     return content | ||||||
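The new lib/config_utils/config_utils.py above is a black-formatted copy of the configure_utils.py removed below, so load_config, configure2str, and dict2config keep their old behaviour. A minimal round-trip sketch; the dict contents and the temp-file path are made up for illustration, and <repo>/lib is assumed to be on sys.path:

    import os, tempfile
    from config_utils import dict2config, configure2str, load_config

    # wrap a plain dict into the namedtuple-style "Configure" object (logger=None skips logging)
    cfg = dict2config({"batch_size": 32, "lr": 0.025, "scheduler": "cos"}, None)

    # serialize it in the [type, value] JSON layout that load_config understands
    xpath = os.path.join(tempfile.gettempdir(), "demo-config.json")
    print(configure2str(cfg, xpath))

    # read it back; `extra` entries are merged into the resulting Configure namedtuple
    cfg2 = load_config(xpath, extra={"workers": 4}, logger=None)
    print(cfg2.batch_size, cfg2.lr, cfg2.scheduler, cfg2.workers)  # 32 0.025 cos 4
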
| @@ -1,106 +0,0 @@ | |||||||
| # Copyright (c) Facebook, Inc. and its affiliates. |  | ||||||
| # All rights reserved. |  | ||||||
| # |  | ||||||
| # This source code is licensed under the license found in the |  | ||||||
| # LICENSE file in the root directory of this source tree. |  | ||||||
| # |  | ||||||
| import os, json |  | ||||||
| from os import path as osp |  | ||||||
| from pathlib import Path |  | ||||||
| from collections import namedtuple |  | ||||||
|  |  | ||||||
| support_types = ('str', 'int', 'bool', 'float', 'none') |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_param(original_lists): |  | ||||||
|   assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists) |  | ||||||
|   ctype, value = original_lists[0], original_lists[1] |  | ||||||
|   assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types) |  | ||||||
|   is_list = isinstance(value, list) |  | ||||||
|   if not is_list: value = [value] |  | ||||||
|   outs = [] |  | ||||||
|   for x in value: |  | ||||||
|     if ctype == 'int': |  | ||||||
|       x = int(x) |  | ||||||
|     elif ctype == 'str': |  | ||||||
|       x = str(x) |  | ||||||
|     elif ctype == 'bool': |  | ||||||
|       x = bool(int(x)) |  | ||||||
|     elif ctype == 'float': |  | ||||||
|       x = float(x) |  | ||||||
|     elif ctype == 'none': |  | ||||||
|       if x.lower() != 'none': |  | ||||||
|         raise ValueError('For the none type, the value must be none instead of {:}'.format(x)) |  | ||||||
|       x = None |  | ||||||
|     else: |  | ||||||
|       raise TypeError('Does not know this type : {:}'.format(ctype)) |  | ||||||
|     outs.append(x) |  | ||||||
|   if not is_list: outs = outs[0] |  | ||||||
|   return outs |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_config(path, extra, logger): |  | ||||||
|   path = str(path) |  | ||||||
|   if hasattr(logger, 'log'): logger.log(path) |  | ||||||
|   assert os.path.exists(path), 'Can not find {:}'.format(path) |  | ||||||
|   # Reading data back |  | ||||||
|   with open(path, 'r') as f: |  | ||||||
|     data = json.load(f) |  | ||||||
|   content = { k: convert_param(v) for k,v in data.items()} |  | ||||||
|   assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra) |  | ||||||
|   if isinstance(extra, dict): content = {**content, **extra} |  | ||||||
|   Arguments = namedtuple('Configure', ' '.join(content.keys())) |  | ||||||
|   content   = Arguments(**content) |  | ||||||
|   if hasattr(logger, 'log'): logger.log('{:}'.format(content)) |  | ||||||
|   return content |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def configure2str(config, xpath=None): |  | ||||||
|   if not isinstance(config, dict): |  | ||||||
|     config = config._asdict() |  | ||||||
|   def cstring(x): |  | ||||||
|     return "\"{:}\"".format(x) |  | ||||||
|   def gtype(x): |  | ||||||
|     if isinstance(x, list): x = x[0] |  | ||||||
|     if isinstance(x, str)  : return 'str' |  | ||||||
|     elif isinstance(x, bool) : return 'bool' |  | ||||||
|     elif isinstance(x, int): return 'int' |  | ||||||
|     elif isinstance(x, float): return 'float' |  | ||||||
|     elif x is None           : return 'none' |  | ||||||
|     else: raise ValueError('invalid : {:}'.format(x)) |  | ||||||
|   def cvalue(x, xtype): |  | ||||||
|     if isinstance(x, list): is_list = True |  | ||||||
|     else: |  | ||||||
|       is_list, x = False, [x] |  | ||||||
|     temps = [] |  | ||||||
|     for temp in x: |  | ||||||
|       if xtype == 'bool'  : temp = cstring(int(temp)) |  | ||||||
|       elif xtype == 'none': temp = cstring('None') |  | ||||||
|       else                : temp = cstring(temp) |  | ||||||
|       temps.append( temp ) |  | ||||||
|     if is_list: |  | ||||||
|       return "[{:}]".format( ', '.join( temps ) ) |  | ||||||
|     else: |  | ||||||
|       return temps[0] |  | ||||||
|  |  | ||||||
|   xstrings = [] |  | ||||||
|   for key, value in config.items(): |  | ||||||
|     xtype  = gtype(value) |  | ||||||
|     string = '  {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype)) |  | ||||||
|     xstrings.append(string) |  | ||||||
|   Fstring = '{\n' + ',\n'.join(xstrings) + '\n}' |  | ||||||
|   if xpath is not None: |  | ||||||
|     parent = Path(xpath).resolve().parent |  | ||||||
|     parent.mkdir(parents=True, exist_ok=True) |  | ||||||
|     if osp.isfile(xpath): os.remove(xpath) |  | ||||||
|     with open(xpath, "w") as text_file: |  | ||||||
|       text_file.write('{:}'.format(Fstring)) |  | ||||||
|   return Fstring |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def dict2config(xdict, logger): |  | ||||||
|   assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) ) |  | ||||||
|   Arguments = namedtuple('Configure', ' '.join(xdict.keys())) |  | ||||||
|   content   = Arguments(**xdict) |  | ||||||
|   if hasattr(logger, 'log'): logger.log('{:}'.format(content)) |  | ||||||
|   return content |  | ||||||
| @@ -1,26 +1,48 @@ | |||||||
| import os, sys, time, random, argparse | import os, sys, time, random, argparse | ||||||
| from .share_args import add_shared_args | from .share_args import add_shared_args | ||||||
|  |  | ||||||
| def obtain_pruning_args(): |  | ||||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |  | ||||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') |  | ||||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') |  | ||||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') |  | ||||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') |  | ||||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') |  | ||||||
|   parser.add_argument('--keep_ratio'  ,     type=float,                 help='The left channel ratio compared to the original network.') |  | ||||||
|   parser.add_argument('--model_version',    type=str,                   help='The network version.') |  | ||||||
|   parser.add_argument('--KD_alpha'    ,     type=float,                 help='The alpha parameter in knowledge distillation.') |  | ||||||
|   parser.add_argument('--KD_temperature',   type=float,                 help='The temperature parameter in knowledge distillation.') |  | ||||||
|   parser.add_argument('--Regular_W_feat',   type=float,                 help='The .') |  | ||||||
|   parser.add_argument('--Regular_W_conv',   type=float,                 help='The .') |  | ||||||
|   add_shared_args( parser ) |  | ||||||
|   # Optimization options |  | ||||||
|   parser.add_argument('--batch_size',       type=int,  default=2,       help='Batch size for training.') |  | ||||||
|   args = parser.parse_args() |  | ||||||
|  |  | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: | def obtain_pruning_args(): | ||||||
|     args.rand_seed = random.randint(1, 100000) |     parser = argparse.ArgumentParser( | ||||||
|   assert args.save_dir is not None, 'save-path argument can not be None' |         description="Train a classification model on typical image classification datasets.", | ||||||
|   assert args.keep_ratio > 0 and args.keep_ratio <= 1, 'invalid keep ratio : {:}'.format(args.keep_ratio) |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|   return args |     ) | ||||||
|  |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|  |     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--model_config", type=str, help="The path to the model configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--keep_ratio", | ||||||
|  |         type=float, | ||||||
|  |         help="The left channel ratio compared to the original network.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--model_version", type=str, help="The network version.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--KD_temperature", | ||||||
|  |         type=float, | ||||||
|  |         help="The temperature parameter in knowledge distillation.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--Regular_W_feat", type=float, help="The .") | ||||||
|  |     parser.add_argument("--Regular_W_conv", type=float, help="The .") | ||||||
|  |     add_shared_args(parser) | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=2, help="Batch size for training." | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "save-path argument can not be None" | ||||||
|  |     assert ( | ||||||
|  |         args.keep_ratio > 0 and args.keep_ratio <= 1 | ||||||
|  |     ), "invalid keep ratio : {:}".format(args.keep_ratio) | ||||||
|  |     return args | ||||||
|   | |||||||
| @@ -3,22 +3,42 @@ from .share_args import add_shared_args | |||||||
|  |  | ||||||
|  |  | ||||||
| def obtain_RandomSearch_args(): | def obtain_RandomSearch_args(): | ||||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |     parser = argparse.ArgumentParser( | ||||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') |         description="Train a classification model on typical image classification datasets.", | ||||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|   parser.add_argument('--expect_flop',      type=float,                 help='The expected flop keep ratio.') |     ) | ||||||
|   parser.add_argument('--arch_nums'   ,     type=int,                   help='The maximum number of running random arch generating..') |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') |     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') |     parser.add_argument( | ||||||
|   parser.add_argument('--random_mode', type=str, choices=['random', 'fix'], help='The path to the optimizer configuration') |         "--expect_flop", type=float, help="The expected flop keep ratio." | ||||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') |     ) | ||||||
|   add_shared_args( parser ) |     parser.add_argument( | ||||||
|   # Optimization options |         "--arch_nums", | ||||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') |         type=int, | ||||||
|   args = parser.parse_args() |         help="The maximum number of running random arch generating..", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--model_config", type=str, help="The path to the model configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--random_mode", | ||||||
|  |         type=str, | ||||||
|  |         choices=["random", "fix"], | ||||||
|  |         help="The path to the optimizer configuration", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||||
|  |     add_shared_args(parser) | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=2, help="Batch size for training." | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|     args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|   assert args.save_dir is not None, 'save-path argument can not be None' |     assert args.save_dir is not None, "save-path argument can not be None" | ||||||
|   #assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) |     # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) | ||||||
|   return args |     return args | ||||||
|   | |||||||
| @@ -3,30 +3,51 @@ from .share_args import add_shared_args | |||||||
|  |  | ||||||
|  |  | ||||||
| def obtain_search_args(): | def obtain_search_args(): | ||||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |     parser = argparse.ArgumentParser( | ||||||
|   parser.add_argument('--resume'        ,   type=str,                   help='Resume path.') |         description="Train a classification model on typical image classification datasets.", | ||||||
|   parser.add_argument('--model_config'  ,   type=str,                   help='The path to the model configuration') |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|   parser.add_argument('--optim_config'  ,   type=str,                   help='The path to the optimizer configuration') |     ) | ||||||
|   parser.add_argument('--split_path'    ,   type=str,                   help='The split file path.') |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|   #parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') |     parser.add_argument( | ||||||
|   parser.add_argument('--gumbel_tau_max',   type=float,                 help='The maximum tau for Gumbel.') |         "--model_config", type=str, help="The path to the model configuration" | ||||||
|   parser.add_argument('--gumbel_tau_min',   type=float,                 help='The minimum tau for Gumbel.') |     ) | ||||||
|   parser.add_argument('--procedure'     ,   type=str,                   help='The procedure basic prefix.') |     parser.add_argument( | ||||||
|   parser.add_argument('--FLOP_ratio'    ,   type=float,                 help='The expected FLOP ratio.') |         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||||
|   parser.add_argument('--FLOP_weight'   ,   type=float,                 help='The loss weight for FLOP.') |     ) | ||||||
|   parser.add_argument('--FLOP_tolerant' ,   type=float,                 help='The tolerant range for FLOP.') |     parser.add_argument("--split_path", type=str, help="The split file path.") | ||||||
|   # ablation studies |     # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||||
|   parser.add_argument('--ablation_num_select', type=int,                help='The number of randomly selected channels.') |     parser.add_argument( | ||||||
|   add_shared_args( parser ) |         "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." | ||||||
|   # Optimization options |     ) | ||||||
|   parser.add_argument('--batch_size'    ,   type=int,   default=2,      help='Batch size for training.') |     parser.add_argument( | ||||||
|   args = parser.parse_args() |         "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||||
|  |     parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") | ||||||
|  |     parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." | ||||||
|  |     ) | ||||||
|  |     # ablation studies | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--ablation_num_select", | ||||||
|  |         type=int, | ||||||
|  |         help="The number of randomly selected channels.", | ||||||
|  |     ) | ||||||
|  |     add_shared_args(parser) | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=2, help="Batch size for training." | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|     args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|   assert args.save_dir is not None, 'save-path argument can not be None' |     assert args.save_dir is not None, "save-path argument can not be None" | ||||||
|   assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None |     assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||||
|   assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant) |     assert ( | ||||||
|   #assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) |         args.FLOP_tolerant is not None and args.FLOP_tolerant > 0 | ||||||
|   #args.arch_para_pure = bool(args.arch_para_pure) |     ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant) | ||||||
|   return args |     # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||||
|  |     # args.arch_para_pure = bool(args.arch_para_pure) | ||||||
|  |     return args | ||||||
|   | |||||||
| @@ -3,29 +3,46 @@ from .share_args import add_shared_args | |||||||
|  |  | ||||||
|  |  | ||||||
| def obtain_search_single_args(): | def obtain_search_single_args(): | ||||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) |     parser = argparse.ArgumentParser( | ||||||
|   parser.add_argument('--resume'        ,   type=str,                   help='Resume path.') |         description="Train a classification model on typical image classification datasets.", | ||||||
|   parser.add_argument('--model_config'  ,   type=str,                   help='The path to the model configuration') |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||||
|   parser.add_argument('--optim_config'  ,   type=str,                   help='The path to the optimizer configuration') |     ) | ||||||
|   parser.add_argument('--split_path'    ,   type=str,                   help='The split file path.') |     parser.add_argument("--resume", type=str, help="Resume path.") | ||||||
|   parser.add_argument('--search_shape'  ,   type=str,                   help='The shape to be searched.') |     parser.add_argument( | ||||||
|   #parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') |         "--model_config", type=str, help="The path to the model configuration" | ||||||
|   parser.add_argument('--gumbel_tau_max',   type=float,                 help='The maximum tau for Gumbel.') |     ) | ||||||
|   parser.add_argument('--gumbel_tau_min',   type=float,                 help='The minimum tau for Gumbel.') |     parser.add_argument( | ||||||
|   parser.add_argument('--procedure'     ,   type=str,                   help='The procedure basic prefix.') |         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||||
|   parser.add_argument('--FLOP_ratio'    ,   type=float,                 help='The expected FLOP ratio.') |     ) | ||||||
|   parser.add_argument('--FLOP_weight'   ,   type=float,                 help='The loss weight for FLOP.') |     parser.add_argument("--split_path", type=str, help="The split file path.") | ||||||
|   parser.add_argument('--FLOP_tolerant' ,   type=float,                 help='The tolerant range for FLOP.') |     parser.add_argument("--search_shape", type=str, help="The shape to be searched.") | ||||||
|   add_shared_args( parser ) |     # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||||
|   # Optimization options |     parser.add_argument( | ||||||
|   parser.add_argument('--batch_size'    ,   type=int,   default=2,      help='Batch size for training.') |         "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." | ||||||
|   args = parser.parse_args() |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||||
|  |     parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") | ||||||
|  |     parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." | ||||||
|  |     ) | ||||||
|  |     add_shared_args(parser) | ||||||
|  |     # Optimization options | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", type=int, default=2, help="Batch size for training." | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|   if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|     args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|   assert args.save_dir is not None, 'save-path argument can not be None' |     assert args.save_dir is not None, "save-path argument can not be None" | ||||||
|   assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None |     assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||||
|   assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant) |     assert ( | ||||||
|   #assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) |         args.FLOP_tolerant is not None and args.FLOP_tolerant > 0 | ||||||
|   #args.arch_para_pure = bool(args.arch_para_pure) |     ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant) | ||||||
|   return args |     # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||||
|  |     # args.arch_para_pure = bool(args.arch_para_pure) | ||||||
|  |     return args | ||||||
|   | |||||||
| @@ -1,17 +1,39 @@ | |||||||
| import os, sys, time, random, argparse | import os, sys, time, random, argparse | ||||||
|  |  | ||||||
| def add_shared_args( parser ): |  | ||||||
|   # Data Generation | def add_shared_args(parser): | ||||||
|   parser.add_argument('--dataset',          type=str,                   help='The dataset name.') |     # Data Generation | ||||||
|   parser.add_argument('--data_path',        type=str,                   help='The dataset name.') |     parser.add_argument("--dataset", type=str, help="The dataset name.") | ||||||
|   parser.add_argument('--cutout_length',    type=int,                   help='The cutout length, negative means not use.') |     parser.add_argument("--data_path", type=str, help="The dataset name.") | ||||||
|   # Printing |     parser.add_argument( | ||||||
|   parser.add_argument('--print_freq',       type=int,   default=100,    help='print frequency (default: 200)') |         "--cutout_length", type=int, help="The cutout length, negative means not use." | ||||||
|   parser.add_argument('--print_freq_eval',  type=int,   default=100,    help='print frequency (default: 200)') |     ) | ||||||
|   # Checkpoints |     # Printing | ||||||
|   parser.add_argument('--eval_frequency',   type=int,   default=1,      help='evaluation frequency (default: 200)') |     parser.add_argument( | ||||||
|   parser.add_argument('--save_dir',         type=str,                   help='Folder to save checkpoints and log.') |         "--print_freq", type=int, default=100, help="print frequency (default: 200)" | ||||||
|   # Acceleration |     ) | ||||||
|   parser.add_argument('--workers',          type=int,   default=8,      help='number of data loading workers (default: 8)') |     parser.add_argument( | ||||||
|   # Random Seed |         "--print_freq_eval", | ||||||
|   parser.add_argument('--rand_seed',        type=int,   default=-1,     help='manual seed') |         type=int, | ||||||
|  |         default=100, | ||||||
|  |         help="print frequency (default: 200)", | ||||||
|  |     ) | ||||||
|  |     # Checkpoints | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--eval_frequency", | ||||||
|  |         type=int, | ||||||
|  |         default=1, | ||||||
|  |         help="evaluation frequency (default: 200)", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||||
|  |     ) | ||||||
|  |     # Acceleration | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--workers", | ||||||
|  |         type=int, | ||||||
|  |         default=8, | ||||||
|  |         help="number of data loading workers (default: 8)", | ||||||
|  |     ) | ||||||
|  |     # Random Seed | ||||||
|  |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|   | |||||||