Fix gradient sync for lora
@@ -134,9 +134,17 @@ def main(_):
 
             lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         pipeline.unet.set_attn_processor(lora_attn_procs)
-        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+
+        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
+        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
+        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return pipeline.unet(*args, **kwargs)
+
+        unet = _Wrapper(pipeline.unet.attn_processors)
     else:
-        trainable_layers = pipeline.unet
+        unet = pipeline.unet
 
     # set up diffusers-friendly checkpoint saving with Accelerate
 
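The comment added in this hunk describes a general property of DistributedDataParallel, which Accelerate uses under the hood: gradients are only all-reduced when the forward pass runs through the DDP-wrapped module itself, so a module that merely registers the trainable parameters (as the old `trainable_layers = AttnProcsLayers(...)` did) while the forward pass goes through `pipeline.unet` directly never triggers synchronization. Below is a minimal, single-process sketch of the same wrapping trick; the names `FullModel` and `AdapterWrapper` are invented for illustration and are not part of the repo.

import torch
import torch.nn as nn


class FullModel(nn.Module):
    """Stand-in for pipeline.unet: a frozen backbone plus a small trainable adapter."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16).requires_grad_(False)  # frozen
        self.adapter = nn.Linear(16, 16)                         # trainable (LoRA-like)

    def forward(self, x):
        return self.adapter(self.backbone(x))


full_model = FullModel()


class AdapterWrapper(nn.Module):
    """Registers only the adapter's parameters, but forwards through the full model
    via a closure -- the same trick as `_Wrapper(AttnProcsLayers)` above."""

    def __init__(self, adapter):
        super().__init__()
        self.adapter = adapter  # only these parameters are registered on this module

    def forward(self, *args, **kwargs):
        return full_model(*args, **kwargs)  # closure over the surrounding model


wrapped = AdapterWrapper(full_model.adapter)

# In multi-GPU training, `accelerator.prepare(wrapped, ...)` would wrap `wrapped` in
# DistributedDataParallel. Because the training loop calls `wrapped(...)`, the forward
# pass goes through the DDP wrapper, so its bookkeeping runs and the adapter gradients
# get all-reduced; calling `full_model(...)` directly would silently skip that.
out = wrapped(torch.randn(2, 16))
out.sum().backward()
print(full_model.adapter.weight.grad is not None)  # True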
@@ -191,7 +199,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW
 
     optimizer = optimizer_cls(
-        trainable_layers.parameters(),
+        unet.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
@@ -225,9 +233,10 @@ def main(_):
     # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
     # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+    # autocast = accelerator.autocast
 
     # Prepare everything with our `accelerator`.
-    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
+    unet, optimizer = accelerator.prepare(unet, optimizer)
 
     # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
     # remote server running llava inference.
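The `contextlib.nullcontext if config.use_lora else accelerator.autocast` selection works because both branches are zero-argument callables that return a context manager, so the call site (`with autocast():`) stays identical either way. A toy, CPU-only illustration of the same toggle follows; the `use_lora` flag and the bfloat16 dtype are made up here and are not taken from the repo's config.

import contextlib

import torch

use_lora = True  # toy stand-in for config.use_lora

# Both branches are callables returning a context manager, so the call site is unchanged.
autocast = (
    contextlib.nullcontext
    if use_lora
    else (lambda: torch.autocast(device_type="cpu", dtype=torch.bfloat16))
)

with autocast():
    y = torch.randn(4, 4) @ torch.randn(4, 4)
print(y.dtype)  # float32 under nullcontext, bfloat16 under CPU autocast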
@@ -335,7 +344,6 @@ def main(_):
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
 
-        
         # this is a hack to force wandb to log the images as JPEGs instead of PNGs
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, image in enumerate(images):
@@ -351,7 +359,7 @@ def main(_):
                 },
                 step=global_step,
             )
-            
+
         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
 
@@ -425,10 +433,10 @@ def main(_):
                     leave=False,
                     disable=not accelerator.is_local_main_process,
                 ):
-                    with accelerator.accumulate(pipeline.unet):
+                    with accelerator.accumulate(unet):
                         with autocast():
                             if config.train.cfg:
-                                noise_pred = pipeline.unet(
+                                noise_pred = unet(
                                     torch.cat([sample["latents"][:, j]] * 2),
                                     torch.cat([sample["timesteps"][:, j]] * 2),
                                     embeds,
@@ -438,8 +446,10 @@ def main(_):
                                     noise_pred_text - noise_pred_uncond
                                 )
                             else:
-                                noise_pred = pipeline.unet(
-                                    sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                                noise_pred = unet(
+                                    sample["latents"][:, j],
+                                    sample["timesteps"][:, j],
+                                    embeds,
                                 ).sample
                             # compute the log prob of next_latents given latents under the current model
                             _, log_prob = ddim_step_with_logprob(
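For context on the two hunks above: `config.train.cfg` enables classifier-free guidance, which is why the latents and timesteps are duplicated with `torch.cat([...] * 2)` so a single UNet call produces both the unconditional and the text-conditioned noise predictions, which are then recombined with the guidance scale. Below is a self-contained toy sketch of that split-and-recombine step; the shapes, `fake_unet`, and `guidance_scale` value are invented for illustration, not the repo's code.

import torch


def fake_unet(latents, timesteps, embeds):
    # stand-in for `unet(...).sample`; just returns "noise" of the right shape
    return torch.randn_like(latents)


latents = torch.randn(2, 4, 8, 8)      # batch of 2
timesteps = torch.tensor([10, 10])
neg_embeds = torch.zeros(2, 77, 768)   # unconditional ("negative") prompt embeddings
text_embeds = torch.randn(2, 77, 768)  # conditional prompt embeddings
guidance_scale = 5.0                   # stands in for config.sample.guidance_scale

# duplicate the batch: first half unconditional, second half conditional
embeds = torch.cat([neg_embeds, text_embeds])
noise_pred = fake_unet(
    torch.cat([latents] * 2),
    torch.cat([timesteps] * 2),
    embeds,
)

# split the doubled batch back apart and recombine with the guidance scale
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([2, 4, 8, 8])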
@@ -473,7 +483,7 @@ def main(_):
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
-                            accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
+                            accelerator.clip_grad_norm_(unet.parameters(), config.train.max_grad_norm)
                         optimizer.step()
                         optimizer.zero_grad()
 
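The `if accelerator.sync_gradients:` guard means gradient clipping happens only on the micro-step where Accelerate actually synchronizes and applies the accumulated gradients, which is also the step the commit's DDP fix is about. A rough plain-PyTorch sketch of the same accumulate-then-clip-then-step pattern follows; `accum_steps`, `max_grad_norm`, and the toy model are invented for illustration.

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4        # analogous to the accumulation configured in Accelerate
max_grad_norm = 1.0    # analogous to config.train.max_grad_norm
micro_batches = [torch.randn(2, 8) for _ in range(8)]

for i, x in enumerate(micro_batches):
    loss = model(x).pow(2).mean() / accum_steps  # scale loss for accumulation
    loss.backward()                              # gradients accumulate in .grad
    if (i + 1) % accum_steps == 0:               # ~ accelerator.sync_gradients
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()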