Update metrics
This commit is contained in:
		
							
								
								
									
										12
									
								
								configs/yaml.loss/top1-ce
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								configs/yaml.loss/top1-ce
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | |||||||
|  | class_or_func: ComposeMetric | ||||||
|  | module_path: xautodl.xmisc.meter_utils | ||||||
|  | args: | ||||||
|  |   - class_or_func: Top1AccMetric | ||||||
|  |     module_path: xautodl.xmisc.meter_utils | ||||||
|  |     args: [False] | ||||||
|  |     kwargs: {} | ||||||
|  |   - class_or_func: CrossEntropyMetric | ||||||
|  |     module_path: xautodl.xmisc.meter_utils | ||||||
|  |     args: [False] | ||||||
|  |     kwargs: {} | ||||||
|  | kwargs: {} | ||||||
							
								
								
									
										4
									
								
								configs/yaml.loss/top1.acc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								configs/yaml.loss/top1.acc
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | |||||||
|  | class_or_func: Top1AccMetric | ||||||
|  | module_path: xautodl.xmisc.meter_utils | ||||||
|  | args: [False] | ||||||
|  | kwargs: {} | ||||||
| @@ -69,10 +69,13 @@ def main(args): | |||||||
|         weight_decay=args.weight_decay, |         weight_decay=args.weight_decay, | ||||||
|     ) |     ) | ||||||
|     objective = xmisc.nested_call_by_yaml(args.loss_config) |     objective = xmisc.nested_call_by_yaml(args.loss_config) | ||||||
|  |     metric = xmisc.nested_call_by_yaml(args.metric_config) | ||||||
|  |  | ||||||
|     logger.log("The optimizer is:\n{:}".format(optimizer)) |     logger.log("The optimizer is:\n{:}".format(optimizer)) | ||||||
|     logger.log("The objective is {:}".format(objective)) |     logger.log("The objective is {:}".format(objective)) | ||||||
|     logger.log("The iters_per_epoch={:}".format(iters_per_epoch)) |     logger.log("The metric is {:}".format(metric)) | ||||||
|  |     logger.log("The iters_per_epoch = {:}, estimated epochs = {:}".format( | ||||||
|  |       iters_per_epoch, args.steps // iters_per_epoch)) | ||||||
|  |  | ||||||
|     model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda() |     model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda() | ||||||
|     scheduler = xmisc.LRMultiplier( |     scheduler = xmisc.LRMultiplier( | ||||||
| @@ -99,6 +102,7 @@ def main(args): | |||||||
|         loss.backward() |         loss.backward() | ||||||
|         optimizer.step() |         optimizer.step() | ||||||
|         scheduler.step() |         scheduler.step() | ||||||
|  |  | ||||||
|         if xiter % iters_per_epoch == 0: |         if xiter % iters_per_epoch == 0: | ||||||
|             logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item())) |             logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item())) | ||||||
|  |  | ||||||
| @@ -123,6 +127,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument("--model_config", type=str, help="The path to the model config") |     parser.add_argument("--model_config", type=str, help="The path to the model config") | ||||||
|     parser.add_argument("--optim_config", type=str, help="The optimizer config file.") |     parser.add_argument("--optim_config", type=str, help="The optimizer config file.") | ||||||
|     parser.add_argument("--loss_config", type=str, help="The loss config file.") |     parser.add_argument("--loss_config", type=str, help="The loss config file.") | ||||||
|  |     parser.add_argument("--metric_config", type=str, help="The metric config file.") | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--train_data_config", type=str, help="The training dataset config path." |         "--train_data_config", type=str, help="The training dataset config path." | ||||||
|     ) |     ) | ||||||
|   | |||||||
| @@ -28,5 +28,6 @@ python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \ | |||||||
| 	--model_config ./configs/yaml.model/vit-cifar10.s0 \ | 	--model_config ./configs/yaml.model/vit-cifar10.s0 \ | ||||||
| 	--optim_config ./configs/yaml.opt/vit.cifar \ | 	--optim_config ./configs/yaml.opt/vit.cifar \ | ||||||
| 	--loss_config ./configs/yaml.loss/cross-entropy \ | 	--loss_config ./configs/yaml.loss/cross-entropy \ | ||||||
|  | 	--metric_config ./configs/yaml.loss/top-ce \ | ||||||
| 	--batch_size 256 \ | 	--batch_size 256 \ | ||||||
| 	--lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000 | 	--lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000 | ||||||
|   | |||||||
| @@ -1,3 +1,8 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 # | ||||||
|  | ##################################################### | ||||||
|  |  | ||||||
|  |  | ||||||
| class AverageMeter: | class AverageMeter: | ||||||
|     """Computes and stores the average and current value""" |     """Computes and stores the average and current value""" | ||||||
|  |  | ||||||
| @@ -20,3 +25,133 @@ class AverageMeter: | |||||||
|         return "{name}(val={val}, avg={avg}, count={count})".format( |         return "{name}(val={val}, avg={avg}, count={count})".format( | ||||||
|             name=self.__class__.__name__, **self.__dict__ |             name=self.__class__.__name__, **self.__dict__ | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Metric(abc.ABC): | ||||||
|  |     """The default meta metric class.""" | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self.reset() | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __call__(self, predictions, targets): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def get_info(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def perf_str(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({inner})".format( | ||||||
|  |             name=self.__class__.__name__, inner=self.inner_repr() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def inner_repr(self): | ||||||
|  |         return "" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ComposeMetric(Metric): | ||||||
|  |     """The composed metric class.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, *metric_list): | ||||||
|  |         self.reset() | ||||||
|  |         for metric in metric_list: | ||||||
|  |             self.append(metric) | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self._metric_list = [] | ||||||
|  |  | ||||||
|  |     def append(self, metric): | ||||||
|  |         if not isinstance(metric, Metric): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "The input metric is not correct: {:}".format(type(metric)) | ||||||
|  |             ) | ||||||
|  |         self._metric_list.append(metric) | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self._metric_list) | ||||||
|  |  | ||||||
|  |     def __call__(self, predictions, targets): | ||||||
|  |         results = list() | ||||||
|  |         for metric in self._metric_list: | ||||||
|  |             results.append(metric(predictions, targets)) | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |     def get_info(self): | ||||||
|  |         results = dict() | ||||||
|  |         for metric in self._metric_list: | ||||||
|  |             for key, value in metric.get_info().items(): | ||||||
|  |                 results[key] = value | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |     def inner_repr(self): | ||||||
|  |         xlist = [] | ||||||
|  |         for metric in self._metric_list: | ||||||
|  |             xlist.append(str(metric)) | ||||||
|  |         return ",".join(xlist) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CrossEntropyMetric(Metric): | ||||||
|  |     """The metric for the cross entropy metric.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, ignore_batch): | ||||||
|  |         super(CrossEntropyMetric, self).__init__() | ||||||
|  |         self._ignore_batch = ignore_batch | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self._loss = AverageMeter() | ||||||
|  |  | ||||||
|  |     def __call__(self, predictions, targets): | ||||||
|  |         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||||
|  |             batch, _ = predictions.shape()  # only support 2-D tensor | ||||||
|  |             max_prob_indexes = torch.argmax(predictions, dim=-1) | ||||||
|  |             if self._ignore_batch: | ||||||
|  |                 loss = F.cross_entropy(predictions, targets, reduction="sum") | ||||||
|  |                 self._loss.update(loss.item(), 1) | ||||||
|  |             else: | ||||||
|  |                 loss = F.cross_entropy(predictions, targets, reduction="mean") | ||||||
|  |                 self._loss.update(loss.item(), batch) | ||||||
|  |             return loss | ||||||
|  |         else: | ||||||
|  |             raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def get_info(self): | ||||||
|  |         return {"loss": self._loss.avg, "score": self._loss.avg * 100} | ||||||
|  |  | ||||||
|  |     def perf_str(self): | ||||||
|  |         return "ce-loss={:.5f}".format(self._loss.avg) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Top1AccMetric(Metric): | ||||||
|  |     """The metric for the top-1 accuracy.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, ignore_batch): | ||||||
|  |         super(Top1AccMetric, self).__init__() | ||||||
|  |         self._ignore_batch = ignore_batch | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self._accuracy = AverageMeter() | ||||||
|  |  | ||||||
|  |     def __call__(self, predictions, targets): | ||||||
|  |         if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): | ||||||
|  |             batch, _ = predictions.shape()  # only support 2-D tensor | ||||||
|  |             max_prob_indexes = torch.argmax(predictions, dim=-1) | ||||||
|  |             corrects = torch.eq(max_prob_indexes, targets) | ||||||
|  |             accuracy = corrects.float().mean().float() | ||||||
|  |             if self._ignore_batch: | ||||||
|  |                 self._accuracy.update(accuracy, 1) | ||||||
|  |             else: | ||||||
|  |                 self._accuracy.update(accuracy, batch) | ||||||
|  |             return accuracy | ||||||
|  |         else: | ||||||
|  |             raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def get_info(self): | ||||||
|  |         return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100} | ||||||
|  |  | ||||||
|  |     def perf_str(self): | ||||||
|  |         return "accuracy={:.3f}%".format(self._accuracy.avg * 100) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user