Update metrics

This commit is contained in:
D-X-Y 2021-06-11 07:20:06 -07:00
parent 248686820c
commit 437f62b055
5 changed files with 158 additions and 1 deletions

12
configs/yaml.loss/top1-ce Normal file
View File

@ -0,0 +1,12 @@
class_or_func: ComposeMetric
module_path: xautodl.xmisc.meter_utils
args:
- class_or_func: Top1AccMetric
module_path: xautodl.xmisc.meter_utils
args: [False]
kwargs: {}
- class_or_func: CrossEntropyMetric
module_path: xautodl.xmisc.meter_utils
args: [False]
kwargs: {}
kwargs: {}

View File

@ -0,0 +1,4 @@
class_or_func: Top1AccMetric
module_path: xautodl.xmisc.meter_utils
args: [False]
kwargs: {}

View File

@ -69,10 +69,13 @@ def main(args):
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
) )
objective = xmisc.nested_call_by_yaml(args.loss_config) objective = xmisc.nested_call_by_yaml(args.loss_config)
metric = xmisc.nested_call_by_yaml(args.metric_config)
logger.log("The optimizer is:\n{:}".format(optimizer)) logger.log("The optimizer is:\n{:}".format(optimizer))
logger.log("The objective is {:}".format(objective)) logger.log("The objective is {:}".format(objective))
logger.log("The iters_per_epoch={:}".format(iters_per_epoch)) logger.log("The metric is {:}".format(metric))
logger.log("The iters_per_epoch = {:}, estimated epochs = {:}".format(
iters_per_epoch, args.steps // iters_per_epoch))
model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda() model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda()
scheduler = xmisc.LRMultiplier( scheduler = xmisc.LRMultiplier(
@ -99,6 +102,7 @@ def main(args):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
if xiter % iters_per_epoch == 0: if xiter % iters_per_epoch == 0:
logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item())) logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item()))
@ -123,6 +127,7 @@ if __name__ == "__main__":
parser.add_argument("--model_config", type=str, help="The path to the model config") parser.add_argument("--model_config", type=str, help="The path to the model config")
parser.add_argument("--optim_config", type=str, help="The optimizer config file.") parser.add_argument("--optim_config", type=str, help="The optimizer config file.")
parser.add_argument("--loss_config", type=str, help="The loss config file.") parser.add_argument("--loss_config", type=str, help="The loss config file.")
parser.add_argument("--metric_config", type=str, help="The metric config file.")
parser.add_argument( parser.add_argument(
"--train_data_config", type=str, help="The training dataset config path." "--train_data_config", type=str, help="The training dataset config path."
) )

View File

@ -28,5 +28,6 @@ python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \
--model_config ./configs/yaml.model/vit-cifar10.s0 \ --model_config ./configs/yaml.model/vit-cifar10.s0 \
--optim_config ./configs/yaml.opt/vit.cifar \ --optim_config ./configs/yaml.opt/vit.cifar \
--loss_config ./configs/yaml.loss/cross-entropy \ --loss_config ./configs/yaml.loss/cross-entropy \
--metric_config ./configs/yaml.loss/top-ce \
--batch_size 256 \ --batch_size 256 \
--lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000 --lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000

View File

@ -1,3 +1,8 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
#####################################################
class AverageMeter: class AverageMeter:
"""Computes and stores the average and current value""" """Computes and stores the average and current value"""
@ -20,3 +25,133 @@ class AverageMeter:
return "{name}(val={val}, avg={avg}, count={count})".format( return "{name}(val={val}, avg={avg}, count={count})".format(
name=self.__class__.__name__, **self.__dict__ name=self.__class__.__name__, **self.__dict__
) )
class Metric(abc.ABC):
"""The default meta metric class."""
def __init__(self):
self.reset()
def reset(self):
raise NotImplementedError
def __call__(self, predictions, targets):
raise NotImplementedError
def get_info(self):
raise NotImplementedError
def perf_str(self):
raise NotImplementedError
def __repr__(self):
return "{name}({inner})".format(
name=self.__class__.__name__, inner=self.inner_repr()
)
def inner_repr(self):
return ""
class ComposeMetric(Metric):
"""The composed metric class."""
def __init__(self, *metric_list):
self.reset()
for metric in metric_list:
self.append(metric)
def reset(self):
self._metric_list = []
def append(self, metric):
if not isinstance(metric, Metric):
raise ValueError(
"The input metric is not correct: {:}".format(type(metric))
)
self._metric_list.append(metric)
def __len__(self):
return len(self._metric_list)
def __call__(self, predictions, targets):
results = list()
for metric in self._metric_list:
results.append(metric(predictions, targets))
return results
def get_info(self):
results = dict()
for metric in self._metric_list:
for key, value in metric.get_info().items():
results[key] = value
return results
def inner_repr(self):
xlist = []
for metric in self._metric_list:
xlist.append(str(metric))
return ",".join(xlist)
class CrossEntropyMetric(Metric):
"""The metric for the cross entropy metric."""
def __init__(self, ignore_batch):
super(CrossEntropyMetric, self).__init__()
self._ignore_batch = ignore_batch
def reset(self):
self._loss = AverageMeter()
def __call__(self, predictions, targets):
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
batch, _ = predictions.shape() # only support 2-D tensor
max_prob_indexes = torch.argmax(predictions, dim=-1)
if self._ignore_batch:
loss = F.cross_entropy(predictions, targets, reduction="sum")
self._loss.update(loss.item(), 1)
else:
loss = F.cross_entropy(predictions, targets, reduction="mean")
self._loss.update(loss.item(), batch)
return loss
else:
raise NotImplementedError
def get_info(self):
return {"loss": self._loss.avg, "score": self._loss.avg * 100}
def perf_str(self):
return "ce-loss={:.5f}".format(self._loss.avg)
class Top1AccMetric(Metric):
"""The metric for the top-1 accuracy."""
def __init__(self, ignore_batch):
super(Top1AccMetric, self).__init__()
self._ignore_batch = ignore_batch
def reset(self):
self._accuracy = AverageMeter()
def __call__(self, predictions, targets):
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
batch, _ = predictions.shape() # only support 2-D tensor
max_prob_indexes = torch.argmax(predictions, dim=-1)
corrects = torch.eq(max_prob_indexes, targets)
accuracy = corrects.float().mean().float()
if self._ignore_batch:
self._accuracy.update(accuracy, 1)
else:
self._accuracy.update(accuracy, batch)
return accuracy
else:
raise NotImplementedError
def get_info(self):
return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
def perf_str(self):
return "accuracy={:.3f}%".format(self._accuracy.avg * 100)