Move str2bool to config_utils

This commit is contained in:
D-X-Y 2021-03-30 09:17:05 +00:00
parent 9fc2c991f5
commit c2270fd153
16 changed files with 519 additions and 305 deletions

View File

@ -36,6 +36,7 @@ jobs:
python -m black ./lib/spaces -l 88 --check --diff --verbose python -m black ./lib/spaces -l 88 --check --diff --verbose
python -m black ./lib/trade_models -l 88 --check --diff --verbose python -m black ./lib/trade_models -l 88 --check --diff --verbose
python -m black ./lib/procedures -l 88 --check --diff --verbose python -m black ./lib/procedures -l 88 --check --diff --verbose
python -m black ./lib/config_utils -l 88 --check --diff --verbose
- name: Test Search Space - name: Test Search Space
run: | run: |

View File

@ -15,7 +15,7 @@
# python exps/trading/baselines.py --alg TabNet # # python exps/trading/baselines.py --alg TabNet #
# # # #
# python exps/trading/baselines.py --alg Transformer# # python exps/trading/baselines.py --alg Transformer#
# python exps/trading/baselines.py --alg TSF # python exps/trading/baselines.py --alg TSF
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0 # python exps/trading/baselines.py --alg TSF-4x64-drop0_0
##################################################### #####################################################
import sys import sys
@ -30,6 +30,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path: if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from config_utils import arg_str2bool
from procedures.q_exps import update_gpu from procedures.q_exps import update_gpu
from procedures.q_exps import update_market from procedures.q_exps import update_market
from procedures.q_exps import run_exp from procedures.q_exps import run_exp
@ -182,6 +183,12 @@ if __name__ == "__main__":
help="The market indicator.", help="The market indicator.",
) )
parser.add_argument("--times", type=int, default=5, help="The repeated run times.") parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
parser.add_argument(
"--shared_dataset",
type=arg_str2bool,
default=False,
help="Whether to share the dataset for all algorithms?",
)
parser.add_argument( parser.add_argument(
"--gpu", type=int, default=0, help="The GPU ID used for train / test." "--gpu", type=int, default=0, help="The GPU ID used for train / test."
) )
@ -189,9 +196,13 @@ if __name__ == "__main__":
"--alg", "--alg",
type=str, type=str,
choices=list(alg2configs.keys()), choices=list(alg2configs.keys()),
nargs="+",
required=True, required=True,
help="The algorithm name.", help="The algorithm name(s).",
) )
args = parser.parse_args() args = parser.parse_args()
main(args, alg2configs[args.alg]) if len(args.alg) == 1:
main(args, alg2configs[args.alg[0]])
else:
print("-")

View File

@ -15,6 +15,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path: if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from config_utils import arg_str2bool
import qlib import qlib
from qlib.config import REG_CN from qlib.config import REG_CN
from qlib.workflow import R from qlib.workflow import R
@ -184,16 +185,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser("Show Results") parser = argparse.ArgumentParser("Show Results")
def str2bool(v):
if isinstance(v, bool):
return v
elif v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser.add_argument( parser.add_argument(
"--save_dir", "--save_dir",
type=str, type=str,
@ -203,7 +194,7 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--verbose", "--verbose",
type=str2bool, type=arg_str2bool,
default=False, default=False,
help="Print detailed log information or not.", help="Print detailed log information or not.",
) )
@ -228,7 +219,7 @@ if __name__ == "__main__":
info_dict["heads"], info_dict["heads"],
info_dict["values"], info_dict["values"],
info_dict["names"], info_dict["names"],
space=14, space=18,
verbose=True, verbose=True,
sort_key=True, sort_key=True,
) )

View File

