Fix gradient sync for lora

Kevin Black 2023-08-22 16:18:49 -07:00
parent d7a63516cb
commit 5955244f37


@@ -134,9 +134,17 @@ def main(_):
             lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         pipeline.unet.set_attn_processor(lora_attn_procs)
-        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+
+        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
+        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
+        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return pipeline.unet(*args, **kwargs)
+
+        unet = _Wrapper(pipeline.unet.attn_processors)
     else:
-        trainable_layers = pipeline.unet
+        unet = pipeline.unet
 
     # set up diffusers-friendly checkpoint saving with Accelerate
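For readers who haven't hit this DDP quirk: `accelerate` wraps whatever module is passed to `prepare` in DistributedDataParallel, and gradients are only synchronized correctly when the forward pass actually goes through the wrapped module, i.e. the module that owns the trainable parameters must also be the one that gets called. Below is a minimal, self-contained sketch of the same wrapper pattern; the names (`Backbone`, `AdapterLayers`) are illustrative stand-ins, not code from this repo.

import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Stands in for the full UNet: a frozen body plus a small trainable adapter."""

    def __init__(self):
        super().__init__()
        self.frozen = nn.Linear(8, 8).requires_grad_(False)
        self.adapter = nn.Linear(8, 8)  # plays the role of the LoRA weights

    def forward(self, x):
        return self.adapter(self.frozen(x))


backbone = Backbone()


class AdapterLayers(nn.Module):
    """Registers only the trainable adapter, like AttnProcsLayers does for LoRA."""

    def __init__(self, adapter):
        super().__init__()
        self.adapter = adapter


class _Wrapper(AdapterLayers):
    # same trick as the commit: the module whose parameters are registered must also
    # run the forward pass, so forward() routes through the full backbone via a closure
    def forward(self, *args, **kwargs):
        return backbone(*args, **kwargs)


model = _Wrapper(backbone.adapter)
print(sum(p.numel() for p in model.parameters()))  # only the adapter's parameters are registered
out = model(torch.randn(2, 8))                     # but the whole backbone runs in the forward pass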
@@ -191,7 +199,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW
 
     optimizer = optimizer_cls(
-        trainable_layers.parameters(),
+        unet.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
@@ -225,9 +233,10 @@ def main(_):
     # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
     # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+    # autocast = accelerator.autocast
 
     # Prepare everything with our `accelerator`.
-    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
+    unet, optimizer = accelerator.prepare(unet, optimizer)
 
     # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
     # remote server running llava inference.
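A side note on the `contextlib.nullcontext` line kept above: choosing the context-manager factory once means every call site can stay `with autocast():` whether or not mixed precision is in effect. A tiny sketch of the pattern, substituting `torch.autocast` for `accelerator.autocast` purely to keep the example self-contained:

import contextlib

import torch

use_lora = True  # stand-in for config.use_lora

# a no-op context manager for LoRA training, real autocast otherwise
autocast = (
    contextlib.nullcontext
    if use_lora
    else lambda: torch.autocast(device_type="cuda", dtype=torch.float16)
)

with autocast():
    pass  # forward pass and loss computation go here, unchanged in either case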
@@ -335,7 +344,6 @@ def main(_):
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
 
         # this is a hack to force wandb to log the images as JPEGs instead of PNGs
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, image in enumerate(images):
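On the wandb hack referenced above: writing the images to a temporary directory as `.jpg` files and logging them by path is what keeps them JPEG-encoded, since `wandb.Image` can take a file path rather than a raw array. A rough sketch of that pattern with random stand-in images and a hypothetical offline run:

import os
import tempfile

import numpy as np
import wandb
from PIL import Image

wandb.init(project="scratch", mode="offline")  # hypothetical project name, offline for the example

images = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(4)]

with tempfile.TemporaryDirectory() as tmpdir:
    paths = []
    for i, arr in enumerate(images):
        path = os.path.join(tmpdir, f"{i}.jpg")  # the .jpg extension forces JPEG encoding
        Image.fromarray(arr).save(path)
        paths.append(path)
    # logging by path hands wandb the already-encoded JPEGs instead of arrays it would re-encode as PNGs
    wandb.log({"images": [wandb.Image(p) for p in paths]})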
@@ -351,7 +359,7 @@ def main(_):
                 },
                 step=global_step,
             )
 
         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
@@ -425,10 +433,10 @@ def main(_):
                     leave=False,
                     disable=not accelerator.is_local_main_process,
                 ):
-                    with accelerator.accumulate(pipeline.unet):
+                    with accelerator.accumulate(unet):
                         with autocast():
                             if config.train.cfg:
-                                noise_pred = pipeline.unet(
+                                noise_pred = unet(
                                     torch.cat([sample["latents"][:, j]] * 2),
                                     torch.cat([sample["timesteps"][:, j]] * 2),
                                     embeds,
@@ -438,8 +446,10 @@ def main(_):
                                     noise_pred_text - noise_pred_uncond
                                 )
                             else:
-                                noise_pred = pipeline.unet(
-                                    sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                                noise_pred = unet(
+                                    sample["latents"][:, j],
+                                    sample["timesteps"][:, j],
+                                    embeds,
                                 ).sample
                             # compute the log prob of next_latents given latents under the current model
                             _, log_prob = ddim_step_with_logprob(
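For reference, the arithmetic in the `cfg` branch above is standard classifier-free guidance: the latents are doubled so the unconditional and text-conditioned predictions come out of one forward pass, then the two halves are split and recombined with the guidance scale. A standalone sketch with random stand-in tensors and a made-up guidance scale:

import torch

guidance_scale = 5.0  # stand-in for config.sample.guidance_scale

# pretend this is the output of one forward pass over torch.cat([latents] * 2):
# the first half is the negative-prompt prediction, the second half is prompt-conditioned
noise_pred = torch.randn(4, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

# push the prediction away from the unconditional one, toward the text-conditioned one
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)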
@@ -473,7 +483,7 @@ def main(_):
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
-                            accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
+                            accelerator.clip_grad_norm_(unet.parameters(), config.train.max_grad_norm)
                         optimizer.step()
                         optimizer.zero_grad()
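Finally, the `accelerator.sync_gradients` guard above matters because of gradient accumulation: under `accelerator.accumulate`, gradients are only all-reduced and final every `gradient_accumulation_steps` micro-batches, so clipping on intermediate steps would clip partial sums (accelerate's prepared optimizer handles skipping the actual step on those intermediate iterations). A plain-PyTorch sketch of the equivalent logic, with made-up hyperparameters:

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 4  # stand-in for config.train.gradient_accumulation_steps
max_grad_norm = 1.0     # stand-in for config.train.max_grad_norm

for step, batch in enumerate(torch.randn(16, 3, 8)):
    loss = model(batch).mean() / accumulation_steps  # scale so accumulated grads match a full batch
    loss.backward()
    sync_gradients = (step + 1) % accumulation_steps == 0  # what accelerator.sync_gradients tracks
    if sync_gradients:
        # clip only once the accumulated (and, under DDP, all-reduced) gradients are final
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()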