From 92fc030123cc472e490f9aa19c959ac8550f8489 Mon Sep 17 00:00:00 2001
From: Kevin Black
Date: Fri, 23 Jun 2023 21:08:32 -0700
Subject: [PATCH] Continue implementation

---
 config/base.py                            | 12 +++---
 .../diffusers_patch/ddim_with_logprob.py  | 33 ++++++++++++-----
 scripts/train.py                          | 37 ++++++++++++++-----
 3 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/config/base.py b/config/base.py
index e38e6bb..8a3118b 100644
--- a/config/base.py
+++ b/config/base.py
@@ -25,9 +25,9 @@ def get_config():
     train.learning_rate = 1e-4
     train.adam_beta1 = 0.9
     train.adam_beta2 = 0.999
-    train.adam_weight_decay = 1e-2
+    train.adam_weight_decay = 1e-4
     train.adam_epsilon = 1e-8
-    train.gradient_accumulation_steps = 1
+    train.gradient_accumulation_steps = 32
     train.max_grad_norm = 1.0
     train.num_inner_epochs = 1
     train.cfg = True
@@ -36,11 +36,11 @@
 
     # sampling
     config.sample = sample = ml_collections.ConfigDict()
-    sample.num_steps = 5
+    sample.num_steps = 30
     sample.eta = 1.0
     sample.guidance_scale = 5.0
-    sample.batch_size = 1
-    sample.num_batches_per_epoch = 4
+    sample.batch_size = 4
+    sample.num_batches_per_epoch = 8
 
     # prompting
     config.prompt_fn = "imagenet_animals"
@@ -50,7 +50,7 @@
     config.reward_fn = "jpeg_compressibility"
 
     config.per_prompt_stat_tracking = ml_collections.ConfigDict()
-    config.per_prompt_stat_tracking.buffer_size = 128
+    config.per_prompt_stat_tracking.buffer_size = 64
    config.per_prompt_stat_tracking.min_count = 16
 
     return config
\ No newline at end of file
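Note on the config changes above: the sampling and optimization settings move in
lockstep. A rough sanity check of the implied batch arithmetic (a sketch, not part
of the patch; `num_processes` and `train.batch_size` are assumptions, since neither
appears in these hunks):

    # Sketch: effective batch sizes implied by the new config values.
    sample_batch_size = 4        # sample.batch_size
    num_batches_per_epoch = 8    # sample.num_batches_per_epoch
    num_processes = 1            # assumption: single-GPU run
    samples_per_epoch = sample_batch_size * num_batches_per_epoch * num_processes  # 32

    train_batch_size = 1         # assumption: not shown in this diff
    grad_accum_steps = 32        # train.gradient_accumulation_steps
    effective_train_batch = train_batch_size * grad_accum_steps * num_processes    # 32
    # Under these assumptions, one optimizer step consumes one epoch's samples.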
diff --git a/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py b/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
index 43b2fe8..be5f421 100644
--- a/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
+++ b/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
@@ -1,6 +1,9 @@
 # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
 # with the following modifications:
-# -
+# - It computes and returns the log prob of `prev_sample` given the UNet prediction.
+# - Instead of `variance_noise`, it takes `prev_sample` as an optional argument. If `prev_sample` is provided,
+#   it uses it to compute the log prob.
+# - Timesteps can be a batched torch.Tensor.
 
 from typing import Optional, Tuple, Union
 
@@ -11,6 +14,19 @@ from diffusers.utils import randn_tensor
 from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
 
 
+def _get_variance(self, timestep, prev_timestep):
+    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t_prev = torch.where(
+        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
+    ).to(timestep.device)
+    beta_prod_t = 1 - alpha_prod_t
+    beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+    return variance
+
+
 def ddim_step_with_logprob(
     self: DDIMScheduler,
     model_output: torch.FloatTensor,
@@ -66,16 +82,13 @@
 
     # 1. get previous step value (=t-1)
     prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
 
     # 2. compute alphas, betas
-    self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
-    self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
-    alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
-    alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod)
-    print(timestep)
-    print(alpha_prod_t)
-    print(alpha_prod_t_prev)
-    print(prev_timestep)
+    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t_prev = torch.where(
+        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
+    ).to(timestep.device)
 
     beta_prod_t = 1 - alpha_prod_t
 
@@ -106,7 +119,7 @@
 
     # 5. compute variance: "sigma_t(η)" -> see formula (16)
     # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-    variance = self._get_variance(timestep, prev_timestep)
+    variance = _get_variance(self, timestep, prev_timestep)
     std_dev_t = eta * variance ** (0.5)
 
     if use_clipped_model_output:
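Note on the scheduler patch above: both `_get_variance` and the alphas computation
gather on the CPU copy of `alphas_cumprod` and only then move the result to the
timestep's device, which supports batched tensor timesteps without keeping a device
copy of the buffer. The log-prob computation itself falls outside the hunks shown;
a minimal sketch of what "log prob of `prev_sample`" means for the DDIM step's
isotropic Gaussian, assuming `prev_sample_mean` and `std_dev_t` are already
broadcast to `prev_sample`'s shape:

    import math
    import torch

    def gaussian_log_prob(prev_sample, prev_sample_mean, std_dev_t):
        # log N(x; mu, sigma^2) = -(x - mu)^2 / (2 sigma^2) - log sigma - 0.5 log(2 pi)
        log_prob = (
            -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * std_dev_t**2)
            - torch.log(std_dev_t)
            - math.log(math.sqrt(2 * math.pi))
        )
        # reduce over all but the batch dimension -> one log prob per sample
        return log_prob.mean(dim=tuple(range(1, log_prob.ndim)))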
diff --git a/scripts/train.py b/scripts/train.py
index c123ba7..86f66ea 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -12,8 +12,12 @@ from ddpo_pytorch.stat_tracking import PerPromptStatTracker
 from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob
 from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob
 import torch
+import wandb
+from functools import partial
 import tqdm
 
+tqdm = partial(tqdm.tqdm, dynamic_ncols=True)
+
 FLAGS = flags.FLAGS
 config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")
 
@@ -25,7 +29,7 @@ def main(_):
     # basic Accelerate and logging setup
     config = FLAGS.config
     accelerator = Accelerator(
-        log_with="all",
+        log_with="wandb",
         mixed_precision=config.mixed_precision,
         project_dir=config.logdir,
     )
@@ -163,11 +167,12 @@ def main(_):
             config.per_prompt_stat_tracking.min_count,
         )
 
+    global_step = 0
     for epoch in range(config.num_epochs):
         #################### SAMPLING ####################
         samples = []
         prompts = []
-        for i in tqdm.tqdm(
+        for i in tqdm(
             range(config.sample.num_batches_per_epoch),
             desc=f"Epoch {epoch}: sampling",
             disable=not accelerator.is_local_main_process,
@@ -216,7 +221,7 @@ def main(_):
                     "latents": latents[:, :-1],  # each entry is the latent before timestep t
                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
-                    "rewards": torch.as_tensor(rewards),
+                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
                 }
             )
 
@@ -226,6 +231,13 @@ def main(_):
         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
 
+        # log sample-related stuff
+        accelerator.log({"reward": rewards, "epoch": epoch}, step=global_step)
+        accelerator.log(
+            {"images": [wandb.Image(image, caption=prompt) for image, prompt in zip(images, prompts)]},
+            step=global_step,
+        )
+
         # per-prompt mean/std tracking
         if config.per_prompt_stat_tracking:
             # gather the prompts across processes
@@ -268,10 +280,11 @@ def main(_):
             samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]
 
             # train
-            for i, sample in tqdm.tqdm(
+            for i, sample in tqdm(
                 list(enumerate(samples_batched)),
-                desc=f"Outer epoch {epoch}, inner epoch {inner_epoch}: training",
+                desc=f"Epoch {epoch}.{inner_epoch}: training",
                 position=0,
+                disable=not accelerator.is_local_main_process,
             ):
                 if config.train.cfg:
                     # concat negative prompts to sample prompts to avoid two forward passes
@@ -279,11 +292,12 @@ def main(_):
                 else:
                     embeds = sample["prompt_embeds"]
 
-                for j in tqdm.trange(
-                    num_timesteps,
+                for j in tqdm(
+                    range(num_timesteps),
                     desc=f"Timestep",
                     position=1,
                     leave=False,
+                    disable=not accelerator.is_local_main_process,
                 ):
                     with accelerator.accumulate(pipeline.unet):
                         if config.train.cfg:
@@ -311,7 +325,7 @@ def main(_):
 
                         # ppo logic
                         advantages = torch.clamp(
-                            sample["advantages"][:, j], -config.train.adv_clip_max, config.train.adv_clip_max
+                            sample["advantages"], -config.train.adv_clip_max, config.train.adv_clip_max
                         )
                         ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                         unclipped_loss = -advantages * ratio
@@ -326,9 +340,14 @@ def main(_):
                         # estimator, but most existing code uses this so...
                         # http://joschu.net/blog/kl-approx.html
                         info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
-                        info["clipfrac"] = torch.mean(torch.abs(ratio - 1.0) > config.train.clip_range)
+                        info["clipfrac"] = torch.mean((torch.abs(ratio - 1.0) > config.train.clip_range).float())
                         info["loss"] = loss
 
+                        # log training-related stuff
+                        info.update({"epoch": epoch, "inner_epoch": inner_epoch, "timestep": j})
+                        accelerator.log(info, step=global_step)
+                        global_step += 1
+
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
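Note on the "ppo logic" hunk: the diff shows the clamped advantages and the
unclipped surrogate; the clipped branch sits outside the diff context. For
reference, a sketch of the standard PPO clipped objective these lines feed into
(`clip_range` mirrors `config.train.clip_range`; this is not a claim about the
lines the patch doesn't show):

    import torch

    def ppo_loss(advantages, ratio, clip_range):
        # ratio = exp(new_log_prob - old_log_prob); advantages already clamped
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
        # pessimistic (elementwise max) combination, averaged over the batch
        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))

The switch from `sample["advantages"][:, j]` to `sample["advantages"]` also
suggests advantages are now stored as one scalar per trajectory and reused at
every timestep, rather than indexed per timestep.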