99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""Functions that handle saving and loading of checkpoints."""
|
|
|
|
import os
|
|
|
|
import pycls.core.distributed as dist
|
|
import torch
|
|
from pycls.core.config import cfg
|
|
|
|
|
|
# Every checkpoint file name starts with this prefix, followed by the epoch
_NAME_PREFIX: str = "model_epoch_"

# Name of the sub-directory (of the output dir) that holds checkpoint files
_DIR_NAME: str = "checkpoints"
|
|
|
|
|
|
def get_checkpoint_dir():
    """Returns the directory that holds checkpoint files.

    This is the ``_DIR_NAME`` sub-directory of ``cfg.OUT_DIR``.
    """
    out_dir = cfg.OUT_DIR
    return os.path.join(out_dir, _DIR_NAME)
|
|
|
|
|
|
def get_checkpoint(epoch):
    """Returns the path of the checkpoint file for ``epoch``.

    The file name is the checkpoint prefix, then the epoch zero-padded to at
    least four digits, then the ``.pyth`` extension (e.g. ``model_epoch_0007.pyth``).
    """
    file_name = "{prefix}{epoch:04d}.pyth".format(
        prefix=_NAME_PREFIX, epoch=epoch
    )
    checkpoint_dir = get_checkpoint_dir()
    return os.path.join(checkpoint_dir, file_name)
|
|
|
|
|
|
def _checkpoint_epoch(name):
    """Parses the epoch out of a checkpoint file name, or returns None.

    Only names of the form ``<_NAME_PREFIX><decimal digits>.pyth`` are accepted.
    Any other file, including one that merely contains the prefix (e.g. a
    ``.bak`` copy or a partial write), yields None.
    """
    suffix = ".pyth"
    if not (name.startswith(_NAME_PREFIX) and name.endswith(suffix)):
        return None
    digits = name[len(_NAME_PREFIX) : len(name) - len(suffix)]
    return int(digits) if digits.isdecimal() else None


def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number).

    Returns:
        Path to the checkpoint file in the checkpoint dir with the largest epoch.

    Raises:
        FileNotFoundError: if the checkpoint dir does not exist or contains no
            checkpoint files.
    """
    checkpoint_dir = get_checkpoint_dir()
    # Collect (epoch, name) pairs for files that really are checkpoints
    candidates = []
    for name in os.listdir(checkpoint_dir):
        epoch = _checkpoint_epoch(name)
        if epoch is not None:
            candidates.append((epoch, name))
    if not candidates:
        raise FileNotFoundError(
            "No checkpoints found in '{}'".format(checkpoint_dir)
        )
    # Compare epochs numerically. Sorting the zero-padded names as text is only
    # correct while every epoch fits in four digits ("..10000" < "..9999").
    # The name breaks ties so the result does not depend on listdir order.
    _, last_checkpoint_name = max(candidates)
    return os.path.join(checkpoint_dir, last_checkpoint_name)
|
|
|
|
|
|
def has_checkpoint():
    """Returns True if the checkpoint dir contains a checkpoint-like file.

    A file counts when its name contains the checkpoint prefix. A missing
    checkpoint dir means there are no checkpoints.
    """
    checkpoint_dir = get_checkpoint_dir()
    if os.path.exists(checkpoint_dir):
        for entry in os.listdir(checkpoint_dir):
            if _NAME_PREFIX in entry:
                return True
    return False
|
|
|
|
|
|
def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint.

    Only the master process writes; every other process returns None right away.
    The data is first written to a temporary file in the checkpoint dir and then
    renamed onto the final path. An interrupted save therefore never leaves a
    truncated checkpoint file behind for a later resume to load.

    Args:
        model: model to save. When ``cfg.NUM_GPUS > 1``, ``model.module`` (the
            model inside the DDP wrapper) is saved instead.
        optimizer: a single optimizer, or a list whose first two entries are
            saved under ``optimizer_w_state`` and ``optimizer_a_state``.
        epoch: stored as-is under ``"epoch"``; the file is named for ``epoch + 1``.

    Returns:
        Path of the written checkpoint file on the master process, else None.
    """
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return None
    # Ensure that the checkpoint dir exists
    checkpoint_dir = get_checkpoint_dir()
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state. Key order is epoch, model, optimizer state(s), cfg.
    checkpoint = {"epoch": epoch, "model_state": sd}
    if isinstance(optimizer, list):
        checkpoint["optimizer_w_state"] = optimizer[0].state_dict()
        checkpoint["optimizer_a_state"] = optimizer[1].state_dict()
    else:
        checkpoint["optimizer_state"] = optimizer.state_dict()
    checkpoint["cfg"] = cfg.dump()
    checkpoint_file = get_checkpoint(epoch + 1)
    # The temp name deliberately omits the checkpoint prefix. A leftover from a
    # hard kill is then never matched by the prefix-based lookups in this module.
    tmp_file = os.path.join(
        checkpoint_dir, ".tmp_checkpoint_{}_{}".format(epoch + 1, os.getpid())
    )
    try:
        torch.save(checkpoint, tmp_file)
        # Same directory, hence same filesystem, so the rename is atomic on POSIX
        os.replace(tmp_file, checkpoint_file)
    except BaseException:
        # Best-effort removal of the partial file, then propagate the error
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise
    return checkpoint_file
|
|
|
|
|
|
def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint at ``checkpoint_file`` into the model (and optimizer).

    Args:
        checkpoint_file: path to a checkpoint file; it must already exist.
        model: model to restore. When ``cfg.NUM_GPUS > 1``, the weights are
            loaded into ``model.module`` (the DDP-wrapped model).
        optimizer: optional single optimizer, or a list of two optimizers
            restored from ``optimizer_w_state`` / ``optimizer_a_state``. It is
            skipped when falsy (commonly when fine-tuning).

    Returns:
        The ``"epoch"`` value stored in the checkpoint.
    """
    assert os.path.exists(checkpoint_file), "Checkpoint '{}' not found".format(
        checkpoint_file
    )
    # Map tensors to CPU while loading to avoid a GPU memory spike
    state = torch.load(checkpoint_file, map_location="cpu")
    # Account for the DDP wrapper in the multi-gpu setting
    target = model.module if cfg.NUM_GPUS > 1 else model
    target.load_state_dict(state["model_state"])
    if not optimizer:
        return state["epoch"]
    if isinstance(optimizer, list):
        optimizer[0].load_state_dict(state["optimizer_w_state"])
        optimizer[1].load_state_dict(state["optimizer_a_state"])
    else:
        optimizer.load_state_dict(state["optimizer_state"])
    return state["epoch"]
|