From c2270fd153c0261c06c3f54df2fff55febb33fa7 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 30 Mar 2021 09:17:05 +0000 Subject: [PATCH] Move str2bool to config_utils --- .github/workflows/basic_test.yml | 1 + exps/trading/baselines.py | 17 +++- exps/trading/organize_results.py | 15 +-- lib/config_utils/__init__.py | 22 ++-- lib/config_utils/args_utils.py | 12 +++ lib/config_utils/attention_args.py | 46 +++++---- lib/config_utils/basic_args.py | 54 ++++++---- lib/config_utils/cls_init_args.py | 44 +++++--- lib/config_utils/cls_kd_args.py | 58 +++++++---- lib/config_utils/config_utils.py | 135 +++++++++++++++++++++++++ lib/config_utils/configure_utils.py | 106 ------------------- lib/config_utils/pruning_args.py | 66 ++++++++---- lib/config_utils/random_baseline.py | 56 ++++++---- lib/config_utils/search_args.py | 73 ++++++++----- lib/config_utils/search_single_args.py | 67 +++++++----- lib/config_utils/share_args.py | 52 +++++++--- 16 files changed, 519 insertions(+), 305 deletions(-) create mode 100644 lib/config_utils/args_utils.py create mode 100644 lib/config_utils/config_utils.py delete mode 100644 lib/config_utils/configure_utils.py diff --git a/.github/workflows/basic_test.yml b/.github/workflows/basic_test.yml index 07b8a45..57751b8 100644 --- a/.github/workflows/basic_test.yml +++ b/.github/workflows/basic_test.yml @@ -36,6 +36,7 @@ jobs: python -m black ./lib/spaces -l 88 --check --diff --verbose python -m black ./lib/trade_models -l 88 --check --diff --verbose python -m black ./lib/procedures -l 88 --check --diff --verbose + python -m black ./lib/config_utils -l 88 --check --diff --verbose - name: Test Search Space run: | diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index e81ce01..59adb5d 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -15,7 +15,7 @@ # python exps/trading/baselines.py --alg TabNet # # # # python exps/trading/baselines.py --alg Transformer# -# python 
exps/trading/baselines.py --alg TSF +# python exps/trading/baselines.py --alg TSF # python exps/trading/baselines.py --alg TSF-4x64-drop0_0 ##################################################### import sys @@ -30,6 +30,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import arg_str2bool from procedures.q_exps import update_gpu from procedures.q_exps import update_market from procedures.q_exps import run_exp @@ -182,6 +183,12 @@ if __name__ == "__main__": help="The market indicator.", ) parser.add_argument("--times", type=int, default=5, help="The repeated run times.") + parser.add_argument( + "--shared_dataset", + type=arg_str2bool, + default=False, + help="Whether to share the dataset for all algorithms?", + ) parser.add_argument( "--gpu", type=int, default=0, help="The GPU ID used for train / test." ) @@ -189,9 +196,13 @@ if __name__ == "__main__": "--alg", type=str, choices=list(alg2configs.keys()), + nargs="+", required=True, - help="The algorithm name.", + help="The algorithm name(s).", ) args = parser.parse_args() - main(args, alg2configs[args.alg]) + if len(args.alg) == 1: + main(args, alg2configs[args.alg[0]]) + else: + print("-") diff --git a/exps/trading/organize_results.py b/exps/trading/organize_results.py index f810057..03eeebb 100644 --- a/exps/trading/organize_results.py +++ b/exps/trading/organize_results.py @@ -15,6 +15,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." 
/ "lib").resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import arg_str2bool import qlib from qlib.config import REG_CN from qlib.workflow import R @@ -184,16 +185,6 @@ if __name__ == "__main__": parser = argparse.ArgumentParser("Show Results") - def str2bool(v): - if isinstance(v, bool): - return v - elif v.lower() in ("yes", "true", "t", "y", "1"): - return True - elif v.lower() in ("no", "false", "f", "n", "0"): - return False - else: - raise argparse.ArgumentTypeError("Boolean value expected.") - parser.add_argument( "--save_dir", type=str, @@ -203,7 +194,7 @@ if __name__ == "__main__": ) parser.add_argument( "--verbose", - type=str2bool, + type=arg_str2bool, default=False, help="Print detailed log information or not.", ) @@ -228,7 +219,7 @@ if __name__ == "__main__": info_dict["heads"], info_dict["values"], info_dict["names"], - space=14, + space=18, verbose=True, sort_key=True, ) diff --git a/lib/config_utils/__init__.py b/lib/config_utils/__init__.py index dd91409..85a162d 100644 --- a/lib/config_utils/__init__.py +++ b/lib/config_utils/__init__.py @@ -1,13 +1,19 @@ ################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ################################################## -from .configure_utils import load_config, dict2config, configure2str -from .basic_args import obtain_basic_args -from .attention_args import obtain_attention_args -from .random_baseline import obtain_RandomSearch_args -from .cls_kd_args import obtain_cls_kd_args -from .cls_init_args import obtain_cls_init_args +# general config related functions +from .config_utils import load_config, dict2config, configure2str +# the args setting for different experiments +from .basic_args import obtain_basic_args +from .attention_args import obtain_attention_args +from .random_baseline import obtain_RandomSearch_args +from .cls_kd_args import obtain_cls_kd_args +from .cls_init_args import obtain_cls_init_args from 
.search_single_args import obtain_search_single_args -from .search_args import obtain_search_args +from .search_args import obtain_search_args + # for network pruning -from .pruning_args import obtain_pruning_args +from .pruning_args import obtain_pruning_args + +# utils for args +from .args_utils import arg_str2bool diff --git a/lib/config_utils/args_utils.py b/lib/config_utils/args_utils.py new file mode 100644 index 0000000..f58e475 --- /dev/null +++ b/lib/config_utils/args_utils.py @@ -0,0 +1,12 @@ +import argparse + + +def arg_str2bool(v): + if isinstance(v, bool): + return v + elif v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/lib/config_utils/attention_args.py b/lib/config_utils/attention_args.py index 1f95b93..f5876ac 100644 --- a/lib/config_utils/attention_args.py +++ b/lib/config_utils/attention_args.py @@ -1,22 +1,32 @@ import random, argparse from .share_args import add_shared_args -def obtain_attention_args(): - parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--resume' , type=str, help='Resume path.') - parser.add_argument('--init_model' , type=str, help='The initialization model path.') - parser.add_argument('--model_config', type=str, help='The path to the model configuration') - parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration') - parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') - parser.add_argument('--att_channel' , type=int, help='.') - parser.add_argument('--att_spatial' , type=str, help='.') - parser.add_argument('--att_active' , type=str, help='.') - add_shared_args( parser ) - # Optimization options - parser.add_argument('--batch_size', type=int, default=2, 
help='Batch size for training.') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, 'save-path argument can not be None' - return args +def obtain_attention_args(): + parser = argparse.ArgumentParser( + description="Train a classification model on typical image classification datasets.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--resume", type=str, help="Resume path.") + parser.add_argument("--init_model", type=str, help="The initialization model path.") + parser.add_argument( + "--model_config", type=str, help="The path to the model configuration" + ) + parser.add_argument( + "--optim_config", type=str, help="The path to the optimizer configuration" + ) + parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") + parser.add_argument("--att_channel", type=int, help=".") + parser.add_argument("--att_spatial", type=str, help=".") + parser.add_argument("--att_active", type=str, help=".") + add_shared_args(parser) + # Optimization options + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training." 
+ ) + args = parser.parse_args() + + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "save-path argument can not be None" + return args diff --git a/lib/config_utils/basic_args.py b/lib/config_utils/basic_args.py index e89c86e..21c18b6 100644 --- a/lib/config_utils/basic_args.py +++ b/lib/config_utils/basic_args.py @@ -4,21 +4,41 @@ import random, argparse from .share_args import add_shared_args -def obtain_basic_args(): - parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--resume' , type=str, help='Resume path.') - parser.add_argument('--init_model' , type=str, help='The initialization model path.') - parser.add_argument('--model_config', type=str, help='The path to the model configuration') - parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration') - parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') - parser.add_argument('--model_source', type=str, default='normal',help='The source of model defination.') - parser.add_argument('--extra_model_path', type=str, default=None, help='The extra model ckp file (help to indicate the searched architecture).') - add_shared_args( parser ) - # Optimization options - parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, 'save-path argument can not be None' - return args +def obtain_basic_args(): + parser = argparse.ArgumentParser( + description="Train a classification model on typical image classification datasets.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--resume", type=str, help="Resume 
path.") + parser.add_argument("--init_model", type=str, help="The initialization model path.") + parser.add_argument( + "--model_config", type=str, help="The path to the model configuration" + ) + parser.add_argument( + "--optim_config", type=str, help="The path to the optimizer configuration" + ) + parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") + parser.add_argument( + "--model_source", + type=str, + default="normal", + help="The source of model defination.", + ) + parser.add_argument( + "--extra_model_path", + type=str, + default=None, + help="The extra model ckp file (help to indicate the searched architecture).", + ) + add_shared_args(parser) + # Optimization options + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training." + ) + args = parser.parse_args() + + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "save-path argument can not be None" + return args diff --git a/lib/config_utils/cls_init_args.py b/lib/config_utils/cls_init_args.py index 32e3125..96c5bb9 100644 --- a/lib/config_utils/cls_init_args.py +++ b/lib/config_utils/cls_init_args.py @@ -1,20 +1,32 @@ import random, argparse from .share_args import add_shared_args -def obtain_cls_init_args(): - parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--resume' , type=str, help='Resume path.') - parser.add_argument('--init_model' , type=str, help='The initialization model path.') - parser.add_argument('--model_config', type=str, help='The path to the model configuration') - parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration') - parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') - parser.add_argument('--init_checkpoint', type=str, help='The 
checkpoint path to the initial model.') - add_shared_args( parser ) - # Optimization options - parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, 'save-path argument can not be None' - return args +def obtain_cls_init_args(): + parser = argparse.ArgumentParser( + description="Train a classification model on typical image classification datasets.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--resume", type=str, help="Resume path.") + parser.add_argument("--init_model", type=str, help="The initialization model path.") + parser.add_argument( + "--model_config", type=str, help="The path to the model configuration" + ) + parser.add_argument( + "--optim_config", type=str, help="The path to the optimizer configuration" + ) + parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") + parser.add_argument( + "--init_checkpoint", type=str, help="The checkpoint path to the initial model." + ) + add_shared_args(parser) + # Optimization options + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training." 
+ ) + args = parser.parse_args() + + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "save-path argument can not be None" + return args diff --git a/lib/config_utils/cls_kd_args.py b/lib/config_utils/cls_kd_args.py index 4020510..03f208a 100644 --- a/lib/config_utils/cls_kd_args.py +++ b/lib/config_utils/cls_kd_args.py @@ -1,23 +1,43 @@ import random, argparse from .share_args import add_shared_args -def obtain_cls_kd_args(): - parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--resume' , type=str, help='Resume path.') - parser.add_argument('--init_model' , type=str, help='The initialization model path.') - parser.add_argument('--model_config', type=str, help='The path to the model configuration') - parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration') - parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') - parser.add_argument('--KD_checkpoint', type=str, help='The teacher checkpoint in knowledge distillation.') - parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.') - parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.') - #parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.') - add_shared_args( parser ) - # Optimization options - parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, 'save-path argument can not be None' - return args +def obtain_cls_kd_args(): + parser = argparse.ArgumentParser( + description="Train a 
classification model on typical image classification datasets.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--resume", type=str, help="Resume path.") + parser.add_argument("--init_model", type=str, help="The initialization model path.") + parser.add_argument( + "--model_config", type=str, help="The path to the model configuration" + ) + parser.add_argument( + "--optim_config", type=str, help="The path to the optimizer configuration" + ) + parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") + parser.add_argument( + "--KD_checkpoint", + type=str, + help="The teacher checkpoint in knowledge distillation.", + ) + parser.add_argument( + "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." + ) + parser.add_argument( + "--KD_temperature", + type=float, + help="The temperature parameter in knowledge distillation.", + ) + # parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.') + add_shared_args(parser) + # Optimization options + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training." + ) + args = parser.parse_args() + + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "save-path argument can not be None" + return args diff --git a/lib/config_utils/config_utils.py b/lib/config_utils/config_utils.py new file mode 100644 index 0000000..733ecc0 --- /dev/null +++ b/lib/config_utils/config_utils.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +import os, json +from os import path as osp +from pathlib import Path +from collections import namedtuple + +support_types = ("str", "int", "bool", "float", "none") + + +def convert_param(original_lists): + assert isinstance(original_lists, list), "The type is not right : {:}".format( + original_lists + ) + ctype, value = original_lists[0], original_lists[1] + assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types) + is_list = isinstance(value, list) + if not is_list: + value = [value] + outs = [] + for x in value: + if ctype == "int": + x = int(x) + elif ctype == "str": + x = str(x) + elif ctype == "bool": + x = bool(int(x)) + elif ctype == "float": + x = float(x) + elif ctype == "none": + if x.lower() != "none": + raise ValueError( + "For the none type, the value must be none instead of {:}".format(x) + ) + x = None + else: + raise TypeError("Does not know this type : {:}".format(ctype)) + outs.append(x) + if not is_list: + outs = outs[0] + return outs + + +def load_config(path, extra, logger): + path = str(path) + if hasattr(logger, "log"): + logger.log(path) + assert os.path.exists(path), "Can not find {:}".format(path) + # Reading data back + with open(path, "r") as f: + data = json.load(f) + content = {k: convert_param(v) for k, v in data.items()} + assert extra is None or isinstance( + extra, dict + ), "invalid type of extra : {:}".format(extra) + if isinstance(extra, dict): + content = {**content, **extra} + Arguments = namedtuple("Configure", " ".join(content.keys())) + content = Arguments(**content) + if hasattr(logger, "log"): + logger.log("{:}".format(content)) + return content + + +def configure2str(config, xpath=None): + if not isinstance(config, dict): + config = config._asdict() + + def cstring(x): + return '"{:}"'.format(x) + + def gtype(x): + if isinstance(x, list): + x = x[0] + if isinstance(x, str): + return "str" + elif isinstance(x, bool): + return "bool" + elif isinstance(x, int): + return "int" + elif 
isinstance(x, float): + return "float" + elif x is None: + return "none" + else: + raise ValueError("invalid : {:}".format(x)) + + def cvalue(x, xtype): + if isinstance(x, list): + is_list = True + else: + is_list, x = False, [x] + temps = [] + for temp in x: + if xtype == "bool": + temp = cstring(int(temp)) + elif xtype == "none": + temp = cstring("None") + else: + temp = cstring(temp) + temps.append(temp) + if is_list: + return "[{:}]".format(", ".join(temps)) + else: + return temps[0] + + xstrings = [] + for key, value in config.items(): + xtype = gtype(value) + string = " {:20s} : [{:8s}, {:}]".format( + cstring(key), cstring(xtype), cvalue(value, xtype) + ) + xstrings.append(string) + Fstring = "{\n" + ",\n".join(xstrings) + "\n}" + if xpath is not None: + parent = Path(xpath).resolve().parent + parent.mkdir(parents=True, exist_ok=True) + if osp.isfile(xpath): + os.remove(xpath) + with open(xpath, "w") as text_file: + text_file.write("{:}".format(Fstring)) + return Fstring + + +def dict2config(xdict, logger): + assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict)) + Arguments = namedtuple("Configure", " ".join(xdict.keys())) + content = Arguments(**xdict) + if hasattr(logger, "log"): + logger.log("{:}".format(content)) + return content diff --git a/lib/config_utils/configure_utils.py b/lib/config_utils/configure_utils.py deleted file mode 100644 index 125e68e..0000000 --- a/lib/config_utils/configure_utils.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# -import os, json -from os import path as osp -from pathlib import Path -from collections import namedtuple - -support_types = ('str', 'int', 'bool', 'float', 'none') - - -def convert_param(original_lists): - assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists) - ctype, value = original_lists[0], original_lists[1] - assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types) - is_list = isinstance(value, list) - if not is_list: value = [value] - outs = [] - for x in value: - if ctype == 'int': - x = int(x) - elif ctype == 'str': - x = str(x) - elif ctype == 'bool': - x = bool(int(x)) - elif ctype == 'float': - x = float(x) - elif ctype == 'none': - if x.lower() != 'none': - raise ValueError('For the none type, the value must be none instead of {:}'.format(x)) - x = None - else: - raise TypeError('Does not know this type : {:}'.format(ctype)) - outs.append(x) - if not is_list: outs = outs[0] - return outs - - -def load_config(path, extra, logger): - path = str(path) - if hasattr(logger, 'log'): logger.log(path) - assert os.path.exists(path), 'Can not find {:}'.format(path) - # Reading data back - with open(path, 'r') as f: - data = json.load(f) - content = { k: convert_param(v) for k,v in data.items()} - assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra) - if isinstance(extra, dict): content = {**content, **extra} - Arguments = namedtuple('Configure', ' '.join(content.keys())) - content = Arguments(**content) - if hasattr(logger, 'log'): logger.log('{:}'.format(content)) - return content - - -def configure2str(config, xpath=None): - if not isinstance(config, dict): - config = config._asdict() - def cstring(x): - return "\"{:}\"".format(x) - def gtype(x): - if isinstance(x, list): x = x[0] - if isinstance(x, str) : return 'str' - elif isinstance(x, bool) : return 'bool' - elif isinstance(x, int): return 'int' - elif isinstance(x, float): return 'float' - elif 
x is None : return 'none' - else: raise ValueError('invalid : {:}'.format(x)) - def cvalue(x, xtype): - if isinstance(x, list): is_list = True - else: - is_list, x = False, [x] - temps = [] - for temp in x: - if xtype == 'bool' : temp = cstring(int(temp)) - elif xtype == 'none': temp = cstring('None') - else : temp = cstring(temp) - temps.append( temp ) - if is_list: - return "[{:}]".format( ', '.join( temps ) ) - else: - return temps[0] - - xstrings = [] - for key, value in config.items(): - xtype = gtype(value) - string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype)) - xstrings.append(string) - Fstring = '{\n' + ',\n'.join(xstrings) + '\n}' - if xpath is not None: - parent = Path(xpath).resolve().parent - parent.mkdir(parents=True, exist_ok=True) - if osp.isfile(xpath): os.remove(xpath) - with open(xpath, "w") as text_file: - text_file.write('{:}'.format(Fstring)) - return Fstring - - -def dict2config(xdict, logger): - assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) ) - Arguments = namedtuple('Configure', ' '.join(xdict.keys())) - content = Arguments(**xdict) - if hasattr(logger, 'log'): logger.log('{:}'.format(content)) - return content diff --git a/lib/config_utils/pruning_args.py b/lib/config_utils/pruning_args.py index 7462a71..01d3504 100644 --- a/lib/config_utils/pruning_args.py +++ b/lib/config_utils/pruning_args.py @@ -1,26 +1,48 @@ import os, sys, time, random, argparse from .share_args import add_shared_args -def obtain_pruning_args(): - parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--resume' , type=str, help='Resume path.') - parser.add_argument('--init_model' , type=str, help='The initialization model path.') - parser.add_argument('--model_config', type=str, help='The path to the model configuration') - parser.add_argument('--optim_config', 
type=str, help='The path to the optimizer configuration') - parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') - parser.add_argument('--keep_ratio' , type=float, help='The left channel ratio compared to the original network.') - parser.add_argument('--model_version', type=str, help='The network version.') - parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.') - parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.') - parser.add_argument('--Regular_W_feat', type=float, help='The .') - parser.add_argument('--Regular_W_conv', type=float, help='The .') - add_shared_args( parser ) - # Optimization options - parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.') - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, 'save-path argument can not be None' - assert args.keep_ratio > 0 and args.keep_ratio <= 1, 'invalid keep ratio : {:}'.format(args.keep_ratio) - return args +def obtain_pruning_args(): + parser = argparse.ArgumentParser( + description="Train a classification model on typical image classification datasets.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--resume", type=str, help="Resume path.") + parser.add_argument("--init_model", type=str, help="The initialization model path.") + parser.add_argument( + "--model_config", type=str, help="The path to the model configuration" + ) + parser.add_argument( + "--optim_config", type=str, help="The path to the optimizer configuration" + ) + parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") + parser.add_argument( + "--keep_ratio", + type=float, + help="The left channel ratio compared to the original network.", + ) + parser.add_argument("--model_version", type=str, help="The 
network version.") + parser.add_argument( + "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation." + ) + parser.add_argument( + "--KD_temperature", + type=float, + help="The temperature parameter in knowledge distillation.", + ) + parser.add_argument("--Regular_W_feat", type=float, help="The .") + parser.add_argument("--Regular_W_conv", type=float, help="The .") + add_shared_args(parser) + # Optimization options + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training." + ) + args = parser.parse_args() + + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "save-path argument can not be None" + assert ( + args.keep_ratio > 0 and args.keep_ratio <= 1 + ), "invalid keep ratio : {:}".format(args.keep_ratio) + return args diff --git a/lib/config_utils/random_baseline.py b/lib/config_utils/random_baseline.py index 79b89c8..184da91 100644 --- a/lib/config_utils/random_baseline.py +++ b/lib/config_utils/random_baseline.py @@ -3,22 +3,42 @@ from .share_args import add_shared_args def obtain_RandomSearch_args(): - parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--resume' , type=str, help='Resume path.') - parser.add_argument('--init_model' , type=str, help='The initialization model path.') - parser.add_argument('--expect_flop', type=float, help='The expected flop keep ratio.') - parser.add_argument('--arch_nums' , type=int, help='The maximum number of running random arch generating..') - parser.add_argument('--model_config', type=str, help='The path to the model configuration') - parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration') - parser.add_argument('--random_mode', type=str, choices=['random', 'fix'], help='The path to the optimizer 
configuration') - parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') - add_shared_args( parser ) - # Optimization options - parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.') - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Train a classification model on typical image classification datasets.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--resume", type=str, help="Resume path.") + parser.add_argument("--init_model", type=str, help="The initialization model path.") + parser.add_argument( + "--expect_flop", type=float, help="The expected flop keep ratio." + ) + parser.add_argument( + "--arch_nums", + type=int, + help="The maximum number of running random arch generating..", + ) + parser.add_argument( + "--model_config", type=str, help="The path to the model configuration" + ) + parser.add_argument( + "--optim_config", type=str, help="The path to the optimizer configuration" + ) + parser.add_argument( + "--random_mode", + type=str, + choices=["random", "fix"], + help="The path to the optimizer configuration", + ) + parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") + add_shared_args(parser) + # Optimization options + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training." 
+ ) + args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, 'save-path argument can not be None' - #assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) - return args + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "save-path argument can not be None" + # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) + return args diff --git a/lib/config_utils/search_args.py b/lib/config_utils/search_args.py index ecb60a1..2d278dc 100644 --- a/lib/config_utils/search_args.py +++ b/lib/config_utils/search_args.py @@ -3,30 +3,51 @@ from .share_args import add_shared_args def obtain_search_args(): - parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--resume' , type=str, help='Resume path.') - parser.add_argument('--model_config' , type=str, help='The path to the model configuration') - parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration') - parser.add_argument('--split_path' , type=str, help='The split file path.') - #parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.') - parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.') - parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.') - parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') - parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.') - parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.') - 
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.') - # ablation studies - parser.add_argument('--ablation_num_select', type=int, help='The number of randomly selected channels.') - add_shared_args( parser ) - # Optimization options - parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.') - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Train a classification model on typical image classification datasets.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--resume", type=str, help="Resume path.") + parser.add_argument( + "--model_config", type=str, help="The path to the model configuration" + ) + parser.add_argument( + "--optim_config", type=str, help="The path to the optimizer configuration" + ) + parser.add_argument("--split_path", type=str, help="The split file path.") + # parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.') + parser.add_argument( + "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." + ) + parser.add_argument( + "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." + ) + parser.add_argument("--procedure", type=str, help="The procedure basic prefix.") + parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.") + parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.") + parser.add_argument( + "--FLOP_tolerant", type=float, help="The tolerant range for FLOP." + ) + # ablation studies + parser.add_argument( + "--ablation_num_select", + type=int, + help="The number of randomly selected channels.", + ) + add_shared_args(parser) + # Optimization options + parser.add_argument( + "--batch_size", type=int, default=2, help="Batch size for training." 
+ )
+ args = parser.parse_args()

