420 lines
17 KiB
Python
420 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""Tools for training and testing a model."""
|
|
|
|
import os
|
|
from thop import profile
|
|
|
|
import numpy as np
|
|
import pycls.core.benchmark as benchmark
|
|
import pycls.core.builders as builders
|
|
import pycls.core.checkpoint as checkpoint
|
|
import pycls.core.config as config
|
|
import pycls.core.distributed as dist
|
|
import pycls.core.logging as logging
|
|
import pycls.core.meters as meters
|
|
import pycls.core.net as net
|
|
import pycls.core.optimizer as optim
|
|
import pycls.datasets.loader as loader
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from pycls.core.config import cfg
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def setup_env():
    """Prepares the process environment for a training or testing run.

    The master process creates the output directory and dumps the config.
    Logging setup, config logging, RNG seeding and the cuDNN benchmark flag
    are then applied.
    """
    on_master = dist.is_master_proc()
    if on_master:
        # Only the master process touches the output directory
        os.makedirs(cfg.OUT_DIR, exist_ok=True)
        config.dump_cfg()
    logging.setup_logging()
    # Emit the config twice: readable text first, then the json log record
    readable_cfg = "Config:\n{}".format(cfg)
    logger.info(readable_cfg)
    json_cfg = logging.dump_log_data(cfg, "cfg")
    logger.info(json_cfg)
    # Seed NumPy and PyTorch (see RNG comment in core/config.py for discussion)
    for seed_fun in (np.random.seed, torch.manual_seed):
        seed_fun(cfg.RNG_SEED)
    # Hand the cuDNN autotuner setting from the config to the backend
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
|
|
|
|
|
|
def setup_model():
    """Builds the model, logs its size and cost, and moves it to the GPU.

    The parameter and operation counts are measured with thop's `profile`
    on one random input whose shape depends on the task and dataset.

    Returns:
        The model on the current CUDA device, wrapped in
        DistributedDataParallel when cfg.NUM_GPUS > 1.
    """
    model = builders.build_model()
    logger.info("Model:\n{}".format(model))
    # NOTE: the net.complexity(model) based logging is disabled in this file;
    # complexity is profiled with thop instead.
    # Spatial size of the probe input: fixed for cityscapes segmentation,
    # square cfg.TRAIN.IM_SIZE otherwise
    if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes":
        height, width = 1025, 2049
    else:
        height = width = cfg.TRAIN.IM_SIZE
    probe_shape = [1, cfg.MODEL.INPUT_CHANNELS, height, width]
    if cfg.TASK == "jig":
        # Jigsaw inputs carry an extra dimension of JIGSAW_GRID ** 2 tiles
        probe_shape.insert(1, cfg.JIGSAW_GRID ** 2)
    probe = torch.randn(*probe_shape)
    macs, params = profile(model, inputs=(probe,), verbose=False)
    logger.info("Params: {:,}".format(params))
    logger.info("Flops: {:,}".format(macs))
    # Place the model on the current GPU device
    err_str = "Cannot use more GPU devices than available"
    assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str
    device = torch.cuda.current_device()
    model = model.cuda(device=device)
    if cfg.NUM_GPUS <= 1:
        return model
    # Multi-gpu setting: one replica per process, bound to the current device
    # NOTE: forwarding model.module.complexity to the wrapper is disabled here
    return torch.nn.parallel.DistributedDataParallel(
        module=model, device_ids=[device], output_device=device
    )
|
|
|
|
|
|
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Runs one full pass of training over train_loader.

    Args:
        train_loader: loader yielding (inputs, labels) batches.
        model: network to train (DistributedDataParallel-wrapped when
            cfg.NUM_GPUS > 1).
        loss_fun: callable applied to (predictions, labels).
        optimizer: optimizer whose learning rate is set here and stepped once
            per batch.
        train_meter: meter that times iterations and accumulates/logs stats.
        cur_epoch: zero-based index of the epoch being run.
    """
    # NAS models: scale the drop path probability with training progress
    if cfg.MODEL.TYPE == "nas":
        bare_model = model.module if cfg.NUM_GPUS > 1 else model
        bare_model.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH)
    loader.shuffle(train_loader, cur_epoch)
    # Per-epoch learning rate schedule
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer, lr)
    model.train()
    is_col = cfg.TASK == "col"
    is_seg = cfg.TASK == "seg"
    train_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(train_loader):
        # Per-iteration learning rate schedule (fractional epoch)
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader))
            optim.set_lr(optimizer, lr)
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        preds = model(inputs)
        if isinstance(preds, tuple):
            # Tuple output: main prediction plus an auxiliary head whose loss
            # is weighted by cfg.NAS.AUX_WEIGHT
            main_out, aux_out = preds[0], preds[1]
            loss = loss_fun(main_out, labels) + cfg.NAS.AUX_WEIGHT * loss_fun(aux_out, labels)
            preds = main_out
        else:
            loss = loss_fun(preds, labels)
        # Backward pass followed by the parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Mini-batch size summed over GPUs; for "col" each spatial position
        # counts as one sample, so predictions/labels are flattened to match
        mb_size = inputs.size(0) * cfg.NUM_GPUS
        if is_col:
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size *= inputs.size(2) * inputs.size(3)
        if is_seg:
            # For segmentation the pair is in fact (intersection, union)
            err1, err5 = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            err1, err5 = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, err1, err5 = dist.scaled_all_reduce([loss, err1, err5])
        # GPU -> CPU copies below are the sync point of the iteration
        loss = loss.item()
        if is_seg:
            err1, err5 = err1.cpu().numpy(), err5.cpu().numpy()
        else:
            err1, err5 = err1.item(), err5.item()
        train_meter.iter_toc()
        train_meter.update_stats(err1, err5, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Epoch summary, then clear the meter for the next epoch
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
|
|
|
|
|
|
def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of differentiable architecture search.

    Each iteration optionally steps the architecture optimizer on a batch
    from the second loader, then steps the weight optimizer on a batch from
    the first loader. At the end of the epoch the genotype and the softmax
    of the normal/reduce alphas are logged.

    Args:
        train_loader: pair of loaders; train_loader[0] drives the loop and
            the weight updates, train_loader[1] feeds the architecture
            updates and is restarted whenever it is exhausted.
        model: search network (DistributedDataParallel-wrapped when
            cfg.NUM_GPUS > 1).
        loss_fun: callable applied to (predictions, labels) for the weights.
        optimizer: pair of optimizers; optimizer[0] updates the weights,
            optimizer[1] updates the architecture parameters.
        train_meter: meter that times iterations and accumulates/logs stats.
        cur_epoch: zero-based index of the epoch being run.
    """
    m = model.module if cfg.NUM_GPUS > 1 else model
    # Shuffle the data
    loader.shuffle(train_loader[0], cur_epoch)
    loader.shuffle(train_loader[1], cur_epoch)
    # Update the learning rate per epoch (weight optimizer only)
    if not cfg.OPTIM.ITER_LR:
        lr = optim.get_epoch_lr(cur_epoch)
        optim.set_lr(optimizer[0], lr)
    # Enable training mode
    model.train()
    num_iters = len(train_loader[0])
    train_meter.iter_tic()
    trainB_iter = iter(train_loader[1])
    for cur_iter, (inputs, labels) in enumerate(train_loader[0]):
        # Fractional epoch, used by both the lr schedule and the arch gate
        frac_epoch = cur_epoch + cur_iter / num_iters
        # Update the learning rate per iter
        if cfg.OPTIM.ITER_LR:
            lr = optim.get_epoch_lr(frac_epoch)
            optim.set_lr(optimizer[0], lr)
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Update architecture once cfg.OPTIM.ARCH_EPOCH has been reached
        if frac_epoch >= cfg.OPTIM.ARCH_EPOCH:
            try:
                inputsB, labelsB = next(trainB_iter)
            except StopIteration:
                # Second loader exhausted: restart it and keep going
                trainB_iter = iter(train_loader[1])
                inputsB, labelsB = next(trainB_iter)
            inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True)
            optimizer[1].zero_grad()
            loss = m._loss(inputsB, labelsB)
            loss.backward()
            optimizer[1].step()
        # Perform the forward pass
        preds = model(inputs)
        # Compute the loss
        loss = loss_fun(preds, labels)
        # Perform the backward pass
        optimizer[0].zero_grad()
        loss.backward()
        # clip_grad_norm_ is the supported in-place API; the un-suffixed
        # clip_grad_norm used previously is deprecated
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        # Update the parameters
        optimizer[0].step()
        # Compute the errors
        if cfg.TASK == "col":
            # Every spatial position counts as one sample
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS
        else:
            mb_size = inputs.size(0) * cfg.NUM_GPUS
        if cfg.TASK == "seg":
            # top1_err is in fact inter; top5_err is in fact union
            top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            top1_err, top5_err = meters.topk_errors(preds, labels, ks)
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss = loss.item()
        if cfg.TASK == "seg":
            top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy()
        else:
            top1_err, top5_err = top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
    # Log genotype and the current architecture distributions
    genotype = m.genotype()
    logger.info("genotype = %s", genotype)
    logger.info(F.softmax(m.net_.alphas_normal, dim=-1))
    logger.info(F.softmax(m.net_.alphas_reduce, dim=-1))
|
|
|
|
|
|
@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Runs the model over test_loader in eval mode and logs the results.

    Args:
        test_loader: loader yielding (inputs, labels) batches.
        model: network to evaluate.
        test_meter: meter that times iterations and accumulates/logs stats.
        cur_epoch: epoch index used when logging.
    """
    model.eval()
    is_col = cfg.TASK == "col"
    is_seg = cfg.TASK == "seg"
    test_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(test_loader):
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        preds = model(inputs)
        # Mini-batch size summed over GPUs; for "col" each spatial position
        # counts as one sample, so predictions/labels are flattened to match
        mb_size = inputs.size(0) * cfg.NUM_GPUS
        if is_col:
            preds = preds.permute(0, 2, 3, 1)
            preds = preds.reshape(-1, preds.size(3))
            labels = labels.reshape(-1)
            mb_size *= inputs.size(2) * inputs.size(3)
        if is_seg:
            # For segmentation the pair is in fact (intersection, union)
            err1, err5 = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES)
        else:
            ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes
            err1, err5 = meters.topk_errors(preds, labels, ks)
        # Combine the errors across the GPUs (no reduction if 1 GPU used)
        err1, err5 = dist.scaled_all_reduce([err1, err5])
        # GPU -> CPU copies below are the sync point of the iteration
        if is_seg:
            err1, err5 = err1.cpu().numpy(), err5.cpu().numpy()
        else:
            err1, err5 = err1.item(), err5.item()
        test_meter.iter_toc()
        test_meter.update_stats(err1, err5, mb_size)
        test_meter.log_iter_stats(cur_epoch, cur_iter)
        test_meter.iter_tic()
    # Epoch summary, then clear the meter for the next evaluation
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()
|
|
|
|
|
|
def train_model():
    """Trains the model.

    Sets up the environment, model, loss and optimizer(s), optionally resumes
    from a checkpoint or loads initial weights, builds the data loaders and
    meters, then runs the epoch loop with periodic checkpointing and
    evaluation.

    When "search" is in cfg.MODEL.TYPE, two optimizers are built (weights and
    "alphas" architecture parameters) and, if cfg.TRAIN.PORTION < 1, two
    training loaders are built from the "l" and "r" sides of the split.

    Raises:
        ValueError: if cfg.OPTIM.ARCH_OPTIM is neither "adam" nor "sgd" in
            search mode.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    is_search = "search" in cfg.MODEL.TYPE
    if is_search:
        # Split parameters into network weights and architecture alphas
        params_w = [v for k, v in model.named_parameters() if "alphas" not in k]
        params_a = [v for k, v in model.named_parameters() if "alphas" in k]
        optimizer_w = torch.optim.SGD(
            params=params_w,
            lr=cfg.OPTIM.BASE_LR,
            momentum=cfg.OPTIM.MOMENTUM,
            weight_decay=cfg.OPTIM.WEIGHT_DECAY,
            dampening=cfg.OPTIM.DAMPENING,
            nesterov=cfg.OPTIM.NESTEROV,
        )
        if cfg.OPTIM.ARCH_OPTIM == "adam":
            optimizer_a = torch.optim.Adam(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                betas=(0.5, 0.999),
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
            )
        elif cfg.OPTIM.ARCH_OPTIM == "sgd":
            optimizer_a = torch.optim.SGD(
                params=params_a,
                lr=cfg.OPTIM.ARCH_BASE_LR,
                momentum=cfg.OPTIM.MOMENTUM,
                weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY,
                dampening=cfg.OPTIM.DAMPENING,
                nesterov=cfg.OPTIM.NESTEROV,
            )
        else:
            # Fail here with a clear message instead of an unbound
            # optimizer_a a few lines below
            raise ValueError(
                "Unsupported cfg.OPTIM.ARCH_OPTIM: {!r} (expected 'adam' or 'sgd')".format(
                    cfg.OPTIM.ARCH_OPTIM
                )
            )
        optimizer = [optimizer_w, optimizer_a]
    else:
        optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    if cfg.TRAIN.PORTION < 1:

        def portion_loader(side, training):
            """Builds a loader over one side of the portioned train split."""
            return loader._construct_loader(
                dataset_name=cfg.TRAIN.DATASET,
                split=cfg.TRAIN.SPLIT,
                batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS),
                shuffle=training,
                drop_last=training,
                portion=cfg.TRAIN.PORTION,
                side=side,
            )

        if is_search:
            # "l" side for the weights, "r" side for the architecture
            train_loader = [portion_loader("l", True), portion_loader("r", True)]
        else:
            train_loader = portion_loader("l", True)
        # Evaluation uses the "r" side, unshuffled and without dropping
        test_loader = portion_loader("r", False)
    else:
        train_loader = loader.construct_train_loader()
        test_loader = loader.construct_test_loader()
    train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter
    test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
    # In search mode train_loader is a pair; the first loader drives the epoch
    weight_loader = train_loader[0] if isinstance(train_loader, list) else train_loader
    train_meter = train_meter_type(len(weight_loader))
    test_meter = test_meter_type(len(test_loader))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, weight_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    epoch_fun = search_epoch if is_search else train_epoch
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        epoch_fun(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            # Pass a single loader, as done for the meter and the benchmark;
            # in search mode train_loader itself is a list of two loaders.
            # NOTE(review): confirm compute_precise_bn_stats expects one loader
            net.compute_precise_bn_stats(model, weight_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model periodically and always after the last epoch
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, cur_epoch)
|
|
|
|
|
|
def test_model():
    """Evaluates a trained model.

    Loads the weights named by cfg.TEST.WEIGHTS into a freshly built model
    and runs a single evaluation pass over the test loader.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model
    model = setup_model()
    # Load model weights
    checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Create data loaders and meters
    test_loader = loader.construct_test_loader()
    # Use the IoU meter for segmentation, matching train_model: test_epoch
    # feeds intersection/union arrays to the meter when cfg.TASK == "seg"
    test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter
    test_meter = test_meter_type(len(test_loader))
    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)
|
|
|
|
|
|
def time_model():
    """Benchmarks the model and both data loaders.

    Builds the environment, model, loss function and the train/test loaders,
    then hands them to benchmark.compute_time_full.
    """
    setup_env()
    # Model first, then the loss function moved to the GPU
    timed_model = setup_model()
    criterion = builders.build_loss_fun().cuda()
    # Loaders are built in train, test order
    loaders = (loader.construct_train_loader(), loader.construct_test_loader())
    benchmark.compute_time_full(timed_model, criterion, *loaders)
|