#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Benchmarking functions."""

import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer


logger = logging.get_logger(__name__)


@torch.no_grad()
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data.

    Requires a CUDA device: the dummy batch is moved with ``.cuda()`` and every
    iteration ends with ``torch.cuda.synchronize()``.

    Args:
        model: network to time. It is switched to eval mode and left in it.

    Returns:
        ``Timer.average_time`` of the forward pass. The timer is reset once
        ``cfg.PREC_TIME.WARMUP_ITER`` iterations have run, so warmup passes are
        excluded whenever ``cfg.PREC_TIME.NUM_ITER`` is positive.
    """
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    im_size = cfg.TRAIN.IM_SIZE
    batch_size = int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        # The jigsaw task takes an extra dimension of JIGSAW_GRID ** 2 tiles
        inputs = torch.rand(
            batch_size,
            cfg.JIGSAW_GRID ** 2,
            cfg.MODEL.INPUT_CHANNELS,
            im_size,
            im_size,
        ).cuda(non_blocking=False)
    else:
        inputs = torch.zeros(
            batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        # Wait for the queued GPU work so the timer covers the finished pass
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time


def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data.

    Requires a CUDA device. The forward timing includes the loss computation.
    The running mean/var of the ``BatchNorm2d`` layers are saved before and
    restored after timing, and gradients produced by the dummy batches are
    cleared before returning.

    Args:
        model: network to time. It is switched to train mode and left in it.
        loss_fun: callable invoked as ``loss_fun(preds, labels)``.

    Returns:
        Tuple ``(forward_time, backward_time)`` holding the ``average_time`` of
        the forward and the backward timer respectively.
    """
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size = cfg.TRAIN.IM_SIZE
    batch_size = int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        input_shape = (
            batch_size,
            cfg.JIGSAW_GRID ** 2,
            cfg.MODEL.INPUT_CHANNELS,
            im_size,
            im_size,
        )
    else:
        input_shape = (batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size)
    inputs = torch.rand(*input_shape).cuda(non_blocking=False)
    # The 'col' and 'seg' tasks use one label per pixel, others one per sample
    if cfg.TASK in ["col", "seg"]:
        label_shape = (batch_size, im_size, im_size)
    else:
        label_shape = (batch_size,)
    labels = torch.zeros(label_shape, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats. Layers without running stats (their
    # running_mean is None) have nothing to save and would fail on .clone().
    # NOTE(review): only running_mean / running_var are saved; any other buffer
    # updated by the forward passes (e.g. num_batches_tracked) is not restored.
    bns = [
        m
        for m in model.modules()
        if isinstance(m, torch.nn.BatchNorm2d) and m.running_mean is not None
    ]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        if isinstance(preds, tuple):
            # Tuple output: main predictions first, auxiliary predictions second
            loss = loss_fun(preds[0], labels)
            loss = loss + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
        else:
            loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Drop the gradients accumulated from the dummy batches so they cannot be
    # added to the gradients of a later, real training step
    model.zero_grad()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time


def compute_time_loader(data_loader):
    """Computes loader time.

    Args:
        data_loader: loader to time. It is shuffled via ``loader.shuffle`` for
            epoch 0 and must support ``len()`` and ``iter()``.

    Returns:
        ``Timer.average_time`` of fetching one batch. At most
        ``len(data_loader)`` batches are fetched, so a loader no longer than
        the warmup phase never triggers the timer reset.
    """
    timer = Timer()
    loader.shuffle(data_loader, 0)
    batches = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    # Never request more batches than a single pass over the loader provides
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        # Reset the timer after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(batches)
        timer.toc()
    return timer.average_time


def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader.

    Logs per-iteration timings, per-epoch timings (per-iteration time scaled by
    the loader length) and the relative overhead of the train data loader.

    Args:
        model: network passed to ``compute_time_eval`` / ``compute_time_train``.
        loss_fun: loss callable passed to ``compute_time_train``.
        train_loader: train loader; timed and used for its length.
        test_loader: test loader; only its length is used.
    """
    logger.info("Computing model and loader timings...")
    # Compute timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    train_loader_time = compute_time_loader(train_loader)
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
        "train_loader_time": train_loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Output epoch timing
    num_train_iters, num_test_iters = len(train_loader), len(test_loader)
    epoch_times = {
        "test_fw_time": test_fw_time * num_test_iters,
        "train_fw_time": train_fw_time * num_train_iters,
        "train_bw_time": train_bw_time * num_train_iters,
        "train_fw_bw_time": train_fw_bw_time * num_train_iters,
        "train_loader_time": train_loader_time * num_train_iters,
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
    if train_fw_bw_time > 0:
        overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
        logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
    else:
        # A zero model time (presumably when no iteration was timed; verify
        # against Timer) would otherwise raise ZeroDivisionError
        logger.info("Overhead of data loader is undefined (model time is zero)")