#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
# To be finished.
#
import os, sys, time, torch
from typing import Optional, Text, Callable

# modules in AutoDL
from xautodl.log_utils import AverageMeter, time_string
from .eval_funcs import obtain_accuracy


def get_device(tensors):
    """Return the ``.device`` of the first leaf found in ``tensors``.

    Lists and tuples are searched through their first element, and dicts
    through their first value (in iteration order); anything else is
    expected to expose a ``.device`` attribute.

    Returns ``None`` for an empty dict. An empty list/tuple raises
    ``IndexError``.
    """
    if isinstance(tensors, (list, tuple)):
        return get_device(tensors[0])
    if isinstance(tensors, dict):
        # Only the first value is inspected.
        for value in tensors.values():
            return get_device(value)
        # Empty dict: there is nothing to take a device from.
        return None
    return tensors.device


def basic_train_fn(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    logger,
):
    """Run one pass over ``xloader`` in "train" mode via :func:`procedure`.

    Returns whatever ``procedure`` returns, i.e. ``metric.get_info()``.
    """
    return procedure(
        xloader,
        network,
        criterion,
        optimizer,
        metric,
        "train",
        logger,
    )


def basic_eval_fn(xloader, network, metric, logger):
    """Run one pass over ``xloader`` in "valid" mode with gradients disabled.

    No criterion or optimizer is passed on, since the "valid" mode of
    :func:`procedure` does not use them.

    Returns whatever ``procedure`` returns, i.e. ``metric.get_info()``.
    """
    with torch.no_grad():
        results = procedure(
            xloader,
            network,
            None,
            None,
            metric,
            "valid",
            logger,
        )
    return results


def procedure(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    mode: Text,
    logger_fn: Callable = None,
):
    """Iterate once over ``xloader`` and train or evaluate ``network``.

    Args:
        xloader: iterable yielding ``(inputs, targets)`` pairs.
        network: callable model exposing ``train()`` and ``eval()``.
        criterion: loss callable ``criterion(outputs, targets)``; only used
            in "train" mode.
        optimizer: object exposing ``zero_grad()`` and ``step()``; only used
            in "train" mode.
        metric: callable ``metric(outputs, targets)`` that also exposes
            ``get_info()``.
        mode: "train" or "valid" (case-insensitive).
        logger_fn: currently unused; kept for interface compatibility.

    Returns:
        The value of ``metric.get_info()`` after the last batch.

    Raises:
        ValueError: if ``mode`` is neither "train" nor "valid".
    """
    data_time, batch_time = AverageMeter(), AverageMeter()

    # Normalise the mode once, so the validation here and the per-batch
    # checks below agree. Previously "Train" switched the network to training
    # mode but silently skipped every optimisation step.
    normalized_mode = mode.lower()
    if normalized_mode == "train":
        network.train()
    elif normalized_mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    is_train = normalized_mode == "train"

    end = time.time()
    for inputs, targets in xloader:
        # measure data loading time
        data_time.update(time.time() - end)

        # calculate prediction and loss
        if is_train:
            optimizer.zero_grad()

        outputs = network(inputs)
        # Move the targets to wherever the network produced its outputs.
        targets = targets.to(get_device(outputs))

        if is_train:
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # record (the metric update itself never needs gradients)
        with torch.no_grad():
            metric(outputs, targets)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return metric.get_info()