- if args.rand_seed is None or args.rand_seed < 0:
- args.rand_seed = random.randint(1, 100000)
- assert args.save_dir is not None, 'save-path argument can not be None'
- assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
- assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant)
- #assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
- #args.arch_para_pure = bool(args.arch_para_pure)
- return args
+ if args.rand_seed is None or args.rand_seed < 0:
+ args.rand_seed = random.randint(1, 100000)
+ assert args.save_dir is not None, "save-path argument can not be None"
+ assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
+ assert (
+ args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
+ ), "invalid FLOP_tolerant : {:}".format(args.FLOP_tolerant)
+ # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
+ # args.arch_para_pure = bool(args.arch_para_pure)
+ return args
diff --git a/lib/config_utils/search_single_args.py b/lib/config_utils/search_single_args.py
index 13e1ea6..6203b17 100644
--- a/lib/config_utils/search_single_args.py
+++ b/lib/config_utils/search_single_args.py
@@ -3,29 +3,46 @@ from .share_args import add_shared_args def obtain_search_single_args(): - parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--resume' , type=str, help='Resume path.') - parser.add_argument('--model_config' , type=str, help='The path to the model configuration') - parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration') - parser.add_argument('--split_path' , type=str, help='The split file path.') - parser.add_argument('--search_shape' , type=str, help='The 
shape to be searched.') - #parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.') - parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.') - parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.') - parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') - parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.') - parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.') - parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.') - add_shared_args( parser ) - # Optimization options - parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.') - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Train a classification model on typical image classification datasets.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--resume", type=str, help="Resume path.") + parser.add_argument( + "--model_config", type=str, help="The path to the model configuration" + ) + parser.add_argument( + "--optim_config", type=str, help="The path to the optimizer configuration" + ) + parser.add_argument("--split_path", type=str, help="The split file path.") + parser.add_argument("--search_shape", type=str, help="The shape to be searched.") + # parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.') + parser.add_argument( + "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel." + ) + parser.add_argument( + "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel." 
+ )
+ parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
+ parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
+ parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
+ parser.add_argument(
+ "--FLOP_tolerant", type=float, help="The tolerant range for FLOP."
+ )
+ add_shared_args(parser)
+ # Optimization options
+ parser.add_argument(
+ "--batch_size", type=int, default=2, help="Batch size for training."
+ )
+ args = parser.parse_args()

