Adding checkpointing and resuming
This commit is contained in:
		
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -304,3 +304,5 @@ tags | |||||||
| # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim | ||||||
|  |  | ||||||
| wandb/ | wandb/ | ||||||
|  | logs/ | ||||||
|  | notebooks/ | ||||||
| @@ -5,12 +5,16 @@ def get_config(): | |||||||
|     config = ml_collections.ConfigDict() |     config = ml_collections.ConfigDict() | ||||||
|  |  | ||||||
|     # misc |     # misc | ||||||
|  |     config.run_name = "" | ||||||
|     config.seed = 42 |     config.seed = 42 | ||||||
|     config.logdir = "logs" |     config.logdir = "logs" | ||||||
|     config.num_epochs = 100 |     config.num_epochs = 100 | ||||||
|  |     config.save_freq = 20 | ||||||
|  |     config.num_checkpoint_limit = 5 | ||||||
|     config.mixed_precision = "fp16" |     config.mixed_precision = "fp16" | ||||||
|     config.allow_tf32 = True |     config.allow_tf32 = True | ||||||
|     config.use_lora = True |     config.use_lora = True | ||||||
|  |     config.resume_from = "" | ||||||
|  |  | ||||||
|     # pretrained model initialization |     # pretrained model initialization | ||||||
|     config.pretrained = pretrained = ml_collections.ConfigDict() |     config.pretrained = pretrained = ml_collections.ConfigDict() | ||||||
|   | |||||||
| @@ -1,12 +1,13 @@ | |||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| import contextlib | import contextlib | ||||||
| import os | import os | ||||||
|  | import datetime | ||||||
| from absl import app, flags | from absl import app, flags | ||||||
| from ml_collections import config_flags | from ml_collections import config_flags | ||||||
| from accelerate import Accelerator | from accelerate import Accelerator | ||||||
| from accelerate.utils import set_seed | from accelerate.utils import set_seed, ProjectConfiguration | ||||||
| from accelerate.logging import get_logger | from accelerate.logging import get_logger | ||||||
| from diffusers import StableDiffusionPipeline, DDIMScheduler | from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel | ||||||
| from diffusers.loaders import AttnProcsLayers | from diffusers.loaders import AttnProcsLayers | ||||||
| from diffusers.models.attention_processor import LoRAAttnProcessor | from diffusers.models.attention_processor import LoRAAttnProcessor | ||||||
| import numpy as np | import numpy as np | ||||||
| @@ -35,19 +36,45 @@ def main(_): | |||||||
|     # basic Accelerate and logging setup |     # basic Accelerate and logging setup | ||||||
|     config = FLAGS.config |     config = FLAGS.config | ||||||
|  |  | ||||||
|  |     unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") | ||||||
|  |     if not config.run_name: | ||||||
|  |         config.run_name = unique_id | ||||||
|  |     else: | ||||||
|  |         config.run_name += "_" + unique_id | ||||||
|  |  | ||||||
|  |     if config.resume_from: | ||||||
|  |         config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from)) | ||||||
|  |         if "checkpoint_" not in os.path.basename(config.resume_from): | ||||||
|  |             # get the most recent checkpoint in this directory | ||||||
|  |             checkpoints = list(filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from))) | ||||||
|  |             if len(checkpoints) == 0: | ||||||
|  |                 raise ValueError(f"No checkpoints found in {config.resume_from}") | ||||||
|  |             config.resume_from = os.path.join( | ||||||
|  |                 config.resume_from, | ||||||
|  |                 sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1], | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     # number of timesteps within each trajectory to train on |     # number of timesteps within each trajectory to train on | ||||||
|     num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction) |     num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction) | ||||||
|  |  | ||||||
|  |     accelerator_config = ProjectConfiguration( | ||||||
|  |         project_dir=os.path.join(config.logdir, config.run_name), | ||||||
|  |         automatic_checkpoint_naming=True, | ||||||
|  |         total_limit=config.num_checkpoint_limit, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     accelerator = Accelerator( |     accelerator = Accelerator( | ||||||
|         log_with="wandb", |         log_with="wandb", | ||||||
|         mixed_precision=config.mixed_precision, |         mixed_precision=config.mixed_precision, | ||||||
|         project_dir=config.logdir, |         project_config=accelerator_config, | ||||||
|         # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of |         # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of | ||||||
|         # _samples_ to accumulate across |         # _samples_ to accumulate across | ||||||
|         gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps, |         gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps, | ||||||
|     ) |     ) | ||||||
|     if accelerator.is_main_process: |     if accelerator.is_main_process: | ||||||
|         accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict()) |         accelerator.init_trackers( | ||||||
|  |             project_name="ddpo-pytorch", config=config.to_dict(), init_kwargs={"wandb": {"name": config.run_name}} | ||||||
|  |         ) | ||||||
|     logger.info(f"\n{config}") |     logger.info(f"\n{config}") | ||||||
|  |  | ||||||
|     # set seed (device_specific is very important to get different prompts on different devices) |     # set seed (device_specific is very important to get different prompts on different devices) | ||||||
| @@ -108,6 +135,40 @@ def main(_): | |||||||
|     else: |     else: | ||||||
|         trainable_layers = pipeline.unet |         trainable_layers = pipeline.unet | ||||||
|  |  | ||||||
|  |     # set up diffusers-friendly checkpoint saving with Accelerate | ||||||
|  |  | ||||||
|  |     def save_model_hook(models, weights, output_dir): | ||||||
|  |         assert len(models) == 1 | ||||||
|  |         if config.use_lora and isinstance(models[0], AttnProcsLayers): | ||||||
|  |             pipeline.unet.save_attn_procs(output_dir) | ||||||
|  |         elif not config.use_lora and isinstance(models[0], UNet2DConditionModel): | ||||||
|  |             models[0].save_pretrained(os.path.join(output_dir, "unet")) | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"Unknown model type {type(models[0])}") | ||||||
|  |         weights.pop()  # ensures that accelerate doesn't try to handle saving of the model | ||||||
|  |  | ||||||
|  |     def load_model_hook(models, input_dir): | ||||||
|  |         assert len(models) == 1 | ||||||
|  |         if config.use_lora and isinstance(models[0], AttnProcsLayers): | ||||||
|  |             # pipeline.unet.load_attn_procs(input_dir) | ||||||
|  |             tmp_unet = UNet2DConditionModel.from_pretrained( | ||||||
|  |                 config.pretrained.model, revision=config.pretrained.revision, subfolder="unet" | ||||||
|  |             ) | ||||||
|  |             tmp_unet.load_attn_procs(input_dir) | ||||||
|  |             models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict()) | ||||||
|  |             del tmp_unet | ||||||
|  |         elif not config.use_lora and isinstance(models[0], UNet2DConditionModel): | ||||||
|  |             load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") | ||||||
|  |             models[0].register_to_config(**load_model.config) | ||||||
|  |             models[0].load_state_dict(load_model.state_dict()) | ||||||
|  |             del load_model | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"Unknown model type {type(models[0])}") | ||||||
|  |         models.pop()  # ensures that accelerate doesn't try to handle loading of the model | ||||||
|  |  | ||||||
|  |     accelerator.register_save_state_pre_hook(save_model_hook) | ||||||
|  |     accelerator.register_load_state_pre_hook(load_model_hook) | ||||||
|  |  | ||||||
|     # Enable TF32 for faster training on Ampere GPUs, |     # Enable TF32 for faster training on Ampere GPUs, | ||||||
|     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices |     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices | ||||||
|     if config.allow_tf32: |     if config.allow_tf32: | ||||||
| @@ -185,8 +246,15 @@ def main(_): | |||||||
|     assert config.sample.batch_size % config.train.batch_size == 0 |     assert config.sample.batch_size % config.train.batch_size == 0 | ||||||
|     assert samples_per_epoch % total_train_batch_size == 0 |     assert samples_per_epoch % total_train_batch_size == 0 | ||||||
|  |  | ||||||
|  |     if config.resume_from: | ||||||
|  |         logger.info(f"Resuming from {config.resume_from}") | ||||||
|  |         accelerator.load_state(config.resume_from) | ||||||
|  |         first_epoch = int(config.resume_from.split("_")[-1]) + 1 | ||||||
|  |     else: | ||||||
|  |         first_epoch = 0 | ||||||
|  |  | ||||||
|     global_step = 0 |     global_step = 0 | ||||||
|     for epoch in range(config.num_epochs): |     for epoch in range(first_epoch, config.num_epochs): | ||||||
|         #################### SAMPLING #################### |         #################### SAMPLING #################### | ||||||
|         pipeline.unet.eval() |         pipeline.unet.eval() | ||||||
|         samples = [] |         samples = [] | ||||||
| @@ -387,7 +455,9 @@ def main(_): | |||||||
|  |  | ||||||
|                     # Checks if the accelerator has performed an optimization step behind the scenes |                     # Checks if the accelerator has performed an optimization step behind the scenes | ||||||
|                     if accelerator.sync_gradients: |                     if accelerator.sync_gradients: | ||||||
|                         assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0 |                         assert (j == num_train_timesteps - 1) and ( | ||||||
|  |                             i + 1 | ||||||
|  |                         ) % config.train.gradient_accumulation_steps == 0 | ||||||
|                         # log training-related stuff |                         # log training-related stuff | ||||||
|                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} |                         info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} | ||||||
|                         info = accelerator.reduce(info, reduction="mean") |                         info = accelerator.reduce(info, reduction="mean") | ||||||
| @@ -399,6 +469,9 @@ def main(_): | |||||||
|             # make sure we did an optimization step at the end of the inner epoch |             # make sure we did an optimization step at the end of the inner epoch | ||||||
|             assert accelerator.sync_gradients |             assert accelerator.sync_gradients | ||||||
|  |  | ||||||
|  |         if epoch % config.save_freq == 0 and accelerator.is_main_process: | ||||||
|  |             accelerator.save_state() | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     app.run(main) |     app.run(main) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user