Update metrics
configs/yaml.loss/top1-ce (new file, +12 lines)
@@ -0,0 +1,12 @@
+class_or_func: ComposeMetric
+module_path: xautodl.xmisc.meter_utils
+args:
+  - class_or_func: Top1AccMetric
+    module_path: xautodl.xmisc.meter_utils
+    args: [False]
+    kwargs: {}
+  - class_or_func: CrossEntropyMetric
+    module_path: xautodl.xmisc.meter_utils
+    args: [False]
+    kwargs: {}
+kwargs: {}
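For orientation, the composed config above should resolve to something like the following once loaded — a minimal sketch assuming xmisc.nested_call_by_yaml imports each module_path, looks up class_or_func, and instantiates it with the listed args/kwargs (nested entries built first):

# Hand-built equivalent of configs/yaml.loss/top1-ce (a sketch, not the loader itself).
from xautodl.xmisc.meter_utils import (
    ComposeMetric,
    CrossEntropyMetric,
    Top1AccMetric,
)

# args: [False] corresponds to ignore_batch=False for both inner metrics.
metric = ComposeMetric(
    Top1AccMetric(False),
    CrossEntropyMetric(False),
)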
							
								
								
									
configs/yaml.loss/top1.acc (new file, +4 lines)
@@ -0,0 +1,4 @@
+class_or_func: Top1AccMetric
+module_path: xautodl.xmisc.meter_utils
+args: [False]
+kwargs: {}
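The single-metric config above simply yields Top1AccMetric(False). A small usage sketch with made-up tensors, based on the Metric API introduced in meter_utils below:

import torch

from xautodl.xmisc.meter_utils import Top1AccMetric

metric = Top1AccMetric(False)            # ignore_batch=False
logits = torch.randn(8, 10)              # hypothetical [batch, num_classes] predictions
labels = torch.randint(0, 10, (8,))      # hypothetical targets
metric(logits, labels)                   # updates the internal AverageMeter
print(metric.get_info())                 # {'accuracy': ..., 'score': ...}
print(metric.perf_str())                 # e.g. "accuracy=12.500%"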
@@ -69,10 +69,13 @@ def main(args):
         weight_decay=args.weight_decay,
     )
     objective = xmisc.nested_call_by_yaml(args.loss_config)
+    metric = xmisc.nested_call_by_yaml(args.metric_config)

     logger.log("The optimizer is:\n{:}".format(optimizer))
     logger.log("The objective is {:}".format(objective))
-    logger.log("The iters_per_epoch={:}".format(iters_per_epoch))
+    logger.log("The metric is {:}".format(metric))
+    logger.log("The iters_per_epoch = {:}, estimated epochs = {:}".format(
+      iters_per_epoch, args.steps // iters_per_epoch))

     model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda()
     scheduler = xmisc.LRMultiplier(
@@ -99,6 +102,7 @@ def main(args):
         loss.backward()
         optimizer.step()
         scheduler.step()
+
         if xiter % iters_per_epoch == 0:
             logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item()))

@@ -123,6 +127,7 @@ if __name__ == "__main__":
     parser.add_argument("--model_config", type=str, help="The path to the model config")
     parser.add_argument("--optim_config", type=str, help="The optimizer config file.")
     parser.add_argument("--loss_config", type=str, help="The loss config file.")
+    parser.add_argument("--metric_config", type=str, help="The metric config file.")
     parser.add_argument(
         "--train_data_config", type=str, help="The training dataset config path."
     )
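The hunks above (apparently exps/basic/xmain.py, judging by the launch script below) only build and log the metric; it is not yet consulted in the training loop. A hedged sketch of how it could be wired into an evaluation pass, relying only on the Metric API added in this commit:

# Hypothetical evaluation pass using the new metric object (not part of this commit).
# Note: ComposeMetric.reset() empties its child list (see meter_utils below), so a
# composed metric is best created fresh per evaluation rather than reset in place.
def evaluate(model, loader, metric):
    model.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            logits = model(inputs)
            metric(logits, targets)      # each call updates the underlying AverageMeter(s)
    return metric.get_info()             # e.g. {'loss': ..., 'accuracy': ..., 'score': ...}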
@@ -28,5 +28,6 @@ python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \
 	--model_config ./configs/yaml.model/vit-cifar10.s0 \
 	--optim_config ./configs/yaml.opt/vit.cifar \
 	--loss_config ./configs/yaml.loss/cross-entropy \
+	--metric_config ./configs/yaml.loss/top1-ce \
 	--batch_size 256 \
 	--lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000
@@ -1,3 +1,12 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
+#####################################################
+import abc
+
+import torch
+import torch.nn.functional as F
+
+
 class AverageMeter:
     """Computes and stores the average and current value"""

@@ -20,3 +29,133 @@ class AverageMeter:
         return "{name}(val={val}, avg={avg}, count={count})".format(
             name=self.__class__.__name__, **self.__dict__
         )
+
+
+class Metric(abc.ABC):
+    """The default meta metric class."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        raise NotImplementedError
+
+    def __call__(self, predictions, targets):
+        raise NotImplementedError
+
+    def get_info(self):
+        raise NotImplementedError
+
+    def perf_str(self):
+        raise NotImplementedError
+
+    def __repr__(self):
+        return "{name}({inner})".format(
+            name=self.__class__.__name__, inner=self.inner_repr()
+        )
+
+    def inner_repr(self):
+        return ""
+
+
+class ComposeMetric(Metric):
+    """The composed metric class."""
+
+    def __init__(self, *metric_list):
+        self.reset()
+        for metric in metric_list:
+            self.append(metric)
+
+    def reset(self):
+        self._metric_list = []
+
+    def append(self, metric):
+        if not isinstance(metric, Metric):
+            raise ValueError(
+                "The input metric is not correct: {:}".format(type(metric))
+            )
+        self._metric_list.append(metric)
+
+    def __len__(self):
+        return len(self._metric_list)
+
+    def __call__(self, predictions, targets):
+        results = list()
+        for metric in self._metric_list:
+            results.append(metric(predictions, targets))
+        return results
+
+    def get_info(self):
+        results = dict()
+        for metric in self._metric_list:
+            for key, value in metric.get_info().items():
+                results[key] = value
+        return results
+
+    def inner_repr(self):
+        xlist = []
+        for metric in self._metric_list:
+            xlist.append(str(metric))
+        return ",".join(xlist)
+
+
+class CrossEntropyMetric(Metric):
+    """The metric for the cross entropy loss."""
+
+    def __init__(self, ignore_batch):
+        super(CrossEntropyMetric, self).__init__()
+        self._ignore_batch = ignore_batch
+
+    def reset(self):
+        self._loss = AverageMeter()
+
+    def __call__(self, predictions, targets):
+        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
+            batch, _ = predictions.shape  # only support 2-D tensor
+            max_prob_indexes = torch.argmax(predictions, dim=-1)
+            if self._ignore_batch:
+                loss = F.cross_entropy(predictions, targets, reduction="sum")
+                self._loss.update(loss.item(), 1)
+            else:
+                loss = F.cross_entropy(predictions, targets, reduction="mean")
+                self._loss.update(loss.item(), batch)
+            return loss
+        else:
+            raise NotImplementedError
+
+    def get_info(self):
+        return {"loss": self._loss.avg, "score": self._loss.avg * 100}
+
+    def perf_str(self):
+        return "ce-loss={:.5f}".format(self._loss.avg)
+
+
+class Top1AccMetric(Metric):
+    """The metric for the top-1 accuracy."""
+
+    def __init__(self, ignore_batch):
+        super(Top1AccMetric, self).__init__()
+        self._ignore_batch = ignore_batch
+
+    def reset(self):
+        self._accuracy = AverageMeter()
+
+    def __call__(self, predictions, targets):
+        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
+            batch, _ = predictions.shape  # only support 2-D tensor
+            max_prob_indexes = torch.argmax(predictions, dim=-1)
+            corrects = torch.eq(max_prob_indexes, targets)
+            accuracy = corrects.float().mean()
+            if self._ignore_batch:
+                self._accuracy.update(accuracy, 1)
+            else:
+                self._accuracy.update(accuracy, batch)
+            return accuracy
+        else:
+            raise NotImplementedError
+
+    def get_info(self):
+        return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
+
+    def perf_str(self):
+        return "accuracy={:.3f}%".format(self._accuracy.avg * 100)
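Two small caveats about ComposeMetric as written: get_info() merges the child dictionaries into one, so duplicate keys overwrite each other — with the top1-ce config both children report a "score" entry and the later one (CrossEntropyMetric) wins — and perf_str() is not overridden, so it falls through to the base class and raises NotImplementedError. A short sketch with made-up inputs showing the merged output:

import torch

from xautodl.xmisc.meter_utils import ComposeMetric, CrossEntropyMetric, Top1AccMetric

metric = ComposeMetric(Top1AccMetric(False), CrossEntropyMetric(False))
logits = torch.randn(4, 3)               # hypothetical predictions
labels = torch.tensor([0, 1, 2, 0])      # hypothetical targets
metric(logits, labels)                   # returns [accuracy, ce_loss]
info = metric.get_info()
# keys: 'accuracy', 'loss', 'score' -- 'score' comes from whichever child metric ran last
print(info)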