Fix bugs in .to(cpu)
.gitignore (vendored) | 1 +
--- a/.gitignore
+++ b/.gitignore
@@ -132,3 +132,4 @@ mlruns*
 outputs
 
 pytest_cache
+*.pkl

Submodule .latent-data/NATS-Bench updated: 0654ea06cf...64593c851b
@@ -7,6 +7,8 @@ import os
 import pprint
 import logging
 from copy import deepcopy
+
+from log_utils import pickle_load
 import qlib
 from qlib.utils import init_instance_by_config
 from qlib.workflow import R
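The commit does not include log_utils itself, so the helper below is only a hypothetical sketch of what a pickle_load of this kind usually looks like; the name comes from the import above, but the signature and body are assumptions, not the repository's actual code.

import pickle

def pickle_load(path):
    # Assumed helper (not shown in the commit): read one pickled
    # Python object back from disk.
    with open(path, "rb") as f:
        return pickle.load(f)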
@@ -143,6 +143,19 @@ class QuantTransformer(Model):
             device = "cpu"
         self.device = device
         self.model.to(self.device)
+        # move the optimizer
+        for param in self.train_optimizer.state.values():
+            # Not sure there are any global tensors in the state dict
+            if isinstance(param, torch.Tensor):
+                param.data = param.data.to(device)
+                if param._grad is not None:
+                    param._grad.data = param._grad.data.to(device)
+            elif isinstance(param, dict):
+                for subparam in param.values():
+                    if isinstance(subparam, torch.Tensor):
+                        subparam.data = subparam.data.to(device)
+                        if subparam._grad is not None:
+                            subparam._grad.data = subparam._grad.data.to(device)
 
     def loss_fn(self, pred, label):
         mask = ~torch.isnan(label)
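For context, here is a self-contained sketch of the same device move outside the diff. Everything in it (the move_optimizer_state name, the toy Linear model) is illustrative rather than the repository's code. optimizer.state maps each parameter to a dict of state tensors (e.g. Adam's exp_avg / exp_avg_sq), which is why the commit's elif-dict branch is the one that actually fires; note also that the commit reaches for the private param._grad attribute, where the public spelling is param.grad.

# Standalone sketch (not the repository's code): after model.to(device),
# per-parameter optimizer state such as Adam's exp_avg / exp_avg_sq would
# otherwise stay on the old device, and the next optimizer.step() would
# raise a device-mismatch error.
import torch

def move_optimizer_state(optimizer, device):
    # optimizer.state maps each parameter to a dict of state tensors.
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters())

# One training step so the optimizer actually allocates its state tensors.
optimizer.zero_grad()
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

device = torch.device("cpu")  # e.g. after falling back from CUDA
model.to(device)
move_optimizer_state(optimizer, device)
optimizer.step()  # parameters and optimizer state now share one device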