Fix bugs in .to(cpu)
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -132,3 +132,4 @@ mlruns* | ||||
| outputs | ||||
|  | ||||
| pytest_cache | ||||
| *.pkl | ||||
|   | ||||
 Submodule .latent-data/NATS-Bench updated: 0654ea06cf...64593c851b
									
								
							| @@ -7,6 +7,8 @@ import os | ||||
| import pprint | ||||
| import logging | ||||
| from copy import deepcopy | ||||
|  | ||||
| from log_utils import pickle_load | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
|   | ||||
| @@ -143,6 +143,19 @@ class QuantTransformer(Model): | ||||
|             device = "cpu" | ||||
|         self.device = device | ||||
|         self.model.to(self.device) | ||||
|         # move the optimizer | ||||
|         for param in self.train_optimizer.state.values(): | ||||
|             # Not sure there are any global tensors in the state dict | ||||
|             if isinstance(param, torch.Tensor): | ||||
|                 param.data = param.data.to(device) | ||||
|                 if param._grad is not None: | ||||
|                     param._grad.data = param._grad.data.to(device) | ||||
|             elif isinstance(param, dict): | ||||
|                 for subparam in param.values(): | ||||
|                     if isinstance(subparam, torch.Tensor): | ||||
|                         subparam.data = subparam.data.to(device) | ||||
|                         if subparam._grad is not None: | ||||
|                             subparam._grad.data = subparam._grad.data.to(device) | ||||
|  | ||||
|     def loss_fn(self, pred, label): | ||||
|         mask = ~torch.isnan(label) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user