Move str2bool to config_utils
This commit is contained in:
parent
9fc2c991f5
commit
c2270fd153
1
.github/workflows/basic_test.yml
vendored
1
.github/workflows/basic_test.yml
vendored
@ -36,6 +36,7 @@ jobs:
|
|||||||
python -m black ./lib/spaces -l 88 --check --diff --verbose
|
python -m black ./lib/spaces -l 88 --check --diff --verbose
|
||||||
python -m black ./lib/trade_models -l 88 --check --diff --verbose
|
python -m black ./lib/trade_models -l 88 --check --diff --verbose
|
||||||
python -m black ./lib/procedures -l 88 --check --diff --verbose
|
python -m black ./lib/procedures -l 88 --check --diff --verbose
|
||||||
|
python -m black ./lib/config_utils -l 88 --check --diff --verbose
|
||||||
|
|
||||||
- name: Test Search Space
|
- name: Test Search Space
|
||||||
run: |
|
run: |
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# python exps/trading/baselines.py --alg TabNet #
|
# python exps/trading/baselines.py --alg TabNet #
|
||||||
# #
|
# #
|
||||||
# python exps/trading/baselines.py --alg Transformer#
|
# python exps/trading/baselines.py --alg Transformer#
|
||||||
# python exps/trading/baselines.py --alg TSF
|
# python exps/trading/baselines.py --alg TSF
|
||||||
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
|
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys
|
import sys
|
||||||
@ -30,6 +30,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
|||||||
if str(lib_dir) not in sys.path:
|
if str(lib_dir) not in sys.path:
|
||||||
sys.path.insert(0, str(lib_dir))
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
from config_utils import arg_str2bool
|
||||||
from procedures.q_exps import update_gpu
|
from procedures.q_exps import update_gpu
|
||||||
from procedures.q_exps import update_market
|
from procedures.q_exps import update_market
|
||||||
from procedures.q_exps import run_exp
|
from procedures.q_exps import run_exp
|
||||||
@ -182,6 +183,12 @@ if __name__ == "__main__":
|
|||||||
help="The market indicator.",
|
help="The market indicator.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
|
parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--shared_dataset",
|
||||||
|
type=arg_str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to share the dataset for all algorithms?",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
|
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
|
||||||
)
|
)
|
||||||
@ -189,9 +196,13 @@ if __name__ == "__main__":
|
|||||||
"--alg",
|
"--alg",
|
||||||
type=str,
|
type=str,
|
||||||
choices=list(alg2configs.keys()),
|
choices=list(alg2configs.keys()),
|
||||||
|
nargs="+",
|
||||||
required=True,
|
required=True,
|
||||||
help="The algorithm name.",
|
help="The algorithm name(s).",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args, alg2configs[args.alg])
|
if len(args.alg) == 1:
|
||||||
|
main(args, alg2configs[args.alg[0]])
|
||||||
|
else:
|
||||||
|
print("-")
|
||||||
|
@ -15,6 +15,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
|||||||
if str(lib_dir) not in sys.path:
|
if str(lib_dir) not in sys.path:
|
||||||
sys.path.insert(0, str(lib_dir))
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
from config_utils import arg_str2bool
|
||||||
import qlib
|
import qlib
|
||||||
from qlib.config import REG_CN
|
from qlib.config import REG_CN
|
||||||
from qlib.workflow import R
|
from qlib.workflow import R
|
||||||
@ -184,16 +185,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser("Show Results")
|
parser = argparse.ArgumentParser("Show Results")
|
||||||
|
|
||||||
def str2bool(v):
|
|
||||||
if isinstance(v, bool):
|
|
||||||
return v
|
|
||||||
elif v.lower() in ("yes", "true", "t", "y", "1"):
|
|
||||||
return True
|
|
||||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_dir",
|
"--save_dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -203,7 +194,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose",
|
"--verbose",
|
||||||
type=str2bool,
|
type=arg_str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Print detailed log information or not.",
|
help="Print detailed log information or not.",
|
||||||
)
|
)
|
||||||
@ -228,7 +219,7 @@ if __name__ == "__main__":
|
|||||||
info_dict["heads"],
|
info_dict["heads"],
|
||||||
info_dict["values"],
|
info_dict["values"],
|
||||||
info_dict["names"],
|
info_dict["names"],
|
||||||
space=14,
|
space=18,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
sort_key=True,
|
sort_key=True,
|
||||||
)
|
)
|
||||||
|
@ -1,13 +1,19 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
##################################################
|
||||||
from .configure_utils import load_config, dict2config, configure2str
|
# general config related functions
|
||||||
from .basic_args import obtain_basic_args
|
from .config_utils import load_config, dict2config, configure2str
|
||||||
from .attention_args import obtain_attention_args
|
# the args setting for different experiments
|
||||||
from .random_baseline import obtain_RandomSearch_args
|
from .basic_args import obtain_basic_args
|
||||||
from .cls_kd_args import obtain_cls_kd_args
|
from .attention_args import obtain_attention_args
|
||||||
from .cls_init_args import obtain_cls_init_args
|
from .random_baseline import obtain_RandomSearch_args
|
||||||
|
from .cls_kd_args import obtain_cls_kd_args
|
||||||
|
from .cls_init_args import obtain_cls_init_args
|
||||||
from .search_single_args import obtain_search_single_args
|
from .search_single_args import obtain_search_single_args
|
||||||
from .search_args import obtain_search_args
|
from .search_args import obtain_search_args
|
||||||
|
|
||||||
# for network pruning
|
# for network pruning
|
||||||
from .pruning_args import obtain_pruning_args
|
from .pruning_args import obtain_pruning_args
|
||||||
|
|
||||||
|
# utils for args
|
||||||
|
from .args_utils import arg_str2bool
|
||||||
|
12
lib/config_utils/args_utils.py
Normal file
12
lib/config_utils/args_utils.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
def arg_str2bool(v):
|
||||||
|
if isinstance(v, bool):
|
||||||
|
return v
|
||||||
|
elif v.lower() in ("yes", "true", "t", "y", "1"):
|
||||||
|
return True
|
||||||
|
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
raise argparse.ArgumentTypeError("Boolean value expected.")
|
@ -1,22 +1,32 @@
|
|||||||
import random, argparse
|
import random, argparse
|
||||||
from .share_args import add_shared_args
|
from .share_args import add_shared_args
|
||||||
|
|
||||||
def obtain_attention_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('--resume' , type=str, help='Resume path.')
|
|
||||||
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
|
|
||||||
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
|
|
||||||
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
|
|
||||||
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
|
|
||||||
parser.add_argument('--att_channel' , type=int, help='.')
|
|
||||||
parser.add_argument('--att_spatial' , type=str, help='.')
|
|
||||||
parser.add_argument('--att_active' , type=str, help='.')
|
|
||||||
add_shared_args( parser )
|
|
||||||
# Optimization options
|
|
||||||
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.rand_seed is None or args.rand_seed < 0:
|
def obtain_attention_args():
|
||||||
args.rand_seed = random.randint(1, 100000)
|
parser = argparse.ArgumentParser(
|
||||||
assert args.save_dir is not None, 'save-path argument can not be None'
|
description="Train a classification model on typical image classification datasets.",
|
||||||
return args
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
|
parser.add_argument("--init_model", type=str, help="The initialization model path.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_config", type=str, help="The path to the model configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--optim_config", type=str, help="The path to the optimizer configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
|
||||||
|
parser.add_argument("--att_channel", type=int, help=".")
|
||||||
|
parser.add_argument("--att_spatial", type=str, help=".")
|
||||||
|
parser.add_argument("--att_active", type=str, help=".")
|
||||||
|
add_shared_args(parser)
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=2, help="Batch size for training."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
|
args.rand_seed = random.randint(1, 100000)
|
||||||
|
assert args.save_dir is not None, "save-path argument can not be None"
|
||||||
|
return args
|
||||||
|
@ -4,21 +4,41 @@
|
|||||||
import random, argparse
|
import random, argparse
|
||||||
from .share_args import add_shared_args
|
from .share_args import add_shared_args
|
||||||
|
|
||||||
def obtain_basic_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('--resume' , type=str, help='Resume path.')
|
|
||||||
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
|
|
||||||
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
|
|
||||||
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
|
|
||||||
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
|
|
||||||
parser.add_argument('--model_source', type=str, default='normal',help='The source of model defination.')
|
|
||||||
parser.add_argument('--extra_model_path', type=str, default=None, help='The extra model ckp file (help to indicate the searched architecture).')
|
|
||||||
add_shared_args( parser )
|
|
||||||
# Optimization options
|
|
||||||
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.rand_seed is None or args.rand_seed < 0:
|
def obtain_basic_args():
|
||||||
args.rand_seed = random.randint(1, 100000)
|
parser = argparse.ArgumentParser(
|
||||||
assert args.save_dir is not None, 'save-path argument can not be None'
|
description="Train a classification model on typical image classification datasets.",
|
||||||
return args
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
|
parser.add_argument("--init_model", type=str, help="The initialization model path.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_config", type=str, help="The path to the model configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--optim_config", type=str, help="The path to the optimizer configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_source",
|
||||||
|
type=str,
|
||||||
|
default="normal",
|
||||||
|
help="The source of model defination.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--extra_model_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The extra model ckp file (help to indicate the searched architecture).",
|
||||||
|
)
|
||||||
|
add_shared_args(parser)
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=2, help="Batch size for training."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
|
args.rand_seed = random.randint(1, 100000)
|
||||||
|
assert args.save_dir is not None, "save-path argument can not be None"
|
||||||
|
return args
|
||||||
|
@ -1,20 +1,32 @@
|
|||||||
import random, argparse
|
import random, argparse
|
||||||
from .share_args import add_shared_args
|
from .share_args import add_shared_args
|
||||||
|
|
||||||
def obtain_cls_init_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('--resume' , type=str, help='Resume path.')
|
|
||||||
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
|
|
||||||
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
|
|
||||||
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
|
|
||||||
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
|
|
||||||
parser.add_argument('--init_checkpoint', type=str, help='The checkpoint path to the initial model.')
|
|
||||||
add_shared_args( parser )
|
|
||||||
# Optimization options
|
|
||||||
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.rand_seed is None or args.rand_seed < 0:
|
def obtain_cls_init_args():
|
||||||
args.rand_seed = random.randint(1, 100000)
|
parser = argparse.ArgumentParser(
|
||||||
assert args.save_dir is not None, 'save-path argument can not be None'
|
description="Train a classification model on typical image classification datasets.",
|
||||||
return args
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
|
parser.add_argument("--init_model", type=str, help="The initialization model path.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_config", type=str, help="The path to the model configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--optim_config", type=str, help="The path to the optimizer configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--init_checkpoint", type=str, help="The checkpoint path to the initial model."
|
||||||
|
)
|
||||||
|
add_shared_args(parser)
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=2, help="Batch size for training."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
|
args.rand_seed = random.randint(1, 100000)
|
||||||
|
assert args.save_dir is not None, "save-path argument can not be None"
|
||||||
|
return args
|
||||||
|
@ -1,23 +1,43 @@
|
|||||||
import random, argparse
|
import random, argparse
|
||||||
from .share_args import add_shared_args
|
from .share_args import add_shared_args
|
||||||
|
|
||||||
def obtain_cls_kd_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('--resume' , type=str, help='Resume path.')
|
|
||||||
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
|
|
||||||
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
|
|
||||||
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
|
|
||||||
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
|
|
||||||
parser.add_argument('--KD_checkpoint', type=str, help='The teacher checkpoint in knowledge distillation.')
|
|
||||||
parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.')
|
|
||||||
parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.')
|
|
||||||
#parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.')
|
|
||||||
add_shared_args( parser )
|
|
||||||
# Optimization options
|
|
||||||
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.rand_seed is None or args.rand_seed < 0:
|
def obtain_cls_kd_args():
|
||||||
args.rand_seed = random.randint(1, 100000)
|
parser = argparse.ArgumentParser(
|
||||||
assert args.save_dir is not None, 'save-path argument can not be None'
|
description="Train a classification model on typical image classification datasets.",
|
||||||
return args
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
|
parser.add_argument("--init_model", type=str, help="The initialization model path.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_config", type=str, help="The path to the model configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--optim_config", type=str, help="The path to the optimizer configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--KD_checkpoint",
|
||||||
|
type=str,
|
||||||
|
help="The teacher checkpoint in knowledge distillation.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--KD_alpha", type=float, help="The alpha parameter in knowledge distillation."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--KD_temperature",
|
||||||
|
type=float,
|
||||||
|
help="The temperature parameter in knowledge distillation.",
|
||||||
|
)
|
||||||
|
# parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.')
|
||||||
|
add_shared_args(parser)
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=2, help="Batch size for training."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
|
args.rand_seed = random.randint(1, 100000)
|
||||||
|
assert args.save_dir is not None, "save-path argument can not be None"
|
||||||
|
return args
|
||||||
|
135
lib/config_utils/config_utils.py
Normal file
135
lib/config_utils/config_utils.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
#
|
||||||
|
import os, json
|
||||||
|
from os import path as osp
|
||||||
|
from pathlib import Path
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
support_types = ("str", "int", "bool", "float", "none")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_param(original_lists):
|
||||||
|
assert isinstance(original_lists, list), "The type is not right : {:}".format(
|
||||||
|
original_lists
|
||||||
|
)
|
||||||
|
ctype, value = original_lists[0], original_lists[1]
|
||||||
|
assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types)
|
||||||
|
is_list = isinstance(value, list)
|
||||||
|
if not is_list:
|
||||||
|
value = [value]
|
||||||
|
outs = []
|
||||||
|
for x in value:
|
||||||
|
if ctype == "int":
|
||||||
|
x = int(x)
|
||||||
|
elif ctype == "str":
|
||||||
|
x = str(x)
|
||||||
|
elif ctype == "bool":
|
||||||
|
x = bool(int(x))
|
||||||
|
elif ctype == "float":
|
||||||
|
x = float(x)
|
||||||
|
elif ctype == "none":
|
||||||
|
if x.lower() != "none":
|
||||||
|
raise ValueError(
|
||||||
|
"For the none type, the value must be none instead of {:}".format(x)
|
||||||
|
)
|
||||||
|
x = None
|
||||||
|
else:
|
||||||
|
raise TypeError("Does not know this type : {:}".format(ctype))
|
||||||
|
outs.append(x)
|
||||||
|
if not is_list:
|
||||||
|
outs = outs[0]
|
||||||
|
return outs
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(path, extra, logger):
|
||||||
|
path = str(path)
|
||||||
|
if hasattr(logger, "log"):
|
||||||
|
logger.log(path)
|
||||||
|
assert os.path.exists(path), "Can not find {:}".format(path)
|
||||||
|
# Reading data back
|
||||||
|
with open(path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
content = {k: convert_param(v) for k, v in data.items()}
|
||||||
|
assert extra is None or isinstance(
|
||||||
|
extra, dict
|
||||||
|
), "invalid type of extra : {:}".format(extra)
|
||||||
|
if isinstance(extra, dict):
|
||||||
|
content = {**content, **extra}
|
||||||
|
Arguments = namedtuple("Configure", " ".join(content.keys()))
|
||||||
|
content = Arguments(**content)
|
||||||
|
if hasattr(logger, "log"):
|
||||||
|
logger.log("{:}".format(content))
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def configure2str(config, xpath=None):
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
config = config._asdict()
|
||||||
|
|
||||||
|
def cstring(x):
|
||||||
|
return '"{:}"'.format(x)
|
||||||
|
|
||||||
|
def gtype(x):
|
||||||
|
if isinstance(x, list):
|
||||||
|
x = x[0]
|
||||||
|
if isinstance(x, str):
|
||||||
|
return "str"
|
||||||
|
elif isinstance(x, bool):
|
||||||
|
return "bool"
|
||||||
|
elif isinstance(x, int):
|
||||||
|
return "int"
|
||||||
|
elif isinstance(x, float):
|
||||||
|
return "float"
|
||||||
|
elif x is None:
|
||||||
|
return "none"
|
||||||
|
else:
|
||||||
|
raise ValueError("invalid : {:}".format(x))
|
||||||
|
|
||||||
|
def cvalue(x, xtype):
|
||||||
|
if isinstance(x, list):
|
||||||
|
is_list = True
|
||||||
|
else:
|
||||||
|
is_list, x = False, [x]
|
||||||
|
temps = []
|
||||||
|
for temp in x:
|
||||||
|
if xtype == "bool":
|
||||||
|
temp = cstring(int(temp))
|
||||||
|
elif xtype == "none":
|
||||||
|
temp = cstring("None")
|
||||||
|
else:
|
||||||
|
temp = cstring(temp)
|
||||||
|
temps.append(temp)
|
||||||
|
if is_list:
|
||||||
|
return "[{:}]".format(", ".join(temps))
|
||||||
|
else:
|
||||||
|
return temps[0]
|
||||||
|
|
||||||
|
xstrings = []
|
||||||
|
for key, value in config.items():
|
||||||
|
xtype = gtype(value)
|
||||||
|
string = " {:20s} : [{:8s}, {:}]".format(
|
||||||
|
cstring(key), cstring(xtype), cvalue(value, xtype)
|
||||||
|
)
|
||||||
|
xstrings.append(string)
|
||||||
|
Fstring = "{\n" + ",\n".join(xstrings) + "\n}"
|
||||||
|
if xpath is not None:
|
||||||
|
parent = Path(xpath).resolve().parent
|
||||||
|
parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if osp.isfile(xpath):
|
||||||
|
os.remove(xpath)
|
||||||
|
with open(xpath, "w") as text_file:
|
||||||
|
text_file.write("{:}".format(Fstring))
|
||||||
|
return Fstring
|
||||||
|
|
||||||
|
|
||||||
|
def dict2config(xdict, logger):
|
||||||
|
assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict))
|
||||||
|
Arguments = namedtuple("Configure", " ".join(xdict.keys()))
|
||||||
|
content = Arguments(**xdict)
|
||||||
|
if hasattr(logger, "log"):
|
||||||
|
logger.log("{:}".format(content))
|
||||||
|
return content
|
@ -1,106 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
#
|
|
||||||
import os, json
|
|
||||||
from os import path as osp
|
|
||||||
from pathlib import Path
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
support_types = ('str', 'int', 'bool', 'float', 'none')
|
|
||||||
|
|
||||||
|
|
||||||
def convert_param(original_lists):
|
|
||||||
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
|
|
||||||
ctype, value = original_lists[0], original_lists[1]
|
|
||||||
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
|
|
||||||
is_list = isinstance(value, list)
|
|
||||||
if not is_list: value = [value]
|
|
||||||
outs = []
|
|
||||||
for x in value:
|
|
||||||
if ctype == 'int':
|
|
||||||
x = int(x)
|
|
||||||
elif ctype == 'str':
|
|
||||||
x = str(x)
|
|
||||||
elif ctype == 'bool':
|
|
||||||
x = bool(int(x))
|
|
||||||
elif ctype == 'float':
|
|
||||||
x = float(x)
|
|
||||||
elif ctype == 'none':
|
|
||||||
if x.lower() != 'none':
|
|
||||||
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
|
|
||||||
x = None
|
|
||||||
else:
|
|
||||||
raise TypeError('Does not know this type : {:}'.format(ctype))
|
|
||||||
outs.append(x)
|
|
||||||
if not is_list: outs = outs[0]
|
|
||||||
return outs
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(path, extra, logger):
|
|
||||||
path = str(path)
|
|
||||||
if hasattr(logger, 'log'): logger.log(path)
|
|
||||||
assert os.path.exists(path), 'Can not find {:}'.format(path)
|
|
||||||
# Reading data back
|
|
||||||
with open(path, 'r') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
content = { k: convert_param(v) for k,v in data.items()}
|
|
||||||
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
|
|
||||||
if isinstance(extra, dict): content = {**content, **extra}
|
|
||||||
Arguments = namedtuple('Configure', ' '.join(content.keys()))
|
|
||||||
content = Arguments(**content)
|
|
||||||
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
def configure2str(config, xpath=None):
|
|
||||||
if not isinstance(config, dict):
|
|
||||||
config = config._asdict()
|
|
||||||
def cstring(x):
|
|
||||||
return "\"{:}\"".format(x)
|
|
||||||
def gtype(x):
|
|
||||||
if isinstance(x, list): x = x[0]
|
|
||||||
if isinstance(x, str) : return 'str'
|
|
||||||
elif isinstance(x, bool) : return 'bool'
|
|
||||||
elif isinstance(x, int): return 'int'
|
|
||||||
elif isinstance(x, float): return 'float'
|
|
||||||
elif x is None : return 'none'
|
|
||||||
else: raise ValueError('invalid : {:}'.format(x))
|
|
||||||
def cvalue(x, xtype):
|
|
||||||
if isinstance(x, list): is_list = True
|
|
||||||
else:
|
|
||||||
is_list, x = False, [x]
|
|
||||||
temps = []
|
|
||||||
for temp in x:
|
|
||||||
if xtype == 'bool' : temp = cstring(int(temp))
|
|
||||||
elif xtype == 'none': temp = cstring('None')
|
|
||||||
else : temp = cstring(temp)
|
|
||||||
temps.append( temp )
|
|
||||||
if is_list:
|
|
||||||
return "[{:}]".format( ', '.join( temps ) )
|
|
||||||
else:
|
|
||||||
return temps[0]
|
|
||||||
|
|
||||||
xstrings = []
|
|
||||||
for key, value in config.items():
|
|
||||||
xtype = gtype(value)
|
|
||||||
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
|
|
||||||
xstrings.append(string)
|
|
||||||
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
|
|
||||||
if xpath is not None:
|
|
||||||
parent = Path(xpath).resolve().parent
|
|
||||||
parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
if osp.isfile(xpath): os.remove(xpath)
|
|
||||||
with open(xpath, "w") as text_file:
|
|
||||||
text_file.write('{:}'.format(Fstring))
|
|
||||||
return Fstring
|
|
||||||
|
|
||||||
|
|
||||||
def dict2config(xdict, logger):
|
|
||||||
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
|
|
||||||
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
|
|
||||||
content = Arguments(**xdict)
|
|
||||||
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
|
|
||||||
return content
|
|
@ -1,26 +1,48 @@
|
|||||||
import os, sys, time, random, argparse
|
import os, sys, time, random, argparse
|
||||||
from .share_args import add_shared_args
|
from .share_args import add_shared_args
|
||||||
|
|
||||||
def obtain_pruning_args():
|
|
||||||
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
||||||
parser.add_argument('--resume' , type=str, help='Resume path.')
|
|
||||||
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
|
|
||||||
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
|
|
||||||
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
|
|
||||||
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
|
|
||||||
parser.add_argument('--keep_ratio' , type=float, help='The left channel ratio compared to the original network.')
|
|
||||||
parser.add_argument('--model_version', type=str, help='The network version.')
|
|
||||||
parser.add_argument('--KD_alpha' , type=float, help='The alpha parameter in knowledge distillation.')
|
|
||||||
parser.add_argument('--KD_temperature', type=float, help='The temperature parameter in knowledge distillation.')
|
|
||||||
parser.add_argument('--Regular_W_feat', type=float, help='The .')
|
|
||||||
parser.add_argument('--Regular_W_conv', type=float, help='The .')
|
|
||||||
add_shared_args( parser )
|
|
||||||
# Optimization options
|
|
||||||
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.rand_seed is None or args.rand_seed < 0:
|
def obtain_pruning_args():
|
||||||
args.rand_seed = random.randint(1, 100000)
|
parser = argparse.ArgumentParser(
|
||||||
assert args.save_dir is not None, 'save-path argument can not be None'
|
description="Train a classification model on typical image classification datasets.",
|
||||||
assert args.keep_ratio > 0 and args.keep_ratio <= 1, 'invalid keep ratio : {:}'.format(args.keep_ratio)
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
return args
|
)
|
||||||
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
|
parser.add_argument("--init_model", type=str, help="The initialization model path.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_config", type=str, help="The path to the model configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--optim_config", type=str, help="The path to the optimizer configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--keep_ratio",
|
||||||
|
type=float,
|
||||||
|
help="The left channel ratio compared to the original network.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--model_version", type=str, help="The network version.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--KD_alpha", type=float, help="The alpha parameter in knowledge distillation."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--KD_temperature",
|
||||||
|
type=float,
|
||||||
|
help="The temperature parameter in knowledge distillation.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--Regular_W_feat", type=float, help="The .")
|
||||||
|
parser.add_argument("--Regular_W_conv", type=float, help="The .")
|
||||||
|
add_shared_args(parser)
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=2, help="Batch size for training."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
|
args.rand_seed = random.randint(1, 100000)
|
||||||
|
assert args.save_dir is not None, "save-path argument can not be None"
|
||||||
|
assert (
|
||||||
|
args.keep_ratio > 0 and args.keep_ratio <= 1
|
||||||
|
), "invalid keep ratio : {:}".format(args.keep_ratio)
|
||||||
|
return args
|
||||||
|
@ -3,22 +3,42 @@ from .share_args import add_shared_args
|
|||||||
|
|
||||||
|
|
||||||
def obtain_RandomSearch_args():
|
def obtain_RandomSearch_args():
|
||||||
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--resume' , type=str, help='Resume path.')
|
description="Train a classification model on typical image classification datasets.",
|
||||||
parser.add_argument('--init_model' , type=str, help='The initialization model path.')
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
parser.add_argument('--expect_flop', type=float, help='The expected flop keep ratio.')
|
)
|
||||||
parser.add_argument('--arch_nums' , type=int, help='The maximum number of running random arch generating..')
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
parser.add_argument('--model_config', type=str, help='The path to the model configuration')
|
parser.add_argument("--init_model", type=str, help="The initialization model path.")
|
||||||
parser.add_argument('--optim_config', type=str, help='The path to the optimizer configuration')
|
parser.add_argument(
|
||||||
parser.add_argument('--random_mode', type=str, choices=['random', 'fix'], help='The path to the optimizer configuration')
|
"--expect_flop", type=float, help="The expected flop keep ratio."
|
||||||
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
|
)
|
||||||
add_shared_args( parser )
|
parser.add_argument(
|
||||||
# Optimization options
|
"--arch_nums",
|
||||||
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
|
type=int,
|
||||||
args = parser.parse_args()
|
help="The maximum number of running random arch generating..",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_config", type=str, help="The path to the model configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--optim_config", type=str, help="The path to the optimizer configuration"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--random_mode",
|
||||||
|
type=str,
|
||||||
|
choices=["random", "fix"],
|
||||||
|
help="The path to the optimizer configuration",
|
||||||
|
)
|
||||||
|
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
|
||||||
|
add_shared_args(parser)
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=2, help="Batch size for training."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, 'save-path argument can not be None'
|
assert args.save_dir is not None, "save-path argument can not be None"
|
||||||
#assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max)
|
# assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max)
|
||||||
return args
|
return args
|
||||||
|
@ -3,30 +3,51 @@ from .share_args import add_shared_args
|
|||||||
|
|
||||||
|
|
||||||
def obtain_search_args():
|
def obtain_search_args():
|
||||||
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--resume' , type=str, help='Resume path.')
|
description="Train a classification model on typical image classification datasets.",
|
||||||
parser.add_argument('--model_config' , type=str, help='The path to the model configuration')
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration')
|
)
|
||||||
parser.add_argument('--split_path' , type=str, help='The split file path.')
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
#parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
|
parser.add_argument(
|
||||||
parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.')
|
"--model_config", type=str, help="The path to the model configuration"
|
||||||
parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.')
|
)
|
||||||
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
|
parser.add_argument(
|
||||||
parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.')
|
"--optim_config", type=str, help="The path to the optimizer configuration"
|
||||||
parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.')
|
)
|
||||||
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.')
|
parser.add_argument("--split_path", type=str, help="The split file path.")
|
||||||
# ablation studies
|
# parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
|
||||||
parser.add_argument('--ablation_num_select', type=int, help='The number of randomly selected channels.')
|
parser.add_argument(
|
||||||
add_shared_args( parser )
|
"--gumbel_tau_max", type=float, help="The maximum tau for Gumbel."
|
||||||
# Optimization options
|
)
|
||||||
parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.')
|
parser.add_argument(
|
||||||
args = parser.parse_args()
|
"--gumbel_tau_min", type=float, help="The minimum tau for Gumbel."
|
||||||
|
)
|
||||||
|
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
|
||||||
|
parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
|
||||||
|
parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--FLOP_tolerant", type=float, help="The tolerant range for FLOP."
|
||||||
|
)
|
||||||
|
# ablation studies
|
||||||
|
parser.add_argument(
|
||||||
|
"--ablation_num_select",
|
||||||
|
type=int,
|
||||||
|
help="The number of randomly selected channels.",
|
||||||
|
)
|
||||||
|
add_shared_args(parser)
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=2, help="Batch size for training."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, 'save-path argument can not be None'
|
assert args.save_dir is not None, "save-path argument can not be None"
|
||||||
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
|
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
|
||||||
assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant)
|
assert (
|
||||||
#assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
|
args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
|
||||||
#args.arch_para_pure = bool(args.arch_para_pure)
|
), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant)
|
||||||
return args
|
# assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
|
||||||
|
# args.arch_para_pure = bool(args.arch_para_pure)
|
||||||
|
return args
|
||||||
|
@ -3,29 +3,46 @@ from .share_args import add_shared_args
|
|||||||
|
|
||||||
|
|
||||||
def obtain_search_single_args():
|
def obtain_search_single_args():
|
||||||
parser = argparse.ArgumentParser(description='Train a classification model on typical image classification datasets.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--resume' , type=str, help='Resume path.')
|
description="Train a classification model on typical image classification datasets.",
|
||||||
parser.add_argument('--model_config' , type=str, help='The path to the model configuration')
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
parser.add_argument('--optim_config' , type=str, help='The path to the optimizer configuration')
|
)
|
||||||
parser.add_argument('--split_path' , type=str, help='The split file path.')
|
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||||
parser.add_argument('--search_shape' , type=str, help='The shape to be searched.')
|
parser.add_argument(
|
||||||
#parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
|
"--model_config", type=str, help="The path to the model configuration"
|
||||||
parser.add_argument('--gumbel_tau_max', type=float, help='The maximum tau for Gumbel.')
|
)
|
||||||
parser.add_argument('--gumbel_tau_min', type=float, help='The minimum tau for Gumbel.')
|
parser.add_argument(
|
||||||
parser.add_argument('--procedure' , type=str, help='The procedure basic prefix.')
|
"--optim_config", type=str, help="The path to the optimizer configuration"
|
||||||
parser.add_argument('--FLOP_ratio' , type=float, help='The expected FLOP ratio.')
|
)
|
||||||
parser.add_argument('--FLOP_weight' , type=float, help='The loss weight for FLOP.')
|
parser.add_argument("--split_path", type=str, help="The split file path.")
|
||||||
parser.add_argument('--FLOP_tolerant' , type=float, help='The tolerant range for FLOP.')
|
parser.add_argument("--search_shape", type=str, help="The shape to be searched.")
|
||||||
add_shared_args( parser )
|
# parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
|
||||||
# Optimization options
|
parser.add_argument(
|
||||||
parser.add_argument('--batch_size' , type=int, default=2, help='Batch size for training.')
|
"--gumbel_tau_max", type=float, help="The maximum tau for Gumbel."
|
||||||
args = parser.parse_args()
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gumbel_tau_min", type=float, help="The minimum tau for Gumbel."
|
||||||
|
)
|
||||||
|
parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
|
||||||
|
parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
|
||||||
|
parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--FLOP_tolerant", type=float, help="The tolerant range for FLOP."
|
||||||
|
)
|
||||||
|
add_shared_args(parser)
|
||||||
|
# Optimization options
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=2, help="Batch size for training."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, 'save-path argument can not be None'
|
assert args.save_dir is not None, "save-path argument can not be None"
|
||||||
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
|
assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
|
||||||
assert args.FLOP_tolerant is not None and args.FLOP_tolerant > 0, 'invalid FLOP_tolerant : {:}'.format(FLOP_tolerant)
|
assert (
|
||||||
#assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
|
args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
|
||||||
#args.arch_para_pure = bool(args.arch_para_pure)
|
), "invalid FLOP_tolerant : {:}".format(FLOP_tolerant)
|
||||||
return args
|
# assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
|
||||||
|
# args.arch_para_pure = bool(args.arch_para_pure)
|
||||||
|
return args
|
||||||
|
@ -1,17 +1,39 @@
|
|||||||
import os, sys, time, random, argparse
|
import os, sys, time, random, argparse
|
||||||
|
|
||||||
def add_shared_args( parser ):
|
|
||||||
# Data Generation
|
def add_shared_args(parser):
|
||||||
parser.add_argument('--dataset', type=str, help='The dataset name.')
|
# Data Generation
|
||||||
parser.add_argument('--data_path', type=str, help='The dataset name.')
|
parser.add_argument("--dataset", type=str, help="The dataset name.")
|
||||||
parser.add_argument('--cutout_length', type=int, help='The cutout length, negative means not use.')
|
parser.add_argument("--data_path", type=str, help="The dataset name.")
|
||||||
# Printing
|
parser.add_argument(
|
||||||
parser.add_argument('--print_freq', type=int, default=100, help='print frequency (default: 200)')
|
"--cutout_length", type=int, help="The cutout length, negative means not use."
|
||||||
parser.add_argument('--print_freq_eval', type=int, default=100, help='print frequency (default: 200)')
|
)
|
||||||
# Checkpoints
|
# Printing
|
||||||
parser.add_argument('--eval_frequency', type=int, default=1, help='evaluation frequency (default: 200)')
|
parser.add_argument(
|
||||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
"--print_freq", type=int, default=100, help="print frequency (default: 200)"
|
||||||
# Acceleration
|
)
|
||||||
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)')
|
parser.add_argument(
|
||||||
# Random Seed
|
"--print_freq_eval",
|
||||||
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="print frequency (default: 200)",
|
||||||
|
)
|
||||||
|
# Checkpoints
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval_frequency",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="evaluation frequency (default: 200)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_dir", type=str, help="Folder to save checkpoints and log."
|
||||||
|
)
|
||||||
|
# Acceleration
|
||||||
|
parser.add_argument(
|
||||||
|
"--workers",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="number of data loading workers (default: 8)",
|
||||||
|
)
|
||||||
|
# Random Seed
|
||||||
|
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||||
|
Loading…
Reference in New Issue
Block a user