@ -1,13 +1,19 @@
################################################## ##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
################################################## ##################################################
from .configure_utils import load_config, dict2config, configure2str # general config related functions
from .basic_args import obtain_basic_args from .config_utils import load_config, dict2config, configure2str
from .attention_args import obtain_attention_args # the args setting for different experiments
from .random_baseline import obtain_RandomSearch_args from .basic_args import obtain_basic_args
from .cls_kd_args import obtain_cls_kd_args from .attention_args import obtain_attention_args
from .cls_init_args import obtain_cls_init_args from .random_baseline import obtain_RandomSearch_args
from .cls_kd_args import obtain_cls_kd_args
from .cls_init_args import obtain_cls_init_args
from .search_single_args import obtain_search_single_args from .search_single_args import obtain_search_single_args
from .search_args import obtain_search_args from .search_args import obtain_search_args
# for network pruning # for network pruning
from .pruning_args import obtain_pruning_args from .pruning_args import obtain_pruning_args
# utils for args
from .args_utils import arg_str2bool

View File

@ -0,0 +1,12 @@
import argparse
def arg_str2bool(v):
if isinstance(v, bool):
return v
elif v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")

View File

@ -1,22 +1,32 @@
import random, argparse import random, argparse
from .share_args import add_shared_args from .share_args import add_shared_args
def obtain_attention_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--att_channel' , type=int, help='.')
parser.add_argument('--att_spatial' , type=str, help='.')
parser.add_argument('--att_active' , type=str, help='.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: def obtain_attention_args():
args.rand_seed = random.randint(1, 100000) parser = argparse.ArgumentParser(
assert args.save_dir is not None, 'save-path argument can not be None' description="Train a classification model on typical image classification datasets.",
return args formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--resume", type=str, help="Resume path.")
parser.add_argument("--init_model", type=str, help="The initialization model path.")
parser.add_argument(
"--model_config", type=str, help="The path to the model configuration"
)
parser.add_argument(
"--optim_config", type=str, help="The path to the optimizer configuration"
)
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
parser.add_argument("--att_channel", type=int, help=".")
parser.add_argument("--att_spatial", type=str, help=".")
parser.add_argument("--att_active", type=str, help=".")
add_shared_args(parser)
# Optimization options
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training."
)
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "save-path argument can not be None"
return args

View File

@ -4,21 +4,41 @@
import random, argparse import random, argparse
from .share_args import add_shared_args from .share_args import add_shared_args
def obtain_basic_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--model_source', type=str, default='normal',help='The source of model defination.')
parser.add_argument('--extra_model_path', type=str, default=None, help='The extra model ckp file (help to indicate the searched architecture).')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: def obtain_basic_args():
args.rand_seed = random.randint(1, 100000) parser = argparse.ArgumentParser(
assert args.save_dir is not None, 'save-path argument can not be None' description="Train a classification model on typical image classification datasets.",
return args formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--resume", type=str, help="Resume path.")
parser.add_argument("--init_model", type=str, help="The initialization model path.")
parser.add_argument(
"--model_config", type=str, help="The path to the model configuration"
)
parser.add_argument(
"--optim_config", type=str, help="The path to the optimizer configuration"
)
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
parser.add_argument(
"--model_source",
type=str,
default="normal",
help="The source of model defination.",
)
parser.add_argument(
"--extra_model_path",
type=str,
default=None,
help="The extra model ckp file (help to indicate the searched architecture).",
)
add_shared_args(parser)
# Optimization options
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training."
)
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "save-path argument can not be None"
return args

View File

@ -1,20 +1,32 @@
import random, argparse import random, argparse
from .share_args import add_shared_args from .share_args import add_shared_args
def obtain_cls_init_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--init_checkpoint', type=str, help='The checkpoint path to the initial model.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: def obtain_cls_init_args():
args.rand_seed = random.randint(1, 100000) parser = argparse.ArgumentParser(
assert args.save_dir is not None, 'save-path argument can not be None' description="Train a classification model on typical image classification datasets.",
return args formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--resume", type=str, help="Resume path.")
parser.add_argument("--init_model", type=str, help="The initialization model path.")
parser.add_argument(
"--model_config", type=str, help="The path to the model configuration"
)
parser.add_argument(
"--optim_config", type=str, help="The path to the optimizer configuration"
)
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
parser.add_argument(
"--init_checkpoint", type=str, help="The checkpoint path to the initial model."
)
add_shared_args(parser)
# Optimization options
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training."
)
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "save-path argument can not be None"
return args

View File

@ -1,23 +1,43 @@
import random, argparse import random, argparse
from .share_args import add_shared_args from .share_args import add_shared_args
def obtain_cls_kd_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--KD_checkpoint', type=str, help='The teacher checkpoint in knowledge distillation.')
parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.')
parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.')
#parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: def obtain_cls_kd_args():
args.rand_seed = random.randint(1, 100000) parser = argparse.ArgumentParser(
assert args.save_dir is not None, 'save-path argument can not be None' description="Train a classification model on typical image classification datasets.",
return args formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--resume", type=str, help="Resume path.")
parser.add_argument("--init_model", type=str, help="The initialization model path.")
parser.add_argument(
"--model_config", type=str, help="The path to the model configuration"
)
parser.add_argument(
"--optim_config", type=str, help="The path to the optimizer configuration"
)
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
parser.add_argument(
"--KD_checkpoint",
type=str,
help="The teacher checkpoint in knowledge distillation.",
)
parser.add_argument(
"--KD_alpha", type=float, help="The alpha parameter in knowledge distillation."
)
parser.add_argument(
"--KD_temperature",
type=float,
help="The temperature parameter in knowledge distillation.",
)
# parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.')
add_shared_args(parser)
# Optimization options
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training."
)
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "save-path argument can not be None"
return args

View File

@ -0,0 +1,135 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple
support_types = ("str", "int", "bool", "float", "none")
def convert_param(original_lists):
assert isinstance(original_lists, list), "The type is not right : {:}".format(
original_lists
)
ctype, value = original_lists[0], original_lists[1]
assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types)
is_list = isinstance(value, list)
if not is_list:
value = [value]
outs = []
for x in value:
if ctype == "int":
x = int(x)
elif ctype == "str":
x = str(x)
elif ctype == "bool":
x = bool(int(x))
elif ctype == "float":
x = float(x)
elif ctype == "none":
if x.lower() != "none":
raise ValueError(
"For the none type, the value must be none instead of {:}".format(x)
)
x = None
else:
raise TypeError("Does not know this type : {:}".format(ctype))
outs.append(x)
if not is_list:
outs = outs[0]
return outs
def load_config(path, extra, logger):
path = str(path)
if hasattr(logger, "log"):
logger.log(path)
assert os.path.exists(path), "Can not find {:}".format(path)
# Reading data back
with open(path, "r") as f:
data = json.load(f)
content = {k: convert_param(v) for k, v in data.items()}
assert extra is None or isinstance(
extra, dict
), "invalid type of extra : {:}".format(extra)
if isinstance(extra, dict):
content = {**content, **extra}
Arguments = namedtuple("Configure", " ".join(content.keys()))
content = Arguments(**content)
if hasattr(logger, "log"):
logger.log("{:}".format(content))
return content
def configure2str(config, xpath=None):
if not isinstance(config, dict):
config = config._asdict()
def cstring(x):
return '"{:}"'.format(x)
def gtype(x):
if isinstance(x, list):
x = x[0]
if isinstance(x, str):
return "str"
elif isinstance(x, bool):
return "bool"
elif isinstance(x, int):
return "int"
elif isinstance(x, float):
return "float"
elif x is None:
return "none"
else:
raise ValueError("invalid : {:}".format(x))
def cvalue(x, xtype):
if isinstance(x, list):
is_list = True
else:
is_list, x = False, [x]
temps = []
for temp in x:
if xtype == "bool":
temp = cstring(int(temp))
elif xtype == "none":
temp = cstring("None")
else:
temp = cstring(temp)
temps.append(temp)
if is_list:
return "[{:}]".format(", ".join(temps))
else:
return temps[0]
xstrings = []
for key, value in config.items():
xtype = gtype(value)
string = " {:20s} : [{:8s}, {:}]".format(
cstring(key), cstring(xtype), cvalue(value, xtype)
)
xstrings.append(string)
Fstring = "{\n" + ",\n".join(xstrings) + "\n}"
if xpath is not None:
parent = Path(xpath).resolve().parent
parent.mkdir(parents=True, exist_ok=True)
if osp.isfile(xpath):
os.remove(xpath)
with open(xpath, "w") as text_file:
text_file.write("{:}".format(Fstring))
return Fstring
def dict2config(xdict, logger):
assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict))
Arguments = namedtuple("Configure", " ".join(xdict.keys()))
content = Arguments(**xdict)
if hasattr(logger, "log"):
logger.log("{:}".format(content))
return content

