naswot/pycls/core/checkpoint.py
Jack Turner b74255e1f3 v2
2021-02-26 16:12:51 +00:00

99 lines
3.3 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Functions that handle saving and loading of checkpoints."""
import os
import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg
# Sub-directory of cfg.OUT_DIR that holds every checkpoint file
_DIR_NAME = "checkpoints"
# Shared leading part of every checkpoint file name; a zero-padded epoch
# number and the ".pyth" extension follow it
_NAME_PREFIX = "model_epoch_"
def get_checkpoint_dir():
    """Return the directory in which checkpoint files are stored.

    The path is the ``_DIR_NAME`` sub-directory of ``cfg.OUT_DIR``. The
    directory is not created here.
    """
    out_dir = cfg.OUT_DIR
    return os.path.join(out_dir, _DIR_NAME)
def get_checkpoint(epoch):
    """Return the full path of the checkpoint file for ``epoch``.

    The epoch is zero-padded to at least four digits. For example, epoch 7
    maps to ``<checkpoint dir>/model_epoch_0007.pyth``.
    """
    file_name = _NAME_PREFIX + format(epoch, "04d") + ".pyth"
    return os.path.join(get_checkpoint_dir(), file_name)
def get_last_checkpoint():
    """Return the path of the most recent checkpoint (highest epoch number).

    Checkpoint names embed a zero-padded epoch (``model_epoch_0012.pyth``).
    The padding is only four digits wide, so plain lexicographic sorting
    breaks once epochs reach 10000 ("model_epoch_10000" < "model_epoch_9999").
    Files are therefore ordered by their parsed epoch number instead.

    Returns:
        Path of the checkpoint file with the highest epoch number.

    Raises:
        FileNotFoundError: if the checkpoint directory does not exist
            (raised by ``os.listdir``).
        IndexError: if the directory contains no checkpoint files.
    """
    checkpoint_dir = get_checkpoint_dir()
    # Same selection rule as has_checkpoint, so the two functions agree.
    checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    if not checkpoints:
        # IndexError matches what the previous ``sorted(...)[-1]`` raised.
        raise IndexError("No checkpoints found in '{}'".format(checkpoint_dir))

    def sort_key(name):
        # Parse "<...>model_epoch_<digits>.pyth". Names that do not follow
        # that pattern (e.g. leftover temp files) get epoch -1, so they sort
        # before every well-formed checkpoint. The name breaks ties.
        digits, ext = os.path.splitext(name.split(_NAME_PREFIX, 1)[1])
        epoch = int(digits) if digits.isdecimal() and ext == ".pyth" else -1
        return epoch, name

    last_checkpoint_name = max(checkpoints, key=sort_key)
    return os.path.join(checkpoint_dir, last_checkpoint_name)
def has_checkpoint():
    """Return True if the checkpoint directory holds at least one checkpoint.

    A file counts as a checkpoint when its name contains ``_NAME_PREFIX``.
    A missing checkpoint directory is treated as "no checkpoints".
    """
    ckpt_dir = get_checkpoint_dir()
    return os.path.exists(ckpt_dir) and any(
        _NAME_PREFIX in entry for entry in os.listdir(ckpt_dir)
    )
def save_checkpoint(model, optimizer, epoch):
    """Save a training checkpoint and return its path.

    Only the master process writes. On other processes this is a no-op that
    returns None.

    Args:
        model: the model. When ``cfg.NUM_GPUS > 1`` it must expose the
            underlying network as ``model.module`` (the DDP wrapper is
            omitted from the saved state).
        optimizer: a single optimizer, or a list whose first two entries are
            stored under "optimizer_w_state" and "optimizer_a_state"
            (presumably weight / architecture optimizers -- confirm with
            callers).
        epoch: stored as-is under "epoch". The file itself is named after
            ``epoch + 1``.

    Returns:
        The path of the written checkpoint file, or None on non-master
        processes.
    """
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return None
    checkpoint_dir = get_checkpoint_dir()
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state. Keys are inserted in the same order as before.
    checkpoint = {"epoch": epoch, "model_state": sd}
    if isinstance(optimizer, list):
        checkpoint["optimizer_w_state"] = optimizer[0].state_dict()
        checkpoint["optimizer_a_state"] = optimizer[1].state_dict()
    else:
        checkpoint["optimizer_state"] = optimizer.state_dict()
    checkpoint["cfg"] = cfg.dump()
    checkpoint_file = get_checkpoint(epoch + 1)
    # Write to a temporary file in the same directory, then atomically rename
    # it into place. An interrupted save therefore never leaves a truncated
    # file under a checkpoint name, which get_last_checkpoint would otherwise
    # pick up on resume. The temporary name deliberately does not contain
    # _NAME_PREFIX, so a leftover file is ignored by has_checkpoint and
    # get_last_checkpoint.
    tmp_file = os.path.join(checkpoint_dir, "_saving_checkpoint.tmp")
    try:
        torch.save(checkpoint, tmp_file)
        os.replace(tmp_file, checkpoint_file)
    except BaseException:
        # Best-effort cleanup of the partial file; the original error is
        # re-raised below.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise
    return checkpoint_file
def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Load a checkpoint into ``model`` and, optionally, the optimizer(s).

    Args:
        checkpoint_file: path of a file written by ``save_checkpoint``.
        model: model to restore. When ``cfg.NUM_GPUS > 1`` its ``.module``
            attribute receives the weights (DDP wrapper).
        optimizer: optional single optimizer, or a list whose first two
            entries receive "optimizer_w_state" / "optimizer_a_state". Pass a
            falsy value (e.g. None when fine-tuning) to skip restoring
            optimizer state.

    Returns:
        The "epoch" value stored in the checkpoint.

    Raises:
        FileNotFoundError: if ``checkpoint_file`` does not exist.
    """
    # Explicit check instead of ``assert``, so it still runs under ``python -O``.
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError("Checkpoint '{}' not found".format(checkpoint_file))
    # Load the checkpoint on CPU to avoid GPU mem spike
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Account for the DDP wrapper in the multi-gpu setting
    ms = model.module if cfg.NUM_GPUS > 1 else model
    ms.load_state_dict(checkpoint["model_state"])
    # Load the optimizer state (commonly not done when fine-tuning)
    if optimizer:
        if isinstance(optimizer, list):
            optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
            optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
        else:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]