Fix bugs in .to(cpu)

This commit is contained in:
D-X-Y 2021-03-30 12:25:47 +00:00
parent e5ec43e04a
commit 756218974f
4 changed files with 17 additions and 1 deletions

1
.gitignore vendored
View File

@@ -132,3 +132,4 @@ mlruns*
outputs
pytest_cache
*.pkl

@@ -1 +1 @@
Subproject commit 0654ea06cf2e7e0ef7b778dfddb808559587a8c9
Subproject commit 64593c851b2ac0472d9b6bcc3eac6816d292827a

View File

@@ -7,6 +7,8 @@ import os
import pprint
import logging
from copy import deepcopy
from log_utils import pickle_load
import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R

View File

@@ -143,6 +143,19 @@ class QuantTransformer(Model):
device = "cpu"
self.device = device
self.model.to(self.device)
# move the optimizer
for param in self.train_optimizer.state.values():
# Not sure there are any global tensors in the state dict
if isinstance(param, torch.Tensor):
param.data = param.data.to(device)
if param._grad is not None:
param._grad.data = param._grad.data.to(device)
elif isinstance(param, dict):
for subparam in param.values():
if isinstance(subparam, torch.Tensor):
subparam.data = subparam.data.to(device)
if subparam._grad is not None:
subparam._grad.data = subparam._grad.data.to(device)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)