View File

@ -1,106 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple
support_types = ('str', 'int', 'bool', 'float', 'none')
def convert_param(original_lists):
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
ctype, value = original_lists[0], original_lists[1]
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
is_list = isinstance(value, list)
if not is_list: value = [value]
outs = []
for x in value:
if ctype == 'int':
x = int(x)
elif ctype == 'str':
x = str(x)
elif ctype == 'bool':
x = bool(int(x))
elif ctype == 'float':
x = float(x)
elif ctype == 'none':
if x.lower() != 'none':
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
x = None
else:
raise TypeError('Does not know this type : {:}'.format(ctype))
outs.append(x)
if not is_list: outs = outs[0]
return outs
def load_config(path, extra, logger):
path = str(path)
if hasattr(logger, 'log'): logger.log(path)
assert os.path.exists(path), 'Can not find {:}'.format(path)
# Reading data back
with open(path, 'r') as f:
data = json.load(f)
content = { k: convert_param(v) for k,v in data.items()}
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
if isinstance(extra, dict): content = {**content, **extra}
Arguments = namedtuple('Configure', ' '.join(content.keys()))
content = Arguments(**content)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content
def configure2str(config, xpath=None):
if not isinstance(config, dict):
config = config._asdict()
def cstring(x):
return "\"{:}\"".format(x)
def gtype(x):
if isinstance(x, list): x = x[0]
if isinstance(x, str) : return 'str'
elif isinstance(x, bool) : return 'bool'
elif isinstance(x, int): return 'int'
elif isinstance(x, float): return 'float'
elif x is None : return 'none'
else: raise ValueError('invalid : {:}'.format(x))
def cvalue(x, xtype):
if isinstance(x, list): is_list = True
else:
is_list, x = False, [x]
temps = []
for temp in x:
if xtype == 'bool' : temp = cstring(int(temp))
elif xtype == 'none': temp = cstring('None')
else : temp = cstring(temp)
temps.append( temp )
if is_list:
return "[{:}]".format( ', '.join( temps ) )
else:
return temps[0]
xstrings = []
for key, value in config.items():
xtype = gtype(value)
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
xstrings.append(string)
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
if xpath is not None:
parent = Path(xpath).resolve().parent
parent.mkdir(parents=True, exist_ok=True)
if osp.isfile(xpath): os.remove(xpath)
with open(xpath, "w") as text_file:
text_file.write('{:}'.format(Fstring))
return Fstring
def dict2config(xdict, logger):
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
content = Arguments(**xdict)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content

