Update metrics
This commit is contained in:
parent
248686820c
commit
437f62b055
12
configs/yaml.loss/top1-ce
Normal file
12
configs/yaml.loss/top1-ce
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
class_or_func: ComposeMetric
|
||||||
|
module_path: xautodl.xmisc.meter_utils
|
||||||
|
args:
|
||||||
|
- class_or_func: Top1AccMetric
|
||||||
|
module_path: xautodl.xmisc.meter_utils
|
||||||
|
args: [False]
|
||||||
|
kwargs: {}
|
||||||
|
- class_or_func: CrossEntropyMetric
|
||||||
|
module_path: xautodl.xmisc.meter_utils
|
||||||
|
args: [False]
|
||||||
|
kwargs: {}
|
||||||
|
kwargs: {}
|
4
configs/yaml.loss/top1.acc
Normal file
4
configs/yaml.loss/top1.acc
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
class_or_func: Top1AccMetric
|
||||||
|
module_path: xautodl.xmisc.meter_utils
|
||||||
|
args: [False]
|
||||||
|
kwargs: {}
|
@ -69,10 +69,13 @@ def main(args):
|
|||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
)
|
)
|
||||||
objective = xmisc.nested_call_by_yaml(args.loss_config)
|
objective = xmisc.nested_call_by_yaml(args.loss_config)
|
||||||
|
metric = xmisc.nested_call_by_yaml(args.metric_config)
|
||||||
|
|
||||||
logger.log("The optimizer is:\n{:}".format(optimizer))
|
logger.log("The optimizer is:\n{:}".format(optimizer))
|
||||||
logger.log("The objective is {:}".format(objective))
|
logger.log("The objective is {:}".format(objective))
|
||||||
logger.log("The iters_per_epoch={:}".format(iters_per_epoch))
|
logger.log("The metric is {:}".format(metric))
|
||||||
|
logger.log("The iters_per_epoch = {:}, estimated epochs = {:}".format(
|
||||||
|
iters_per_epoch, args.steps // iters_per_epoch))
|
||||||
|
|
||||||
model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda()
|
model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda()
|
||||||
scheduler = xmisc.LRMultiplier(
|
scheduler = xmisc.LRMultiplier(
|
||||||
@ -99,6 +102,7 @@ def main(args):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
if xiter % iters_per_epoch == 0:
|
if xiter % iters_per_epoch == 0:
|
||||||
logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item()))
|
logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item()))
|
||||||
|
|
||||||
@ -123,6 +127,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--model_config", type=str, help="The path to the model config")
|
parser.add_argument("--model_config", type=str, help="The path to the model config")
|
||||||
parser.add_argument("--optim_config", type=str, help="The optimizer config file.")
|
parser.add_argument("--optim_config", type=str, help="The optimizer config file.")
|
||||||
parser.add_argument("--loss_config", type=str, help="The loss config file.")
|
parser.add_argument("--loss_config", type=str, help="The loss config file.")
|
||||||
|
parser.add_argument("--metric_config", type=str, help="The metric config file.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_data_config", type=str, help="The training dataset config path."
|
"--train_data_config", type=str, help="The training dataset config path."
|
||||||
)
|
)
|
||||||
|
@ -28,5 +28,6 @@ python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \
|
|||||||
--model_config ./configs/yaml.model/vit-cifar10.s0 \
|
--model_config ./configs/yaml.model/vit-cifar10.s0 \
|
||||||
--optim_config ./configs/yaml.opt/vit.cifar \
|
--optim_config ./configs/yaml.opt/vit.cifar \
|
||||||
--loss_config ./configs/yaml.loss/cross-entropy \
|
--loss_config ./configs/yaml.loss/cross-entropy \
|
||||||
|
--metric_config ./configs/yaml.loss/top-ce \
|
||||||
--batch_size 256 \
|
--batch_size 256 \
|
||||||
--lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000
|
--lr 0.003 --weight_decay 0.3 --scheduler warm-cos --steps 10000
|
||||||
|
@ -1,3 +1,8 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||||||
|
#####################################################
|
||||||
|
|
||||||
|
|
||||||
class AverageMeter:
|
class AverageMeter:
|
||||||
"""Computes and stores the average and current value"""
|
"""Computes and stores the average and current value"""
|
||||||
|
|
||||||
@ -20,3 +25,133 @@ class AverageMeter:
|
|||||||
return "{name}(val={val}, avg={avg}, count={count})".format(
|
return "{name}(val={val}, avg={avg}, count={count})".format(
|
||||||
name=self.__class__.__name__, **self.__dict__
|
name=self.__class__.__name__, **self.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Metric(abc.ABC):
|
||||||
|
"""The default meta metric class."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __call__(self, predictions, targets):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def perf_str(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({inner})".format(
|
||||||
|
name=self.__class__.__name__, inner=self.inner_repr()
|
||||||
|
)
|
||||||
|
|
||||||
|
def inner_repr(self):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class ComposeMetric(Metric):
|
||||||
|
"""The composed metric class."""
|
||||||
|
|
||||||
|
def __init__(self, *metric_list):
|
||||||
|
self.reset()
|
||||||
|
for metric in metric_list:
|
||||||
|
self.append(metric)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._metric_list = []
|
||||||
|
|
||||||
|
def append(self, metric):
|
||||||
|
if not isinstance(metric, Metric):
|
||||||
|
raise ValueError(
|
||||||
|
"The input metric is not correct: {:}".format(type(metric))
|
||||||
|
)
|
||||||
|
self._metric_list.append(metric)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._metric_list)
|
||||||
|
|
||||||
|
def __call__(self, predictions, targets):
|
||||||
|
results = list()
|
||||||
|
for metric in self._metric_list:
|
||||||
|
results.append(metric(predictions, targets))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
results = dict()
|
||||||
|
for metric in self._metric_list:
|
||||||
|
for key, value in metric.get_info().items():
|
||||||
|
results[key] = value
|
||||||
|
return results
|
||||||
|
|
||||||
|
def inner_repr(self):
|
||||||
|
xlist = []
|
||||||
|
for metric in self._metric_list:
|
||||||
|
xlist.append(str(metric))
|
||||||
|
return ",".join(xlist)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEntropyMetric(Metric):
|
||||||
|
"""The metric for the cross entropy metric."""
|
||||||
|
|
||||||
|
def __init__(self, ignore_batch):
|
||||||
|
super(CrossEntropyMetric, self).__init__()
|
||||||
|
self._ignore_batch = ignore_batch
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._loss = AverageMeter()
|
||||||
|
|
||||||
|
def __call__(self, predictions, targets):
|
||||||
|
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
|
||||||
|
batch, _ = predictions.shape() # only support 2-D tensor
|
||||||
|
max_prob_indexes = torch.argmax(predictions, dim=-1)
|
||||||
|
if self._ignore_batch:
|
||||||
|
loss = F.cross_entropy(predictions, targets, reduction="sum")
|
||||||
|
self._loss.update(loss.item(), 1)
|
||||||
|
else:
|
||||||
|
loss = F.cross_entropy(predictions, targets, reduction="mean")
|
||||||
|
self._loss.update(loss.item(), batch)
|
||||||
|
return loss
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
return {"loss": self._loss.avg, "score": self._loss.avg * 100}
|
||||||
|
|
||||||
|
def perf_str(self):
|
||||||
|
return "ce-loss={:.5f}".format(self._loss.avg)
|
||||||
|
|
||||||
|
|
||||||
|
class Top1AccMetric(Metric):
|
||||||
|
"""The metric for the top-1 accuracy."""
|
||||||
|
|
||||||
|
def __init__(self, ignore_batch):
|
||||||
|
super(Top1AccMetric, self).__init__()
|
||||||
|
self._ignore_batch = ignore_batch
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._accuracy = AverageMeter()
|
||||||
|
|
||||||
|
def __call__(self, predictions, targets):
|
||||||
|
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
|
||||||
|
batch, _ = predictions.shape() # only support 2-D tensor
|
||||||
|
max_prob_indexes = torch.argmax(predictions, dim=-1)
|
||||||
|
corrects = torch.eq(max_prob_indexes, targets)
|
||||||
|
accuracy = corrects.float().mean().float()
|
||||||
|
if self._ignore_batch:
|
||||||
|
self._accuracy.update(accuracy, 1)
|
||||||
|
else:
|
||||||
|
self._accuracy.update(accuracy, batch)
|
||||||
|
return accuracy
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
|
||||||
|
|
||||||
|
def perf_str(self):
|
||||||
|
return "accuracy={:.3f}%".format(self._accuracy.avg * 100)
|
||||||
|
Loading…
Reference in New Issue
Block a user