Adding checkpointing and resuming
parent ad28862b48
commit 8779f62a1c
.gitignore (vendored) | 4
@@ -303,4 +303,6 @@ tags
 
 # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim
 
 wandb/
+logs/
+notebooks/
@@ -5,12 +5,16 @@ def get_config():
     config = ml_collections.ConfigDict()
 
     # misc
+    config.run_name = ""
     config.seed = 42
     config.logdir = "logs"
     config.num_epochs = 100
+    config.save_freq = 20
+    config.num_checkpoint_limit = 5
     config.mixed_precision = "fp16"
     config.allow_tf32 = True
     config.use_lora = True
+    config.resume_from = ""
 
     # pretrained model initialization
     config.pretrained = pretrained = ml_collections.ConfigDict()
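The four new keys cover run naming (run_name), checkpoint cadence (save_freq), retention (num_checkpoint_limit), and resuming (resume_from). As a minimal sketch, assuming the config file is exposed through ml_collections config_flags the way the training script's imports below suggest (the file path and the driver here are illustrative, not taken from the commit), any of these can be overridden from the command line:

# Hypothetical driver showing --config.<key>=<value> overrides; only the
# config_flags mechanism itself comes from the training script's imports.
from absl import app, flags
from ml_collections import config_flags

config_flags.DEFINE_config_file("config", "config/base.py")  # default path is an assumption
FLAGS = flags.FLAGS


def main(_):
    config = FLAGS.config
    # e.g.: python train.py --config=config/base.py \
    #           --config.run_name=lora_run --config.resume_from=logs/lora_run
    print(config.run_name, config.save_freq, config.num_checkpoint_limit, config.resume_from)


if __name__ == "__main__":
    app.run(main)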
@@ -1,12 +1,13 @@
 from collections import defaultdict
 import contextlib
 import os
+import datetime
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
-from accelerate.utils import set_seed
+from accelerate.utils import set_seed, ProjectConfiguration
 from accelerate.logging import get_logger
-from diffusers import StableDiffusionPipeline, DDIMScheduler
+from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.attention_processor import LoRAAttnProcessor
 import numpy as np
@@ -35,19 +36,45 @@ def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
 
+    unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    if not config.run_name:
+        config.run_name = unique_id
+    else:
+        config.run_name += "_" + unique_id
+
+    if config.resume_from:
+        config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
+        if "checkpoint_" not in os.path.basename(config.resume_from):
+            # get the most recent checkpoint in this directory
+            checkpoints = list(filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from)))
+            if len(checkpoints) == 0:
+                raise ValueError(f"No checkpoints found in {config.resume_from}")
+            config.resume_from = os.path.join(
+                config.resume_from,
+                sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
+            )
+
     # number of timesteps within each trajectory to train on
     num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
 
+    accelerator_config = ProjectConfiguration(
+        project_dir=os.path.join(config.logdir, config.run_name),
+        automatic_checkpoint_naming=True,
+        total_limit=config.num_checkpoint_limit,
+    )
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
-        project_dir=config.logdir,
+        project_config=accelerator_config,
         # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
         # _samples_ to accumulate across
         gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
-        accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
+        accelerator.init_trackers(
+            project_name="ddpo-pytorch", config=config.to_dict(), init_kwargs={"wandb": {"name": config.run_name}}
+        )
     logger.info(f"\n{config}")
 
     # set seed (device_specific is very important to get different prompts on different devices)
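Two things happen here: the run gets a timestamped name that doubles as the Accelerate project directory, and resume_from is normalized so it may point either at a specific checkpoint_N directory or at a directory containing several of them. In the latter case the newest checkpoint is chosen by its integer suffix rather than lexically, which matters once the count passes nine; a tiny illustration with made-up names:

# Numeric-suffix sort vs. plain lexical sort on hypothetical checkpoint names.
checkpoints = ["checkpoint_9", "checkpoint_10", "checkpoint_2"]

latest = sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1]
assert latest == "checkpoint_10"  # what the commit's logic selects

lexical = sorted(checkpoints)[-1]
assert lexical == "checkpoint_9"  # what a naive string sort would have picked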
@@ -108,6 +135,40 @@ def main(_):
     else:
         trainable_layers = pipeline.unet
 
+    # set up diffusers-friendly checkpoint saving with Accelerate
+
+    def save_model_hook(models, weights, output_dir):
+        assert len(models) == 1
+        if config.use_lora and isinstance(models[0], AttnProcsLayers):
+            pipeline.unet.save_attn_procs(output_dir)
+        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
+            models[0].save_pretrained(os.path.join(output_dir, "unet"))
+        else:
+            raise ValueError(f"Unknown model type {type(models[0])}")
+        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
+
+    def load_model_hook(models, input_dir):
+        assert len(models) == 1
+        if config.use_lora and isinstance(models[0], AttnProcsLayers):
+            # pipeline.unet.load_attn_procs(input_dir)
+            tmp_unet = UNet2DConditionModel.from_pretrained(
+                config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
+            )
+            tmp_unet.load_attn_procs(input_dir)
+            models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
+            del tmp_unet
+        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
+            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+            models[0].register_to_config(**load_model.config)
+            models[0].load_state_dict(load_model.state_dict())
+            del load_model
+        else:
+            raise ValueError(f"Unknown model type {type(models[0])}")
+        models.pop()  # ensures that accelerate doesn't try to handle loading of the model
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if config.allow_tf32:
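These callbacks hook into Accelerate's save_state / load_state machinery: register_save_state_pre_hook and register_load_state_pre_hook run before Accelerate's default (de)serialization, and popping the handled entry from weights / models tells Accelerate to skip its generic handling for that model while still writing and restoring optimizer, scaler, and RNG state. A stripped-down sketch of the same pattern with a plain torch module (the toy model and file names are assumptions, not part of the commit):

# Toy version of the hook pattern above (assumes an Accelerate version that
# provides the register_*_state_pre_hook APIs, as the commit does).
import os
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))


def save_model_hook(models, weights, output_dir):
    # Write the model in a custom format, then pop it so Accelerate
    # doesn't serialize a second copy with its default logic.
    torch.save(models[0].state_dict(), os.path.join(output_dir, "model.pt"))
    weights.pop()


def load_model_hook(models, input_dir):
    models[0].load_state_dict(torch.load(os.path.join(input_dir, "model.pt")))
    models.pop()  # mark as handled; Accelerate skips its default loading


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)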
@@ -185,8 +246,15 @@ def main(_):
     assert config.sample.batch_size % config.train.batch_size == 0
     assert samples_per_epoch % total_train_batch_size == 0
 
+    if config.resume_from:
+        logger.info(f"Resuming from {config.resume_from}")
+        accelerator.load_state(config.resume_from)
+        first_epoch = int(config.resume_from.split("_")[-1]) + 1
+    else:
+        first_epoch = 0
+
     global_step = 0
-    for epoch in range(config.num_epochs):
+    for epoch in range(first_epoch, config.num_epochs):
         #################### SAMPLING ####################
         pipeline.unet.eval()
         samples = []
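first_epoch is read straight off the numeric suffix of the resolved checkpoint directory, so training restarts on the epoch after the one whose state was loaded. Tracing the arithmetic on a hypothetical path:

# Worked example of the bookkeeping above, with a made-up resume path.
resume_from = "logs/myrun/checkpoints/checkpoint_7"
first_epoch = int(resume_from.split("_")[-1]) + 1
assert first_epoch == 8  # the loop then runs range(8, config.num_epochs)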
@@ -387,7 +455,9 @@ def main(_):
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
-                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
+                        assert (j == num_train_timesteps - 1) and (
+                            i + 1
+                        ) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                         info = accelerator.reduce(info, reduction="mean")
@@ -399,6 +469,9 @@ def main(_):
             # make sure we did an optimization step at the end of the inner epoch
             assert accelerator.sync_gradients
 
+        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+            accelerator.save_state()
+
 
 if __name__ == "__main__":
     app.run(main)
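Because save_state() inherits the ProjectConfiguration defined earlier, automatic checkpoint naming places each saved state in a numbered directory under the run's project_dir, and total_limit prunes the oldest ones. A rough sketch of the resulting layout under the commit's defaults, assuming Accelerate's usual checkpoints/checkpoint_N naming (the run name is made up):

import os

# Defaults from the config diff: logdir="logs", num_epochs=100, save_freq=20,
# num_checkpoint_limit=5. The run name below is hypothetical.
project_dir = os.path.join("logs", "myrun_20240101_120000")

# save_state() fires at epochs 0, 20, 40, 60 and 80, i.e. five times,
# which is exactly what num_checkpoint_limit=5 retains.
save_epochs = [e for e in range(100) if e % 20 == 0]
checkpoint_dirs = [
    os.path.join(project_dir, "checkpoints", f"checkpoint_{i}")
    for i, _ in enumerate(save_epochs)
]
print("\n".join(checkpoint_dirs))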