Fix gradient sync for lora
@@ -134,9 +134,17 @@ def main(_):
 
             lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         pipeline.unet.set_attn_processor(lora_attn_procs)
-        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+
+        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
+        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
+        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return pipeline.unet(*args, **kwargs)
+
+        unet = _Wrapper(pipeline.unet.attn_processors)
     else:
-        trainable_layers = pipeline.unet
+        unet = pipeline.unet
 
     # set up diffusers-friendly checkpoint saving with Accelerate
 
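This hunk is the heart of the commit. `accelerator.prepare` wraps the module in `torch.nn.parallel.DistributedDataParallel`, whose gradient-averaging hooks only fire when the forward pass runs through the wrapped module itself. `AttnProcsLayers` registers exactly the LoRA parameters, but the computation happens in `pipeline.unet`, so the subclass bridges the two. A minimal sketch of the same pattern with toy modules (all names below are illustrative, not from the repo):

```python
import torch
import torch.nn as nn

class TrainableSubset(nn.Module):
    """Registers only the parameters we want trained (like AttnProcsLayers)."""
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

full_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

class Wrapper(TrainableSubset):
    # forward delegates to the full model via a closure, so DDP's hooks
    # (attached to this module's parameters) see the whole computation
    def forward(self, *args, **kwargs):
        return full_model(*args, **kwargs)

# registers only the first layer's parameters; preparing `trainable` with
# DDP would now sync exactly those gradients during backward
trainable = Wrapper([full_model[0]])
out = trainable(torch.randn(2, 8))  # runs the full model's forward pass
```

The closure over `full_model` mirrors how `_Wrapper.forward` closes over `pipeline.unet` above.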
@@ -191,7 +199,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW
 
     optimizer = optimizer_cls(
-        trainable_layers.parameters(),
+        unet.parameters(),
        lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
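Because `_Wrapper` subclasses `AttnProcsLayers`, `unet.parameters()` in the LoRA branch still yields only the LoRA processor weights, so renaming `trainable_layers` to `unet` doesn't change which parameters the optimizer sees. A hypothetical sanity check (not in the repo) that could be dropped into `main()` to make that visible:

```python
# compare what the wrapper exposes against the full UNet (illustrative only;
# `unet` and `pipeline` come from the surrounding script)
num_trainable = sum(p.numel() for p in unet.parameters())
num_total = sum(p.numel() for p in pipeline.unet.parameters())
print(f"optimizing {num_trainable:,} of {num_total:,} UNet parameters")
```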
@@ -225,9 +233,10 @@ def main(_):
     # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
     # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+    # autocast = accelerator.autocast
 
     # Prepare everything with our `accelerator`.
-    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
+    unet, optimizer = accelerator.prepare(unet, optimizer)
 
     # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
     # remote server running llava inference.
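`contextlib.nullcontext` and `accelerator.autocast` are both zero-argument callables that return context managers, which is what lets the training loop write `with autocast():` unconditionally. A standalone sketch of the same trick, with a local flag standing in for `config.use_lora`:

```python
import contextlib
import functools
import torch

use_lora = True  # stand-in for config.use_lora
# both branches are callables returning a context manager, so call sites
# never need to know which path was taken
autocast = (
    contextlib.nullcontext
    if use_lora
    else functools.partial(torch.autocast, device_type="cuda", dtype=torch.float16)
)

with autocast():
    pass  # mixed precision applies only on the non-LoRA path
```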
@@ -335,7 +344,6 @@ def main(_):
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
 
-
         # this is a hack to force wandb to log the images as JPEGs instead of PNGs
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, image in enumerate(images):
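The collation one-liner is dense; a toy run shows what it does (shapes are made up):

```python
import torch

# a list of per-batch dicts becomes one dict of concatenated tensors
samples = [
    {"latents": torch.zeros(2, 4), "rewards": torch.zeros(2)},
    {"latents": torch.ones(2, 4), "rewards": torch.ones(2)},
]
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
assert samples["latents"].shape == (4, 4) and samples["rewards"].shape == (4,)
```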
@@ -351,7 +359,7 @@ def main(_):
                 },
                 step=global_step,
             )
-
+
         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
 
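`accelerator.gather` concatenates each process's tensor along the batch dimension and returns the result on every rank, so downstream statistics are computed over the global sample set. A hedged sketch of the kind of use that follows (the normalization line is an assumption about the surrounding code, not part of this hunk):

```python
# every rank ends up with the full, identical reward vector
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
# normalizing with *global* statistics now gives consistent values on all
# processes (illustrative; the exact downstream use is outside the hunk)
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
```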
@@ -425,10 +433,10 @@ def main(_):
                     leave=False,
                     disable=not accelerator.is_local_main_process,
                 ):
-                    with accelerator.accumulate(pipeline.unet):
+                    with accelerator.accumulate(unet):
                         with autocast():
                             if config.train.cfg:
-                                noise_pred = pipeline.unet(
+                                noise_pred = unet(
                                     torch.cat([sample["latents"][:, j]] * 2),
                                     torch.cat([sample["timesteps"][:, j]] * 2),
                                     embeds,
@@ -438,8 +446,10 @@ def main(_):
                                     noise_pred_text - noise_pred_uncond
                                 )
                             else:
-                                noise_pred = pipeline.unet(
-                                    sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                                noise_pred = unet(
+                                    sample["latents"][:, j],
+                                    sample["timesteps"][:, j],
+                                    embeds,
                                 ).sample
                             # compute the log prob of next_latents given latents under the current model
                             _, log_prob = ddim_step_with_logprob(
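The CFG branch doubles the batch so the unconditional and text-conditioned predictions come out of a single UNet call; judging by the variable names, the two halves are then split (e.g. via `chunk(2)`) and recombined with the guidance scale. The recombination arithmetic, runnable in isolation (the scale value is illustrative):

```python
import torch

# stand-ins for the two halves of the doubled batch
noise_pred_uncond = torch.zeros(1, 4, 64, 64)
noise_pred_text = torch.ones(1, 4, 64, 64)
guidance_scale = 5.0  # illustrative value

# classifier-free guidance: start from the unconditional prediction and
# extrapolate past the text-conditioned one
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert torch.allclose(noise_pred, torch.full_like(noise_pred, 5.0))
```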
@@ -473,7 +483,7 @@ def main(_):
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
-                            accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
+                            accelerator.clip_grad_norm_(unet.parameters(), config.train.max_grad_norm)
                         optimizer.step()
                         optimizer.zero_grad()
 
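The final hunk follows the standard Accelerate gradient-accumulation pattern: `accelerator.sync_gradients` is True only on the micro-batch where `accumulate()` actually flushes the window, so clipping runs once per effective optimizer step, and it now clips the same parameter set the optimizer was built from. The loop shape, sketched with hypothetical helpers (`compute_loss`, `max_grad_norm`):

```python
for batch in dataloader:
    with accelerator.accumulate(model):  # same module passed to prepare()
        loss = compute_loss(model, batch)  # hypothetical loss helper
        accelerator.backward(loss)
        # True only when this step flushes the accumulation window
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()  # a prepared optimizer skips mid-window steps
        optimizer.zero_grad()
```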