View File

@ -1,26 +1,48 @@
import os, sys, time, random, argparse import os, sys, time, random, argparse
from .share_args import add_shared_args from .share_args import add_shared_args
def obtain_pruning_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--resume' , type=str, help='Resume path.')
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
parser.add_argument('--keep_ratio' , type=float, help='The left channel ratio compared to the original network.')
parser.add_argument('--model_version', type=str, help='The network version.')
parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.')
parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.')
parser.add_argument('--Regular_W_feat', type=float, help='The .')
parser.add_argument('--Regular_W_conv', type=float, help='The .')
add_shared_args( parser )
# Optimization options
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: def obtain_pruning_args():
args.rand_seed = random.randint(1, 100000) parser = argparse.ArgumentParser(
assert args.save_dir is not None, 'save-path argument can not be None' description="Train a classification model on typical image classification datasets.",
assert args.keep_ratio > 0 and args.keep_ratio <= 1, 'invalid keep ratio : {:}'.format(args.keep_ratio) formatter_class=argparse.ArgumentDefaultsHelpFormatter,
return args )
parser.add_argument("--resume", type=str, help="Resume path.")
parser.add_argument("--init_model", type=str, help="The initialization model path.")
parser.add_argument(
"--model_config", type=str, help="The path to the model configuration"
)
parser.add_argument(
"--optim_config", type=str, help="The path to the optimizer configuration"
)
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
parser.add_argument(
"--keep_ratio",
type=float,
help="The left channel ratio compared to the original network.",
)
parser.add_argument("--model_version", type=str, help="The network version.")
parser.add_argument(
"--KD_alpha", type=float, help="The alpha parameter in knowledge distillation."
)
parser.add_argument(
"--KD_temperature",
type=float,
help="The temperature parameter in knowledge distillation.",
)
parser.add_argument("--Regular_W_feat", type=float, help="The .")
parser.add_argument("--Regular_W_conv", type=float, help="The .")
add_shared_args(parser)
# Optimization options
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training."
)
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "save-path argument can not be None"
assert (
args.keep_ratio > 0 and args.keep_ratio <= 1
), "invalid keep ratio : {:}".format(args.keep_ratio)
return args

