#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions that handle saving and loading of checkpoints."""

import os

import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir():
    """Retrieves the location for storing checkpoints."""
    return os.path.join(cfg.OUT_DIR, _DIR_NAME)


def get_checkpoint(epoch):
    """Retrieves the path to a checkpoint file."""
    name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
    return os.path.join(get_checkpoint_dir(), name)


def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number)."""
    checkpoint_dir = get_checkpoint_dir()
    # Checkpoint file names are in lexicographic order
    checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    last_checkpoint_name = sorted(checkpoints)[-1]
    return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint():
    """Determines if there are checkpoints available."""
    checkpoint_dir = get_checkpoint_dir()
    if not os.path.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))


def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    os.makedirs(get_checkpoint_dir(), exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state; a list of optimizers (weights / architecture) gets
    # both states stored under separate keys
    if isinstance(optimizer, list):
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_w_state": optimizer[0].state_dict(),
            "optimizer_a_state": optimizer[1].state_dict(),
            "cfg": cfg.dump(),
        }
    else:
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_state": optimizer.state_dict(),
            "cfg": cfg.dump(),
        }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    # Load the checkpoint on CPU to avoid GPU mem spike
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Account for the DDP wrapper in the multi-gpu setting
    ms = model.module if cfg.NUM_GPUS > 1 else model
    ms.load_state_dict(checkpoint["model_state"])
    # Load the optimizer state (commonly not done when fine-tuning)
    if optimizer:
        if isinstance(optimizer, list):
            optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
            optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
        else:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
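

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): illustrates the
# save/load round trip for a toy model. It assumes a single-process,
# single-GPU configuration (cfg.NUM_GPUS == 1, so dist.is_master_proc() is
# expected to be True and no DDP wrapper is involved) and that cfg.OUT_DIR
# points to a writable directory.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Toy model and optimizer standing in for a real pycls network
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Save a checkpoint after epoch 0; the file is named model_epoch_0001.pyth
    checkpoint_file = save_checkpoint(model, optimizer, epoch=0)
    print("Saved:", checkpoint_file)
    # Resume: restore model and optimizer state and recover the epoch counter
    if has_checkpoint():
        epoch = load_checkpoint(get_last_checkpoint(), model, optimizer)
        print("Resumed from epoch:", epoch)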