From 437f62b055d846c9b30e0d00b916b258972a8669 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Fri, 11 Jun 2021 07:20:06 -0700 Subject: [PATCH] Update metrics --- configs/yaml.loss/top1-ce | 12 +++ configs/yaml.loss/top1.acc | 4 + exps/basic/xmain.py | 7 +- scripts/experimental/train-vit.sh | 1 + xautodl/xmisc/meter_utils.py | 135 ++++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 configs/yaml.loss/top1-ce create mode 100644 configs/yaml.loss/top1.acc diff --git a/configs/yaml.loss/top1-ce b/configs/yaml.loss/top1-ce new file mode 100644 index 0000000..f82f7b3 --- /dev/null +++ b/configs/yaml.loss/top1-ce @@ -0,0 +1,12 @@ +class_or_func: ComposeMetric +module_path: xautodl.xmisc.meter_utils +args: + - class_or_func: Top1AccMetric + module_path: xautodl.xmisc.meter_utils + args: [False] + kwargs: {} + - class_or_func: CrossEntropyMetric + module_path: xautodl.xmisc.meter_utils + args: [False] + kwargs: {} +kwargs: {} diff --git a/configs/yaml.loss/top1.acc b/configs/yaml.loss/top1.acc new file mode 100644 index 0000000..aae4d60 --- /dev/null +++ b/configs/yaml.loss/top1.acc @@ -0,0 +1,4 @@ +class_or_func: Top1AccMetric +module_path: xautodl.xmisc.meter_utils +args: [False] +kwargs: {} diff --git a/exps/basic/xmain.py b/exps/basic/xmain.py index 7a0217a..cffc590 100644 --- a/exps/basic/xmain.py +++ b/exps/basic/xmain.py @@ -69,10 +69,13 @@ def main(args): weight_decay=args.weight_decay, ) objective = xmisc.nested_call_by_yaml(args.loss_config) + metric = xmisc.nested_call_by_yaml(args.metric_config) logger.log("The optimizer is:\n{:}".format(optimizer)) logger.log("The objective is {:}".format(objective)) - logger.log("The iters_per_epoch={:}".format(iters_per_epoch)) + logger.log("The metric is {:}".format(metric)) + logger.log("The iters_per_epoch = {:}, estimated epochs = {:}".format( + iters_per_epoch, args.steps // iters_per_epoch)) model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda() scheduler = xmisc.LRMultiplier( @@ -99,6 +102,7 @@ def main(args): loss.backward() optimizer.step() scheduler.step() + if xiter % iters_per_epoch == 0: logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item())) @@ -123,6 +127,7 @@ if __name__ == "__main__": parser.add_argument("--model_config", type=str, help="The path to the model config") parser.add_argument("--optim_config", type=str, help="The optimizer config file.") parser.add_argument("--loss_config", type=str, help="The loss config file.") + parser.add_argument("--metric_config", type=str, help="The metric config file.") parser.add_argument( "--train_data_config", type=str, help="The training dataset config path." ) diff --git a/scripts/experimental/train-vit.sh b/scripts/experimental/train-vit.sh index 84e2b48..3a36974 100644 --- a/scripts/experimental/train-vit.sh +++ b/scripts/experimental/train-vit.sh @@ -28,5 +28,6 @@ python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \ --model_config ./configs/yaml.model/vit-cifar10.s0 \ --optim_config ./configs/yaml.opt/vit.cifar \ --loss_config ./configs/yaml.loss/cross-entropy \ + --metric_config ./configs/yaml.loss/top-ce \ --batch_size 256 \ --lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000 diff --git a/xautodl/xmisc/meter_utils.py b/xautodl/xmisc/meter_utils.py index 923db1a..0e4ac43 100644 --- a/xautodl/xmisc/meter_utils.py +++ b/xautodl/xmisc/meter_utils.py @@ -1,3 +1,8 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 # +##################################################### + + class AverageMeter: """Computes and stores the average and current value""" @@ -20,3 +25,133 @@ class AverageMeter: return "{name}(val={val}, avg={avg}, count={count})".format( name=self.__class__.__name__, **self.__dict__ ) + + +class Metric(abc.ABC): + """The default meta metric class.""" + + def __init__(self): + self.reset() + + def reset(self): + raise NotImplementedError + + def __call__(self, predictions, targets): + raise NotImplementedError + + def get_info(self): + raise NotImplementedError + + def perf_str(self): + raise NotImplementedError + + def __repr__(self): + return "{name}({inner})".format( + name=self.__class__.__name__, inner=self.inner_repr() + ) + + def inner_repr(self): + return "" + + +class ComposeMetric(Metric): + """The composed metric class.""" + + def __init__(self, *metric_list): + self.reset() + for metric in metric_list: + self.append(metric) + + def reset(self): + self._metric_list = [] + + def append(self, metric): + if not isinstance(metric, Metric): + raise ValueError( + "The input metric is not correct: {:}".format(type(metric)) + ) + self._metric_list.append(metric) + + def __len__(self): + return len(self._metric_list) + + def __call__(self, predictions, targets): + results = list() + for metric in self._metric_list: + results.append(metric(predictions, targets)) + return results + + def get_info(self): + results = dict() + for metric in self._metric_list: + for key, value in metric.get_info().items(): + results[key] = value + return results + + def inner_repr(self): + xlist = [] + for metric in self._metric_list: + xlist.append(str(metric)) + return ",".join(xlist) + + +class CrossEntropyMetric(Metric): + """The metric for the cross entropy metric.""" + + def __init__(self, ignore_batch): + super(CrossEntropyMetric, self).__init__() + self._ignore_batch = ignore_batch + + def reset(self): + self._loss = AverageMeter() + + def __call__(self, predictions, targets): + if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): + batch, _ = predictions.shape() # only support 2-D tensor + max_prob_indexes = torch.argmax(predictions, dim=-1) + if self._ignore_batch: + loss = F.cross_entropy(predictions, targets, reduction="sum") + self._loss.update(loss.item(), 1) + else: + loss = F.cross_entropy(predictions, targets, reduction="mean") + self._loss.update(loss.item(), batch) + return loss + else: + raise NotImplementedError + + def get_info(self): + return {"loss": self._loss.avg, "score": self._loss.avg * 100} + + def perf_str(self): + return "ce-loss={:.5f}".format(self._loss.avg) + + +class Top1AccMetric(Metric): + """The metric for the top-1 accuracy.""" + + def __init__(self, ignore_batch): + super(Top1AccMetric, self).__init__() + self._ignore_batch = ignore_batch + + def reset(self): + self._accuracy = AverageMeter() + + def __call__(self, predictions, targets): + if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor): + batch, _ = predictions.shape() # only support 2-D tensor + max_prob_indexes = torch.argmax(predictions, dim=-1) + corrects = torch.eq(max_prob_indexes, targets) + accuracy = corrects.float().mean().float() + if self._ignore_batch: + self._accuracy.update(accuracy, 1) + else: + self._accuracy.update(accuracy, batch) + return accuracy + else: + raise NotImplementedError + + def get_info(self): + return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100} + + def perf_str(self): + return "accuracy={:.3f}%".format(self._accuracy.avg * 100)