- if args.rand_seed is None or args.rand_seed < 0:
- args.rand_seed = random.randint(1, 100000)
- assert args.save_dir is not None, 'save-path argument can not be None'
- assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
- assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant)
- #assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
- #args.arch_para_pure = bool(args.arch_para_pure)
- return args
+ if args.rand_seed is None or args.rand_seed < 0:
+ args.rand_seed = random.randint(1, 100000)
+ assert args.save_dir is not None, "save-path argument can not be None"
+ assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
+ assert (
+ args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
+ ), "invalid FLOP_tolerant : {:}".format(args.FLOP_tolerant)
+ # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
+ # args.arch_para_pure = bool(args.arch_para_pure)
+ return args
diff --git a/lib/config_utils/share_args.py b/lib/config_utils/share_args.py
index b582373..241696e 100644
--- a/lib/config_utils/share_args.py
+++ b/lib/config_utils/share_args.py
@@ -1,17 +1,39 @@ import os, sys, time, random, argparse -def add_shared_args( parser ): - # Data Generation - parser.add_argument('--dataset', type=str, help='The dataset name.') - 
parser.add_argument('--data_path', type=str, help='The dataset name.') - parser.add_argument('--cutout_length', type=int, help='The cutout length, negative means not use.') - # Printing - parser.add_argument('--print_freq', type=int, default=100, help='print frequency (default: 200)') - parser.add_argument('--print_freq_eval', type=int, default=100, help='print frequency (default: 200)') - # Checkpoints - parser.add_argument('--eval_frequency', type=int, default=1, help='evaluation frequency (default: 200)') - parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') - # Acceleration - parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)') - # Random Seed - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') + +def add_shared_args(parser): + # Data Generation + parser.add_argument("--dataset", type=str, help="The dataset name.") + parser.add_argument("--data_path", type=str, help="The path to the dataset.") + parser.add_argument( + "--cutout_length", type=int, help="The cutout length, negative means not use." + ) + # Printing + parser.add_argument( + "--print_freq", type=int, default=100, help="print frequency (default: 100)" + ) + parser.add_argument( + "--print_freq_eval", + type=int, + default=100, + help="print frequency (default: 100)", + ) + # Checkpoints + parser.add_argument( + "--eval_frequency", + type=int, + default=1, + help="evaluation frequency (default: 1)", + ) + parser.add_argument( + "--save_dir", type=str, help="Folder to save checkpoints and log." + ) + # Acceleration + parser.add_argument( + "--workers", + type=int, + default=8, + help="number of data loading workers (default: 8)", + ) + # Random Seed + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")