Fix bugs in .to(cpu)
This commit is contained in:
parent e5ec43e04a
commit 756218974f
.gitignore (vendored)
@@ -132,3 +132,4 @@ mlruns*
 outputs
 
 pytest_cache
+*.pkl
@@ -1 +1 @@
-Subproject commit 0654ea06cf2e7e0ef7b778dfddb808559587a8c9
+Subproject commit 64593c851b2ac0472d9b6bcc3eac6816d292827a
@@ -7,6 +7,8 @@ import os
 import pprint
 import logging
 from copy import deepcopy
+
+from log_utils import pickle_load
 import qlib
 from qlib.utils import init_instance_by_config
 from qlib.workflow import R
@@ -143,6 +143,19 @@ class QuantTransformer(Model):
             device = "cpu"
         self.device = device
         self.model.to(self.device)
+        # move the optimizer
+        for param in self.train_optimizer.state.values():
+            # Not sure there are any global tensors in the state dict
+            if isinstance(param, torch.Tensor):
+                param.data = param.data.to(device)
+                if param._grad is not None:
+                    param._grad.data = param._grad.data.to(device)
+            elif isinstance(param, dict):
+                for subparam in param.values():
+                    if isinstance(subparam, torch.Tensor):
+                        subparam.data = subparam.data.to(device)
+                        if subparam._grad is not None:
+                            subparam._grad.data = subparam._grad.data.to(device)
 
     def loss_fn(self, pred, label):
         mask = ~torch.isnan(label)
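
For context on why the added loop is needed: Module.to() moves a model's parameters and buffers, but not the tensors an optimizer keeps in optimizer.state (e.g. Adam's running moments), so after the model is moved those state tensors stay on the old device and the next step() fails with a device mismatch. Below is a minimal, self-contained sketch of the same technique; nn.Linear and torch.optim.Adam are illustration choices, not this repo's code, and it reassigns the state dict entries rather than mutating .data in place, which is the more common spelling.

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one step on the CPU so the optimizer actually has state to move.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# optimizer.state maps each parameter to a per-parameter dict such as
# {"step": ..., "exp_avg": tensor, "exp_avg_sq": tensor}; move every tensor.
for state in optimizer.state.values():
    for key, value in state.items():
        if isinstance(value, torch.Tensor):
            state[key] = value.to(device)

A commonly used alternative is the round trip optimizer.load_state_dict(optimizer.state_dict()) after moving the model, since PyTorch's load_state_dict casts state tensors to the device of their matching parameters. Note also that the diff reads gradients through the private param._grad; the public param.grad attribute refers to the same tensor.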