Fix bugs
This commit is contained in:
		| @@ -1,8 +1,10 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # | ||||
| ##################################################### | ||||
| # To be finished. | ||||
| # | ||||
| import os, sys, time, torch | ||||
| from typing import import Optional, Text, Callable | ||||
| from typing import Optional, Text, Callable | ||||
|  | ||||
| # modules in AutoDL | ||||
| from log_utils import AverageMeter | ||||
| @@ -60,9 +62,10 @@ def procedure( | ||||
|     network, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     eval_metric, | ||||
|     mode: Text, | ||||
|     print_freq: int = 100, | ||||
|     logger_fn: Callable = None | ||||
|     logger_fn: Callable = None, | ||||
| ): | ||||
|     data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     if mode.lower() == "train": | ||||
| @@ -90,7 +93,7 @@ def procedure( | ||||
|             optimizer.step() | ||||
|  | ||||
|         # record | ||||
|         metrics =  | ||||
|         metrics = eval_metric(logits.data, targets.data) | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1.update(prec1.item(), inputs.size(0)) | ||||
|   | ||||
| @@ -3,6 +3,7 @@ | ||||
| ##################################################### | ||||
| import abc | ||||
|  | ||||
|  | ||||
| def obtain_accuracy(output, target, topk=(1,)): | ||||
|     """Computes the precision@k for the specified values of k""" | ||||
|     maxk = max(topk) | ||||
| @@ -20,7 +21,6 @@ def obtain_accuracy(output, target, topk=(1,)): | ||||
|  | ||||
|  | ||||
| class EvaluationMetric(abc.ABC): | ||||
|      | ||||
|     def __init__(self): | ||||
|         self._total_metrics = 0 | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user