Add simple baseline for LFNA
@@ -12,49 +12,48 @@ from log_utils import time_string
 from .eval_funcs import obtain_accuracy
 
 
-def basic_train(
+def get_device(tensors):
+    if isinstance(tensors, (list, tuple)):
+        return get_device(tensors[0])
+    elif isinstance(tensors, dict):
+        for key, value in tensors.items():
+            return get_device(value)
+    else:
+        return tensors.device
+
+
+def basic_train_fn(
     xloader,
     network,
     criterion,
-    scheduler,
     optimizer,
-    optim_config,
-    extra_info,
-    print_freq,
+    metric,
     logger,
 ):
-    loss, acc1, acc5 = procedure(
+    results = procedure(
         xloader,
         network,
         criterion,
-        scheduler,
         optimizer,
+        metric,
         "train",
-        optim_config,
-        extra_info,
-        print_freq,
         logger,
     )
-    return loss, acc1, acc5
+    return results
 
 
-def basic_valid(
-    xloader, network, criterion, optim_config, extra_info, print_freq, logger
-):
+def basic_eval_fn(xloader, network, metric, logger):
     with torch.no_grad():
-        loss, acc1, acc5 = procedure(
+        results = procedure(
             xloader,
             network,
-            criterion,
+            None,
             None,
+            metric,
             "valid",
-            None,
-            extra_info,
-            print_freq,
             logger,
         )
-    return loss, acc1, acc5
+    return results
 
 
 def procedure(
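As a quick aside on the new `get_device` helper above: it recurses into lists/tuples and dicts until it reaches a tensor and returns that tensor's `.device`. A minimal sanity-check sketch, assuming `get_device` is in scope (the module's path is not visible in this diff):

```python
import torch

t = torch.zeros(2, 3)                          # CPU tensor, for illustration only
assert get_device(t) == t.device               # plain tensor: its own device
assert get_device([t, t]) == t.device          # list/tuple: first element decides
assert get_device({"logits": t}) == t.device   # dict: first value decides
```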
@@ -62,12 +61,11 @@ def procedure(
     network,
     criterion,
     optimizer,
-    eval_metric,
+    metric,
     mode: Text,
-    print_freq: int = 100,
     logger_fn: Callable = None,
 ):
-    data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
+    data_time, batch_time = AverageMeter(), AverageMeter()
     if mode.lower() == "train":
         network.train()
     elif mode.lower() == "valid":
@@ -80,49 +78,23 @@ def procedure(
         # measure data loading time
         data_time.update(time.time() - end)
         # calculate prediction and loss
-        targets = targets.cuda(non_blocking=True)
 
         if mode == "train":
             optimizer.zero_grad()
 
         outputs = network(inputs)
-        loss = criterion(outputs, targets)
+        targets = targets.to(get_device(outputs))
 
         if mode == "train":
+            loss = criterion(outputs, targets)
             loss.backward()
             optimizer.step()
 
         # record
-        metrics = eval_metric(logits.data, targets.data)
-        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
-        losses.update(loss.item(), inputs.size(0))
-        top1.update(prec1.item(), inputs.size(0))
-        top5.update(prec5.item(), inputs.size(0))
+        with torch.no_grad():
+            results = metric(outputs, targets)
 
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
-
-        if i % print_freq == 0 or (i + 1) == len(xloader):
-            Sstr = (
-                " {:5s} ".format(mode.upper())
-                + time_string()
-                + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
-            )
-            Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
-                loss=losses, top1=top1, top5=top5
-            )
-            Istr = "Size={:}".format(list(inputs.size()))
-            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
-
-    logger.log(
-        " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
-            mode=mode.upper(),
-            top1=top1,
-            top5=top5,
-            error1=100 - top1.avg,
-            error5=100 - top5.avg,
-            loss=losses.avg,
-        )
-    )
-    return losses.avg, top1.avg, top5.avg
+    return metric.get_info()
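With this refactor, `procedure` no longer tracks losses or accuracies itself: it hands every batch to the `metric` object and returns `metric.get_info()` at the end. A hedged usage sketch follows; the toy model, data, and optimizer are invented for illustration, `MSEMetric` comes from the new `lib/procedures/metric_utils.py` added below, and the sketch assumes the loop's `end` timer is initialized in the unshown context before the loop.

```python
import torch

model = torch.nn.Linear(4, 1)        # toy regressor, illustrative only
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]

train_info = basic_train_fn(loader, model, criterion, optimizer, MSEMetric(), None)
valid_info = basic_eval_fn(loader, model, MSEMetric(), None)
print(train_info["mse"], valid_info["mse"])   # summary dicts from Metric.get_info()
```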

--- a/lib/procedures/eval_funcs.py
+++ b/lib/procedures/eval_funcs.py
@@ -18,11 +18,3 @@ def obtain_accuracy(output, target, topk=(1,)):
         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
-
-
-class EvaluationMetric(abc.ABC):
-    def __init__(self):
-        self._total_metrics = 0
-
-    def __len__(self):
-        return self._total_metrics
--- /dev/null
+++ b/lib/procedures/metric_utils.py
@@ -0,0 +1,134 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
+#####################################################
+import abc
+import numpy as np
+import torch
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0.0
+        self.avg = 0.0
+        self.sum = 0.0
+        self.count = 0.0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __repr__(self):
+        return "{name}(val={val}, avg={avg}, count={count})".format(
+            name=self.__class__.__name__, **self.__dict__
+        )
+
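A one-line sanity check of the running average (illustrative, not part of the commit):

```python
meter = AverageMeter()
meter.update(1.0, n=2)               # two samples with value 1.0
meter.update(4.0, n=1)               # one sample with value 4.0
assert abs(meter.avg - 2.0) < 1e-9   # (1.0 * 2 + 4.0 * 1) / 3 == 2.0
print(meter)                         # AverageMeter(val=4.0, avg=2.0, count=3.0)
```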
+
+class Metric(abc.ABC):
+    """The default meta metric class."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        raise NotImplementedError
+
+    def __call__(self, predictions, targets):
+        raise NotImplementedError
+
+    def get_info(self):
+        raise NotImplementedError
+
+    def __repr__(self):
+        return "{name}({inner})".format(
+            name=self.__class__.__name__, inner=self.inner_repr()
+        )
+
+    def inner_repr(self):
+        return ""
+
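`Metric` fixes a small contract: `reset()` initializes state, `__call__(predictions, targets)` accumulates one batch, and `get_info()` returns a summary dict. A sketch of a hypothetical subclass (`Top1AccMetric` is invented for illustration, not part of this commit):

```python
class Top1AccMetric(Metric):
    """Illustrative only: tracks top-1 accuracy with an AverageMeter."""

    def reset(self):
        self._acc = AverageMeter()

    def __call__(self, predictions, targets):
        # predictions: (batch, classes) logits; targets: (batch,) class indices
        correct = (predictions.argmax(dim=1) == targets).float().mean().item()
        self._acc.update(correct * 100.0, predictions.shape[0])
        return correct

    def get_info(self):
        return {"top1": self._acc.avg}
```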
+
+class ComposeMetric(Metric):
+    """The composed metric class."""
+
+    def __init__(self, *metric_list):
+        self.reset()
+        for metric in metric_list:
+            self.append(metric)
+
+    def reset(self):
+        self._metric_list = []
+
+    def append(self, metric):
+        if not isinstance(metric, Metric):
+            raise ValueError(
+                "The input metric is not correct: {:}".format(type(metric))
+            )
+        self._metric_list.append(metric)
+
+    def __len__(self):
+        return len(self._metric_list)
+
+    def __call__(self, predictions, targets):
+        results = list()
+        for metric in self._metric_list:
+            results.append(metric(predictions, targets))
+        return results
+
+    def get_info(self):
+        results = dict()
+        for metric in self._metric_list:
+            for key, value in metric.get_info().items():
+                results[key] = value
+        return results
+
+    def inner_repr(self):
+        xlist = []
+        for metric in self._metric_list:
+            xlist.append(str(metric))
+        return ",".join(xlist)
+
+
+class MSEMetric(Metric):
+    """The metric for mse."""
+
+    def reset(self):
+        self._mse = AverageMeter()
+
+    def __call__(self, predictions, targets):
+        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
+            batch = predictions.shape[0]
+            loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
+            loss = loss.item()
+            self._mse.update(loss, batch)
+            return loss
+        else:
+            raise NotImplementedError
+
+    def get_info(self):
+        return {"mse": self._mse.avg}
+
+
+class SaveMetric(Metric):
+    """The metric for saving predictions."""
+
+    def reset(self):
+        self._predicts = []
+
+    def __call__(self, predictions, targets=None):
+        if isinstance(predictions, torch.Tensor):
+            predicts = predictions.cpu().numpy()
+            self._predicts.append(predicts)
+            return predicts
+        else:
+            raise NotImplementedError
+
+    def get_info(self):
+        all_predicts = np.concatenate(self._predicts)
+        return {"predictions": all_predicts}