From 756218974f0f24558d046290178fdbc43083af8b Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 30 Mar 2021 12:25:47 +0000 Subject: [PATCH] Fix bugs in .to(cpu) --- .gitignore | 1 + .latent-data/NATS-Bench | 2 +- lib/procedures/q_exps.py | 2 ++ lib/trade_models/quant_transformer.py | 13 +++++++++++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index a47c212..be684fb 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,4 @@ mlruns* outputs pytest_cache +*.pkl diff --git a/.latent-data/NATS-Bench b/.latent-data/NATS-Bench index 0654ea0..64593c8 160000 --- a/.latent-data/NATS-Bench +++ b/.latent-data/NATS-Bench @@ -1 +1 @@ -Subproject commit 0654ea06cf2e7e0ef7b778dfddb808559587a8c9 +Subproject commit 64593c851b2ac0472d9b6bcc3eac6816d292827a diff --git a/lib/procedures/q_exps.py b/lib/procedures/q_exps.py index 47c5a73..48a6b7e 100644 --- a/lib/procedures/q_exps.py +++ b/lib/procedures/q_exps.py @@ -7,6 +7,8 @@ import os import pprint import logging from copy import deepcopy + +from log_utils import pickle_load import qlib from qlib.utils import init_instance_by_config from qlib.workflow import R diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index c29e62b..ccd97d5 100644 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -143,6 +143,19 @@ class QuantTransformer(Model): device = "cpu" self.device = device self.model.to(self.device) + # move the optimizer + for param in self.train_optimizer.state.values(): + # Not sure there are any global tensors in the state dict + if isinstance(param, torch.Tensor): + param.data = param.data.to(device) + if param._grad is not None: + param._grad.data = param._grad.data.to(device) + elif isinstance(param, dict): + for subparam in param.values(): + if isinstance(subparam, torch.Tensor): + subparam.data = subparam.data.to(device) + if subparam._grad is not None: + subparam._grad.data = subparam._grad.data.to(device) def loss_fn(self, pred, label): mask = ~torch.isnan(label)