Fix bugs in update_gpu in procedures
This commit is contained in:
		| @@ -1,15 +1,16 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ##################################################### | ||||
| # python exps/trading/baselines.py --alg GRU | ||||
| # python exps/trading/baselines.py --alg LSTM | ||||
| # python exps/trading/baselines.py --alg ALSTM | ||||
| # python exps/trading/baselines.py --alg MLP | ||||
| # python exps/trading/baselines.py --alg SFM | ||||
| # python exps/trading/baselines.py --alg XGBoost | ||||
| # python exps/trading/baselines.py --alg LightGBM | ||||
| # python exps/trading/baselines.py --alg GRU        # | ||||
| # python exps/trading/baselines.py --alg LSTM       # | ||||
| # python exps/trading/baselines.py --alg ALSTM      # | ||||
| # python exps/trading/baselines.py --alg MLP        # | ||||
| # python exps/trading/baselines.py --alg SFM        # | ||||
| # python exps/trading/baselines.py --alg XGBoost    # | ||||
| # python exps/trading/baselines.py --alg LightGBM   # | ||||
| ##################################################### | ||||
| import sys, argparse | ||||
| import sys | ||||
| import argparse | ||||
| from collections import OrderedDict | ||||
| from pathlib import Path | ||||
| from pprint import pprint | ||||
| @@ -20,7 +21,6 @@ if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from procedures.q_exps import update_gpu | ||||
| from procedures.q_exps import update_market | ||||
| from procedures.q_exps import run_exp | ||||
|  | ||||
| import qlib | ||||
| @@ -64,7 +64,6 @@ def main(xargs, exp_yaml): | ||||
|     with open(exp_yaml) as fp: | ||||
|         config = yaml.safe_load(fp) | ||||
|     config = update_gpu(config, xargs.gpu) | ||||
|     # config = update_market(config, 'csi300') | ||||
|  | ||||
|     qlib.init(**config.get("qlib_init")) | ||||
|     dataset_config = config.get("task").get("dataset") | ||||
|   | ||||
| @@ -12,10 +12,18 @@ from qlib.log import get_module_logger | ||||
|  | ||||
| def update_gpu(config, gpu): | ||||
|     config = config.copy() | ||||
|     if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]: | ||||
|         config["task"]["model"]["GPU"] = gpu | ||||
|     elif "model" in config and "GPU" in config["model"]: | ||||
|         config["model"]["GPU"] = gpu | ||||
|     if "task" in config and "model" in config["task"]: | ||||
|         if "GPU" in config["task"]["model"]: | ||||
|             config["task"]["model"]["GPU"] = gpu | ||||
|         elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]: | ||||
|             config["task"]["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "model" in config: | ||||
|         if "GPU" in config["model"]: | ||||
|             config["model"]["GPU"] = gpu | ||||
|         elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]: | ||||
|             config["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "kwargs" in config and "GPU" in config["kwargs"]: | ||||
|         config["kwargs"]["GPU"] = gpu | ||||
|     elif "GPU" in config: | ||||
|         config["GPU"] = gpu | ||||
|     return config | ||||
|   | ||||
		Reference in New Issue
	
	Block a user