View File

@ -3,22 +3,42 @@ from .share_args import add_shared_args
def obtain_RandomSearch_args(): def obtain_RandomSearch_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(
parser.add_argument('--resume' , type=str, help='Resume path.') description="Train a classification model on typical image classification datasets.",
parser.add_argument('--init_model' , type=str, help='The initialization model path.') formatter_class=argparse.ArgumentDefaultsHelpFormatter,
parser.add_argument('--expect_flop', type=float, help='The expected flop keep ratio.') )
parser.add_argument('--arch_nums' , type=int, help='The maximum number of running random arch generating..') parser.add_argument("--resume", type=str, help="Resume path.")
parser.add_argument('--model_config', type=str, help='The path to the model configuration') parser.add_argument("--init_model", type=str, help="The initialization model path.")
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration') parser.add_argument(
parser.add_argument('--random_mode', type=str, choices=['random', 'fix'], help='The path to the optimizer configuration') "--expect_flop", type=float, help="The expected flop keep ratio."
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') )
add_shared_args( parser ) parser.add_argument(
# Optimization options "--arch_nums",
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.') type=int,
args = parser.parse_args() help="The maximum number of running random arch generating..",
)
parser.add_argument(
"--model_config", type=str, help="The path to the model configuration"
)
parser.add_argument(
"--optim_config", type=str, help="The path to the optimizer configuration"
)
parser.add_argument(
"--random_mode",
type=str,
choices=["random", "fix"],
help="The path to the optimizer configuration",
)
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
add_shared_args(parser)
# Optimization options
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training."
)
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000) args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None' assert args.save_dir is not None, "save-path argument can not be None"
#assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max) # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max)
return args return args

View File

