naswot/pycls/core/benchmark.py
Jack Turner b74255e1f3 v2
2021-02-26 16:12:51 +00:00

137 lines
5.2 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Benchmarking functions."""
import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer
logger = logging.get_logger(__name__)
@torch.no_grad()
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data.

    Puts ``model`` in eval mode and runs it for
    ``cfg.PREC_TIME.WARMUP_ITER + cfg.PREC_TIME.NUM_ITER`` forward passes on a
    synthetic mini-batch that is moved to the GPU with ``.cuda()``. The timer
    is reset when the warmup iterations are over, so the returned value only
    covers the iterations that follow.

    Args:
        model: network to time (presumably already on the GPU, since the
            dummy inputs are placed there -- confirm with the caller).

    Returns:
        ``Timer.average_time`` for the post-warmup forward passes.

    Note:
        The model is left in eval mode on return.
    """
    model.eval()
    # Per-GPU dummy batch; the spatial size is taken from cfg.TRAIN.IM_SIZE
    side = cfg.TRAIN.IM_SIZE
    per_gpu_batch = int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        # Jigsaw inputs carry one extra axis of JIGSAW_GRID ** 2 entries
        shape = (
            per_gpu_batch,
            cfg.JIGSAW_GRID ** 2,
            cfg.MODEL.INPUT_CHANNELS,
            side,
            side,
        )
        dummy = torch.rand(*shape)
    else:
        shape = (per_gpu_batch, cfg.MODEL.INPUT_CHANNELS, side, side)
        dummy = torch.zeros(*shape)
    dummy = dummy.cuda(non_blocking=False)
    # Time the forward passes
    fw_timer = Timer()
    warmup_steps = cfg.PREC_TIME.WARMUP_ITER
    n_steps = cfg.PREC_TIME.NUM_ITER + warmup_steps
    for step in range(n_steps):
        # Discard everything measured during warmup
        if step == warmup_steps:
            fw_timer.reset()
        fw_timer.tic()
        model(dummy)
        # Wait for queued GPU work so toc() sees the full forward pass
        torch.cuda.synchronize()
        fw_timer.toc()
    return fw_timer.average_time
def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data.

    Puts ``model`` in train mode and runs
    ``cfg.PREC_TIME.WARMUP_ITER + cfg.PREC_TIME.NUM_ITER`` forward/backward
    iterations on a random mini-batch moved to the GPU with ``.cuda()``.
    Both timers are reset once the warmup iterations are over. BatchNorm2d
    running statistics are snapshotted before the loop and restored after it
    so that the dummy data does not leak into them.

    Args:
        model: network to time (presumably already on the GPU, since the
            dummy inputs are placed there -- confirm with the caller). It may
            return either a single prediction or a tuple whose first two
            entries are the main and auxiliary predictions.
        loss_fun: callable invoked as ``loss_fun(preds, labels)``; the result
            must support ``.backward()``.

    Returns:
        Tuple ``(fw_time, bw_time)``: ``Timer.average_time`` of the forward
        timer (model call plus loss computation) and of the backward timer.

    Note:
        The model is left in train mode, and the gradients accumulated by the
        repeated ``loss.backward()`` calls are not cleared here.
    """
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(
            batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    else:
        inputs = torch.rand(
            batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    if cfg.TASK in ['col', 'seg']:
        # One label per pixel for the 'col' / 'seg' tasks
        labels = torch.zeros(
            batch_size, im_size, im_size, dtype=torch.int64
        ).cuda(non_blocking=False)
    else:
        labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats. Layers created with
    # track_running_stats=False keep these buffers as None, so they have
    # nothing to snapshot and are skipped instead of crashing on .clone().
    bns = [
        m
        for m in model.modules()
        if isinstance(m, torch.nn.BatchNorm2d)
        and m.running_mean is not None
        and m.running_var is not None
    ]
    bn_stats = []
    for bn in bns:
        # num_batches_tracked also advances on every train-mode forward pass
        tracked = getattr(bn, "num_batches_tracked", None)
        bn_stats.append(
            (
                bn.running_mean.clone(),
                bn.running_var.clone(),
                tracked.clone() if tracked is not None else None,
            )
        )
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        if isinstance(preds, tuple):
            # Main prediction plus auxiliary prediction weighted by NAS.AUX_WEIGHT
            loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
        else:
            loss = loss_fun(preds, labels)
        # Wait for queued GPU work so the timer covers the whole forward pass
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats (including the batch counter)
    for bn, (mean, var, tracked) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
        if tracked is not None:
            bn.num_batches_tracked = tracked
    return fw_timer.average_time, bw_timer.average_time
def compute_time_loader(data_loader):
    """Computes loader time.

    Calls ``loader.shuffle(data_loader, 0)`` (the ``0`` is presumably the
    epoch index -- confirm against ``pycls.datasets.loader``), then times
    successive ``next()`` calls on a fresh iterator over ``data_loader``.

    The number of fetches is
    ``cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER`` capped at
    ``len(data_loader)`` so the iterator is never exhausted. The timer is
    reset when the warmup fetches are over; if the loader has no more than
    ``WARMUP_ITER`` batches that reset never happens and the result also
    covers the warmup fetches.

    Args:
        data_loader: sized iterable of batches accepted by ``loader.shuffle``.

    Returns:
        ``Timer.average_time`` for the timed fetches.
    """
    fetch_timer = Timer()
    loader.shuffle(data_loader, 0)
    batches = iter(data_loader)
    warmup_steps = cfg.PREC_TIME.WARMUP_ITER
    # Never ask for more batches than the loader can provide
    n_fetches = min(cfg.PREC_TIME.NUM_ITER + warmup_steps, len(data_loader))
    for step in range(n_fetches):
        # Discard everything measured during warmup
        if step == warmup_steps:
            fetch_timer.reset()
        fetch_timer.tic()
        next(batches)
        fetch_timer.toc()
    return fetch_timer.average_time
def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader.

    Runs ``compute_time_eval``, ``compute_time_train`` and
    ``compute_time_loader`` (on ``train_loader`` only) and logs three things:
    per-iteration timings, those timings scaled by the number of batches in
    the matching loader, and the data loader overhead relative to the
    forward + backward time.

    Args:
        model: network handed to the eval and train timing functions.
        loss_fun: loss callable handed to ``compute_time_train``.
        train_loader: training loader; it is timed and its length scales the
            train timings.
        test_loader: test loader; only its length is used, to scale the test
            forward time.

    Returns:
        None. All results are emitted through the module logger.
    """
    logger.info("Computing model and loader timings...")
    # Compute timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    train_loader_time = compute_time_loader(train_loader)
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
        "train_loader_time": train_loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Output epoch timing (per-iteration time times batches per epoch)
    epoch_times = {
        "test_fw_time": test_fw_time * len(test_loader),
        "train_fw_time": train_fw_time * len(train_loader),
        "train_bw_time": train_bw_time * len(train_loader),
        "train_fw_bw_time": train_fw_bw_time * len(train_loader),
        "train_loader_time": train_loader_time * len(train_loader),
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1).
    # Only loader time in excess of the fw + bw time counts as overhead.
    excess = max(0, train_loader_time - train_fw_bw_time)
    if train_fw_bw_time != 0:
        overhead = excess / train_fw_bw_time
    else:
        # A zero fw + bw measurement would otherwise raise ZeroDivisionError;
        # report unbounded overhead if the loader took any time, else none.
        overhead = float("inf") if excess > 0 else 0.0
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))