Adding checkpointing and resuming

Kevin Black 2023-06-28 17:58:25 -07:00
parent ad28862b48
commit 8779f62a1c
3 changed files with 86 additions and 7 deletions

.gitignore

@@ -303,4 +303,6 @@ tags

 # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim

 wandb/
+logs/
+notebooks/

config/base.py

@@ -5,12 +5,16 @@ def get_config():
     config = ml_collections.ConfigDict()

     # misc
+    config.run_name = ""
     config.seed = 42
     config.logdir = "logs"
     config.num_epochs = 100
+    config.save_freq = 20
+    config.num_checkpoint_limit = 5
     config.mixed_precision = "fp16"
     config.allow_tf32 = True
     config.use_lora = True
+    config.resume_from = ""

     # pretrained model initialization
     config.pretrained = pretrained = ml_collections.ConfigDict()
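
The new keys drive the checkpointing flow wired up in the training script below: save_freq sets how often accelerator.save_state() runs, num_checkpoint_limit caps how many checkpoints Accelerate keeps on disk, and resume_from (empty by default) names a run or checkpoint directory to restart from. A quick sketch of the cadence these defaults imply (not part of the commit, just the arithmetic):

num_epochs, save_freq = 100, 20
save_epochs = [e for e in range(num_epochs) if e % save_freq == 0]
assert save_epochs == [0, 20, 40, 60, 80]  # 5 checkpoints, exactly the limit of 5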

scripts/train.py

@@ -1,12 +1,13 @@
 from collections import defaultdict
 import contextlib
 import os
+import datetime
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
-from accelerate.utils import set_seed
+from accelerate.utils import set_seed, ProjectConfiguration
 from accelerate.logging import get_logger
-from diffusers import StableDiffusionPipeline, DDIMScheduler
+from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.attention_processor import LoRAAttnProcessor
 import numpy as np
@@ -35,19 +36,45 @@ def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config

+    unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    if not config.run_name:
+        config.run_name = unique_id
+    else:
+        config.run_name += "_" + unique_id
+
+    if config.resume_from:
+        config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
+        if "checkpoint_" not in os.path.basename(config.resume_from):
+            # get the most recent checkpoint in this directory
+            checkpoints = list(filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from)))
+            if len(checkpoints) == 0:
+                raise ValueError(f"No checkpoints found in {config.resume_from}")
+            config.resume_from = os.path.join(
+                config.resume_from,
+                sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
+            )
+
     # number of timesteps within each trajectory to train on
     num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)

+    accelerator_config = ProjectConfiguration(
+        project_dir=os.path.join(config.logdir, config.run_name),
+        automatic_checkpoint_naming=True,
+        total_limit=config.num_checkpoint_limit,
+    )
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
-        project_dir=config.logdir,
+        project_config=accelerator_config,
         # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
         # _samples_ to accumulate across
         gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
-        accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
+        accelerator.init_trackers(
+            project_name="ddpo-pytorch", config=config.to_dict(), init_kwargs={"wandb": {"name": config.run_name}}
+        )
     logger.info(f"\n{config}")

     # set seed (device_specific is very important to get different prompts on different devices)
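
A detail worth noting in the resume logic above: config.resume_from may point either at a specific checkpoint_<n> directory or at its parent run directory, in which case the most recent checkpoint is selected by its integer suffix. A standalone sketch of that selection (directory names are hypothetical):

# The numeric key matters: plain lexicographic order would rank
# "checkpoint_9" above "checkpoint_10".
checkpoints = ["checkpoint_1", "checkpoint_9", "checkpoint_10"]
latest = sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1]
assert latest == "checkpoint_10"
assert sorted(checkpoints)[-1] == "checkpoint_9"  # lexicographic order gets it wrong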
@@ -108,6 +135,40 @@ def main(_):
     else:
         trainable_layers = pipeline.unet

+    # set up diffusers-friendly checkpoint saving with Accelerate
+    def save_model_hook(models, weights, output_dir):
+        assert len(models) == 1
+        if config.use_lora and isinstance(models[0], AttnProcsLayers):
+            pipeline.unet.save_attn_procs(output_dir)
+        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
+            models[0].save_pretrained(os.path.join(output_dir, "unet"))
+        else:
+            raise ValueError(f"Unknown model type {type(models[0])}")
+        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
+
+    def load_model_hook(models, input_dir):
+        assert len(models) == 1
+        if config.use_lora and isinstance(models[0], AttnProcsLayers):
+            # pipeline.unet.load_attn_procs(input_dir)
+            tmp_unet = UNet2DConditionModel.from_pretrained(
+                config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
+            )
+            tmp_unet.load_attn_procs(input_dir)
+            models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
+            del tmp_unet
+        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
+            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+            models[0].register_to_config(**load_model.config)
+            models[0].load_state_dict(load_model.state_dict())
+            del load_model
+        else:
+            raise ValueError(f"Unknown model type {type(models[0])}")
+        models.pop()  # ensures that accelerate doesn't try to handle loading of the model
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if config.allow_tf32:
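
These hooks plug into Accelerate's save_state()/load_state() machinery: popping from weights (or models) signals that the hook has handled that model, so Accelerate skips its default serialization. The LoRA branch round-trips through a throwaway UNet because the attn-procs file format is defined on the UNet, not on AttnProcsLayers. A minimal self-contained sketch of the same mechanism with a toy model (paths and the filename are hypothetical, not from this commit):

import os
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))

def save_model_hook(models, weights, output_dir):
    # custom serialization, standing in for save_attn_procs()/save_pretrained()
    torch.save(models[0].state_dict(), os.path.join(output_dir, "custom_model.pt"))
    weights.pop()  # skip Accelerate's own copy of this model

def load_model_hook(models, input_dir):
    models[0].load_state_dict(torch.load(os.path.join(input_dir, "custom_model.pt")))
    models.pop()  # skip Accelerate's own loading of this model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
accelerator.save_state("ckpt")  # hypothetical output directory
accelerator.load_state("ckpt")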
@@ -185,8 +246,15 @@ def main(_):
     assert config.sample.batch_size % config.train.batch_size == 0
     assert samples_per_epoch % total_train_batch_size == 0

+    if config.resume_from:
+        logger.info(f"Resuming from {config.resume_from}")
+        accelerator.load_state(config.resume_from)
+        first_epoch = int(config.resume_from.split("_")[-1]) + 1
+    else:
+        first_epoch = 0
+
     global_step = 0
-    for epoch in range(config.num_epochs):
+    for epoch in range(first_epoch, config.num_epochs):
         #################### SAMPLING ####################
         pipeline.unet.eval()
         samples = []
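
Resuming restores model and optimizer state via load_state() and then fast-forwards the epoch counter from the checkpoint directory's suffix. One caveat: with automatic_checkpoint_naming, Accelerate numbers checkpoints by a save counter that increments on every save_state() call, so the suffix matches the epoch only when a checkpoint is saved every epoch (save_freq == 1); with the default save_freq of 20 the derived first_epoch lags the true epoch. A sketch of the parsing itself (the path is hypothetical):

resume_from = "logs/run/checkpoints/checkpoint_7"
first_epoch = int(resume_from.split("_")[-1]) + 1
assert first_epoch == 8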
@@ -387,7 +455,9 @@ def main(_):
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
+                    assert (j == num_train_timesteps - 1) and (
+                        i + 1
+                    ) % config.train.gradient_accumulation_steps == 0

                     # log training-related stuff
                     info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                     info = accelerator.reduce(info, reduction="mean")
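
For context on the (reflowed) assert: the Accelerator was constructed with gradient_accumulation_steps = config.train.gradient_accumulation_steps * num_train_timesteps, i.e. it accumulates across every timestep of a sample batch and across sample batches, and the assert checks that optimizer steps land exactly where that arithmetic predicts. A toy illustration (numbers are hypothetical):

num_train_timesteps, accumulation_steps = 25, 2
for i in range(4):                        # sample batches in the inner training loop
    for j in range(num_train_timesteps):  # timesteps within a sample batch
        step_now = (j == num_train_timesteps - 1) and (i + 1) % accumulation_steps == 0
        if step_now:
            print(f"optimizer step after sample batch {i}, timestep {j}")  # fires at i = 1, 3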
@@ -399,6 +469,9 @@ def main(_):
         # make sure we did an optimization step at the end of the inner epoch
         assert accelerator.sync_gradients

+        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+            accelerator.save_state()
+

 if __name__ == "__main__":
     app.run(main)
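
Putting the pieces together, the save/resume round trip introduced by this commit looks roughly like the sketch below, assuming Accelerate's automatic naming writes to <project_dir>/checkpoints/checkpoint_<n> (toy model; "logs/demo" is a hypothetical project dir):

import torch
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

project_config = ProjectConfiguration(
    project_dir="logs/demo", automatic_checkpoint_naming=True, total_limit=5
)
accelerator = Accelerator(project_config=project_config)
model = accelerator.prepare(torch.nn.Linear(4, 4))

accelerator.save_state()  # -> logs/demo/checkpoints/checkpoint_0
accelerator.load_state("logs/demo/checkpoints/checkpoint_0")  # restore it later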