@ -3,30 +3,51 @@ from .share_args import add_shared_args
def obtain_search_args(): def obtain_search_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(
parser.add_argument('--resume' , type=str, help='Resume path.') description="Train a classification model on typical image classification datasets.",
parser.add_argument('--model_config' , type=str, help='The path to the model configuration') formatter_class=argparse.ArgumentDefaultsHelpFormatter,
parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration') )
parser.add_argument('--split_path' , type=str, help='The split file path.') parser.add_argument("--resume", type=str, help="Resume path.")
#parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.') parser.add_argument(
parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.') "--model_config", type=str, help="The path to the model configuration"
parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.') )
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') parser.add_argument(
parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.') "--optim_config", type=str, help="The path to the optimizer configuration"
parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.') )
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.') parser.add_argument("--split_path", type=str, help="The split file path.")
# ablation studies # parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
parser.add_argument('--ablation_num_select', type=int, help='The number of randomly selected channels.') parser.add_argument(
add_shared_args( parser ) "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel."
# Optimization options )
parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.') parser.add_argument(
args = parser.parse_args() "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel."
)
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
parser.add_argument(
"--FLOP_tolerant", type=float, help="The tolerant range for FLOP."
)
# ablation studies
parser.add_argument(
"--ablation_num_select",
type=int,
help="The number of randomly selected channels.",
)
add_shared_args(parser)
# Optimization options
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training."
)
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000) args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None' assert args.save_dir is not None, "save-path argument can not be None"
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant) assert (
#assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
#args.arch_para_pure = bool(args.arch_para_pure) ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant)
return args # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
# args.arch_para_pure = bool(args.arch_para_pure)
return args

View File

@ -3,29 +3,46 @@ from .share_args import add_shared_args
def obtain_search_single_args(): def obtain_search_single_args():
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(
parser.add_argument('--resume' , type=str, help='Resume path.') description="Train a classification model on typical image classification datasets.",
parser.add_argument('--model_config' , type=str, help='The path to the model configuration') formatter_class=argparse.ArgumentDefaultsHelpFormatter,
parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration') )
parser.add_argument('--split_path' , type=str, help='The split file path.') parser.add_argument("--resume", type=str, help="Resume path.")
parser.add_argument('--search_shape' , type=str, help='The shape to be searched.') parser.add_argument(
#parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.') "--model_config", type=str, help="The path to the model configuration"
parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.') )
parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.') parser.add_argument(
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.') "--optim_config", type=str, help="The path to the optimizer configuration"
parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.') )
parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.') parser.add_argument("--split_path", type=str, help="The split file path.")
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.') parser.add_argument("--search_shape", type=str, help="The shape to be searched.")
add_shared_args( parser ) # parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
# Optimization options parser.add_argument(
parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.') "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel."
args = parser.parse_args() )
parser.add_argument(
"--gumbel_tau_min", type=float, help="The minimum tau for Gumbel."
)
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
parser.add_argument(
"--FLOP_tolerant", type=float, help="The tolerant range for FLOP."
)
add_shared_args(parser)
# Optimization options
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training."
)
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000) args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, 'save-path argument can not be None' assert args.save_dir is not None, "save-path argument can not be None"
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant) assert (
#assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure) args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
#args.arch_para_pure = bool(args.arch_para_pure) ), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant)
return args # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
# args.arch_para_pure = bool(args.arch_para_pure)
return args

View File

@ -1,17 +1,39 @@
import os, sys, time, random, argparse import os, sys, time, random, argparse
def add_shared_args( parser ):
# Data Generation def add_shared_args(parser):
parser.add_argument('--dataset', type=str, help='The dataset name.') # Data Generation
parser.add_argument('--data_path', type=str, help='The dataset name.') parser.add_argument("--dataset", type=str, help="The dataset name.")
parser.add_argument('--cutout_length', type=int, help='The cutout length, negative means not use.') parser.add_argument("--data_path", type=str, help="The dataset name.")
# Printing parser.add_argument(
parser.add_argument('--print_freq', type=int, default=100, help='print frequency (default: 200)') "--cutout_length", type=int, help="The cutout length, negative means not use."
parser.add_argument('--print_freq_eval', type=int, default=100, help='print frequency (default: 200)') )
# Checkpoints # Printing
parser.add_argument('--eval_frequency', type=int, default=1, help='evaluation frequency (default: 200)') parser.add_argument(
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') "--print_freq", type=int, default=100, help="print frequency (default: 200)"
# Acceleration )
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)') parser.add_argument(
# Random Seed "--print_freq_eval",
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') type=int,
default=100,
help="print frequency (default: 200)",
)
# Checkpoints
parser.add_argument(
"--eval_frequency",
type=int,
default=1,
help="evaluation frequency (default: 200)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
# Acceleration
parser.add_argument(
"--workers",
type=int,
default=8,
help="number of data loading workers (default: 8)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")