From c0bc708549db278c33e1a5d9b06493a17aeb9ee5 Mon Sep 17 00:00:00 2001
From: Kevin Black
Date: Thu, 29 Jun 2023 00:51:38 -0700
Subject: [PATCH] Commenting pass

---
 config/base.py   | 103 +++++++++++++++++++++++++++++++++++------------
 scripts/train.py |   9 +++--
 2 files changed, 84 insertions(+), 28 deletions(-)

diff --git a/config/base.py b/config/base.py
index 6bc59ca..c9601cd 100644
--- a/config/base.py
+++ b/config/base.py
@@ -1,60 +1,113 @@
 import ml_collections
 
 
-def get_config():
+def get_config():
     config = ml_collections.ConfigDict()
 
-    # misc
+    ###### General ######
+    # run name for wandb logging and checkpoint saving -- if not provided, will be auto-generated based on the datetime.
     config.run_name = ""
+    # random seed for reproducibility.
     config.seed = 42
+    # top-level logging directory for checkpoint saving.
     config.logdir = "logs"
+    # number of epochs to train for. each epoch is one round of sampling from the model followed by training on those
+    # samples.
     config.num_epochs = 100
+    # number of epochs between saving model checkpoints.
     config.save_freq = 20
+    # number of checkpoints to keep before overwriting old ones.
     config.num_checkpoint_limit = 5
+    # mixed precision training. options are "fp16", "bf16", and "no". half-precision speeds up training significantly.
     config.mixed_precision = "fp16"
+    # allow tf32 on Ampere GPUs, which can speed up training.
     config.allow_tf32 = True
-    config.use_lora = True
+    # resume training from a checkpoint. either an exact checkpoint directory (e.g. checkpoint_50), or a directory
+    # containing checkpoints, in which case the latest one will be used. `config.use_lora` must be set to the same value
+    # as the run that generated the saved checkpoint.
     config.resume_from = ""
+    # whether or not to use LoRA. LoRA reduces memory usage significantly by injecting small weight matrices into the
+    # attention layers of the UNet. with LoRA, fp16, and a batch size of 1, finetuning Stable Diffusion should take
+    # about 10GB of GPU memory. beware that if LoRA is disabled, training will take a lot of memory and saved checkpoint
+    # files will also be large.
+    config.use_lora = True
 
-    # pretrained model initialization
+    ###### Pretrained Model ######
     config.pretrained = pretrained = ml_collections.ConfigDict()
+    # base model to load. either a path to a local directory, or a model name from the HuggingFace model hub.
     pretrained.model = "runwayml/stable-diffusion-v1-5"
+    # revision of the model to load.
     pretrained.revision = "main"
 
-    # training
-    config.train = train = ml_collections.ConfigDict()
-    train.batch_size = 1
-    train.use_8bit_adam = False
-    train.learning_rate = 1e-4
-    train.adam_beta1 = 0.9
-    train.adam_beta2 = 0.999
-    train.adam_weight_decay = 1e-4
-    train.adam_epsilon = 1e-8
-    train.gradient_accumulation_steps = 1
-    train.max_grad_norm = 1.0
-    train.num_inner_epochs = 1
-    train.cfg = True
-    train.adv_clip_max = 10
-    train.clip_range = 1e-4
-    train.timestep_fraction = 1.0
-
-    # sampling
+    ###### Sampling ######
     config.sample = sample = ml_collections.ConfigDict()
+    # number of sampler inference steps.
     sample.num_steps = 10
+    # eta parameter for the DDIM sampler. this controls the amount of noise injected into the sampling process, with 0.0
+    # being fully deterministic and 1.0 being equivalent to the DDPM sampler.
     sample.eta = 1.0
+    # classifier-free guidance weight. 1.0 is no guidance.
     sample.guidance_scale = 5.0
+    # batch size (per GPU!) to use for sampling.
     sample.batch_size = 1
+    # number of batches to sample per epoch. the total number of samples per epoch is `num_batches_per_epoch *
+    # batch_size * num_gpus`.
     sample.num_batches_per_epoch = 2
 
-    # prompting
+    ###### Training ######
+    config.train = train = ml_collections.ConfigDict()
+    # batch size (per GPU!) to use for training.
+    train.batch_size = 1
+    # whether to use the 8bit Adam optimizer from bitsandbytes.
+    train.use_8bit_adam = False
+    # learning rate.
+    train.learning_rate = 1e-4
+    # Adam beta1.
+    train.adam_beta1 = 0.9
+    # Adam beta2.
+    train.adam_beta2 = 0.999
+    # Adam weight decay.
+    train.adam_weight_decay = 1e-4
+    # Adam epsilon.
+    train.adam_epsilon = 1e-8
+    # number of gradient accumulation steps. the effective batch size is `batch_size * num_gpus *
+    # gradient_accumulation_steps`.
+    train.gradient_accumulation_steps = 1
+    # maximum gradient norm for gradient clipping.
+    train.max_grad_norm = 1.0
+    # number of inner epochs per outer epoch. each inner epoch is one iteration through the data collected during one
+    # outer epoch's round of sampling.
+    train.num_inner_epochs = 1
+    # whether or not to use classifier-free guidance during training. if enabled, the same guidance scale used during
+    # sampling will be used during training.
+    train.cfg = True
+    # clip advantages to the range [-adv_clip_max, adv_clip_max].
+    train.adv_clip_max = 10
+    # the PPO clip range.
+    train.clip_range = 1e-4
+    # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the
+    # timesteps for each sample. this will speed up training but reduce the accuracy of policy gradient estimates.
+    train.timestep_fraction = 1.0
+
+    ###### Prompt Function ######
+    # prompt function to use. see `prompts.py` for available prompt functions.
     config.prompt_fn = "imagenet_animals"
+    # kwargs to pass to the prompt function.
     config.prompt_fn_kwargs = {}
 
-    # rewards
+    ###### Reward Function ######
+    # reward function to use. see `rewards.py` for available reward functions.
     config.reward_fn = "jpeg_compressibility"
 
+    ###### Per-Prompt Stat Tracking ######
+    # when enabled, the model will track the mean and std of reward on a per-prompt basis and use that to compute
+    # advantages. set `config.per_prompt_stat_tracking` to None to disable per-prompt stat tracking, in which case
+    # advantages will be calculated using the mean and std of the entire batch.
     config.per_prompt_stat_tracking = ml_collections.ConfigDict()
+    # number of reward values to store in the buffer for each prompt. the buffer persists across epochs.
     config.per_prompt_stat_tracking.buffer_size = 16
+    # the minimum number of reward values to store in the buffer before using the per-prompt mean and std. if the buffer
+    # contains fewer than `min_count` values, the mean and std of the entire batch will be used instead.
     config.per_prompt_stat_tracking.min_count = 16
 
-    return config
\ No newline at end of file
+    return config
diff --git a/scripts/train.py b/scripts/train.py
index fc458c4..f5928d5 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -36,7 +36,7 @@ def main(_):
 
     # basic Accelerate and logging setup
     config = FLAGS.config
-    unique_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
     if not config.run_name:
         config.run_name = unique_id
     else:
@@ -67,8 +67,9 @@ def main(_):
         log_with="wandb",
         mixed_precision=config.mixed_precision,
         project_config=accelerator_config,
-        # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
-        # _samples_ to accumulate across
+        # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
+        # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
+        # the total number of optimizer steps to accumulate across.
         gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
     )
     if accelerator.is_main_process:
@@ -243,6 +244,7 @@ def main(_):
     logger.info(f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}")
     logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}")
 
+    assert config.sample.batch_size >= config.train.batch_size
     assert config.sample.batch_size % config.train.batch_size == 0
     assert samples_per_epoch % total_train_batch_size == 0
 
@@ -418,6 +420,7 @@ def main(_):
                     noise_pred = pipeline.unet(
                         sample["latents"][:, j], sample["timesteps"][:, j], embeds
                     ).sample
+                    # compute the log prob of next_latents given latents under the current model
                     _, log_prob = ddim_step_with_logprob(
                         pipeline.scheduler,
                         noise_pred,
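
For reference, the batch-size comments in config/base.py and the gradient-accumulation comment in scripts/train.py combine as in the worked sketch below. The GPU count is assumed (4), and deriving num_train_timesteps as int(sample.num_steps * train.timestep_fraction) is an assumption based on the timestep_fraction comment rather than something shown in this diff.

# worked example of the batch-size bookkeeping described by the comments above (assumed values, not repo code)
num_gpus = 4                     # assumed
sample_batch_size = 1            # sample.batch_size (per GPU)
num_batches_per_epoch = 2        # sample.num_batches_per_epoch
train_batch_size = 1             # train.batch_size (per GPU)
gradient_accumulation_steps = 1  # train.gradient_accumulation_steps, in units of samples
num_steps = 10                   # sample.num_steps
timestep_fraction = 1.0          # train.timestep_fraction

# samples collected per outer epoch
samples_per_epoch = num_batches_per_epoch * sample_batch_size * num_gpus            # 8
# effective training batch size, in samples
total_train_batch_size = train_batch_size * num_gpus * gradient_accumulation_steps  # 4
# assumed derivation: number of timesteps actually trained on per sample
num_train_timesteps = int(num_steps * timestep_fraction)                             # 10
# Accelerate accumulates per timestep, so the value passed to Accelerator() is scaled up
accelerator_accumulation_steps = gradient_accumulation_steps * num_train_timesteps   # 10

# the asserts added to train.py then require
assert sample_batch_size >= train_batch_size
assert sample_batch_size % train_batch_size == 0
assert samples_per_epoch % total_train_batch_size == 0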
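
The train.adv_clip_max and train.clip_range comments refer to advantage clipping and PPO ratio clipping, applied when the per-timestep log prob computed by ddim_step_with_logprob is compared against the log prob recorded at sampling time. The sketch below is the standard clipped objective those hyperparameters parameterize, not necessarily the exact loss code in this repo.

import torch


def ppo_clip_loss(log_prob, old_log_prob, advantages, clip_range=1e-4, adv_clip_max=10.0):
    # clip advantages to [-adv_clip_max, adv_clip_max], as the train.adv_clip_max comment describes
    advantages = torch.clamp(advantages, -adv_clip_max, adv_clip_max)
    # importance ratio between the current model and the model that generated the samples
    ratio = torch.exp(log_prob - old_log_prob)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # pessimistic bound: take the worse (larger) of the two losses, then average
    return torch.mean(torch.maximum(unclipped, clipped))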
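
The per-prompt stat tracking comments describe a rolling buffer of recent rewards for each prompt, with a fall-back to whole-batch statistics while a prompt has fewer than min_count entries. The class below is a minimal sketch of that behaviour; the name and details are illustrative, not the repo's implementation.

from collections import defaultdict, deque

import numpy as np


class PerPromptStatTracker:
    # illustrative sketch of the behaviour described in the per-prompt stat tracking comments
    def __init__(self, buffer_size=16, min_count=16):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.buffers = defaultdict(lambda: deque(maxlen=buffer_size))

    def update(self, prompts, rewards):
        prompts = np.asarray(prompts)
        rewards = np.asarray(rewards, dtype=np.float64)
        advantages = np.empty_like(rewards)
        for prompt in np.unique(prompts):
            mask = prompts == prompt
            self.buffers[prompt].extend(rewards[mask])  # buffer persists across epochs
            history = np.asarray(self.buffers[prompt])
            if len(history) < self.min_count:
                # not enough history for this prompt yet: fall back to batch statistics
                mean, std = rewards.mean(), rewards.std() + 1e-6
            else:
                mean, std = history.mean(), history.std() + 1e-6
            advantages[mask] = (rewards[mask] - mean) / std
        return advantages


tracker = PerPromptStatTracker(buffer_size=16, min_count=16)
advantages = tracker.update(["a cat", "a dog", "a cat"], [1.0, 0.5, 2.0])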
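
Because train.py reads the configuration through FLAGS.config, the config files are presumably wired up with absl and ml_collections.config_flags. The sketch below shows how that hookup typically looks and how individual values can be overridden per run; the --config flag name and the launch command are assumptions, not taken from this diff.

# sketch of the absl + ml_collections plumbing implied by FLAGS.config in train.py (assumed, not repo code)
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")


def main(_):
    config = FLAGS.config
    print(config.run_name, config.train.learning_rate)


if __name__ == "__main__":
    app.run(main)

# individual fields can then be overridden on the command line, e.g. (hypothetical invocation):
#   accelerate launch scripts/train.py --config config/base.py --config.train.learning_rate=3e-4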