100 lines
2.2 KiB
Python
100 lines
2.2 KiB
Python
#####################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
|
#####################################################
|
|
# To be finished.
|
|
#
|
|
import os, sys, time, torch
|
|
from typing import Optional, Text, Callable
|
|
|
|
# modules in AutoDL
|
|
from xautodl.log_utils import AverageMeter, time_string
|
|
from .eval_funcs import obtain_accuracy
|
|
|
|
|
|
def get_device(tensors):
    """Return the device of the first leaf reached inside ``tensors``.

    Lists and tuples are followed through their first element and dicts
    through their first value (in iteration order). Any other object is
    treated as a leaf and its ``device`` attribute is returned as-is.

    Args:
        tensors: An object exposing a ``device`` attribute, or a (possibly
            nested) list / tuple / dict whose first entry leads to one.

    Returns:
        The ``device`` attribute of the first leaf that is reached.

    Raises:
        IndexError: If an empty list or tuple is encountered.
        ValueError: If an empty dict is encountered.
    """
    if isinstance(tensors, (list, tuple)):
        return get_device(tensors[0])
    if isinstance(tensors, dict):
        if not tensors:
            # An empty dict used to fall through and yield an implicit None,
            # which only failed later and far away from the actual cause.
            raise ValueError("Can not infer the device from an empty dict")
        # Only the first value is inspected, so fetch it directly instead of
        # starting a loop that always returns on its first iteration.
        return get_device(next(iter(tensors.values())))
    return tensors.device
|
|
|
|
|
|
def basic_train_fn(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    logger,
):
    """Run ``procedure`` over ``xloader`` in "train" mode.

    Every argument is forwarded unchanged; ``logger`` is handed over as the
    ``logger_fn`` argument of ``procedure``.

    Returns:
        Whatever ``procedure`` returns for this pass.
    """
    return procedure(
        xloader,
        network,
        criterion,
        optimizer,
        metric,
        mode="train",
        logger_fn=logger,
    )
|
|
|
|
|
|
def basic_eval_fn(xloader, network, metric, logger):
    """Run ``procedure`` over ``xloader`` in "valid" mode without autograd.

    The criterion and optimizer slots of ``procedure`` are filled with
    ``None``; ``logger`` is handed over as its ``logger_fn`` argument. The
    whole call happens inside ``torch.no_grad()``.

    Returns:
        Whatever ``procedure`` returns for this pass.
    """
    with torch.no_grad():
        return procedure(
            xloader,
            network,
            criterion=None,
            optimizer=None,
            metric=metric,
            mode="valid",
            logger_fn=logger,
        )
|
|
|
|
|
|
def procedure(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    mode: Text,
    logger_fn: Optional[Callable] = None,
):
    """Run one pass over ``xloader`` in either "train" or "valid" mode.

    For every ``(inputs, targets)`` batch the network is applied to the
    inputs, the targets are moved to the device of the outputs, and
    ``metric(outputs, targets)`` is called under ``torch.no_grad()``. In
    "train" mode the gradients are additionally zeroed before the forward
    pass, and the loss is computed, back-propagated and followed by an
    optimizer step.

    Args:
        xloader: Iterable yielding ``(inputs, targets)`` pairs.
        network: Callable model providing ``train()`` and ``eval()``.
        criterion: Loss callable ``criterion(outputs, targets)``; only used
            in "train" mode.
        optimizer: Object providing ``zero_grad()`` and ``step()``; only
            used in "train" mode.
        metric: Callable invoked as ``metric(outputs, targets)`` on each
            batch; its ``get_info()`` result is returned at the end.
        mode: "train" or "valid", matched case-insensitively.
        logger_fn: Accepted for interface compatibility; currently unused.

    Returns:
        The value of ``metric.get_info()`` after the last batch.

    Raises:
        ValueError: If ``mode`` is neither "train" nor "valid".
    """
    # Resolve the mode once. The per-batch checks used to compare the raw
    # string with "train", so a value such as "Train" passed the
    # case-insensitive validation but silently skipped every update step.
    normalized_mode = mode.lower()
    if normalized_mode == "train":
        is_training = True
    elif normalized_mode == "valid":
        is_training = False
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    if is_training:
        network.train()
    else:
        network.eval()

    # The timing meters are updated per batch but not reported yet.
    data_time, batch_time = AverageMeter(), AverageMeter()

    end = time.time()
    for inputs, targets in xloader:
        # measure data loading time
        data_time.update(time.time() - end)

        # calculate prediction and loss
        if is_training:
            optimizer.zero_grad()

        outputs = network(inputs)
        targets = targets.to(get_device(outputs))

        if is_training:
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # record; the metric is called for its effect, the value is not kept
        with torch.no_grad():
            metric(outputs, targets)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return metric.get_info()
|