Adding checkpointing and resuming
This commit is contained in:
parent
ad28862b48
commit
8779f62a1c
4
.gitignore
vendored
4
.gitignore
vendored
@ -303,4 +303,6 @@ tags
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim
|
||||
|
||||
wandb/
|
||||
wandb/
|
||||
logs/
|
||||
notebooks/
|
@ -5,12 +5,16 @@ def get_config():
|
||||
config = ml_collections.ConfigDict()
|
||||
|
||||
# misc
|
||||
config.run_name = ""
|
||||
config.seed = 42
|
||||
config.logdir = "logs"
|
||||
config.num_epochs = 100
|
||||
config.save_freq = 20
|
||||
config.num_checkpoint_limit = 5
|
||||
config.mixed_precision = "fp16"
|
||||
config.allow_tf32 = True
|
||||
config.use_lora = True
|
||||
config.resume_from = ""
|
||||
|
||||
# pretrained model initialization
|
||||
config.pretrained = pretrained = ml_collections.ConfigDict()
|
||||
|
@ -1,12 +1,13 @@
|
||||
from collections import defaultdict
|
||||
import contextlib
|
||||
import os
|
||||
import datetime
|
||||
from absl import app, flags
|
||||
from ml_collections import config_flags
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from accelerate.utils import set_seed, ProjectConfiguration
|
||||
from accelerate.logging import get_logger
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
import numpy as np
|
||||
@ -35,19 +36,45 @@ def main(_):
|
||||
# basic Accelerate and logging setup
|
||||
config = FLAGS.config
|
||||
|
||||
unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
if not config.run_name:
|
||||
config.run_name = unique_id
|
||||
else:
|
||||
config.run_name += "_" + unique_id
|
||||
|
||||
if config.resume_from:
|
||||
config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
|
||||
if "checkpoint_" not in os.path.basename(config.resume_from):
|
||||
# get the most recent checkpoint in this directory
|
||||
checkpoints = list(filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from)))
|
||||
if len(checkpoints) == 0:
|
||||
raise ValueError(f"No checkpoints found in {config.resume_from}")
|
||||
config.resume_from = os.path.join(
|
||||
config.resume_from,
|
||||
sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
|
||||
)
|
||||
|
||||
# number of timesteps within each trajectory to train on
|
||||
num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
|
||||
|
||||
accelerator_config = ProjectConfiguration(
|
||||
project_dir=os.path.join(config.logdir, config.run_name),
|
||||
automatic_checkpoint_naming=True,
|
||||
total_limit=config.num_checkpoint_limit,
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
log_with="wandb",
|
||||
mixed_precision=config.mixed_precision,
|
||||
project_dir=config.logdir,
|
||||
project_config=accelerator_config,
|
||||
# we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
|
||||
# _samples_ to accumulate across
|
||||
gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
|
||||
accelerator.init_trackers(
|
||||
project_name="ddpo-pytorch", config=config.to_dict(), init_kwargs={"wandb": {"name": config.run_name}}
|
||||
)
|
||||
logger.info(f"\n{config}")
|
||||
|
||||
# set seed (device_specific is very important to get different prompts on different devices)
|
||||
@ -108,6 +135,40 @@ def main(_):
|
||||
else:
|
||||
trainable_layers = pipeline.unet
|
||||
|
||||
# set up diffusers-friendly checkpoint saving with Accelerate
|
||||
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
assert len(models) == 1
|
||||
if config.use_lora and isinstance(models[0], AttnProcsLayers):
|
||||
pipeline.unet.save_attn_procs(output_dir)
|
||||
elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
|
||||
models[0].save_pretrained(os.path.join(output_dir, "unet"))
|
||||
else:
|
||||
raise ValueError(f"Unknown model type {type(models[0])}")
|
||||
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
assert len(models) == 1
|
||||
if config.use_lora and isinstance(models[0], AttnProcsLayers):
|
||||
# pipeline.unet.load_attn_procs(input_dir)
|
||||
tmp_unet = UNet2DConditionModel.from_pretrained(
|
||||
config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
|
||||
)
|
||||
tmp_unet.load_attn_procs(input_dir)
|
||||
models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
|
||||
del tmp_unet
|
||||
elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
|
||||
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
|
||||
models[0].register_to_config(**load_model.config)
|
||||
models[0].load_state_dict(load_model.state_dict())
|
||||
del load_model
|
||||
else:
|
||||
raise ValueError(f"Unknown model type {type(models[0])}")
|
||||
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if config.allow_tf32:
|
||||
@ -185,8 +246,15 @@ def main(_):
|
||||
assert config.sample.batch_size % config.train.batch_size == 0
|
||||
assert samples_per_epoch % total_train_batch_size == 0
|
||||
|
||||
if config.resume_from:
|
||||
logger.info(f"Resuming from {config.resume_from}")
|
||||
accelerator.load_state(config.resume_from)
|
||||
first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
||||
else:
|
||||
first_epoch = 0
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(config.num_epochs):
|
||||
for epoch in range(first_epoch, config.num_epochs):
|
||||
#################### SAMPLING ####################
|
||||
pipeline.unet.eval()
|
||||
samples = []
|
||||
@ -387,7 +455,9 @@ def main(_):
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
|
||||
assert (j == num_train_timesteps - 1) and (
|
||||
i + 1
|
||||
) % config.train.gradient_accumulation_steps == 0
|
||||
# log training-related stuff
|
||||
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
|
||||
info = accelerator.reduce(info, reduction="mean")
|
||||
@ -399,6 +469,9 @@ def main(_):
|
||||
# make sure we did an optimization step at the end of the inner epoch
|
||||
assert accelerator.sync_gradients
|
||||
|
||||
if epoch % config.save_freq == 0 and accelerator.is_main_process:
|
||||
accelerator.save_state()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
Loading…
Reference in New Issue
Block a user