Move str2bool to config_utils
		| @@ -1,13 +1,19 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| from .configure_utils    import load_config, dict2config, configure2str | ||||
| from .basic_args         import obtain_basic_args | ||||
| from .attention_args     import obtain_attention_args | ||||
| from .random_baseline    import obtain_RandomSearch_args | ||||
| from .cls_kd_args        import obtain_cls_kd_args | ||||
| from .cls_init_args      import obtain_cls_init_args | ||||
| # general config related functions | ||||
| from .config_utils import load_config, dict2config, configure2str | ||||
| # the args setting for different experiments | ||||
| from .basic_args import obtain_basic_args | ||||
| from .attention_args import obtain_attention_args | ||||
| from .random_baseline import obtain_RandomSearch_args | ||||
| from .cls_kd_args import obtain_cls_kd_args | ||||
| from .cls_init_args import obtain_cls_init_args | ||||
| from .search_single_args import obtain_search_single_args | ||||
| from .search_args        import obtain_search_args | ||||
| from .search_args import obtain_search_args | ||||
|  | ||||
| # for network pruning | ||||
| from .pruning_args       import obtain_pruning_args | ||||
| from .pruning_args import obtain_pruning_args | ||||
|  | ||||
| # utils for args | ||||
| from .args_utils import arg_str2bool | ||||
|   | ||||
							
								
								
									
lib/config_utils/args_utils.py  (new file, 12 lines)
| @@ -0,0 +1,12 @@ | ||||
| import argparse | ||||
|  | ||||
|  | ||||
| def arg_str2bool(v): | ||||
|     if isinstance(v, bool): | ||||
|         return v | ||||
|     elif v.lower() in ("yes", "true", "t", "y", "1"): | ||||
|         return True | ||||
|     elif v.lower() in ("no", "false", "f", "n", "0"): | ||||
|         return False | ||||
|     else: | ||||
|         raise argparse.ArgumentTypeError("Boolean value expected.") | ||||
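For reference (not part of the commit), a minimal usage sketch of the new helper. The flag name below is hypothetical, and the import assumes lib/ is on sys.path so that config_utils resolves to lib/config_utils:

import argparse
from config_utils import arg_str2bool  # re-exported by lib/config_utils/__init__.py in this commit

parser = argparse.ArgumentParser()
# hypothetical flag for illustration; accepts yes/no, true/false, t/f, y/n, 1/0 (case-insensitive)
parser.add_argument("--use_cutout", type=arg_str2bool, default=False)
args = parser.parse_args(["--use_cutout", "yes"])
assert args.use_cutout is True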
| @@ -1,22 +1,32 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_attention_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--att_channel' ,     type=int,                   help='.') | ||||
|   parser.add_argument('--att_spatial' ,     type=str,                   help='.') | ||||
|   parser.add_argument('--att_active'  ,     type=str,                   help='.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   return args | ||||
| def obtain_attention_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--att_channel", type=int, help=".") | ||||
|     parser.add_argument("--att_spatial", type=str, help=".") | ||||
|     parser.add_argument("--att_active", type=str, help=".") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     return args | ||||
|   | ||||
| @@ -4,21 +4,41 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_basic_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--model_source',     type=str,  default='normal',help='The source of model defination.') | ||||
|   parser.add_argument('--extra_model_path', type=str,  default=None,    help='The extra model ckp file (help to indicate the searched architecture).') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,  default=2,       help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   return args | ||||
| def obtain_basic_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--model_source", | ||||
|         type=str, | ||||
|         default="normal", | ||||
|         help="The source of model defination.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--extra_model_path", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="The extra model ckp file (help to indicate the searched architecture).", | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     return args | ||||
|   | ||||
| @@ -1,20 +1,32 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_cls_init_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--init_checkpoint',  type=str,                   help='The checkpoint path to the initial model.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   return args | ||||
| def obtain_cls_init_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--init_checkpoint", type=str, help="The checkpoint path to the initial model." | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     return args | ||||
|   | ||||
| @@ -1,23 +1,43 @@ | ||||
| import random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_cls_kd_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--KD_checkpoint',    type=str,                   help='The teacher checkpoint in knowledge distillation.') | ||||
|   parser.add_argument('--KD_alpha'    ,     type=float,                 help='The alpha parameter in knowledge distillation.') | ||||
|   parser.add_argument('--KD_temperature',   type=float,                 help='The temperature parameter in knowledge distillation.') | ||||
|   #parser.add_argument('--KD_feature',       type=float,                 help='Knowledge distillation at the feature level.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   return args | ||||
| def obtain_cls_kd_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--KD_checkpoint", | ||||
|         type=str, | ||||
|         help="The teacher checkpoint in knowledge distillation.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_temperature", | ||||
|         type=float, | ||||
|         help="The temperature parameter in knowledge distillation.", | ||||
|     ) | ||||
|     # parser.add_argument('--KD_feature',       type=float,                 help='Knowledge distillation at the feature level.') | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     return args | ||||
|   | ||||
							
								
								
									
lib/config_utils/config_utils.py  (new file, 135 lines)
| @@ -0,0 +1,135 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import os, json | ||||
| from os import path as osp | ||||
| from pathlib import Path | ||||
| from collections import namedtuple | ||||
|  | ||||
| support_types = ("str", "int", "bool", "float", "none") | ||||
|  | ||||
|  | ||||
| def convert_param(original_lists): | ||||
|     assert isinstance(original_lists, list), "The type is not right : {:}".format( | ||||
|         original_lists | ||||
|     ) | ||||
|     ctype, value = original_lists[0], original_lists[1] | ||||
|     assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types) | ||||
|     is_list = isinstance(value, list) | ||||
|     if not is_list: | ||||
|         value = [value] | ||||
|     outs = [] | ||||
|     for x in value: | ||||
|         if ctype == "int": | ||||
|             x = int(x) | ||||
|         elif ctype == "str": | ||||
|             x = str(x) | ||||
|         elif ctype == "bool": | ||||
|             x = bool(int(x)) | ||||
|         elif ctype == "float": | ||||
|             x = float(x) | ||||
|         elif ctype == "none": | ||||
|             if x.lower() != "none": | ||||
|                 raise ValueError( | ||||
|                     "For the none type, the value must be none instead of {:}".format(x) | ||||
|                 ) | ||||
|             x = None | ||||
|         else: | ||||
|             raise TypeError("Does not know this type : {:}".format(ctype)) | ||||
|         outs.append(x) | ||||
|     if not is_list: | ||||
|         outs = outs[0] | ||||
|     return outs | ||||
|  | ||||
|  | ||||
| def load_config(path, extra, logger): | ||||
|     path = str(path) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log(path) | ||||
|     assert os.path.exists(path), "Can not find {:}".format(path) | ||||
|     # Reading data back | ||||
|     with open(path, "r") as f: | ||||
|         data = json.load(f) | ||||
|     content = {k: convert_param(v) for k, v in data.items()} | ||||
|     assert extra is None or isinstance( | ||||
|         extra, dict | ||||
|     ), "invalid type of extra : {:}".format(extra) | ||||
|     if isinstance(extra, dict): | ||||
|         content = {**content, **extra} | ||||
|     Arguments = namedtuple("Configure", " ".join(content.keys())) | ||||
|     content = Arguments(**content) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("{:}".format(content)) | ||||
|     return content | ||||
|  | ||||
|  | ||||
| def configure2str(config, xpath=None): | ||||
|     if not isinstance(config, dict): | ||||
|         config = config._asdict() | ||||
|  | ||||
|     def cstring(x): | ||||
|         return '"{:}"'.format(x) | ||||
|  | ||||
|     def gtype(x): | ||||
|         if isinstance(x, list): | ||||
|             x = x[0] | ||||
|         if isinstance(x, str): | ||||
|             return "str" | ||||
|         elif isinstance(x, bool): | ||||
|             return "bool" | ||||
|         elif isinstance(x, int): | ||||
|             return "int" | ||||
|         elif isinstance(x, float): | ||||
|             return "float" | ||||
|         elif x is None: | ||||
|             return "none" | ||||
|         else: | ||||
|             raise ValueError("invalid : {:}".format(x)) | ||||
|  | ||||
|     def cvalue(x, xtype): | ||||
|         if isinstance(x, list): | ||||
|             is_list = True | ||||
|         else: | ||||
|             is_list, x = False, [x] | ||||
|         temps = [] | ||||
|         for temp in x: | ||||
|             if xtype == "bool": | ||||
|                 temp = cstring(int(temp)) | ||||
|             elif xtype == "none": | ||||
|                 temp = cstring("None") | ||||
|             else: | ||||
|                 temp = cstring(temp) | ||||
|             temps.append(temp) | ||||
|         if is_list: | ||||
|             return "[{:}]".format(", ".join(temps)) | ||||
|         else: | ||||
|             return temps[0] | ||||
|  | ||||
|     xstrings = [] | ||||
|     for key, value in config.items(): | ||||
|         xtype = gtype(value) | ||||
|         string = "  {:20s} : [{:8s}, {:}]".format( | ||||
|             cstring(key), cstring(xtype), cvalue(value, xtype) | ||||
|         ) | ||||
|         xstrings.append(string) | ||||
|     Fstring = "{\n" + ",\n".join(xstrings) + "\n}" | ||||
|     if xpath is not None: | ||||
|         parent = Path(xpath).resolve().parent | ||||
|         parent.mkdir(parents=True, exist_ok=True) | ||||
|         if osp.isfile(xpath): | ||||
|             os.remove(xpath) | ||||
|         with open(xpath, "w") as text_file: | ||||
|             text_file.write("{:}".format(Fstring)) | ||||
|     return Fstring | ||||
|  | ||||
|  | ||||
| def dict2config(xdict, logger): | ||||
|     assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict)) | ||||
|     Arguments = namedtuple("Configure", " ".join(xdict.keys())) | ||||
|     content = Arguments(**xdict) | ||||
|     if hasattr(logger, "log"): | ||||
|         logger.log("{:}".format(content)) | ||||
|     return content | ||||
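A short self-contained sketch (hypothetical keys and values, not from the repository) of how these helpers compose; it assumes lib/ is on sys.path so the package import resolves:

import json, os, tempfile
from config_utils import load_config, dict2config, configure2str  # assumption: lib/ on sys.path

# each value is a [type, value] pair, which is what convert_param expects
demo = {
    "class_num":  ["int", 10],
    "layers":     ["int", [2, 2, 2]],
    "use_cutout": ["bool", 1],
    "arch":       ["str", "resnet"],
}
path = os.path.join(tempfile.mkdtemp(), "demo.config")
with open(path, "w") as f:
    json.dump(demo, f)

# logger=None is fine here: load_config only calls logger.log when that attribute exists
config = load_config(path, extra={"rand_seed": 42}, logger=None)
print(config.class_num, config.layers, config.use_cutout, config.rand_seed)  # 10 [2, 2, 2] True 42

# dict2config builds the same kind of namedtuple from an already-typed dict
opt = dict2config({"LR": 0.1, "momentum": 0.9}, logger=None)

# configure2str round-trips a config back into the typed on-disk format
print(configure2str(config))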
| @@ -1,106 +0,0 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| import os, json | ||||
| from os import path as osp | ||||
| from pathlib import Path | ||||
| from collections import namedtuple | ||||
|  | ||||
| support_types = ('str', 'int', 'bool', 'float', 'none') | ||||
|  | ||||
|  | ||||
| def convert_param(original_lists): | ||||
|   assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists) | ||||
|   ctype, value = original_lists[0], original_lists[1] | ||||
|   assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types) | ||||
|   is_list = isinstance(value, list) | ||||
|   if not is_list: value = [value] | ||||
|   outs = [] | ||||
|   for x in value: | ||||
|     if ctype == 'int': | ||||
|       x = int(x) | ||||
|     elif ctype == 'str': | ||||
|       x = str(x) | ||||
|     elif ctype == 'bool': | ||||
|       x = bool(int(x)) | ||||
|     elif ctype == 'float': | ||||
|       x = float(x) | ||||
|     elif ctype == 'none': | ||||
|       if x.lower() != 'none': | ||||
|         raise ValueError('For the none type, the value must be none instead of {:}'.format(x)) | ||||
|       x = None | ||||
|     else: | ||||
|       raise TypeError('Does not know this type : {:}'.format(ctype)) | ||||
|     outs.append(x) | ||||
|   if not is_list: outs = outs[0] | ||||
|   return outs | ||||
|  | ||||
|  | ||||
| def load_config(path, extra, logger): | ||||
|   path = str(path) | ||||
|   if hasattr(logger, 'log'): logger.log(path) | ||||
|   assert os.path.exists(path), 'Can not find {:}'.format(path) | ||||
|   # Reading data back | ||||
|   with open(path, 'r') as f: | ||||
|     data = json.load(f) | ||||
|   content = { k: convert_param(v) for k,v in data.items()} | ||||
|   assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra) | ||||
|   if isinstance(extra, dict): content = {**content, **extra} | ||||
|   Arguments = namedtuple('Configure', ' '.join(content.keys())) | ||||
|   content   = Arguments(**content) | ||||
|   if hasattr(logger, 'log'): logger.log('{:}'.format(content)) | ||||
|   return content | ||||
|  | ||||
|  | ||||
| def configure2str(config, xpath=None): | ||||
|   if not isinstance(config, dict): | ||||
|     config = config._asdict() | ||||
|   def cstring(x): | ||||
|     return "\"{:}\"".format(x) | ||||
|   def gtype(x): | ||||
|     if isinstance(x, list): x = x[0] | ||||
|     if isinstance(x, str)  : return 'str' | ||||
|     elif isinstance(x, bool) : return 'bool' | ||||
|     elif isinstance(x, int): return 'int' | ||||
|     elif isinstance(x, float): return 'float' | ||||
|     elif x is None           : return 'none' | ||||
|     else: raise ValueError('invalid : {:}'.format(x)) | ||||
|   def cvalue(x, xtype): | ||||
|     if isinstance(x, list): is_list = True | ||||
|     else: | ||||
|       is_list, x = False, [x] | ||||
|     temps = [] | ||||
|     for temp in x: | ||||
|       if xtype == 'bool'  : temp = cstring(int(temp)) | ||||
|       elif xtype == 'none': temp = cstring('None') | ||||
|       else                : temp = cstring(temp) | ||||
|       temps.append( temp ) | ||||
|     if is_list: | ||||
|       return "[{:}]".format( ', '.join( temps ) ) | ||||
|     else: | ||||
|       return temps[0] | ||||
|  | ||||
|   xstrings = [] | ||||
|   for key, value in config.items(): | ||||
|     xtype  = gtype(value) | ||||
|     string = '  {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype)) | ||||
|     xstrings.append(string) | ||||
|   Fstring = '{\n' + ',\n'.join(xstrings) + '\n}' | ||||
|   if xpath is not None: | ||||
|     parent = Path(xpath).resolve().parent | ||||
|     parent.mkdir(parents=True, exist_ok=True) | ||||
|     if osp.isfile(xpath): os.remove(xpath) | ||||
|     with open(xpath, "w") as text_file: | ||||
|       text_file.write('{:}'.format(Fstring)) | ||||
|   return Fstring | ||||
|  | ||||
|  | ||||
| def dict2config(xdict, logger): | ||||
|   assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) ) | ||||
|   Arguments = namedtuple('Configure', ' '.join(xdict.keys())) | ||||
|   content   = Arguments(**xdict) | ||||
|   if hasattr(logger, 'log'): logger.log('{:}'.format(content)) | ||||
|   return content | ||||
| @@ -1,26 +1,48 @@ | ||||
| import os, sys, time, random, argparse | ||||
| from .share_args import add_shared_args | ||||
|  | ||||
| def obtain_pruning_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--keep_ratio'  ,     type=float,                 help='The left channel ratio compared to the original network.') | ||||
|   parser.add_argument('--model_version',    type=str,                   help='The network version.') | ||||
|   parser.add_argument('--KD_alpha'    ,     type=float,                 help='The alpha parameter in knowledge distillation.') | ||||
|   parser.add_argument('--KD_temperature',   type=float,                 help='The temperature parameter in knowledge distillation.') | ||||
|   parser.add_argument('--Regular_W_feat',   type=float,                 help='The .') | ||||
|   parser.add_argument('--Regular_W_conv',   type=float,                 help='The .') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,  default=2,       help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   assert args.keep_ratio > 0 and args.keep_ratio <= 1, 'invalid keep ratio : {:}'.format(args.keep_ratio) | ||||
|   return args | ||||
| def obtain_pruning_args(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument( | ||||
|         "--keep_ratio", | ||||
|         type=float, | ||||
|         help="The left channel ratio compared to the original network.", | ||||
|     ) | ||||
|     parser.add_argument("--model_version", type=str, help="The network version.") | ||||
|     parser.add_argument( | ||||
|         "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--KD_temperature", | ||||
|         type=float, | ||||
|         help="The temperature parameter in knowledge distillation.", | ||||
|     ) | ||||
|     parser.add_argument("--Regular_W_feat", type=float, help="The .") | ||||
|     parser.add_argument("--Regular_W_conv", type=float, help="The .") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     assert ( | ||||
|         args.keep_ratio > 0 and args.keep_ratio <= 1 | ||||
|     ), "invalid keep ratio : {:}".format(args.keep_ratio) | ||||
|     return args | ||||
|   | ||||
| @@ -3,22 +3,42 @@ from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_RandomSearch_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'      ,     type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--init_model'  ,     type=str,                   help='The initialization model path.') | ||||
|   parser.add_argument('--expect_flop',      type=float,                 help='The expected flop keep ratio.') | ||||
|   parser.add_argument('--arch_nums'   ,     type=int,                   help='The maximum number of running random arch generating..') | ||||
|   parser.add_argument('--model_config',     type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--random_mode', type=str, choices=['random', 'fix'], help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size',       type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument("--init_model", type=str, help="The initialization model path.") | ||||
|     parser.add_argument( | ||||
|         "--expect_flop", type=float, help="The expected flop keep ratio." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nums", | ||||
|         type=int, | ||||
|         help="The maximum number of running random arch generating..", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--random_mode", | ||||
|         type=str, | ||||
|         choices=["random", "fix"], | ||||
|         help="The path to the optimizer configuration", | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   #assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) | ||||
|   return args | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) | ||||
|     return args | ||||
|   | ||||
| @@ -3,30 +3,51 @@ from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_search_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'        ,   type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--model_config'  ,   type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config'  ,   type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--split_path'    ,   type=str,                   help='The split file path.') | ||||
|   #parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|   parser.add_argument('--gumbel_tau_max',   type=float,                 help='The maximum tau for Gumbel.') | ||||
|   parser.add_argument('--gumbel_tau_min',   type=float,                 help='The minimum tau for Gumbel.') | ||||
|   parser.add_argument('--procedure'     ,   type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--FLOP_ratio'    ,   type=float,                 help='The expected FLOP ratio.') | ||||
|   parser.add_argument('--FLOP_weight'   ,   type=float,                 help='The loss weight for FLOP.') | ||||
|   parser.add_argument('--FLOP_tolerant' ,   type=float,                 help='The tolerant range for FLOP.') | ||||
|   # ablation studies | ||||
|   parser.add_argument('--ablation_num_select', type=int,                help='The number of randomly selected channels.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size'    ,   type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--split_path", type=str, help="The split file path.") | ||||
|     # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") | ||||
|     parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") | ||||
|     parser.add_argument( | ||||
|         "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." | ||||
|     ) | ||||
|     # ablation studies | ||||
|     parser.add_argument( | ||||
|         "--ablation_num_select", | ||||
|         type=int, | ||||
|         help="The number of randomly selected channels.", | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|   assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant) | ||||
|   #assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|   #args.arch_para_pure = bool(args.arch_para_pure) | ||||
|   return args | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|     assert ( | ||||
|         args.FLOP_tolerant is not None and args.FLOP_tolerant > 0 | ||||
|     ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant) | ||||
|     # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|     # args.arch_para_pure = bool(args.arch_para_pure) | ||||
|     return args | ||||
|   | ||||
| @@ -3,29 +3,46 @@ from .share_args import add_shared_args | ||||
|  | ||||
|  | ||||
| def obtain_search_single_args(): | ||||
|   parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--resume'        ,   type=str,                   help='Resume path.') | ||||
|   parser.add_argument('--model_config'  ,   type=str,                   help='The path to the model configuration') | ||||
|   parser.add_argument('--optim_config'  ,   type=str,                   help='The path to the optimizer configuration') | ||||
|   parser.add_argument('--split_path'    ,   type=str,                   help='The split file path.') | ||||
|   parser.add_argument('--search_shape'  ,   type=str,                   help='The shape to be searched.') | ||||
|   #parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|   parser.add_argument('--gumbel_tau_max',   type=float,                 help='The maximum tau for Gumbel.') | ||||
|   parser.add_argument('--gumbel_tau_min',   type=float,                 help='The minimum tau for Gumbel.') | ||||
|   parser.add_argument('--procedure'     ,   type=str,                   help='The procedure basic prefix.') | ||||
|   parser.add_argument('--FLOP_ratio'    ,   type=float,                 help='The expected FLOP ratio.') | ||||
|   parser.add_argument('--FLOP_weight'   ,   type=float,                 help='The loss weight for FLOP.') | ||||
|   parser.add_argument('--FLOP_tolerant' ,   type=float,                 help='The tolerant range for FLOP.') | ||||
|   add_shared_args( parser ) | ||||
|   # Optimization options | ||||
|   parser.add_argument('--batch_size'    ,   type=int,   default=2,      help='Batch size for training.') | ||||
|   args = parser.parse_args() | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Train a classification model on typical image classification datasets.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--resume", type=str, help="Resume path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", type=str, help="The path to the model configuration" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--optim_config", type=str, help="The path to the optimizer configuration" | ||||
|     ) | ||||
|     parser.add_argument("--split_path", type=str, help="The split file path.") | ||||
|     parser.add_argument("--search_shape", type=str, help="The shape to be searched.") | ||||
|     # parser.add_argument('--arch_para_pure',   type=int,                   help='The architecture-parameter pure or not.') | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." | ||||
|     ) | ||||
|     parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") | ||||
|     parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") | ||||
|     parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") | ||||
|     parser.add_argument( | ||||
|         "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." | ||||
|     ) | ||||
|     add_shared_args(parser) | ||||
|     # Optimization options | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=2, help="Batch size for training." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|   if args.rand_seed is None or args.rand_seed < 0: | ||||
|     args.rand_seed = random.randint(1, 100000) | ||||
|   assert args.save_dir is not None, 'save-path argument can not be None' | ||||
|   assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|   assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant) | ||||
|   #assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|   #args.arch_para_pure = bool(args.arch_para_pure) | ||||
|   return args | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "save-path argument can not be None" | ||||
|     assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None | ||||
|     assert ( | ||||
|         args.FLOP_tolerant is not None and args.FLOP_tolerant > 0 | ||||
|     ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant) | ||||
|     # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) | ||||
|     # args.arch_para_pure = bool(args.arch_para_pure) | ||||
|     return args | ||||
|   | ||||
| @@ -1,17 +1,39 @@ | ||||
| import os, sys, time, random, argparse | ||||
|  | ||||
| def add_shared_args( parser ): | ||||
|   # Data Generation | ||||
|   parser.add_argument('--dataset',          type=str,                   help='The dataset name.') | ||||
|   parser.add_argument('--data_path',        type=str,                   help='The dataset name.') | ||||
|   parser.add_argument('--cutout_length',    type=int,                   help='The cutout length, negative means not use.') | ||||
|   # Printing | ||||
|   parser.add_argument('--print_freq',       type=int,   default=100,    help='print frequency (default: 200)') | ||||
|   parser.add_argument('--print_freq_eval',  type=int,   default=100,    help='print frequency (default: 200)') | ||||
|   # Checkpoints | ||||
|   parser.add_argument('--eval_frequency',   type=int,   default=1,      help='evaluation frequency (default: 200)') | ||||
|   parser.add_argument('--save_dir',         type=str,                   help='Folder to save checkpoints and log.') | ||||
|   # Acceleration | ||||
|   parser.add_argument('--workers',          type=int,   default=8,      help='number of data loading workers (default: 8)') | ||||
|   # Random Seed | ||||
|   parser.add_argument('--rand_seed',        type=int,   default=-1,     help='manual seed') | ||||
|  | ||||
| def add_shared_args(parser): | ||||
|     # Data Generation | ||||
|     parser.add_argument("--dataset", type=str, help="The dataset name.") | ||||
|     parser.add_argument("--data_path", type=str, help="The dataset name.") | ||||
|     parser.add_argument( | ||||
|         "--cutout_length", type=int, help="The cutout length, negative means not use." | ||||
|     ) | ||||
|     # Printing | ||||
|     parser.add_argument( | ||||
|         "--print_freq", type=int, default=100, help="print frequency (default: 200)" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--print_freq_eval", | ||||
|         type=int, | ||||
|         default=100, | ||||
|         help="print frequency (default: 200)", | ||||
|     ) | ||||
|     # Checkpoints | ||||
|     parser.add_argument( | ||||
|         "--eval_frequency", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         help="evaluation frequency (default: 200)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     # Acceleration | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="number of data loading workers (default: 8)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|   | ||||
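Finally, a brief sketch (hypothetical values) of how the shared arguments are attached to a script's parser, again assuming lib/ is on sys.path:

import argparse
from config_utils.share_args import add_shared_args  # assumption: lib/ on sys.path

parser = argparse.ArgumentParser()
add_shared_args(parser)
args = parser.parse_args(["--dataset", "cifar10", "--save_dir", "./output/demo"])
print(args.dataset, args.save_dir, args.workers, args.rand_seed)  # cifar10 ./output/demo 8 -1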