v2

pycls/core/__init__.py (new file, 0 lines)

pycls/core/benchmark.py (new file, 136 lines)
@@ -0,0 +1,136 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Benchmarking functions."""

import pycls.core.logging as logging
import pycls.datasets.loader as loader
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer


logger = logging.get_logger(__name__)


@torch.no_grad()
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data."""
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(
            batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    else:
        inputs = torch.zeros(
            batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time


def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data."""
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    if cfg.TASK == "jig":
        inputs = torch.rand(
            batch_size, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    else:
        inputs = torch.rand(
            batch_size, cfg.MODEL.INPUT_CHANNELS, im_size, im_size
        ).cuda(non_blocking=False)
    if cfg.TASK in ["col", "seg"]:
        labels = torch.zeros(
            batch_size, im_size, im_size, dtype=torch.int64
        ).cuda(non_blocking=False)
    else:
        labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        if isinstance(preds, tuple):
            loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels)
            preds = preds[0]
        else:
            loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time


def compute_time_loader(data_loader):
    """Computes loader time."""
    timer = Timer()
    loader.shuffle(data_loader, 0)
    data_loader_iterator = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(data_loader_iterator)
        timer.toc()
    return timer.average_time


def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader."""
    logger.info("Computing model and loader timings...")
    # Compute timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    train_loader_time = compute_time_loader(train_loader)
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
        "train_loader_time": train_loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Output epoch timing
    epoch_times = {
        "test_fw_time": test_fw_time * len(test_loader),
        "train_fw_time": train_fw_time * len(train_loader),
        "train_bw_time": train_bw_time * len(train_loader),
        "train_fw_bw_time": train_fw_bw_time * len(train_loader),
        "train_loader_time": train_loader_time * len(train_loader),
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS > 1)
    overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
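
Timing entry-point sketch (illustrative, not part of the commit; the loader
constructor names are assumed from pycls.datasets.loader and the cfg is assumed
to be set up already):

    import pycls.core.benchmark as benchmark
    import pycls.core.builders as builders
    import pycls.datasets.loader as data_loader

    model = builders.build_model().cuda()
    loss_fun = builders.build_loss_fun().cuda()
    benchmark.compute_time_full(
        model,
        loss_fun,
        data_loader.construct_train_loader(),
        data_loader.construct_test_loader(),
    )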

pycls/core/builders.py (new file, 88 lines)
@@ -0,0 +1,88 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Model and loss construction functions."""

import torch
from pycls.core.config import cfg
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
from pycls.models.resnet import ResNet
from pycls.models.nas.nas import NAS
from pycls.models.nas.nas_search import NAS_Search
from pycls.models.nas_bench.model_builder import NAS_Bench


class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
    """CrossEntropyLoss with label smoothing."""

    def __init__(self):
        super(LabelSmoothedCrossEntropyLoss, self).__init__()
        self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
        self.num_classes = cfg.MODEL.NUM_CLASSES

    def forward(self, logits, target):
        pred = logits.log_softmax(dim=-1)
        with torch.no_grad():
            target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
            target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
        return (-target_dist * pred).sum(dim=-1).mean()
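
# Illustrative numbers (not in the commit): with eps = 0.1 and num_classes = 10,
# each incorrect class receives probability 0.1 / 9 ~= 0.0111 and the true class
# receives 0.9, so the smoothed target distribution still sums to 1.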


# Supported models
_models = {
    "anynet": AnyNet,
    "effnet": EffNet,
    "resnet": ResNet,
    "regnet": RegNet,
    "nas": NAS,
    "nas_search": NAS_Search,
    "nas_bench": NAS_Bench,
}

# Supported loss functions
_loss_funs = {
    "cross_entropy": torch.nn.CrossEntropyLoss,
    "label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
}


def get_model():
    """Gets the model class specified in the config."""
    err_str = "Model type '{}' not supported"
    assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
    return _models[cfg.MODEL.TYPE]


def get_loss_fun():
    """Gets the loss function class specified in the config."""
    err_str = "Loss function type '{}' not supported"
    assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
    return _loss_funs[cfg.MODEL.LOSS_FUN]


def build_model():
    """Builds the model."""
    return get_model()()


def build_loss_fun():
    """Builds the loss function."""
    if cfg.TASK == "seg":
        return get_loss_fun()(ignore_index=255)
    else:
        return get_loss_fun()()


def register_model(name, ctor):
    """Registers a model dynamically."""
    _models[name] = ctor


def register_loss_fun(name, ctor):
    """Registers a loss function dynamically."""
    _loss_funs[name] = ctor
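
The registries above make the builders extensible without editing pycls itself;
a minimal sketch (MyNet is a hypothetical user-defined module):

    import torch
    import pycls.core.builders as builders

    class MyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(3 * 32 * 32, 10)

        def forward(self, x):
            return self.fc(x.flatten(1))

    builders.register_model("mynet", MyNet)
    # With cfg.MODEL.TYPE = "mynet", build_model() now returns a MyNet instance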

pycls/core/checkpoint.py (new file, 98 lines)
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions that handle saving and loading of checkpoints."""

import os

import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir():
    """Retrieves the location for storing checkpoints."""
    return os.path.join(cfg.OUT_DIR, _DIR_NAME)


def get_checkpoint(epoch):
    """Retrieves the path to a checkpoint file."""
    name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
    return os.path.join(get_checkpoint_dir(), name)


def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number)."""
    checkpoint_dir = get_checkpoint_dir()
    # Checkpoint file names are in lexicographic order
    checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    last_checkpoint_name = sorted(checkpoints)[-1]
    return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint():
    """Determines if there are checkpoints available."""
    checkpoint_dir = get_checkpoint_dir()
    if not os.path.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))


def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    os.makedirs(get_checkpoint_dir(), exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state
    if isinstance(optimizer, list):
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_w_state": optimizer[0].state_dict(),
            "optimizer_a_state": optimizer[1].state_dict(),
            "cfg": cfg.dump(),
        }
    else:
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_state": optimizer.state_dict(),
            "cfg": cfg.dump(),
        }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    # Load the checkpoint on CPU to avoid GPU mem spike
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Account for the DDP wrapper in the multi-gpu setting
    ms = model.module if cfg.NUM_GPUS > 1 else model
    ms.load_state_dict(checkpoint["model_state"])
    # Load the optimizer state (commonly not done when fine-tuning)
    if optimizer:
        if isinstance(optimizer, list):
            optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
            optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
        else:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
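
A typical auto-resume pattern built from these helpers (sketch; the surrounding
model, optimizer, and training loop are assumed):

    import pycls.core.checkpoint as cp

    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cp.has_checkpoint():
        # load_checkpoint returns the epoch stored in the file; resume after it
        start_epoch = cp.load_checkpoint(cp.get_last_checkpoint(), model, optimizer) + 1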

pycls/core/config.py (new file, 500 lines)
@@ -0,0 +1,500 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Configuration file (powered by YACS)."""

import argparse
import os
import sys

from pycls.core.io import cache_url
from yacs.config import CfgNode


# Global config object
_C = CfgNode()

# Example usage:
#   from core.config import cfg
cfg = _C


# ------------------------------------------------------------------------------------ #
# Model options
# ------------------------------------------------------------------------------------ #
_C.MODEL = CfgNode()

# Model type
_C.MODEL.TYPE = ""

# Number of weight layers
_C.MODEL.DEPTH = 0

# Number of input channels
_C.MODEL.INPUT_CHANNELS = 3

# Number of classes
_C.MODEL.NUM_CLASSES = 10

# Loss function (see pycls/core/builders.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"

# Label smoothing eps
_C.MODEL.LABEL_SMOOTHING_EPS = 0.0

# ASPP channels
_C.MODEL.ASPP_CHANNELS = 256

# ASPP dilation rates
_C.MODEL.ASPP_RATES = [6, 12, 18]


# ------------------------------------------------------------------------------------ #
# ResNet options
# ------------------------------------------------------------------------------------ #
_C.RESNET = CfgNode()

# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"

# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1

# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64

# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True


# ------------------------------------------------------------------------------------ #
# AnyNet options
# ------------------------------------------------------------------------------------ #
_C.ANYNET = CfgNode()

# Stem type
_C.ANYNET.STEM_TYPE = "simple_stem_in"

# Stem width
_C.ANYNET.STEM_W = 32

# Block type
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"

# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []

# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []

# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []

# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []

# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False

# SE ratio
_C.ANYNET.SE_R = 0.25


# ------------------------------------------------------------------------------------ #
# RegNet options
# ------------------------------------------------------------------------------------ #
_C.REGNET = CfgNode()

# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"

# Stem width
_C.REGNET.STEM_W = 32

# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"

# Stride of each stage
_C.REGNET.STRIDE = 2

# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25

# Depth
_C.REGNET.DEPTH = 10

# Initial width
_C.REGNET.W0 = 32

# Slope
_C.REGNET.WA = 5.0

# Quantization
_C.REGNET.WM = 2.5

# Group width
_C.REGNET.GROUP_W = 16

# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0


# ------------------------------------------------------------------------------------ #
# EfficientNet options
# ------------------------------------------------------------------------------------ #
_C.EN = CfgNode()

# Stem width
_C.EN.STEM_W = 32

# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []

# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []

# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25

# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []

# Kernel sizes for each stage
_C.EN.KERNELS = []

# Head width
_C.EN.HEAD_W = 1280

# Drop connect ratio
_C.EN.DC_RATIO = 0.0

# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0


# ------------------------------------------------------------------------------------ #
# NAS options
# ------------------------------------------------------------------------------------ #
_C.NAS = CfgNode()

# Cell genotype
_C.NAS.GENOTYPE = "nas"

# Custom genotype
_C.NAS.CUSTOM_GENOTYPE = []

# Base NAS width
_C.NAS.WIDTH = 16

# Total number of cells
_C.NAS.DEPTH = 20

# Auxiliary heads
_C.NAS.AUX = False

# Weight for auxiliary heads
_C.NAS.AUX_WEIGHT = 0.4

# Drop path probability
_C.NAS.DROP_PROB = 0.0

# Matrix in NAS Bench
_C.NAS.MATRIX = []

# Operations in NAS Bench
_C.NAS.OPS = []

# Number of stacks in NAS Bench
_C.NAS.NUM_STACKS = 3

# Number of modules per stack in NAS Bench
_C.NAS.NUM_MODULES_PER_STACK = 3


# ------------------------------------------------------------------------------------ #
# Batch norm options
# ------------------------------------------------------------------------------------ #
_C.BN = CfgNode()

# BN epsilon
_C.BN.EPS = 1e-5

# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1

# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024

# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False

# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0


# ------------------------------------------------------------------------------------ #
# Optimizer options
# ------------------------------------------------------------------------------------ #
_C.OPTIM = CfgNode()

# Base learning rate
_C.OPTIM.BASE_LR = 0.1

# Learning rate policy (select from 'cos', 'exp', or 'steps')
_C.OPTIM.LR_POLICY = "cos"

# Exponential decay factor
_C.OPTIM.GAMMA = 0.1

# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []

# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1

# Maximum number of epochs
_C.OPTIM.MAX_EPOCH = 200

# Momentum
_C.OPTIM.MOMENTUM = 0.9

# Momentum dampening
_C.OPTIM.DAMPENING = 0.0

# Nesterov momentum
_C.OPTIM.NESTEROV = True

# L2 regularization
_C.OPTIM.WEIGHT_DECAY = 5e-4

# Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
_C.OPTIM.WARMUP_FACTOR = 0.1

# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0

# Update the learning rate per iter
_C.OPTIM.ITER_LR = False

# Base learning rate for arch
_C.OPTIM.ARCH_BASE_LR = 0.0003

# L2 regularization for arch
_C.OPTIM.ARCH_WEIGHT_DECAY = 0.001

# Optimizer for arch
_C.OPTIM.ARCH_OPTIM = "adam"

# Epoch to start optimizing arch
_C.OPTIM.ARCH_EPOCH = 0.0


# ------------------------------------------------------------------------------------ #
# Training options
# ------------------------------------------------------------------------------------ #
_C.TRAIN = CfgNode()

# Dataset and split
_C.TRAIN.DATASET = ""
_C.TRAIN.SPLIT = "train"

# Total mini-batch size
_C.TRAIN.BATCH_SIZE = 128

# Image size
_C.TRAIN.IM_SIZE = 224

# Evaluate model on test data every eval period epochs
_C.TRAIN.EVAL_PERIOD = 1

# Save model checkpoint every checkpoint period epochs
_C.TRAIN.CHECKPOINT_PERIOD = 1

# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = True

# Weights to start training from
_C.TRAIN.WEIGHTS = ""

# Percentage of gray images in jig
_C.TRAIN.GRAY_PERCENTAGE = 0.0

# Portion to create trainA/trainB split
_C.TRAIN.PORTION = 1.0


# ------------------------------------------------------------------------------------ #
# Testing options
# ------------------------------------------------------------------------------------ #
_C.TEST = CfgNode()

# Dataset and split
_C.TEST.DATASET = ""
_C.TEST.SPLIT = "val"

# Total mini-batch size
_C.TEST.BATCH_SIZE = 200

# Image size
_C.TEST.IM_SIZE = 256

# Weights to use for testing
_C.TEST.WEIGHTS = ""


# ------------------------------------------------------------------------------------ #
# Common train/test data loader options
# ------------------------------------------------------------------------------------ #
_C.DATA_LOADER = CfgNode()

# Number of data loader workers per process
_C.DATA_LOADER.NUM_WORKERS = 8

# Load data to pinned host memory
_C.DATA_LOADER.PIN_MEMORY = True


# ------------------------------------------------------------------------------------ #
# Memory options
# ------------------------------------------------------------------------------------ #
_C.MEM = CfgNode()

# Perform ReLU inplace
_C.MEM.RELU_INPLACE = True


# ------------------------------------------------------------------------------------ #
# CUDNN options
# ------------------------------------------------------------------------------------ #
_C.CUDNN = CfgNode()

# Perform benchmarking to select the fastest CUDNN algorithms to use
# Note that this may increase the memory usage and will likely not result
# in overall speedups when variable size inputs are used (e.g. COCO training)
_C.CUDNN.BENCHMARK = True


# ------------------------------------------------------------------------------------ #
# Precise timing options
# ------------------------------------------------------------------------------------ #
_C.PREC_TIME = CfgNode()

# Number of iterations to warm up the caches
_C.PREC_TIME.WARMUP_ITER = 3

# Number of iterations to compute avg time
_C.PREC_TIME.NUM_ITER = 30


# ------------------------------------------------------------------------------------ #
# Misc options
# ------------------------------------------------------------------------------------ #

# Number of GPUs to use (applies to both training and testing)
_C.NUM_GPUS = 1

# Task (cls, seg, rot, col, jig)
_C.TASK = "cls"

# Grid in Jigsaw (2, 3); no effect if TASK is not jig
_C.JIGSAW_GRID = 3

# Output directory
_C.OUT_DIR = "/tmp"

# Config destination (in OUT_DIR)
_C.CFG_DEST = "config.yaml"

# Note that non-determinism may still be present due to non-deterministic
# operator implementations in GPU operator libraries
_C.RNG_SEED = 1

# Log destination ('stdout' or 'file')
_C.LOG_DEST = "stdout"

# Log period in iters
_C.LOG_PERIOD = 10

# Distributed backend
_C.DIST_BACKEND = "nccl"

# Hostname and port for initializing multi-process groups
_C.HOST = "localhost"
_C.PORT = 10001

# Models weights referred to by URL are downloaded to this local cache
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"


# ------------------------------------------------------------------------------------ #
# Deprecated keys
# ------------------------------------------------------------------------------------ #

_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
_C.register_deprecated_key("PREC_TIME.ENABLED")


def assert_and_infer_cfg(cache_urls=True):
    """Checks config invariants."""
    err_str = "The first lr step must start at 0"
    assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
    data_splits = ["train", "val", "test"]
    err_str = "Data split '{}' not supported"
    assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
    assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
    err_str = "Mini-batch size should be a multiple of NUM_GPUS."
    assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
    assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
    err_str = "Precise BN stats computation not verified for > 1 GPU"
    assert not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1, err_str
    err_str = "Log destination '{}' not supported"
    assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
    if cache_urls:
        cache_cfg_urls()


def cache_cfg_urls():
    """Downloads URLs in config, caches them, and rewrites cfg to use cached files."""
    _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
    _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)


def dump_cfg():
    """Dumps the config to the output directory."""
    cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
    with open(cfg_file, "w") as f:
        _C.dump(stream=f)


def load_cfg(out_dir, cfg_dest="config.yaml"):
    """Loads config from specified output directory."""
    cfg_file = os.path.join(out_dir, cfg_dest)
    _C.merge_from_file(cfg_file)


def load_cfg_fom_args(description="Config file options."):
    """Loads config from command line arguments and sets any specified options."""
    parser = argparse.ArgumentParser(description=description)
    help_s = "Config file location"
    parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
    help_s = "See pycls/core/config.py for all options"
    parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    _C.merge_from_file(args.cfg_file)
    _C.merge_from_list(args.opts)
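
Invocation sketch for load_cfg_fom_args (script and config paths are
illustrative); any trailing KEY VALUE pairs are merged on top of the YAML file
via merge_from_list:

    python tools/train_net.py --cfg configs/example.yaml \
        OPTIM.BASE_LR 0.05 NUM_GPUS 2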

pycls/core/distributed.py (new file, 157 lines)
@@ -0,0 +1,157 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Distributed helpers."""

import multiprocessing
import os
import signal
import threading
import traceback

import torch
from pycls.core.config import cfg


def is_master_proc():
    """Determines if the current process is the master process.

    Master process is responsible for logging, writing and loading checkpoints. In
    the multi GPU setting, we assign the master role to the rank 0 process. When
    training using a single GPU, there is a single process which is considered master.
    """
    return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0


def init_process_group(proc_rank, world_size):
    """Initializes the default process group."""
    # Set the GPU to use
    torch.cuda.set_device(proc_rank)
    # Initialize the process group
    torch.distributed.init_process_group(
        backend=cfg.DIST_BACKEND,
        init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
        world_size=world_size,
        rank=proc_rank,
    )


def destroy_process_group():
    """Destroys the default process group."""
    torch.distributed.destroy_process_group()


def scaled_all_reduce(tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of the
    process group (equivalent to cfg.NUM_GPUS).
    """
    # There is no need for reduction in the single-proc case
    if cfg.NUM_GPUS == 1:
        return tensors
    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS)
    return tensors
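
# Illustrative effect (not in the commit): with cfg.NUM_GPUS = 2 and each
# process holding its local loss tensor, scaled_all_reduce([loss]) leaves every
# process with the mean of the two losses (sum-reduce, then scale by 1/2).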


class ChildException(Exception):
    """Wraps an exception from a child process."""

    def __init__(self, child_trace):
        super(ChildException, self).__init__(child_trace)


class ErrorHandler(object):
    """Multiprocessing error handler (based on fairseq's).

    Listens for errors in child processes and propagates the tracebacks to the parent.
    """

    def __init__(self, error_queue):
        # Shared error queue
        self.error_queue = error_queue
        # Children processes sharing the error queue
        self.children_pids = []
        # Start a thread listening to errors
        self.error_listener = threading.Thread(target=self.listen, daemon=True)
        self.error_listener.start()
        # Register the signal handler
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Registers a child process."""
        self.children_pids.append(pid)

    def listen(self):
        """Listens for errors in the error queue."""
        # Wait until there is an error in the queue
        child_trace = self.error_queue.get()
        # Put the error back for the signal handler
        self.error_queue.put(child_trace)
        # Invoke the signal handler
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, _sig_num, _stack_frame):
        """Signal handler."""
        # Kill children processes
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        # Propagate the error from the child process
        raise ChildException(self.error_queue.get())


def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
    """Runs a function from a child process."""
    try:
        # Initialize the process group
        init_process_group(proc_rank, world_size)
        # Run the function
        fun(*fun_args, **fun_kwargs)
    except KeyboardInterrupt:
        # Killed by the parent process
        pass
    except Exception:
        # Propagate exception to the parent process
        error_queue.put(traceback.format_exc())
    finally:
        # Destroy the process group
        destroy_process_group()


def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
    """Runs a function in a multi-proc setting (unless num_proc == 1)."""
    # There is no need for multi-proc in the single-proc case
    fun_kwargs = fun_kwargs if fun_kwargs else {}
    if num_proc == 1:
        fun(*fun_args, **fun_kwargs)
        return
    # Handle errors from training subprocesses
    error_queue = multiprocessing.SimpleQueue()
    error_handler = ErrorHandler(error_queue)
    # Run each training subprocess
    ps = []
    for i in range(num_proc):
        p_i = multiprocessing.Process(
            target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
        )
        ps.append(p_i)
        p_i.start()
        error_handler.add_child(p_i.pid)
    # Wait for each subprocess to finish
    for p in ps:
        p.join()
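
Launch sketch (train_fun is a placeholder for any function that expects an
initialized process group, such as a training loop):

    import pycls.core.distributed as dist
    from pycls.core.config import cfg

    def train_fun():
        ...  # build the model and loaders, then train

    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=train_fun)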

pycls/core/io.py (new file, 77 lines)
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""IO utilities (adapted from Detectron)"""

import logging
import os
import re
import sys
from urllib import request as urlrequest


logger = logging.getLogger(__name__)

_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"


def cache_url(url_or_file, cache_dir):
    """Download the file specified by the URL to the cache_dir and return the path to
    the cached file. If the argument is not a URL, simply return it as is.
    """
    is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
    if not is_url:
        return url_or_file
    url = url_or_file
    err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
    assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL)
    cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
    if os.path.exists(cache_file_path):
        return cache_file_path
    cache_file_dir = os.path.dirname(cache_file_path)
    if not os.path.exists(cache_file_dir):
        os.makedirs(cache_file_dir)
    logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
    download_url(url, cache_file_path)
    return cache_file_path


def _progress_bar(count, total):
    """Report download progress. Credit:
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
    """
    bar_len = 60
    filled_len = int(round(bar_len * count / float(total)))
    percents = round(100.0 * count / float(total), 1)
    bar = "=" * filled_len + "-" * (bar_len - filled_len)
    sys.stdout.write(
        "  [{}] {}% of {:.1f}MB file  \r".format(bar, percents, total / 1024 / 1024)
    )
    sys.stdout.flush()
    if count >= total:
        sys.stdout.write("\n")


def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
    """Download url and write it to dst_file_path. Credit:
    https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
    """
    req = urlrequest.Request(url)
    response = urlrequest.urlopen(req)
    total_size = response.info().get("Content-Length").strip()
    total_size = int(total_size)
    bytes_so_far = 0
    with open(dst_file_path, "wb") as f:
        while True:
            chunk = response.read(chunk_size)
            bytes_so_far += len(chunk)
            if not chunk:
                break
            if progress_hook:
                progress_hook(bytes_so_far, total_size)
            f.write(chunk)
    return bytes_so_far
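
Usage sketch (the exact file path under the pycls bucket is illustrative):

    from pycls.core.io import cache_url

    # Non-URL arguments pass through unchanged; URLs under the pycls bucket are
    # downloaded once, and the local cache path is returned on later calls
    weights = cache_url(
        "https://dl.fbaipublicfiles.com/pycls/models/example.pyth",
        "/tmp/pycls-download-cache",
    )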

pycls/core/logging.py (new file, 138 lines)
@@ -0,0 +1,138 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Logging."""

import builtins
import decimal
import logging
import os
import sys

import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg


# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"

# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"

# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "

# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"


def _suppress_print():
    """Suppresses printing from the current process."""

    def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
        pass

    builtins.print = ignore


def setup_logging():
    """Sets up the logging."""
    # Enable logging only for the master process
    if dist.is_master_proc():
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Construct logging configuration
        logging_config = {"level": logging.INFO, "format": _FORMAT}
        # Log either to stdout or to a file
        if cfg.LOG_DEST == "stdout":
            logging_config["stream"] = sys.stdout
        else:
            logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
        # Configure logging
        logging.basicConfig(**logging_config)
    else:
        _suppress_print()


def get_logger(name):
    """Retrieves the logger."""
    return logging.getLogger(name)


def dump_log_data(data, data_type, prec=4):
    """Converts data (a dictionary) into a tagged json string for logging."""
    data[_TYPE] = data_type
    data = float_to_decimal(data, prec)
    data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
    return "{:s}{:s}".format(_TAG, data_json)


def float_to_decimal(data, prec=4):
    """Converts floats to decimals which allows for fixed width json."""
    if isinstance(data, dict):
        return {k: float_to_decimal(v, prec) for k, v in data.items()}
    if isinstance(data, float):
        return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
    else:
        return data


def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
    """Gets all log files in a directory containing subdirs of trained models."""
    names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
    files = [os.path.join(log_dir, n, log_file) for n in names]
    f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
    files, names = zip(*f_n_ps) if f_n_ps else ([], [])
    return files, names


def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index]."""
    # Load log_file
    assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
    with open(log_file, "r") as f:
        lines = f.readlines()
    # Extract and parse lines that start with _TAG and have a type specified
    lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
    lines = [simplejson.loads(l) for l in lines]
    lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
    # Generate data structure accessed by data[data_type][index][metric]
    data_types = [l[_TYPE] for l in lines]
    data = {t: [] for t in data_types}
    for t, line in zip(data_types, lines):
        del line[_TYPE]
        data[t].append(line)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data


def sort_log_data(data):
    """Sorts each data[data_type][metric] by epoch or keeps only the first instance."""
    for t in data:
        if "epoch" in data[t]:
            assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
            data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
            data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
            epoch = data[t]["epoch_ind"]
            if "iter" in data[t]:
                assert "iter_ind" not in data[t] and "iter_max" not in data[t]
                data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
                data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
                itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
                epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
            for m in data[t]:
                data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
        else:
            data[t] = {m: d[0] for m, d in data[t].items()}
    return data
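
The fixed _TAG prefix keeps these structured records easy to grep; an
illustrative line produced by dump_log_data({"top1_err": 23.4567}, "test_epoch")
would be:

    json_stats: {"_type": "test_epoch", "top1_err": 23.4567}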

pycls/core/meters.py (new file, 435 lines)
@@ -0,0 +1,435 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Meters."""

from collections import deque

import numpy as np
import pycls.core.logging as logging
import torch
from pycls.core.config import cfg
from pycls.core.timer import Timer


logger = logging.get_logger(__name__)


def time_string(seconds):
    """Converts time in seconds to a fixed-width string format."""
    days, rem = divmod(int(seconds), 24 * 3600)
    hrs, rem = divmod(rem, 3600)
    mins, secs = divmod(rem, 60)
    return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)


def inter_union(preds, labels, num_classes):
    """Computes per-class intersection and union areas (for segmentation IoU)."""
    _, preds = torch.max(preds, 1)
    # Shift classes to 1..num_classes; uint8 arithmetic wraps ignore-index
    # labels (255) to 0, and the corresponding predictions are masked out
    preds = preds.type(torch.uint8) + 1
    labels = labels.type(torch.uint8) + 1
    preds = preds * (labels > 0).type(torch.uint8)
    # Per-class areas of the intersection, predictions, and labels
    inter = preds * (preds == labels).type(torch.uint8)
    area_inter = torch.histc(inter.type(torch.int64), bins=num_classes, min=1, max=num_classes)
    area_preds = torch.histc(preds.type(torch.int64), bins=num_classes, min=1, max=num_classes)
    area_labels = torch.histc(labels.type(torch.int64), bins=num_classes, min=1, max=num_classes)
    area_union = area_preds + area_labels - area_inter
    # Normalize the per-class areas by the batch size
    n = labels.size(0)
    return [area_inter.type(torch.float64) / n, area_union.type(torch.float64) / n]


def topk_errors(preds, labels, ks):
    """Computes the top-k error for each k."""
    err_str = "Batch dim of predictions and labels must match"
    assert preds.size(0) == labels.size(0), err_str
    # Find the top max_k predictions for each sample
    _top_max_k_vals, top_max_k_inds = torch.topk(
        preds, max(ks), dim=1, largest=True, sorted=True
    )
    # (batch_size, max_k) -> (max_k, batch_size)
    top_max_k_inds = top_max_k_inds.t()
    # (batch_size, ) -> (max_k, batch_size)
    rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
    # (i, j) = 1 if top i-th prediction for the j-th sample is correct
    top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
    # Compute the number of topk correct predictions for each k
    topks_correct = [top_max_k_correct[:k, :].view(-1).float().sum() for k in ks]
    return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct]
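
# Illustrative check (not in the commit): for a batch of 4 samples where exactly
# one label falls outside its top-1 prediction but inside its top-5,
# topk_errors(preds, labels, [1, 5]) returns [25.0, 0.0] (percentages).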


def gpu_mem_usage():
    """Computes the GPU memory usage for the current device (MB)."""
    mem_usage_bytes = torch.cuda.max_memory_allocated()
    return mem_usage_bytes / 1024 / 1024


class ScalarMeter(object):
    """Measures a scalar value (adapted from Detectron)."""

    def __init__(self, window_size):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def reset(self):
        self.deque.clear()
        self.total = 0.0
        self.count = 0

    def add_value(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    def get_win_median(self):
        return np.median(self.deque)

    def get_win_avg(self):
        return np.mean(self.deque)

    def get_global_avg(self):
        return self.total / self.count


class TrainMeter(object):
    """Measures training stats."""

    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, timer=False):
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
        # Current minibatch stats
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        avg_loss = self.loss_total / self.num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "top1_err": top1_err,
            "top5_err": top5_err,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))


class TestMeter(object):
    """Measures testing stats."""

    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
|         self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) | ||||
|         # Min errors (over the full test set) | ||||
|         self.min_top1_err = 100.0 | ||||
|         self.min_top5_err = 100.0 | ||||
|         # Number of misclassified examples | ||||
|         self.num_top1_mis = 0 | ||||
|         self.num_top5_mis = 0 | ||||
|         self.num_samples = 0 | ||||
|  | ||||
|     def reset(self, min_errs=False): | ||||
|         if min_errs: | ||||
|             self.min_top1_err = 100.0 | ||||
|             self.min_top5_err = 100.0 | ||||
|         self.iter_timer.reset() | ||||
|         self.mb_top1_err.reset() | ||||
|         self.mb_top5_err.reset() | ||||
|         self.num_top1_mis = 0 | ||||
|         self.num_top5_mis = 0 | ||||
|         self.num_samples = 0 | ||||
|  | ||||
|     def iter_tic(self): | ||||
|         self.iter_timer.tic() | ||||
|  | ||||
|     def iter_toc(self): | ||||
|         self.iter_timer.toc() | ||||
|  | ||||
|     def update_stats(self, top1_err, top5_err, mb_size): | ||||
|         self.mb_top1_err.add_value(top1_err) | ||||
|         self.mb_top5_err.add_value(top5_err) | ||||
|         self.num_top1_mis += top1_err * mb_size | ||||
|         self.num_top5_mis += top5_err * mb_size | ||||
|         self.num_samples += mb_size | ||||
|  | ||||
|     def get_iter_stats(self, cur_epoch, cur_iter): | ||||
|         mem_usage = gpu_mem_usage() | ||||
|         iter_stats = { | ||||
|             "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||||
|             "iter": "{}/{}".format(cur_iter + 1, self.max_iter), | ||||
|             "time_avg": self.iter_timer.average_time, | ||||
|             "time_diff": self.iter_timer.diff, | ||||
|             "top1_err": self.mb_top1_err.get_win_median(), | ||||
|             "top5_err": self.mb_top5_err.get_win_median(), | ||||
|             "mem": int(np.ceil(mem_usage)), | ||||
|         } | ||||
|         return iter_stats | ||||
|  | ||||
|     def log_iter_stats(self, cur_epoch, cur_iter): | ||||
|         if (cur_iter + 1) % cfg.LOG_PERIOD != 0: | ||||
|             return | ||||
|         stats = self.get_iter_stats(cur_epoch, cur_iter) | ||||
|         logger.info(logging.dump_log_data(stats, "test_iter")) | ||||
|  | ||||
|     def get_epoch_stats(self, cur_epoch): | ||||
|         top1_err = self.num_top1_mis / self.num_samples | ||||
|         top5_err = self.num_top5_mis / self.num_samples | ||||
|         self.min_top1_err = min(self.min_top1_err, top1_err) | ||||
|         self.min_top5_err = min(self.min_top5_err, top5_err) | ||||
|         mem_usage = gpu_mem_usage() | ||||
|         stats = { | ||||
|             "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||||
|             "time_avg": self.iter_timer.average_time, | ||||
|             "top1_err": top1_err, | ||||
|             "top5_err": top5_err, | ||||
|             "min_top1_err": self.min_top1_err, | ||||
|             "min_top5_err": self.min_top5_err, | ||||
|             "mem": int(np.ceil(mem_usage)), | ||||
|         } | ||||
|         return stats | ||||
|  | ||||
|     def log_epoch_stats(self, cur_epoch): | ||||
|         stats = self.get_epoch_stats(cur_epoch) | ||||
|         logger.info(logging.dump_log_data(stats, "test_epoch")) | ||||
|  | ||||
|  | ||||
| class TrainMeterIoU(object): | ||||
|     """Measures training stats.""" | ||||
|  | ||||
|     def __init__(self, epoch_iters): | ||||
|         self.epoch_iters = epoch_iters | ||||
|         self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters | ||||
|         self.iter_timer = Timer() | ||||
|         self.loss = ScalarMeter(cfg.LOG_PERIOD) | ||||
|         self.loss_total = 0.0 | ||||
|         self.lr = None | ||||
|  | ||||
|         self.mb_miou = ScalarMeter(cfg.LOG_PERIOD) | ||||
|  | ||||
|         self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES) | ||||
|         self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES) | ||||
|         self.num_samples = 0 | ||||
|  | ||||
|     def reset(self, timer=False): | ||||
|         if timer: | ||||
|             self.iter_timer.reset() | ||||
|         self.loss.reset() | ||||
|         self.loss_total = 0.0 | ||||
|         self.lr = None | ||||
|         self.mb_miou.reset() | ||||
|         self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES) | ||||
|         self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES) | ||||
|         self.num_samples = 0 | ||||
|  | ||||
|     def iter_tic(self): | ||||
|         self.iter_timer.tic() | ||||
|  | ||||
|     def iter_toc(self): | ||||
|         self.iter_timer.toc() | ||||
|  | ||||
|     def update_stats(self, inter, union, loss, lr, mb_size): | ||||
|         # Current minibatch stats | ||||
|         self.mb_miou.add_value((inter / (union + 1e-10)).mean()) | ||||
|         self.loss.add_value(loss) | ||||
|         self.lr = lr | ||||
|         # Aggregate stats | ||||
|         self.num_inter += inter * mb_size | ||||
|         self.num_union += union * mb_size | ||||
|         self.loss_total += loss * mb_size | ||||
|         self.num_samples += mb_size | ||||
|  | ||||
|     def get_iter_stats(self, cur_epoch, cur_iter): | ||||
|         cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1 | ||||
|         eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total) | ||||
|         mem_usage = gpu_mem_usage() | ||||
|         stats = { | ||||
|             "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||||
|             "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), | ||||
|             "time_avg": self.iter_timer.average_time, | ||||
|             "time_diff": self.iter_timer.diff, | ||||
|             "eta": time_string(eta_sec), | ||||
|             "miou": self.mb_miou.get_win_median(), | ||||
|             "loss": self.loss.get_win_median(), | ||||
|             "lr": self.lr, | ||||
|             "mem": int(np.ceil(mem_usage)), | ||||
|         } | ||||
|         return stats | ||||
|  | ||||
|     def log_iter_stats(self, cur_epoch, cur_iter): | ||||
|         if (cur_iter + 1) % cfg.LOG_PERIOD != 0: | ||||
|             return | ||||
|         stats = self.get_iter_stats(cur_epoch, cur_iter) | ||||
|         logger.info(logging.dump_log_data(stats, "train_iter")) | ||||
|  | ||||
|     def get_epoch_stats(self, cur_epoch): | ||||
|         cur_iter_total = (cur_epoch + 1) * self.epoch_iters | ||||
|         eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total) | ||||
|         mem_usage = gpu_mem_usage() | ||||
|         miou = (self.num_inter / (self.num_union + 1e-10)).mean() | ||||
|         avg_loss = self.loss_total / self.num_samples | ||||
|         stats = { | ||||
|             "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||||
|             "time_avg": self.iter_timer.average_time, | ||||
|             "eta": time_string(eta_sec), | ||||
|             "miou": miou, | ||||
|             "loss": avg_loss, | ||||
|             "lr": self.lr, | ||||
|             "mem": int(np.ceil(mem_usage)), | ||||
|         } | ||||
|         return stats | ||||
|  | ||||
|     def log_epoch_stats(self, cur_epoch): | ||||
|         stats = self.get_epoch_stats(cur_epoch) | ||||
|         logger.info(logging.dump_log_data(stats, "train_epoch")) | ||||
|  | ||||
|  | ||||
| class TestMeterIoU(object): | ||||
|     """Measures testing stats.""" | ||||
|  | ||||
|     def __init__(self, max_iter): | ||||
|         self.max_iter = max_iter | ||||
|         self.iter_timer = Timer() | ||||
|  | ||||
|         self.mb_miou = ScalarMeter(cfg.LOG_PERIOD) | ||||
|  | ||||
|         self.max_miou = 0.0 | ||||
|  | ||||
|         self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES) | ||||
|         self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES) | ||||
|         self.num_samples = 0 | ||||
|  | ||||
|     def reset(self, min_errs=False):  # arg name kept for TestMeter API parity | ||||
|         if min_errs: | ||||
|             self.max_miou = 0.0 | ||||
|         self.iter_timer.reset() | ||||
|         self.mb_miou.reset() | ||||
|         self.num_inter = np.zeros(cfg.MODEL.NUM_CLASSES) | ||||
|         self.num_union = np.zeros(cfg.MODEL.NUM_CLASSES) | ||||
|         self.num_samples = 0 | ||||
|  | ||||
|     def iter_tic(self): | ||||
|         self.iter_timer.tic() | ||||
|  | ||||
|     def iter_toc(self): | ||||
|         self.iter_timer.toc() | ||||
|  | ||||
|     def update_stats(self, inter, union, mb_size): | ||||
|         self.mb_miou.add_value((inter / (union + 1e-10)).mean()) | ||||
|         self.num_inter += inter * mb_size | ||||
|         self.num_union += union * mb_size | ||||
|         self.num_samples += mb_size | ||||
|  | ||||
|     def get_iter_stats(self, cur_epoch, cur_iter): | ||||
|         mem_usage = gpu_mem_usage() | ||||
|         iter_stats = { | ||||
|             "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||||
|             "iter": "{}/{}".format(cur_iter + 1, self.max_iter), | ||||
|             "time_avg": self.iter_timer.average_time, | ||||
|             "time_diff": self.iter_timer.diff, | ||||
|             "miou": self.mb_miou.get_win_median(), | ||||
|             "mem": int(np.ceil(mem_usage)), | ||||
|         } | ||||
|         return iter_stats | ||||
|  | ||||
|     def log_iter_stats(self, cur_epoch, cur_iter): | ||||
|         if (cur_iter + 1) % cfg.LOG_PERIOD != 0: | ||||
|             return | ||||
|         stats = self.get_iter_stats(cur_epoch, cur_iter) | ||||
|         logger.info(logging.dump_log_data(stats, "test_iter")) | ||||
|  | ||||
|     def get_epoch_stats(self, cur_epoch): | ||||
|         miou = (self.num_inter / (self.num_union + 1e-10)).mean() | ||||
|         self.max_miou = max(self.max_miou, miou) | ||||
|         mem_usage = gpu_mem_usage() | ||||
|         stats = { | ||||
|             "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), | ||||
|             "time_avg": self.iter_timer.average_time, | ||||
|             "miou": miou, | ||||
|             "max_miou": self.max_miou, | ||||
|             "mem": int(np.ceil(mem_usage)), | ||||
|         } | ||||
|         return stats | ||||
|  | ||||
|     def log_epoch_stats(self, cur_epoch): | ||||
|         stats = self.get_epoch_stats(cur_epoch) | ||||
|         logger.info(logging.dump_log_data(stats, "test_epoch")) | ||||
							
								
								
									
										129
									
								
								pycls/core/net.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								pycls/core/net.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,129 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """Functions for manipulating networks.""" | ||||
|  | ||||
| import itertools | ||||
| import math | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pycls.core.config import cfg | ||||
|  | ||||
|  | ||||
| def init_weights(m): | ||||
|     """Performs ResNet-style weight initialization.""" | ||||
|     if isinstance(m, nn.Conv2d): | ||||
|         # Note that there is no bias due to BN | ||||
|         fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||||
|         m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) | ||||
|     elif isinstance(m, nn.BatchNorm2d): | ||||
|         zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA | ||||
|         zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma | ||||
|         m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) | ||||
|         m.bias.data.zero_() | ||||
|     elif isinstance(m, nn.Linear): | ||||
|         m.weight.data.normal_(mean=0.0, std=0.01) | ||||
|         m.bias.data.zero_() | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| def compute_precise_bn_stats(model, loader): | ||||
|     """Computes precise BN stats on training data.""" | ||||
|     # Compute the number of minibatches to use | ||||
|     num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader)) | ||||
|     # Retrieve the BN layers | ||||
|     bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] | ||||
|     # Initialize stats storage | ||||
|     mus = [torch.zeros_like(bn.running_mean) for bn in bns] | ||||
|     sqs = [torch.zeros_like(bn.running_var) for bn in bns] | ||||
|     # Remember momentum values | ||||
|     moms = [bn.momentum for bn in bns] | ||||
|     # Disable momentum | ||||
|     for bn in bns: | ||||
|         bn.momentum = 1.0 | ||||
|     # Accumulate the stats across the data samples | ||||
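|     # With momentum fixed at 1.0, running_mean/running_var hold exactly the | ||||
|     # current batch statistics, so averaging E[x] and E[x^2] over num_iter | ||||
|     # batches lets us recover var = E[x^2] - (E[x])^2 afterwards. | ||||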
|     for inputs, _labels in itertools.islice(loader, num_iter): | ||||
|         model(inputs.cuda()) | ||||
|         # Accumulate the stats for each BN layer | ||||
|         for i, bn in enumerate(bns): | ||||
|             m, v = bn.running_mean, bn.running_var | ||||
|             sqs[i] += (v + m * m) / num_iter | ||||
|             mus[i] += m / num_iter | ||||
|     # Set the stats and restore momentum values | ||||
|     for i, bn in enumerate(bns): | ||||
|         bn.running_var = sqs[i] - mus[i] * mus[i] | ||||
|         bn.running_mean = mus[i] | ||||
|         bn.momentum = moms[i] | ||||
|  | ||||
|  | ||||
| def reset_bn_stats(model): | ||||
|     """Resets running BN stats.""" | ||||
|     for m in model.modules(): | ||||
|         if isinstance(m, torch.nn.BatchNorm2d): | ||||
|             m.reset_running_stats() | ||||
|  | ||||
|  | ||||
| def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False): | ||||
|     """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts).""" | ||||
|     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] | ||||
|     h = (h + 2 * padding - k) // stride + 1 | ||||
|     w = (w + 2 * padding - k) // stride + 1 | ||||
|     flops += k * k * w_in * w_out * h * w // groups | ||||
|     params += k * k * w_in * w_out // groups | ||||
|     flops += w_out if bias else 0 | ||||
|     params += w_out if bias else 0 | ||||
|     acts += w_out * h * w | ||||
|     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} | ||||
|  | ||||
|  | ||||
| def complexity_batchnorm2d(cx, w_in): | ||||
|     """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts).""" | ||||
|     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] | ||||
|     params += 2 * w_in | ||||
|     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} | ||||
|  | ||||
|  | ||||
| def complexity_maxpool2d(cx, k, stride, padding): | ||||
|     """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts).""" | ||||
|     h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] | ||||
|     h = (h + 2 * padding - k) // stride + 1 | ||||
|     w = (w + 2 * padding - k) // stride + 1 | ||||
|     return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} | ||||
|  | ||||
|  | ||||
| def complexity(model): | ||||
|     """Compute model complexity (model can be model instance or model class).""" | ||||
|     size = cfg.TRAIN.IM_SIZE | ||||
|     cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0} | ||||
|     cx = model.complexity(cx) | ||||
|     return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]} | ||||
|  | ||||
|  | ||||
| def drop_connect(x, drop_ratio): | ||||
|     """Drop connect (adapted from DARTS).""" | ||||
|     keep_ratio = 1.0 - drop_ratio | ||||
|     mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) | ||||
|     mask.bernoulli_(keep_ratio) | ||||
|     x.div_(keep_ratio) | ||||
|     x.mul_(mask) | ||||
|     return x | ||||
|  | ||||
|  | ||||
| def get_flat_weights(model): | ||||
|     """Gets all model weights as a single flat vector.""" | ||||
|     return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0) | ||||
|  | ||||
|  | ||||
| def set_flat_weights(model, flat_weights): | ||||
|     """Sets all model weights from a single flat vector.""" | ||||
|     k = 0 | ||||
|     for p in model.parameters(): | ||||
|         n = p.data.numel() | ||||
|         p.data.copy_(flat_weights[k : (k + n)].view_as(p.data)) | ||||
|         k += n | ||||
|     assert k == flat_weights.numel() | ||||

95  pycls/core/optimizer.py  Normal file
							| @@ -0,0 +1,95 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """Optimizer.""" | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| from pycls.core.config import cfg | ||||
|  | ||||
|  | ||||
| def construct_optimizer(model): | ||||
|     """Constructs the optimizer. | ||||
|  | ||||
|     Note that the momentum update in PyTorch differs from the one in Caffe2. | ||||
|     In particular, | ||||
|  | ||||
|         Caffe2: | ||||
|             V := mu * V + lr * g | ||||
|             p := p - V | ||||
|  | ||||
|         PyTorch: | ||||
|             V := mu * V + g | ||||
|             p := p - lr * V | ||||
|  | ||||
|     where V is the velocity, mu is the momentum factor, lr is the learning rate, | ||||
|     g is the gradient and p are the parameters. | ||||
|  | ||||
|     Since V is defined independently of the learning rate in PyTorch, | ||||
|     when the learning rate is changed there is no need to perform the | ||||
|     momentum correction by scaling V (unlike in the Caffe2 case). | ||||
|     """ | ||||
|     if cfg.BN.USE_CUSTOM_WEIGHT_DECAY: | ||||
|         # Apply different weight decay to Batchnorm and non-batchnorm parameters. | ||||
|         p_bn = [p for n, p in model.named_parameters() if "bn" in n] | ||||
|         p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n] | ||||
|         optim_params = [ | ||||
|             {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY}, | ||||
|             {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, | ||||
|         ] | ||||
|     else: | ||||
|         optim_params = model.parameters() | ||||
|     return torch.optim.SGD( | ||||
|         optim_params, | ||||
|         lr=cfg.OPTIM.BASE_LR, | ||||
|         momentum=cfg.OPTIM.MOMENTUM, | ||||
|         weight_decay=cfg.OPTIM.WEIGHT_DECAY, | ||||
|         dampening=cfg.OPTIM.DAMPENING, | ||||
|         nesterov=cfg.OPTIM.NESTEROV, | ||||
|     ) | ||||
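|  | ||||
| # A minimal sketch (illustrative only, not used by pycls) of the docstring | ||||
| # note above: with a constant lr the two rules coincide, since at every step | ||||
| # V_caffe2 = lr * V_pytorch. | ||||
| # | ||||
| #     mu, lr, g = 0.9, 0.1, 1.0 | ||||
| #     v_c2 = v_pt = 0.0 | ||||
| #     p_c2 = p_pt = 1.0 | ||||
| #     for _ in range(3): | ||||
| #         v_c2 = mu * v_c2 + lr * g; p_c2 -= v_c2       # Caffe2 | ||||
| #         v_pt = mu * v_pt + g;      p_pt -= lr * v_pt  # PyTorch | ||||
| #     assert abs(p_c2 - p_pt) < 1e-12 | ||||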
|  | ||||
|  | ||||
| def lr_fun_steps(cur_epoch): | ||||
|     """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps').""" | ||||
|     ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1] | ||||
|     return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind) | ||||
|  | ||||
|  | ||||
| def lr_fun_exp(cur_epoch): | ||||
|     """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp').""" | ||||
|     return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch) | ||||
|  | ||||
|  | ||||
| def lr_fun_cos(cur_epoch): | ||||
|     """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos').""" | ||||
|     base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH | ||||
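|     # Decays from base_lr at epoch 0 to 0 at max_epoch, halving at midpoint. | ||||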
|     return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch)) | ||||
|  | ||||
|  | ||||
| def get_lr_fun(): | ||||
|     """Retrieves the specified lr policy function""" | ||||
|     lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY | ||||
|     if lr_fun not in globals(): | ||||
|         raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY) | ||||
|     return globals()[lr_fun] | ||||
|  | ||||
|  | ||||
| def get_epoch_lr(cur_epoch): | ||||
|     """Retrieves the lr for the given epoch according to the policy.""" | ||||
|     lr = get_lr_fun()(cur_epoch) | ||||
|     # Linear warmup | ||||
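|     # e.g. with WARMUP_FACTOR = 0.1 and WARMUP_EPOCHS = 5, the scaling grows | ||||
|     # linearly from 0.1 at epoch 0 toward 1.0 at epoch 5. | ||||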
|     if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS: | ||||
|         alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS | ||||
|         warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha | ||||
|         lr *= warmup_factor | ||||
|     return lr | ||||
|  | ||||
|  | ||||
| def set_lr(optimizer, new_lr): | ||||
|     """Sets the optimizer lr to the specified value.""" | ||||
|     for param_group in optimizer.param_groups: | ||||
|         param_group["lr"] = new_lr | ||||
							
								
								
									
										132
									
								
								pycls/core/plotting.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								pycls/core/plotting.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,132 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """Plotting functions.""" | ||||
|  | ||||
| import colorlover as cl | ||||
| import matplotlib.pyplot as plt | ||||
| import plotly.graph_objs as go | ||||
| import plotly.offline as offline | ||||
| import pycls.core.logging as logging | ||||
|  | ||||
|  | ||||
| def get_plot_colors(max_colors, color_format="pyplot"): | ||||
|     """Generate colors for plotting.""" | ||||
|     colors = cl.scales["11"]["qual"]["Paired"] | ||||
|     if max_colors > len(colors): | ||||
|         colors = cl.to_rgb(cl.interp(colors, max_colors)) | ||||
|     if color_format == "pyplot": | ||||
|         return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)] | ||||
|     return colors | ||||
|  | ||||
|  | ||||
| def prepare_plot_data(log_files, names, metric="top1_err"): | ||||
|     """Load logs and extract data for plotting error curves.""" | ||||
|     plot_data = [] | ||||
|     for file, name in zip(log_files, names): | ||||
|         d, data = {}, logging.sort_log_data(logging.load_log_data(file)) | ||||
|         for phase in ["train", "test"]: | ||||
|             x = data[phase + "_epoch"]["epoch_ind"] | ||||
|             y = data[phase + "_epoch"][metric] | ||||
|             d["x_" + phase], d["y_" + phase] = x, y | ||||
|             d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name | ||||
|         plot_data.append(d) | ||||
|     assert len(plot_data) > 0, "No data to plot" | ||||
|     return plot_data | ||||
|  | ||||
|  | ||||
| def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"): | ||||
|     """Plot error curves using plotly and save to file.""" | ||||
|     plot_data = prepare_plot_data(log_files, names, metric) | ||||
|     colors = get_plot_colors(len(plot_data), "plotly") | ||||
|     # Prepare data for plots (3 sets, train duplicated w and w/o legend) | ||||
|     data = [] | ||||
|     for i, d in enumerate(plot_data): | ||||
|         s = str(i) | ||||
|         line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5} | ||||
|         line_test = {"color": colors[i], "dash": "solid", "width": 1.5} | ||||
|         data.append( | ||||
|             go.Scatter( | ||||
|                 x=d["x_train"], | ||||
|                 y=d["y_train"], | ||||
|                 mode="lines", | ||||
|                 name=d["train_label"], | ||||
|                 line=line_train, | ||||
|                 legendgroup=s, | ||||
|                 visible=True, | ||||
|                 showlegend=False, | ||||
|             ) | ||||
|         ) | ||||
|         data.append( | ||||
|             go.Scatter( | ||||
|                 x=d["x_test"], | ||||
|                 y=d["y_test"], | ||||
|                 mode="lines", | ||||
|                 name=d["test_label"], | ||||
|                 line=line_test, | ||||
|                 legendgroup=s, | ||||
|                 visible=True, | ||||
|                 showlegend=True, | ||||
|             ) | ||||
|         ) | ||||
|         data.append( | ||||
|             go.Scatter( | ||||
|                 x=d["x_train"], | ||||
|                 y=d["y_train"], | ||||
|                 mode="lines", | ||||
|                 name=d["train_label"], | ||||
|                 line=line_train, | ||||
|                 legendgroup=s, | ||||
|                 visible=False, | ||||
|                 showlegend=True, | ||||
|             ) | ||||
|         ) | ||||
|     # Prepare layout w ability to toggle 'all', 'train', 'test' | ||||
|     titlefont = {"size": 18, "color": "#7f7f7f"} | ||||
|     vis = [[True, True, False], [False, False, True], [False, True, False]] | ||||
|     buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis]) | ||||
|     buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons] | ||||
|     layout = go.Layout( | ||||
|         title=metric + " vs. epoch<br>[dash=train, solid=test]", | ||||
|         xaxis={"title": "epoch", "titlefont": titlefont}, | ||||
|         yaxis={"title": metric, "titlefont": titlefont}, | ||||
|         showlegend=True, | ||||
|         hoverlabel={"namelength": -1}, | ||||
|         updatemenus=[ | ||||
|             { | ||||
|                 "buttons": buttons, | ||||
|                 "direction": "down", | ||||
|                 "showactive": True, | ||||
|                 "x": 1.02, | ||||
|                 "xanchor": "left", | ||||
|                 "y": 1.08, | ||||
|                 "yanchor": "top", | ||||
|             } | ||||
|         ], | ||||
|     ) | ||||
|     # Create plotly plot | ||||
|     offline.plot({"data": data, "layout": layout}, filename=filename) | ||||
|  | ||||
|  | ||||
| def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"): | ||||
|     """Plot error curves using matplotlib.pyplot and save to file.""" | ||||
|     plot_data = prepare_plot_data(log_files, names, metric) | ||||
|     colors = get_plot_colors(len(names)) | ||||
|     for ind, d in enumerate(plot_data): | ||||
|         c, lbl = colors[ind], d["test_label"] | ||||
|         plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8) | ||||
|         plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl) | ||||
|     plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14) | ||||
|     plt.xlabel("epoch", fontsize=14) | ||||
|     plt.ylabel(metric, fontsize=14) | ||||
|     plt.grid(alpha=0.4) | ||||
|     plt.legend() | ||||
|     if filename: | ||||
|         plt.savefig(filename) | ||||
|         plt.clf() | ||||
|     else: | ||||
|         plt.show() | ||||

39  pycls/core/timer.py  Normal file
							| @@ -0,0 +1,39 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """Timer.""" | ||||
|  | ||||
| import time | ||||
|  | ||||
|  | ||||
| class Timer(object): | ||||
|     """A simple timer (adapted from Detectron).""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.total_time = None | ||||
|         self.calls = None | ||||
|         self.start_time = None | ||||
|         self.diff = None | ||||
|         self.average_time = None | ||||
|         self.reset() | ||||
|  | ||||
|     def tic(self): | ||||
|         # using time.time since time.clock is deprecated and measures CPU time | ||||
|         self.start_time = time.time() | ||||
|  | ||||
|     def toc(self): | ||||
|         self.diff = time.time() - self.start_time | ||||
|         self.total_time += self.diff | ||||
|         self.calls += 1 | ||||
|         self.average_time = self.total_time / self.calls | ||||
|  | ||||
|     def reset(self): | ||||
|         self.total_time = 0.0 | ||||
|         self.calls = 0 | ||||
|         self.start_time = 0.0 | ||||
|         self.diff = 0.0 | ||||
|         self.average_time = 0.0 | ||||

419  pycls/core/trainer.py  Normal file
							| @@ -0,0 +1,419 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| """Tools for training and testing a model.""" | ||||
|  | ||||
| import os | ||||
| from thop import profile | ||||
|  | ||||
| import numpy as np | ||||
| import pycls.core.benchmark as benchmark | ||||
| import pycls.core.builders as builders | ||||
| import pycls.core.checkpoint as checkpoint | ||||
| import pycls.core.config as config | ||||
| import pycls.core.distributed as dist | ||||
| import pycls.core.logging as logging | ||||
| import pycls.core.meters as meters | ||||
| import pycls.core.net as net | ||||
| import pycls.core.optimizer as optim | ||||
| import pycls.datasets.loader as loader | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from pycls.core.config import cfg | ||||
|  | ||||
|  | ||||
| logger = logging.get_logger(__name__) | ||||
|  | ||||
|  | ||||
| def setup_env(): | ||||
|     """Sets up environment for training or testing.""" | ||||
|     if dist.is_master_proc(): | ||||
|         # Ensure that the output dir exists | ||||
|         os.makedirs(cfg.OUT_DIR, exist_ok=True) | ||||
|         # Save the config | ||||
|         config.dump_cfg() | ||||
|     # Setup logging | ||||
|     logging.setup_logging() | ||||
|     # Log the config as both human readable and as a json | ||||
|     logger.info("Config:\n{}".format(cfg)) | ||||
|     logger.info(logging.dump_log_data(cfg, "cfg")) | ||||
|     # Fix the RNG seeds (see RNG comment in core/config.py for discussion) | ||||
|     np.random.seed(cfg.RNG_SEED) | ||||
|     torch.manual_seed(cfg.RNG_SEED) | ||||
|     # Configure the CUDNN backend | ||||
|     torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK | ||||
|  | ||||
|  | ||||
| def setup_model(): | ||||
|     """Sets up a model for training or testing and log the results.""" | ||||
|     # Build the model | ||||
|     model = builders.build_model() | ||||
|     logger.info("Model:\n{}".format(model)) | ||||
|     # Log model complexity | ||||
|     # logger.info(logging.dump_log_data(net.complexity(model), "complexity")) | ||||
|     if cfg.TASK == "seg" and cfg.TRAIN.DATASET == "cityscapes": | ||||
|         h, w = 1025, 2049 | ||||
|     else: | ||||
|         h, w = cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE | ||||
|     if cfg.TASK == "jig": | ||||
|         x = torch.randn(1, cfg.JIGSAW_GRID ** 2, cfg.MODEL.INPUT_CHANNELS, h, w) | ||||
|     else: | ||||
|         x = torch.randn(1, cfg.MODEL.INPUT_CHANNELS, h, w) | ||||
|     macs, params = profile(model, inputs=(x, ), verbose=False) | ||||
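|     # Note: thop's profile() counts multiply-accumulates; the value is logged | ||||
|     # under the "Flops" label below. | ||||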
|     logger.info("Params: {:,}".format(params)) | ||||
|     logger.info("Flops: {:,}".format(macs)) | ||||
|     # Transfer the model to the current GPU device | ||||
|     err_str = "Cannot use more GPU devices than available" | ||||
|     assert cfg.NUM_GPUS <= torch.cuda.device_count(), err_str | ||||
|     cur_device = torch.cuda.current_device() | ||||
|     model = model.cuda(device=cur_device) | ||||
|     # Use multi-process data parallel model in the multi-gpu setting | ||||
|     if cfg.NUM_GPUS > 1: | ||||
|         # Make model replica operate on the current device | ||||
|         model = torch.nn.parallel.DistributedDataParallel( | ||||
|             module=model, device_ids=[cur_device], output_device=cur_device | ||||
|         ) | ||||
|         # Set complexity function to be module's complexity function | ||||
|         # model.complexity = model.module.complexity | ||||
|     return model | ||||
|  | ||||
|  | ||||
| def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch): | ||||
|     """Performs one epoch of training.""" | ||||
|     # Update drop path prob for NAS | ||||
|     if cfg.MODEL.TYPE == "nas": | ||||
|         m = model.module if cfg.NUM_GPUS > 1 else model | ||||
|         m.set_drop_path_prob(cfg.NAS.DROP_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH) | ||||
|     # Shuffle the data | ||||
|     loader.shuffle(train_loader, cur_epoch) | ||||
|     # Update the learning rate per epoch | ||||
|     if not cfg.OPTIM.ITER_LR: | ||||
|         lr = optim.get_epoch_lr(cur_epoch) | ||||
|         optim.set_lr(optimizer, lr) | ||||
|     # Enable training mode | ||||
|     model.train() | ||||
|     train_meter.iter_tic() | ||||
|     for cur_iter, (inputs, labels) in enumerate(train_loader): | ||||
|         # Update the learning rate per iter | ||||
|         if cfg.OPTIM.ITER_LR: | ||||
|             lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader)) | ||||
|             optim.set_lr(optimizer, lr) | ||||
|         # Transfer the data to the current GPU device | ||||
|         inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) | ||||
|         # Perform the forward pass | ||||
|         preds = model(inputs) | ||||
|         # Compute the loss | ||||
|         if isinstance(preds, tuple): | ||||
|             loss = loss_fun(preds[0], labels) + cfg.NAS.AUX_WEIGHT * loss_fun(preds[1], labels) | ||||
|             preds = preds[0] | ||||
|         else: | ||||
|             loss = loss_fun(preds, labels) | ||||
|         # Perform the backward pass | ||||
|         optimizer.zero_grad() | ||||
|         loss.backward() | ||||
|         # Update the parameters | ||||
|         optimizer.step() | ||||
|         # Compute the errors | ||||
|         if cfg.TASK == "col": | ||||
|             preds = preds.permute(0, 2, 3, 1) | ||||
|             preds = preds.reshape(-1, preds.size(3)) | ||||
|             labels = labels.reshape(-1) | ||||
|             mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS | ||||
|         else: | ||||
|             mb_size = inputs.size(0) * cfg.NUM_GPUS | ||||
|         if cfg.TASK == "seg": | ||||
|             # top1_err is in fact inter; top5_err is in fact union | ||||
|             top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES) | ||||
|         else: | ||||
|             ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes | ||||
|             top1_err, top5_err = meters.topk_errors(preds, labels, ks) | ||||
|         # Combine the stats across the GPUs (no reduction if 1 GPU used) | ||||
|         loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err]) | ||||
|         # Copy the stats from GPU to CPU (sync point) | ||||
|         loss = loss.item() | ||||
|         if cfg.TASK == "seg": | ||||
|             top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy() | ||||
|         else: | ||||
|             top1_err, top5_err = top1_err.item(), top5_err.item() | ||||
|         train_meter.iter_toc() | ||||
|         # Update and log stats | ||||
|         train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size) | ||||
|         train_meter.log_iter_stats(cur_epoch, cur_iter) | ||||
|         train_meter.iter_tic() | ||||
|     # Log epoch stats | ||||
|     train_meter.log_epoch_stats(cur_epoch) | ||||
|     train_meter.reset() | ||||
|  | ||||
|  | ||||
| def search_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch): | ||||
|     """Performs one epoch of differentiable architecture search.""" | ||||
|     m = model.module if cfg.NUM_GPUS > 1 else model | ||||
|     # Shuffle the data | ||||
|     loader.shuffle(train_loader[0], cur_epoch) | ||||
|     loader.shuffle(train_loader[1], cur_epoch) | ||||
|     # Update the learning rate per epoch | ||||
|     if not cfg.OPTIM.ITER_LR: | ||||
|         lr = optim.get_epoch_lr(cur_epoch) | ||||
|         optim.set_lr(optimizer[0], lr) | ||||
|     # Enable training mode | ||||
|     model.train() | ||||
|     train_meter.iter_tic() | ||||
|     trainB_iter = iter(train_loader[1]) | ||||
|     for cur_iter, (inputs, labels) in enumerate(train_loader[0]): | ||||
|         # Update the learning rate per iter | ||||
|         if cfg.OPTIM.ITER_LR: | ||||
|             lr = optim.get_epoch_lr(cur_epoch + cur_iter / len(train_loader[0])) | ||||
|             optim.set_lr(optimizer[0], lr) | ||||
|         # Transfer the data to the current GPU device | ||||
|         inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) | ||||
|         # Update architecture | ||||
|         if cur_epoch + cur_iter / len(train_loader[0]) >= cfg.OPTIM.ARCH_EPOCH: | ||||
|             try: | ||||
|                 inputsB, labelsB = next(trainB_iter) | ||||
|             except StopIteration: | ||||
|                 trainB_iter = iter(train_loader[1]) | ||||
|                 inputsB, labelsB = next(trainB_iter) | ||||
|             inputsB, labelsB = inputsB.cuda(), labelsB.cuda(non_blocking=True) | ||||
|             optimizer[1].zero_grad() | ||||
|             loss = m._loss(inputsB, labelsB) | ||||
|             loss.backward() | ||||
|             optimizer[1].step() | ||||
|         # Perform the forward pass | ||||
|         preds = model(inputs) | ||||
|         # Compute the loss | ||||
|         loss = loss_fun(preds, labels) | ||||
|         # Perform the backward pass | ||||
|         optimizer[0].zero_grad() | ||||
|         loss.backward() | ||||
|         torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) | ||||
|         # Update the parameters | ||||
|         optimizer[0].step() | ||||
|         # Compute the errors | ||||
|         if cfg.TASK == "col": | ||||
|             preds = preds.permute(0, 2, 3, 1) | ||||
|             preds = preds.reshape(-1, preds.size(3)) | ||||
|             labels = labels.reshape(-1) | ||||
|             mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS | ||||
|         else: | ||||
|             mb_size = inputs.size(0) * cfg.NUM_GPUS | ||||
|         if cfg.TASK == "seg": | ||||
|             # top1_err is in fact inter; top5_err is in fact union | ||||
|             top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES) | ||||
|         else: | ||||
|             ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes | ||||
|             top1_err, top5_err = meters.topk_errors(preds, labels, ks) | ||||
|         # Combine the stats across the GPUs (no reduction if 1 GPU used) | ||||
|         loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err]) | ||||
|         # Copy the stats from GPU to CPU (sync point) | ||||
|         loss = loss.item() | ||||
|         if cfg.TASK == "seg": | ||||
|             top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy() | ||||
|         else: | ||||
|             top1_err, top5_err = top1_err.item(), top5_err.item() | ||||
|         train_meter.iter_toc() | ||||
|         # Update and log stats | ||||
|         train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size) | ||||
|         train_meter.log_iter_stats(cur_epoch, cur_iter) | ||||
|         train_meter.iter_tic() | ||||
|     # Log epoch stats | ||||
|     train_meter.log_epoch_stats(cur_epoch) | ||||
|     train_meter.reset() | ||||
|     # Log genotype | ||||
|     genotype = m.genotype() | ||||
|     logger.info("genotype = %s", genotype) | ||||
|     logger.info(F.softmax(m.net_.alphas_normal, dim=-1)) | ||||
|     logger.info(F.softmax(m.net_.alphas_reduce, dim=-1)) | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| def test_epoch(test_loader, model, test_meter, cur_epoch): | ||||
|     """Evaluates the model on the test set.""" | ||||
|     # Enable eval mode | ||||
|     model.eval() | ||||
|     test_meter.iter_tic() | ||||
|     for cur_iter, (inputs, labels) in enumerate(test_loader): | ||||
|         # Transfer the data to the current GPU device | ||||
|         inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) | ||||
|         # Compute the predictions | ||||
|         preds = model(inputs) | ||||
|         # Compute the errors | ||||
|         if cfg.TASK == "col": | ||||
|             preds = preds.permute(0, 2, 3, 1) | ||||
|             preds = preds.reshape(-1, preds.size(3)) | ||||
|             labels = labels.reshape(-1) | ||||
|             mb_size = inputs.size(0) * inputs.size(2) * inputs.size(3) * cfg.NUM_GPUS | ||||
|         else: | ||||
|             mb_size = inputs.size(0) * cfg.NUM_GPUS | ||||
|         if cfg.TASK == "seg": | ||||
|             # top1_err is in fact inter; top5_err is in fact union | ||||
|             top1_err, top5_err = meters.inter_union(preds, labels, cfg.MODEL.NUM_CLASSES) | ||||
|         else: | ||||
|             ks = [1, min(5, cfg.MODEL.NUM_CLASSES)]  # rot only has 4 classes | ||||
|             top1_err, top5_err = meters.topk_errors(preds, labels, ks) | ||||
|         # Combine the errors across the GPUs  (no reduction if 1 GPU used) | ||||
|         top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err]) | ||||
|         # Copy the errors from GPU to CPU (sync point) | ||||
|         if cfg.TASK == "seg": | ||||
|             top1_err, top5_err = top1_err.cpu().numpy(), top5_err.cpu().numpy() | ||||
|         else: | ||||
|             top1_err, top5_err = top1_err.item(), top5_err.item() | ||||
|         test_meter.iter_toc() | ||||
|         # Update and log stats | ||||
|         test_meter.update_stats(top1_err, top5_err, mb_size) | ||||
|         test_meter.log_iter_stats(cur_epoch, cur_iter) | ||||
|         test_meter.iter_tic() | ||||
|     # Log epoch stats | ||||
|     test_meter.log_epoch_stats(cur_epoch) | ||||
|     test_meter.reset() | ||||
|  | ||||
|  | ||||
| def train_model(): | ||||
|     """Trains the model.""" | ||||
|     # Setup training/testing environment | ||||
|     setup_env() | ||||
|     # Construct the model, loss_fun, and optimizer | ||||
|     model = setup_model() | ||||
|     loss_fun = builders.build_loss_fun().cuda() | ||||
|     if "search" in cfg.MODEL.TYPE: | ||||
|         params_w = [v for k, v in model.named_parameters() if "alphas" not in k] | ||||
|         params_a = [v for k, v in model.named_parameters() if "alphas" in k] | ||||
|         optimizer_w = torch.optim.SGD( | ||||
|             params=params_w, | ||||
|             lr=cfg.OPTIM.BASE_LR, | ||||
|             momentum=cfg.OPTIM.MOMENTUM, | ||||
|             weight_decay=cfg.OPTIM.WEIGHT_DECAY, | ||||
|             dampening=cfg.OPTIM.DAMPENING, | ||||
|             nesterov=cfg.OPTIM.NESTEROV | ||||
|         ) | ||||
|         if cfg.OPTIM.ARCH_OPTIM == "adam": | ||||
|             optimizer_a = torch.optim.Adam( | ||||
|                 params=params_a, | ||||
|                 lr=cfg.OPTIM.ARCH_BASE_LR, | ||||
|                 betas=(0.5, 0.999), | ||||
|                 weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY | ||||
|             ) | ||||
|         elif cfg.OPTIM.ARCH_OPTIM == "sgd": | ||||
|             optimizer_a = torch.optim.SGD( | ||||
|                 params=params_a, | ||||
|                 lr=cfg.OPTIM.ARCH_BASE_LR, | ||||
|                 momentum=cfg.OPTIM.MOMENTUM, | ||||
|                 weight_decay=cfg.OPTIM.ARCH_WEIGHT_DECAY, | ||||
|                 dampening=cfg.OPTIM.DAMPENING, | ||||
|                 nesterov=cfg.OPTIM.NESTEROV | ||||
|             ) | ||||
|         optimizer = [optimizer_w, optimizer_a] | ||||
|     else: | ||||
|         optimizer = optim.construct_optimizer(model) | ||||
|     # Load checkpoint or initial weights | ||||
|     start_epoch = 0 | ||||
|     if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint(): | ||||
|         last_checkpoint = checkpoint.get_last_checkpoint() | ||||
|         checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer) | ||||
|         logger.info("Loaded checkpoint from: {}".format(last_checkpoint)) | ||||
|         start_epoch = checkpoint_epoch + 1 | ||||
|     elif cfg.TRAIN.WEIGHTS: | ||||
|         checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model) | ||||
|         logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS)) | ||||
|     # Create data loaders and meters | ||||
|     if cfg.TRAIN.PORTION < 1: | ||||
|         if "search" in cfg.MODEL.TYPE: | ||||
|             train_loader = [loader._construct_loader( | ||||
|                 dataset_name=cfg.TRAIN.DATASET, | ||||
|                 split=cfg.TRAIN.SPLIT, | ||||
|                 batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS), | ||||
|                 shuffle=True, | ||||
|                 drop_last=True, | ||||
|                 portion=cfg.TRAIN.PORTION, | ||||
|                 side="l" | ||||
|             ), | ||||
|             loader._construct_loader( | ||||
|                 dataset_name=cfg.TRAIN.DATASET, | ||||
|                 split=cfg.TRAIN.SPLIT, | ||||
|                 batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS), | ||||
|                 shuffle=True, | ||||
|                 drop_last=True, | ||||
|                 portion=cfg.TRAIN.PORTION, | ||||
|                 side="r" | ||||
|             )] | ||||
|         else: | ||||
|             train_loader = loader._construct_loader( | ||||
|                 dataset_name=cfg.TRAIN.DATASET, | ||||
|                 split=cfg.TRAIN.SPLIT, | ||||
|                 batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS), | ||||
|                 shuffle=True, | ||||
|                 drop_last=True, | ||||
|                 portion=cfg.TRAIN.PORTION, | ||||
|                 side="l" | ||||
|             ) | ||||
|         test_loader = loader._construct_loader( | ||||
|             dataset_name=cfg.TRAIN.DATASET, | ||||
|             split=cfg.TRAIN.SPLIT, | ||||
|             batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS), | ||||
|             shuffle=False, | ||||
|             drop_last=False, | ||||
|             portion=cfg.TRAIN.PORTION, | ||||
|             side="r" | ||||
|         ) | ||||
|     else: | ||||
|         train_loader = loader.construct_train_loader() | ||||
|         test_loader = loader.construct_test_loader() | ||||
|     train_meter_type = meters.TrainMeterIoU if cfg.TASK == "seg" else meters.TrainMeter | ||||
|     test_meter_type = meters.TestMeterIoU if cfg.TASK == "seg" else meters.TestMeter | ||||
|     first_loader = train_loader[0] if isinstance(train_loader, list) else train_loader | ||||
|     train_meter = train_meter_type(len(first_loader)) | ||||
|     test_meter = test_meter_type(len(test_loader)) | ||||
|     # Compute model and loader timings | ||||
|     if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0: | ||||
|         first_loader = train_loader[0] if isinstance(train_loader, list) else train_loader | ||||
|         benchmark.compute_time_full(model, loss_fun, first_loader, test_loader) | ||||
|     # Perform the training loop | ||||
|     logger.info("Start epoch: {}".format(start_epoch + 1)) | ||||
|     for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): | ||||
|         # Train for one epoch | ||||
|         f = search_epoch if "search" in cfg.MODEL.TYPE else train_epoch | ||||
|         f(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch) | ||||
|         # Compute precise BN stats | ||||
|         if cfg.BN.USE_PRECISE_STATS: | ||||
|             net.compute_precise_bn_stats(model, train_loader) | ||||
|         # Save a checkpoint | ||||
|         if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0: | ||||
|             checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch) | ||||
|             logger.info("Wrote checkpoint to: {}".format(checkpoint_file)) | ||||
|         # Evaluate the model | ||||
|         next_epoch = cur_epoch + 1 | ||||
|         if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH: | ||||
|             test_epoch(test_loader, model, test_meter, cur_epoch) | ||||
|  | ||||
|  | ||||
| def test_model(): | ||||
|     """Evaluates a trained model.""" | ||||
|     # Setup training/testing environment | ||||
|     setup_env() | ||||
|     # Construct the model | ||||
|     model = setup_model() | ||||
|     # Load model weights | ||||
|     checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model) | ||||
|     logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS)) | ||||
|     # Create data loaders and meters | ||||
|     test_loader = loader.construct_test_loader() | ||||
|     test_meter = meters.TestMeter(len(test_loader)) | ||||
|     # Evaluate the model | ||||
|     test_epoch(test_loader, model, test_meter, 0) | ||||
|  | ||||
|  | ||||
| def time_model(): | ||||
|     """Times model and data loader.""" | ||||
|     # Setup training/testing environment | ||||
|     setup_env() | ||||
|     # Construct the model and loss_fun | ||||
|     model = setup_model() | ||||
|     loss_fun = builders.build_loss_fun().cuda() | ||||
|     # Create data loaders | ||||
|     train_loader = loader.construct_train_loader() | ||||
|     test_loader = loader.construct_test_loader() | ||||
|     # Compute model and loader timings | ||||
|     benchmark.compute_time_full(model, loss_fun, train_loader, test_loader) | ||||