137 lines
5.2 KiB
Python
137 lines
5.2 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""Benchmarking functions."""
|
|
|
|
import pycls.core.logging as logging
|
|
import pycls.datasets.loader as loader
|
|
import torch
|
|
from pycls.core.config import cfg
|
|
from pycls.core.timer import Timer
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@torch.no_grad()
|
|
def compute_time_eval(model):
|
|
"""Computes precise model forward test time using dummy data."""
|
|
# Use eval mode
|
|
model.eval()
|
|
# Generate a dummy mini-batch and copy data to GPU
|
|
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
|
|
if cfg.TASK == "jig":
|
|
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
|
|
else:
|
|
inputs = torch.zeros(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
|
|
# Compute precise forward pass time
|
|
timer = Timer()
|
|
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
|
|
for cur_iter in range(total_iter):
|
|
# Reset the timers after the warmup phase
|
|
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
|
|
timer.reset()
|
|
# Forward
|
|
timer.tic()
|
|
model(inputs)
|
|
torch.cuda.synchronize()
|
|
timer.toc()
|
|
return timer.average_time
|
|
|
|
|
|
def compute_time_train(model, loss_fun):
|
|
"""Computes precise model forward + backward time using dummy data."""
|
|
# Use train mode
|
|
model.train()
|
|
# Generate a dummy mini-batch and copy data to GPU
|
|
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
|
|
if cfg.TASK == "jig":
|
|
inputs = torch.rand(batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
|
|
else:
|
|
inputs = torch.rand(batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size).cuda(non_blocking=False)
|
|
if cfg.TASK in ['col', 'seg']:
|
|
labels = torch.zeros(batch_size, im_size, im_size, dtype=torch.int64).cuda(non_blocking=False)
|
|
else:
|
|
labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
|
|
# Cache BatchNorm2D running stats
|
|
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
|
|
bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
|
|
# Compute precise forward backward pass time
|
|
fw_timer, bw_timer = Timer(), Timer()
|
|
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
|
|
for cur_iter in range(total_iter):
|
|
# Reset the timers after the warmup phase
|
|
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
|
|
fw_timer.reset()
|
|
bw_timer.reset()
|
|
# Forward
|
|
fw_timer.tic()
|
|
preds = model(inputs)
|
|
if isinstance(preds, tuple):
|
|
loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
|
|
preds = preds[0]
|
|
else:
|
|
loss = loss_fun(preds, labels)
|
|
torch.cuda.synchronize()
|
|
fw_timer.toc()
|
|
# Backward
|
|
bw_timer.tic()
|
|
loss.backward()
|
|
torch.cuda.synchronize()
|
|
bw_timer.toc()
|
|
# Restore BatchNorm2D running stats
|
|
for bn, (mean, var) in zip(bns, bn_stats):
|
|
bn.running_mean, bn.running_var = mean, var
|
|
return fw_timer.average_time, bw_timer.average_time
|
|
|
|
|
|
def compute_time_loader(data_loader):
|
|
"""Computes loader time."""
|
|
timer = Timer()
|
|
loader.shuffle(data_loader, 0)
|
|
data_loader_iterator = iter(data_loader)
|
|
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
|
|
total_iter = min(total_iter, len(data_loader))
|
|
for cur_iter in range(total_iter):
|
|
if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
|
|
timer.reset()
|
|
timer.tic()
|
|
next(data_loader_iterator)
|
|
timer.toc()
|
|
return timer.average_time
|
|
|
|
|
|
def compute_time_full(model, loss_fun, train_loader, test_loader):
|
|
"""Times model and data loader."""
|
|
logger.info("Computing model and loader timings...")
|
|
# Compute timings
|
|
test_fw_time = compute_time_eval(model)
|
|
train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
|
|
train_fw_bw_time = train_fw_time + train_bw_time
|
|
train_loader_time = compute_time_loader(train_loader)
|
|
# Output iter timing
|
|
iter_times = {
|
|
"test_fw_time": test_fw_time,
|
|
"train_fw_time": train_fw_time,
|
|
"train_bw_time": train_bw_time,
|
|
"train_fw_bw_time": train_fw_bw_time,
|
|
"train_loader_time": train_loader_time,
|
|
}
|
|
logger.info(logging.dump_log_data(iter_times, "iter_times"))
|
|
# Output epoch timing
|
|
epoch_times = {
|
|
"test_fw_time": test_fw_time * len(test_loader),
|
|
"train_fw_time": train_fw_time * len(train_loader),
|
|
"train_bw_time": train_bw_time * len(train_loader),
|
|
"train_fw_bw_time": train_fw_bw_time * len(train_loader),
|
|
"train_loader_time": train_loader_time * len(train_loader),
|
|
}
|
|
logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
|
|
# Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
|
|
overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
|
|
logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
|