Fix bugs in update_gpu in procedures
This commit is contained in:
		| @@ -12,10 +12,18 @@ from qlib.log import get_module_logger | ||||
|  | ||||
| def update_gpu(config, gpu): | ||||
|     config = config.copy() | ||||
|     if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]: | ||||
|         config["task"]["model"]["GPU"] = gpu | ||||
|     elif "model" in config and "GPU" in config["model"]: | ||||
|         config["model"]["GPU"] = gpu | ||||
|     if "task" in config and "model" in config["task"]: | ||||
|         if "GPU" in config["task"]["model"]: | ||||
|             config["task"]["model"]["GPU"] = gpu | ||||
|         elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]: | ||||
|             config["task"]["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "model" in config: | ||||
|         if "GPU" in config["model"]: | ||||
|             config["model"]["GPU"] = gpu | ||||
|         elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]: | ||||
|             config["model"]["kwargs"]["GPU"] = gpu | ||||
|     elif "kwargs" in config and "GPU" in config["kwargs"]: | ||||
|         config["kwargs"]["GPU"] = gpu | ||||
|     elif "GPU" in config: | ||||
|         config["GPU"] = gpu | ||||
|     return config | ||||
|   | ||||
		Reference in New Issue
	
	Block a user