Fix CUDA memory issues
This commit is contained in:
		| @@ -45,6 +45,32 @@ DEFAULT_OPT_CONFIG = dict( | ||||
| ) | ||||
|  | ||||
|  | ||||
| def train_or_test_epoch( | ||||
|     xloader, model, loss_fn, metric_fn, is_train, optimizer, device | ||||
| ): | ||||
|     """Run one pass over `xloader`, optionally updating the model per batch. | ||||
|  | ||||
|     Args: | ||||
|         xloader: iterable yielding `(feats, labels)` batches. | ||||
|         model: network invoked as `model(feats)`; put in train or eval mode | ||||
|             according to `is_train`. | ||||
|         loss_fn: callable `loss_fn(preds, labels)`; its result is | ||||
|             back-propagated when training. | ||||
|         metric_fn: callable `metric_fn(preds, labels)` producing the score. | ||||
|         is_train: selects `model.train()` vs `model.eval()` and, together | ||||
|             with `optimizer`, whether parameters are updated. | ||||
|         optimizer: stepped once per batch when training; if None, no | ||||
|             backward pass or update is performed. | ||||
|         device: target passed to `.to()` for every batch. | ||||
|  | ||||
|     Returns: | ||||
|         Tuple `(loss_meter.avg, score_meter.avg)`. Each meter is updated once | ||||
|         per batch with the batch value and `feats.size(0)`. | ||||
|     """ | ||||
|     if is_train: | ||||
|         model.train() | ||||
|     else: | ||||
|         model.eval() | ||||
|     score_meter, loss_meter = AverageMeter(), AverageMeter() | ||||
|     for ibatch, (feats, labels) in enumerate(xloader): | ||||
|         feats, labels = feats.to(device), labels.to(device) | ||||
|         # forward the network | ||||
|         # NOTE: gradient tracking is not disabled here, so a caller that only | ||||
|         # evaluates must wrap this function in its own `torch.no_grad()`. | ||||
|         preds = model(feats) | ||||
|         loss = loss_fn(preds, labels) | ||||
|         # The metric and the meter bookkeeping are kept out of the autograd | ||||
|         # graph; `.item()` hands plain Python numbers to the meters. | ||||
|         with torch.no_grad(): | ||||
|             score = metric_fn(preds, labels) | ||||
|             loss_meter.update(loss.item(), feats.size(0)) | ||||
|             score_meter.update(score.item(), feats.size(0)) | ||||
|         # optimize the network | ||||
|         if is_train and optimizer is not None: | ||||
|             optimizer.zero_grad() | ||||
|             loss.backward() | ||||
|             # Clamp every gradient element to [-3.0, 3.0] before the update. | ||||
|             torch.nn.utils.clip_grad_value_(model.parameters(), 3.0) | ||||
|             optimizer.step() | ||||
|     return loss_meter.avg, score_meter.avg | ||||
|  | ||||
|  | ||||
| class QuantTransformer(Model): | ||||
|     """Transformer-based Quant Model""" | ||||
|  | ||||
| @@ -132,32 +158,6 @@ class QuantTransformer(Model): | ||||
|         else: | ||||
|             raise ValueError("unknown metric `{:}`".format(self.metric)) | ||||
|  | ||||
|     def train_or_test_epoch( | ||||
|         self, xloader, model, loss_fn, metric_fn, is_train, optimizer=None | ||||
|     ): | ||||
|         """Run one pass over `xloader` on `self.device`, optionally training. | ||||
|  | ||||
|         Args: | ||||
|             xloader: iterable yielding `(feats, labels)` batches. | ||||
|             model: network invoked as `model(feats)`; put in train or eval | ||||
|                 mode according to `is_train`. | ||||
|             loss_fn: callable `loss_fn(preds, labels)`; its result is | ||||
|                 back-propagated when training. | ||||
|             metric_fn: accepted but not read in this body; the score is | ||||
|                 computed with `self.metric_fn` instead. | ||||
|             is_train: selects `model.train()` vs `model.eval()` and, together | ||||
|                 with `optimizer`, whether parameters are updated. | ||||
|             optimizer: stepped once per batch when training; if None (the | ||||
|                 default), no backward pass or update is performed. | ||||
|  | ||||
|         Returns: | ||||
|             Tuple `(loss_meter.avg, score_meter.avg)`. Each meter is updated | ||||
|             once per batch with the batch value and `feats.size(0)`. | ||||
|         """ | ||||
|         if is_train: | ||||
|             model.train() | ||||
|         else: | ||||
|             model.eval() | ||||
|         score_meter, loss_meter = AverageMeter(), AverageMeter() | ||||
|         for ibatch, (feats, labels) in enumerate(xloader): | ||||
|             # `non_blocking=True` asks for an asynchronous copy where the | ||||
|             # backend supports it. | ||||
|             feats = feats.to(self.device, non_blocking=True) | ||||
|             labels = labels.to(self.device, non_blocking=True) | ||||
|             # forward the network | ||||
|             preds = model(feats) | ||||
|             loss = loss_fn(preds, labels) | ||||
|             with torch.no_grad(): | ||||
|                 # NOTE(review): uses `self.metric_fn`, not the `metric_fn` | ||||
|                 # argument -- confirm whether the parameter should be honored. | ||||
|                 score = self.metric_fn(preds, labels) | ||||
|                 loss_meter.update(loss.item(), feats.size(0)) | ||||
|                 score_meter.update(score.item(), feats.size(0)) | ||||
|             # optimize the network | ||||
|             if is_train and optimizer is not None: | ||||
|                 optimizer.zero_grad() | ||||
|                 loss.backward() | ||||
|                 # Clamp every gradient element to [-3.0, 3.0] before the update. | ||||
|                 torch.nn.utils.clip_grad_value_(model.parameters(), 3.0) | ||||
|                 optimizer.step() | ||||
|         return loss_meter.avg, score_meter.avg | ||||
|  | ||||
|     def fit( | ||||
|         self, | ||||
|         dataset: DatasetH, | ||||
| @@ -204,14 +204,22 @@ class QuantTransformer(Model): | ||||
|  | ||||
|         def _internal_test(ckp_epoch=None, results_dict=None): | ||||
|             with torch.no_grad(): | ||||
|                 train_loss, train_score = self.train_or_test_epoch( | ||||
|                     train_loader, self.model, self.loss_fn, self.metric_fn, False, None | ||||
|                 shared_kwards = { | ||||
|                     "model": self.model, | ||||
|                     "loss_fn": self.loss_fn, | ||||
|                     "metric_fn": self.metric_fn, | ||||
|                     "is_train": False, | ||||
|                     "optimizer": None, | ||||
|                     "device": self.device, | ||||
|                 } | ||||
|                 train_loss, train_score = train_or_test_epoch( | ||||
|                     train_loader, **shared_kwards | ||||
|                 ) | ||||
|                 valid_loss, valid_score = self.train_or_test_epoch( | ||||
|                     valid_loader, self.model, self.loss_fn, self.metric_fn, False, None | ||||
|                 valid_loss, valid_score = train_or_test_epoch( | ||||
|                     valid_loader, **shared_kwards | ||||
|                 ) | ||||
|                 test_loss, test_score = self.train_or_test_epoch( | ||||
|                     test_loader, self.model, self.loss_fn, self.metric_fn, False, None | ||||
|                 test_loss, test_score = train_or_test_epoch( | ||||
|                     test_loader, **shared_kwards | ||||
|                 ) | ||||
|                 xstr = ( | ||||
|                     "train-score={:.6f}, valid-score={:.6f}, test-score={:.6f}".format( | ||||
| @@ -255,13 +263,14 @@ class QuantTransformer(Model): | ||||
|                     iepoch, self.opt_config["epochs"], best_epoch, best_score | ||||
|                 ) | ||||
|             ) | ||||
|             train_loss, train_score = self.train_or_test_epoch( | ||||
|             train_loss, train_score = train_or_test_epoch( | ||||
|                 train_loader, | ||||
|                 self.model, | ||||
|                 self.loss_fn, | ||||
|                 self.metric_fn, | ||||
|                 True, | ||||
|                 self.train_optimizer, | ||||
|                 self.device, | ||||
|             ) | ||||
|             self.logger.info( | ||||
|                 "Training :: loss={:.6f}, score={:.6f}".format(train_loss, train_score) | ||||
| @@ -307,7 +316,8 @@ class QuantTransformer(Model): | ||||
|         self.logger.info("Reload the best parameter :: {:}".format(eval_str)) | ||||
|  | ||||
|         if self.use_gpu: | ||||
|             torch.cuda.empty_cache() | ||||
|             with torch.cuda.device(self.device): | ||||
|                 torch.cuda.empty_cache() | ||||
|         self.fitted = True | ||||
|  | ||||
|     def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user