Fix bugs in update_gpu in procedures
This commit is contained in:
		| @@ -1,15 +1,16 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/trading/baselines.py --alg GRU | # python exps/trading/baselines.py --alg GRU        # | ||||||
| # python exps/trading/baselines.py --alg LSTM | # python exps/trading/baselines.py --alg LSTM       # | ||||||
| # python exps/trading/baselines.py --alg ALSTM | # python exps/trading/baselines.py --alg ALSTM      # | ||||||
| # python exps/trading/baselines.py --alg MLP | # python exps/trading/baselines.py --alg MLP        # | ||||||
| # python exps/trading/baselines.py --alg SFM | # python exps/trading/baselines.py --alg SFM        # | ||||||
| # python exps/trading/baselines.py --alg XGBoost | # python exps/trading/baselines.py --alg XGBoost    # | ||||||
| # python exps/trading/baselines.py --alg LightGBM | # python exps/trading/baselines.py --alg LightGBM   # | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, argparse | import sys | ||||||
|  | import argparse | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from pprint import pprint | from pprint import pprint | ||||||
| @@ -20,7 +21,6 @@ if str(lib_dir) not in sys.path: | |||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from procedures.q_exps import update_gpu | from procedures.q_exps import update_gpu | ||||||
| from procedures.q_exps import update_market |  | ||||||
| from procedures.q_exps import run_exp | from procedures.q_exps import run_exp | ||||||
|  |  | ||||||
| import qlib | import qlib | ||||||
| @@ -64,7 +64,6 @@ def main(xargs, exp_yaml): | |||||||
|     with open(exp_yaml) as fp: |     with open(exp_yaml) as fp: | ||||||
|         config = yaml.safe_load(fp) |         config = yaml.safe_load(fp) | ||||||
|     config = update_gpu(config, xargs.gpu) |     config = update_gpu(config, xargs.gpu) | ||||||
|     # config = update_market(config, 'csi300') |  | ||||||
|  |  | ||||||
|     qlib.init(**config.get("qlib_init")) |     qlib.init(**config.get("qlib_init")) | ||||||
|     dataset_config = config.get("task").get("dataset") |     dataset_config = config.get("task").get("dataset") | ||||||
|   | |||||||
| @@ -12,10 +12,18 @@ from qlib.log import get_module_logger | |||||||
|  |  | ||||||
| def update_gpu(config, gpu): | def update_gpu(config, gpu): | ||||||
|     config = config.copy() |     config = config.copy() | ||||||
|     if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]: |     if "task" in config and "model" in config["task"]: | ||||||
|  |         if "GPU" in config["task"]["model"]: | ||||||
|             config["task"]["model"]["GPU"] = gpu |             config["task"]["model"]["GPU"] = gpu | ||||||
|     elif "model" in config and "GPU" in config["model"]: |         elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]: | ||||||
|  |             config["task"]["model"]["kwargs"]["GPU"] = gpu | ||||||
|  |     elif "model" in config: | ||||||
|  |         if "GPU" in config["model"]: | ||||||
|             config["model"]["GPU"] = gpu |             config["model"]["GPU"] = gpu | ||||||
|  |         elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]: | ||||||
|  |             config["model"]["kwargs"]["GPU"] = gpu | ||||||
|  |     elif "kwargs" in config and "GPU" in config["kwargs"]: | ||||||
|  |         config["kwargs"]["GPU"] = gpu | ||||||
|     elif "GPU" in config: |     elif "GPU" in config: | ||||||
|         config["GPU"] = gpu |         config["GPU"] = gpu | ||||||
|     return config |     return config | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user