#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
#####################################################
# This file contains the meter classes, which may   #
# need to use PyTorch or NumPy.                     #
#####################################################
import abc
import torch
import torch.nn.functional as F


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{name}(val={val}, avg={avg}, count={count})".format(
            name=self.__class__.__name__, **self.__dict__
        )
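
# Usage sketch (illustrative, not part of the original file): AverageMeter
# keeps a sample-weighted running mean, e.g. of the loss over mini-batches
# of different sizes; the values below are made up for illustration only.
#
#     meter = AverageMeter()
#     meter.update(0.9, n=32)  # batch of 32 samples with mean loss 0.9
#     meter.update(0.7, n=16)  # smaller final batch
#     meter.avg                # weighted mean over all 48 samples

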
class Metric(abc.ABC):
    """The abstract base metric class."""

    def __init__(self):
        self.reset()

    def reset(self):
        raise NotImplementedError

    def __call__(self, predictions, targets):
        raise NotImplementedError

    def get_info(self):
        raise NotImplementedError

    def perf_str(self):
        raise NotImplementedError

    def __repr__(self):
        return "{name}({inner})".format(
            name=self.__class__.__name__, inner=self.inner_repr()
        )

    def inner_repr(self):
        return ""
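
# Note (added for clarity): concrete subclasses are expected to override
# reset(), __call__(predictions, targets), get_info(), and perf_str();
# ComposeMetric, CrossEntropyMetric, and Top1AccMetric below are examples.

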
class ComposeMetric(Metric):
    """The composed metric class."""

    def __init__(self, *metric_list):
        self.reset()
        for metric in metric_list:
            self.append(metric)

    def reset(self):
        self._metric_list = []

    def append(self, metric):
        if not isinstance(metric, Metric):
            raise ValueError(
                "The input metric is not correct: {:}".format(type(metric))
            )
        self._metric_list.append(metric)

    def __len__(self):
        return len(self._metric_list)

    def __call__(self, predictions, targets):
        results = list()
        for metric in self._metric_list:
            results.append(metric(predictions, targets))
        return results

    def get_info(self):
        results = dict()
        for metric in self._metric_list:
            for key, value in metric.get_info().items():
                results[key] = value
        return results

    def inner_repr(self):
        xlist = []
        for metric in self._metric_list:
            xlist.append(str(metric))
        return ",".join(xlist)
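
# Note (added for clarity): get_info() merges the per-metric dictionaries in
# order, so a later metric overwrites duplicate keys (e.g. "score", which is
# reported by both CrossEntropyMetric and Top1AccMetric).  A runnable usage
# sketch is at the bottom of this file.

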
class CrossEntropyMetric(Metric):
    """The cross-entropy loss metric."""

    def __init__(self, ignore_batch):
        super(CrossEntropyMetric, self).__init__()
        self._ignore_batch = ignore_batch

    def reset(self):
        self._loss = AverageMeter()

    def __call__(self, predictions, targets):
        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
            batch, _ = predictions.shape  # only support 2-D tensors
            if self._ignore_batch:
                loss = F.cross_entropy(predictions, targets, reduction="sum")
                self._loss.update(loss.item(), 1)
            else:
                loss = F.cross_entropy(predictions, targets, reduction="mean")
                self._loss.update(loss.item(), batch)
            return loss
        else:
            raise NotImplementedError

    def get_info(self):
        return {"loss": self._loss.avg, "score": self._loss.avg * 100}

    def perf_str(self):
        return "ce-loss={:.5f}".format(self._loss.avg)
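
# Note (added for clarity): with ignore_batch=True each call contributes its
# summed loss with weight 1, so .avg is the mean per-batch total loss; with
# ignore_batch=False each call contributes its mean loss weighted by the batch
# size, so .avg is the mean per-sample loss.

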
class Top1AccMetric(Metric):
    """The metric for the top-1 accuracy."""

    def __init__(self, ignore_batch):
        super(Top1AccMetric, self).__init__()
        self._ignore_batch = ignore_batch

    def reset(self):
        self._accuracy = AverageMeter()

    def __call__(self, predictions, targets):
        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
            batch, _ = predictions.shape  # only support 2-D tensors
            max_prob_indexes = torch.argmax(predictions, dim=-1)
            corrects = torch.eq(max_prob_indexes, targets)
            accuracy = corrects.float().mean()
            if self._ignore_batch:
                self._accuracy.update(accuracy.item(), 1)
            else:
                self._accuracy.update(accuracy.item(), batch)
            return accuracy
        else:
            raise NotImplementedError

    def get_info(self):
        return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}

    def perf_str(self):
        return "accuracy={:.3f}%".format(self._accuracy.avg * 100)
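

# Minimal, hedged usage sketch (not part of the original file): the shapes and
# random inputs below are assumptions chosen purely for illustration.
if __name__ == "__main__":
    predictions = torch.randn(8, 10)      # a batch of 8 samples, 10 classes
    targets = torch.randint(0, 10, (8,))  # integer class labels
    metric = ComposeMetric(CrossEntropyMetric(False), Top1AccMetric(False))
    metric(predictions, targets)          # updates both inner meters
    print(metric)                         # repr built via Metric.__repr__ / inner_repr
    print(metric.get_info())              # merged dict with "loss", "accuracy", "score"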