From 5955244f37d74295fa9959c3bef999b1ef1cf7b8 Mon Sep 17 00:00:00 2001
From: Kevin Black
Date: Tue, 22 Aug 2023 16:18:49 -0700
Subject: [PATCH] Fix gradient sync for lora

---
 scripts/train.py | 32 +++++++++++++++++++++-----------
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/scripts/train.py b/scripts/train.py
index ae2b71e..e2bc11f 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -134,9 +134,17 @@ def main(_):
 
             lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         pipeline.unet.set_attn_processor(lora_attn_procs)
-        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+
+        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
+        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
+        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return pipeline.unet(*args, **kwargs)
+
+        unet = _Wrapper(pipeline.unet.attn_processors)
     else:
-        trainable_layers = pipeline.unet
+        unet = pipeline.unet
 
     # set up diffusers-friendly checkpoint saving with Accelerate
 
@@ -191,7 +199,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW
 
     optimizer = optimizer_cls(
-        trainable_layers.parameters(),
+        unet.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
@@ -225,9 +233,10 @@ def main(_):
     # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
     # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+    # autocast = accelerator.autocast
 
     # Prepare everything with our `accelerator`.
-    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
+    unet, optimizer = accelerator.prepare(unet, optimizer)
 
     # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
     # remote server running llava inference.
@@ -335,7 +344,6 @@ def main(_):
 
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
-
         # this is a hack to force wandb to log the images as JPEGs instead of PNGs
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, image in enumerate(images):
@@ -351,7 +359,7 @@ def main(_):
                 },
                 step=global_step,
             )
-        
+
         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
 
@@ -425,10 +433,10 @@ def main(_):
                     leave=False,
                     disable=not accelerator.is_local_main_process,
                 ):
-                    with accelerator.accumulate(pipeline.unet):
+                    with accelerator.accumulate(unet):
                         with autocast():
                             if config.train.cfg:
-                                noise_pred = pipeline.unet(
+                                noise_pred = unet(
                                     torch.cat([sample["latents"][:, j]] * 2),
                                     torch.cat([sample["timesteps"][:, j]] * 2),
                                     embeds,
@@ -438,8 +446,10 @@ def main(_):
                                     noise_pred_text - noise_pred_uncond
                                 )
                             else:
-                                noise_pred = pipeline.unet(
-                                    sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                                noise_pred = unet(
+                                    sample["latents"][:, j],
+                                    sample["timesteps"][:, j],
+                                    embeds,
                                 ).sample
                             # compute the log prob of next_latents given latents under the current model
                             _, log_prob = ddim_step_with_logprob(
@@ -473,7 +483,7 @@ def main(_):
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
-                            accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
+                            accelerator.clip_grad_norm_(unet.parameters(), config.train.max_grad_norm)
                         optimizer.step()
                         optimizer.zero_grad()
 
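
Note on the gradient-sync hack introduced in the first hunk: when training on multiple GPUs, accelerator.prepare() wraps the module it is given in DistributedDataParallel, and DDP only all-reduces gradients for parameters registered on that wrapped module, with the reduction tied to calls of that module's forward(). Previously the prepared module (AttnProcsLayers) was never the one invoked during the forward pass, so the LoRA gradients were not synchronized across processes. The sketch below illustrates the same pattern outside of diffusers and accelerate; SmallAdapter, BigFrozenModel, and the variable names are hypothetical stand-ins for AttnProcsLayers and the UNet, not code from this repository.

# minimal sketch of the wrapper pattern, using plain torch modules instead of
# accelerate (accelerator.prepare() performs the equivalent DDP wrapping internally).
# SmallAdapter, BigFrozenModel, and the reuse of the name _Wrapper are illustrative.
import torch
import torch.nn as nn


class SmallAdapter(nn.Module):
    # plays the role of AttnProcsLayers: it registers only the trainable
    # parameters, sharing them with the big model rather than owning copies
    def __init__(self, proj: nn.Linear):
        super().__init__()
        self.proj = proj

    def forward(self, x):
        return self.proj(x)


class BigFrozenModel(nn.Module):
    # plays the role of the UNet: a frozen backbone plus the trainable adapter
    def __init__(self, adapter: SmallAdapter):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.backbone.requires_grad_(False)
        self.adapter = adapter

    def forward(self, x):
        return self.backbone(x) + self.adapter(x)


proj = nn.Linear(16, 16)
adapter = SmallAdapter(proj)
model = BigFrozenModel(adapter)

# DDP all-reduces gradients only for parameters registered on the module it
# wraps, and the reduction is triggered by that module's forward(). Wrapping
# `adapter` directly would register the right parameters, but training never
# calls adapter.forward(), so nothing would be synchronized. The fix mirrors
# the patch: subclass the parameter-holding module and give it a forward()
# that runs the full model (the patch captures `pipeline` in a closure).
class _Wrapper(SmallAdapter):
    def forward(self, *args, **kwargs):
        return model(*args, **kwargs)


wrapped = _Wrapper(proj)  # registers the same `proj` module that lives inside `model`
# in multi-GPU training `wrapped` would then be handed to DDP, e.g. via
# accelerator.prepare(); omitted here because it needs an initialized process group
optimizer = torch.optim.AdamW(wrapped.parameters(), lr=1e-4)

loss = wrapped(torch.randn(2, 16)).sum()  # forward pass runs the full model
loss.backward()                           # gradients land only on the shared adapter
optimizer.step()
optimizer.zero_grad()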