Adding checkpointing and resuming
@@ -1,12 +1,13 @@
 from collections import defaultdict
 import contextlib
 import os
+import datetime
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
-from accelerate.utils import set_seed
+from accelerate.utils import set_seed, ProjectConfiguration
 from accelerate.logging import get_logger
-from diffusers import StableDiffusionPipeline, DDIMScheduler
+from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.attention_processor import LoRAAttnProcessor
 import numpy as np
@@ -35,19 +36,45 @@ def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
 
+    unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    if not config.run_name:
+        config.run_name = unique_id
+    else:
+        config.run_name += "_" + unique_id
+
+    if config.resume_from:
+        config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
+        if "checkpoint_" not in os.path.basename(config.resume_from):
+            # get the most recent checkpoint in this directory
+            checkpoints = list(filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from)))
+            if len(checkpoints) == 0:
+                raise ValueError(f"No checkpoints found in {config.resume_from}")
+            config.resume_from = os.path.join(
+                config.resume_from,
+                sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
+            )
+
     # number of timesteps within each trajectory to train on
     num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
 
+    accelerator_config = ProjectConfiguration(
+        project_dir=os.path.join(config.logdir, config.run_name),
+        automatic_checkpoint_naming=True,
+        total_limit=config.num_checkpoint_limit,
+    )
+
     accelerator = Accelerator(
         log_with="wandb",
         mixed_precision=config.mixed_precision,
-        project_dir=config.logdir,
+        project_config=accelerator_config,
         # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
         # _samples_ to accumulate across
         gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
-        accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
+        accelerator.init_trackers(
+            project_name="ddpo-pytorch", config=config.to_dict(), init_kwargs={"wandb": {"name": config.run_name}}
+        )
     logger.info(f"\n{config}")
 
     # set seed (device_specific is very important to get different prompts on different devices)
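The resume logic above accepts either a specific `checkpoint_<n>` directory or a parent directory containing several of them, in which case it selects the highest-numbered one. A minimal standalone sketch of that resolution step (`resolve_checkpoint` and the example path are illustrative names, not from the repo):

```python
import os

def resolve_checkpoint(path: str) -> str:
    """Resolve `path` to a concrete checkpoint_<n> directory."""
    path = os.path.normpath(os.path.expanduser(path))
    if "checkpoint_" in os.path.basename(path):
        return path  # already points at a specific checkpoint
    checkpoints = [d for d in os.listdir(path) if "checkpoint_" in d]
    if not checkpoints:
        raise ValueError(f"No checkpoints found in {path}")
    # sort numerically on the trailing index so checkpoint_10 beats checkpoint_9;
    # a plain lexicographic sort would get this wrong
    latest = max(checkpoints, key=lambda d: int(d.split("_")[-1]))
    return os.path.join(path, latest)

# e.g. resolve_checkpoint("logs/my_run/checkpoints") -> ".../checkpoint_12"
```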
@@ -108,6 +135,40 @@ def main(_):
     else:
         trainable_layers = pipeline.unet
 
+    # set up diffusers-friendly checkpoint saving with Accelerate
+
+    def save_model_hook(models, weights, output_dir):
+        assert len(models) == 1
+        if config.use_lora and isinstance(models[0], AttnProcsLayers):
+            pipeline.unet.save_attn_procs(output_dir)
+        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
+            models[0].save_pretrained(os.path.join(output_dir, "unet"))
+        else:
+            raise ValueError(f"Unknown model type {type(models[0])}")
+        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
+
+    def load_model_hook(models, input_dir):
+        assert len(models) == 1
+        if config.use_lora and isinstance(models[0], AttnProcsLayers):
+            # pipeline.unet.load_attn_procs(input_dir)
+            tmp_unet = UNet2DConditionModel.from_pretrained(
+                config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
+            )
+            tmp_unet.load_attn_procs(input_dir)
+            models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
+            del tmp_unet
+        elif not config.use_lora and isinstance(models[0], UNet2DConditionModel):
+            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+            models[0].register_to_config(**load_model.config)
+            models[0].load_state_dict(load_model.state_dict())
+            del load_model
+        else:
+            raise ValueError(f"Unknown model type {type(models[0])}")
+        models.pop()  # ensures that accelerate doesn't try to handle loading of the model
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if config.allow_tf32:
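With `automatic_checkpoint_naming=True`, `accelerator.save_state()` writes each checkpoint under `<project_dir>/checkpoints/checkpoint_<n>`, and these pre-hooks intercept the single prepared model so it is saved and restored in the diffusers format rather than as Accelerate's default state dict; popping the entry off `weights` (or `models`) is what tells Accelerate to skip its own handling of that object. A toy sketch of the same hook contract with a plain `nn.Module` (all names here are illustrative, assuming an accelerate version recent enough to provide these hook registration methods):

```python
import os
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

accelerator = Accelerator(
    project_config=ProjectConfiguration(
        project_dir="logs/demo", automatic_checkpoint_naming=True
    )
)
model = accelerator.prepare(nn.Linear(4, 4))

def save_hook(models, weights, output_dir):
    # custom serialization for the single model we prepared
    torch.save(
        accelerator.unwrap_model(models[0]).state_dict(),
        os.path.join(output_dir, "model.pt"),
    )
    weights.pop()  # Accelerate skips its default save for this model

def load_hook(models, input_dir):
    models[0].load_state_dict(torch.load(os.path.join(input_dir, "model.pt")))
    models.pop()  # Accelerate skips its default load for this model

accelerator.register_save_state_pre_hook(save_hook)
accelerator.register_load_state_pre_hook(load_hook)

accelerator.save_state()  # -> logs/demo/checkpoints/checkpoint_0
accelerator.load_state("logs/demo/checkpoints/checkpoint_0")
```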
@@ -185,8 +246,15 @@ def main(_):
     assert config.sample.batch_size % config.train.batch_size == 0
     assert samples_per_epoch % total_train_batch_size == 0
 
+    if config.resume_from:
+        logger.info(f"Resuming from {config.resume_from}")
+        accelerator.load_state(config.resume_from)
+        first_epoch = int(config.resume_from.split("_")[-1]) + 1
+    else:
+        first_epoch = 0
+
     global_step = 0
-    for epoch in range(config.num_epochs):
+    for epoch in range(first_epoch, config.num_epochs):
         #################### SAMPLING ####################
         pipeline.unet.eval()
         samples = []
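Since the resolved checkpoint path always ends in `checkpoint_<n>`, the epoch to resume from can be recovered from the path itself instead of being stored separately. A quick worked example (the path is hypothetical):

```python
resume_from = "logs/my_run/checkpoints/checkpoint_7"
first_epoch = int(resume_from.split("_")[-1]) + 1  # -> 8
```

This equates the checkpoint index with the epoch it was saved at, which lines up exactly when one checkpoint is written per epoch (i.e. a `save_freq` of 1).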
@@ -387,7 +455,9 @@ def main(_):
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
-                        assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
+                        assert (j == num_train_timesteps - 1) and (
+                            i + 1
+                        ) % config.train.gradient_accumulation_steps == 0
                         # log training-related stuff
                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                         info = accelerator.reduce(info, reduction="mean")
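The reflowed assertion is the same invariant as before: Accelerate performs an optimizer step once every `config.train.gradient_accumulation_steps * num_train_timesteps` entries into `accelerator.accumulate(...)`, and the inner loop enters it `num_train_timesteps` times (index `j`) per sample (index `i`), so `sync_gradients` can only fire on the last timestep of a sample whose 1-based index is a multiple of `config.train.gradient_accumulation_steps`. A tiny counter simulation of that arithmetic (illustrative only):

```python
num_train_timesteps = 50
grad_accum = 2  # stands in for config.train.gradient_accumulation_steps
step = 0
for i in range(4):                        # samples
    for j in range(num_train_timesteps):  # timesteps within a sample
        step += 1
        sync_gradients = step % (grad_accum * num_train_timesteps) == 0
        if sync_gradients:
            # fires at (i=1, j=49), (i=3, j=49), ...
            assert j == num_train_timesteps - 1 and (i + 1) % grad_accum == 0
```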
@@ -399,6 +469,9 @@ def main(_):
             # make sure we did an optimization step at the end of the inner epoch
             assert accelerator.sync_gradients
 
+        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+            accelerator.save_state()
+
 
 if __name__ == "__main__":
     app.run(main)
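With these pieces in place, a run can be resumed by pointing `config.resume_from` at either a run's `checkpoints` directory or one specific `checkpoint_<n>` inside it, e.g. via an `ml_collections` override such as `--config.resume_from=logs/my_run/checkpoints` on the command line (the exact flag spelling follows from the `config_flags` setup imported at the top of the script; the path is illustrative).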