Fix bugs in .to(cpu)
@@ -143,6 +143,19 @@ class QuantTransformer(Model):
             device = "cpu"
         self.device = device
         self.model.to(self.device)
+        # move the optimizer
+        for param in self.train_optimizer.state.values():
+            # Not sure there are any global tensors in the state dict
+            if isinstance(param, torch.Tensor):
+                param.data = param.data.to(device)
+                if param._grad is not None:
+                    param._grad.data = param._grad.data.to(device)
+            elif isinstance(param, dict):
+                for subparam in param.values():
+                    if isinstance(subparam, torch.Tensor):
+                        subparam.data = subparam.data.to(device)
+                        if subparam._grad is not None:
+                            subparam._grad.data = subparam._grad.data.to(device)
 
     def loss_fn(self, pred, label):
         mask = ~torch.isnan(label)
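The diff fixes `.to(cpu)` by also migrating the optimizer's state: moving only the model leaves Adam-style buffers on the old device, which fails on the next `step()`. Below is a minimal standalone sketch of the same idea, assuming a stock torch.optim optimizer whose per-parameter state is a dict of tensors (e.g. Adam's exp_avg / exp_avg_sq); the helper name optimizer_to and the toy model are illustrative and not part of this commit.

import torch
import torch.nn as nn

def optimizer_to(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    """Move every tensor held in the optimizer's state to `device`."""
    for state in optimizer.state.values():
        # Built-in optimizers keep per-parameter state as a dict of
        # tensors, e.g. Adam's "exp_avg" and "exp_avg_sq" buffers.
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)

net = nn.Linear(8, 1)
optimizer = torch.optim.Adam(net.parameters())

# One backward/step pass so the optimizer actually allocates its state.
optimizer.zero_grad()
net(torch.randn(4, 8)).sum().backward()
optimizer.step()

device = torch.device("cpu")  # would be "cuda" on a GPU machine
net.to(device)
optimizer_to(optimizer, device)

A design note on the difference: the sketch reassigns the state-dict entries, while the commit rewrites `.data` (and the private `_grad`) in place, which keeps any external references to those tensor objects pointing at the moved storage.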