Update metrics
This commit is contained in:
		| @@ -1,3 +1,8 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 # | ||||
| ##################################################### | ||||
|  | ||||
|  | ||||
| class AverageMeter: | ||||
|     """Computes and stores the average and current value""" | ||||
|  | ||||
| @@ -20,3 +25,133 @@ class AverageMeter: | ||||
|         return "{name}(val={val}, avg={avg}, count={count})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class Metric(abc.ABC): | ||||
|     """The default meta metric class.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def perf_str(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({inner})".format( | ||||
|             name=self.__class__.__name__, inner=self.inner_repr() | ||||
|         ) | ||||
|  | ||||
|     def inner_repr(self): | ||||
|         return "" | ||||
|  | ||||
|  | ||||
| class ComposeMetric(Metric): | ||||
|     """The composed metric class.""" | ||||
|  | ||||
|     def __init__(self, *metric_list): | ||||
|         self.reset() | ||||
|         for metric in metric_list: | ||||
|             self.append(metric) | ||||
|  | ||||
|     def reset(self): | ||||
|         self._metric_list = [] | ||||
|  | ||||
|     def append(self, metric): | ||||
|         if not isinstance(metric, Metric): | ||||
|             raise ValueError( | ||||
|                 "The input metric is not correct: {:}".format(type(metric)) | ||||
|             ) | ||||
|         self._metric_list.append(metric) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._metric_list) | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         results = list() | ||||
|         for metric in self._metric_list: | ||||
|             results.append(metric(predictions, targets)) | ||||
|         return results | ||||
|  | ||||
|     def get_info(self): | ||||
|         results = dict() | ||||
|         for metric in self._metric_list: | ||||
|             for key, value in metric.get_info().items(): | ||||
|                 results[key] = value | ||||
|         return results | ||||
|  | ||||
|     def inner_repr(self): | ||||
|         xlist = [] | ||||
|         for metric in self._metric_list: | ||||
|             xlist.append(str(metric)) | ||||
|         return ",".join(xlist) | ||||
|  | ||||
|  | ||||
| class CrossEntropyMetric(Metric): | ||||
|     """The metric for the cross entropy metric.""" | ||||
|  | ||||
|     def __init__(self, ignore_batch): | ||||
|         super(CrossEntropyMetric, self).__init__() | ||||
|         self._ignore_batch = ignore_batch | ||||
|  | ||||
|     def reset(self): | ||||
|         self._loss = AverageMeter() | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||
|             batch, _ = predictions.shape()  # only support 2-D tensor | ||||
|             max_prob_indexes = torch.argmax(predictions, dim=-1) | ||||
|             if self._ignore_batch: | ||||
|                 loss = F.cross_entropy(predictions, targets, reduction="sum") | ||||
|                 self._loss.update(loss.item(), 1) | ||||
|             else: | ||||
|                 loss = F.cross_entropy(predictions, targets, reduction="mean") | ||||
|                 self._loss.update(loss.item(), batch) | ||||
|             return loss | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         return {"loss": self._loss.avg, "score": self._loss.avg * 100} | ||||
|  | ||||
|     def perf_str(self): | ||||
|         return "ce-loss={:.5f}".format(self._loss.avg) | ||||
|  | ||||
|  | ||||
| class Top1AccMetric(Metric): | ||||
|     """The metric for the top-1 accuracy.""" | ||||
|  | ||||
|     def __init__(self, ignore_batch): | ||||
|         super(Top1AccMetric, self).__init__() | ||||
|         self._ignore_batch = ignore_batch | ||||
|  | ||||
|     def reset(self): | ||||
|         self._accuracy = AverageMeter() | ||||
|  | ||||
|     def __call__(self, predictions, targets): | ||||
|         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||
|             batch, _ = predictions.shape()  # only support 2-D tensor | ||||
|             max_prob_indexes = torch.argmax(predictions, dim=-1) | ||||
|             corrects = torch.eq(max_prob_indexes, targets) | ||||
|             accuracy = corrects.float().mean().float() | ||||
|             if self._ignore_batch: | ||||
|                 self._accuracy.update(accuracy, 1) | ||||
|             else: | ||||
|                 self._accuracy.update(accuracy, batch) | ||||
|             return accuracy | ||||
|         else: | ||||
|             raise NotImplementedError | ||||
|  | ||||
|     def get_info(self): | ||||
|         return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100} | ||||
|  | ||||
|     def perf_str(self): | ||||
|         return "accuracy={:.3f}%".format(self._accuracy.avg * 100) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user