Move str2bool to config_utils
This commit is contained in:
		| @@ -15,7 +15,7 @@ | ||||
| # python exps/trading/baselines.py --alg TabNet     # | ||||
| #                                                   # | ||||
| # python exps/trading/baselines.py --alg Transformer# | ||||
| # python exps/trading/baselines.py --alg TSF          | ||||
| # python exps/trading/baselines.py --alg TSF | ||||
| # python exps/trading/baselines.py --alg TSF-4x64-drop0_0 | ||||
| ##################################################### | ||||
| import sys | ||||
| @@ -30,6 +30,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from config_utils import arg_str2bool | ||||
| from procedures.q_exps import update_gpu | ||||
| from procedures.q_exps import update_market | ||||
| from procedures.q_exps import run_exp | ||||
| @@ -182,6 +183,12 @@ if __name__ == "__main__": | ||||
|         help="The market indicator.", | ||||
|     ) | ||||
|     parser.add_argument("--times", type=int, default=5, help="The repeated run times.") | ||||
|     parser.add_argument( | ||||
|         "--shared_dataset", | ||||
|         type=arg_str2bool, | ||||
|         default=False, | ||||
|         help="Whether to share the dataset for all algorithms?", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--gpu", type=int, default=0, help="The GPU ID used for train / test." | ||||
|     ) | ||||
| @@ -189,9 +196,13 @@ if __name__ == "__main__": | ||||
|         "--alg", | ||||
|         type=str, | ||||
|         choices=list(alg2configs.keys()), | ||||
|         nargs="+", | ||||
|         required=True, | ||||
|         help="The algorithm name.", | ||||
|         help="The algorithm name(s).", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     main(args, alg2configs[args.alg]) | ||||
|     if len(args.alg) == 1: | ||||
|         main(args, alg2configs[args.alg[0]]) | ||||
|     else: | ||||
|         print("-") | ||||
|   | ||||
| @@ -15,6 +15,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from config_utils import arg_str2bool | ||||
| import qlib | ||||
| from qlib.config import REG_CN | ||||
| from qlib.workflow import R | ||||
| @@ -184,16 +185,6 @@ if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Show Results") | ||||
|  | ||||
|     def str2bool(v): | ||||
|         if isinstance(v, bool): | ||||
|             return v | ||||
|         elif v.lower() in ("yes", "true", "t", "y", "1"): | ||||
|             return True | ||||
|         elif v.lower() in ("no", "false", "f", "n", "0"): | ||||
|             return False | ||||
|         else: | ||||
|             raise argparse.ArgumentTypeError("Boolean value expected.") | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
| @@ -203,7 +194,7 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--verbose", | ||||
|         type=str2bool, | ||||
|         type=arg_str2bool, | ||||
|         default=False, | ||||
|         help="Print detailed log information or not.", | ||||
|     ) | ||||
| @@ -228,7 +219,7 @@ if __name__ == "__main__": | ||||
|         info_dict["heads"], | ||||
|         info_dict["values"], | ||||
|         info_dict["names"], | ||||
|         space=14, | ||||
|         space=18, | ||||
|         verbose=True, | ||||
|         sort_key=True, | ||||
|     ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user