Fix gradient sync for lora
parent d7a63516cb
commit 5955244f37
@@ -134,9 +134,17 @@ def main(_):
             lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         pipeline.unet.set_attn_processor(lora_attn_procs)
-        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+
+        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
+        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
+        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return pipeline.unet(*args, **kwargs)
+
+        unet = _Wrapper(pipeline.unet.attn_processors)
     else:
-        trainable_layers = pipeline.unet
+        unet = pipeline.unet

     # set up diffusers-friendly checkpoint saving with Accelerate

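For readers reviewing the hunk above: the wrapper exists because DistributedDataParallel (which Accelerate wraps the prepared module in) only hooks gradient synchronization onto the module it was given, and those hooks only fire when that module's own forward is the one being called. A minimal, self-contained sketch of the same closure pattern, with a hypothetical AdapterLayers and full_model standing in for AttnProcsLayers and pipeline.unet:

import torch.nn as nn

class AdapterLayers(nn.Module):
    """Hypothetical stand-in for AttnProcsLayers: registers only the adapter params."""
    def __init__(self, adapters):
        super().__init__()
        self.adapters = nn.ModuleList(adapters)

full_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
adapters = [full_model[0]]  # pretend only this layer is trainable

# same trick as the diff: subclass the parameter-holding module and capture the
# full model in a closure, so the registered module is also the one used for forward
class _Wrapper(AdapterLayers):
    def forward(self, *args, **kwargs):
        return full_model(*args, **kwargs)

unet_like = _Wrapper(adapters)
# accelerator.prepare(unet_like, ...) would now wrap a module whose forward pass
# produces gradients for the parameters it registered, so they get synchronized.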
@@ -191,7 +199,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW

     optimizer = optimizer_cls(
-        trainable_layers.parameters(),
+        unet.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
@@ -225,9 +233,10 @@ def main(_):
     # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
     # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+    # autocast = accelerator.autocast

     # Prepare everything with our `accelerator`.
-    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
+    unet, optimizer = accelerator.prepare(unet, optimizer)

     # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
     # remote server running llava inference.
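A side note on the autocast line kept in this hunk (and the commented-out alternative it adds): both contextlib.nullcontext and an autocast factory are zero-argument callables that return context managers, so the training loop can always write `with autocast():` regardless of which was picked. A small sketch using plain torch.autocast rather than accelerator.autocast, with a hypothetical use_lora flag:

import contextlib
import torch

use_lora = True  # hypothetical flag standing in for config.use_lora

autocast = (
    contextlib.nullcontext  # LoRA path: skip autocast, which the comment notes uses more memory
    if use_lora
    else (lambda: torch.autocast(device_type="cuda", dtype=torch.float16))
)

with autocast():
    pass  # forward pass goes here; a no-op context when LoRA is enabled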
@@ -335,7 +344,6 @@ def main(_):
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}

-
         # this is a hack to force wandb to log the images as JPEGs instead of PNGs
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, image in enumerate(images):
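The collation one-liner kept in this hunk turns a list of per-batch dicts into a single dict of concatenated tensors; a tiny runnable illustration with made-up shapes:

import torch

# two "batches" of samples, each a dict of tensors with batch size 4
samples = [
    {"latents": torch.randn(4, 3), "rewards": torch.randn(4)},
    {"latents": torch.randn(4, 3), "rewards": torch.randn(4)},
]

# same comprehension as in the diff: concatenate along the batch dimension
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}

print(samples["latents"].shape)  # torch.Size([8, 3])
print(samples["rewards"].shape)  # torch.Size([8])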
@@ -351,7 +359,7 @@ def main(_):
                 },
                 step=global_step,
             )


         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
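The reward-gathering line above is the usual multi-process logging pattern: accelerator.gather concatenates each process's tensor along the batch dimension before it is moved to the CPU for logging. A minimal sketch (single process, where gather is effectively a no-op):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
rewards = torch.randn(4, device=accelerator.device)  # this process's rewards

# with N processes the result has shape (N * 4,); with one process it is unchanged
all_rewards = accelerator.gather(rewards).cpu().numpy()
print(all_rewards.shape)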
@@ -425,10 +433,10 @@ def main(_):
                 leave=False,
                 disable=not accelerator.is_local_main_process,
             ):
-                with accelerator.accumulate(pipeline.unet):
+                with accelerator.accumulate(unet):
                     with autocast():
                         if config.train.cfg:
-                            noise_pred = pipeline.unet(
+                            noise_pred = unet(
                                 torch.cat([sample["latents"][:, j]] * 2),
                                 torch.cat([sample["timesteps"][:, j]] * 2),
                                 embeds,
@@ -438,8 +446,10 @@ def main(_):
                                 noise_pred_text - noise_pred_uncond
                             )
                         else:
-                            noise_pred = pipeline.unet(
-                                sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                            noise_pred = unet(
+                                sample["latents"][:, j],
+                                sample["timesteps"][:, j],
+                                embeds,
                             ).sample
                         # compute the log prob of next_latents given latents under the current model
                         _, log_prob = ddim_step_with_logprob(
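For context on the doubled batch in the config.train.cfg branch of the two hunks above: the latents and timesteps are duplicated so a single UNet call yields both the unconditional and text-conditioned noise predictions, which are then combined with the guidance scale. A tensor-only sketch with a fake unet and a hypothetical guidance_scale (assuming the embeddings were concatenated as [uncond, text]):

import torch

guidance_scale = 5.0  # hypothetical; the script reads this from its config
latents = torch.randn(2, 4, 8, 8)

def unet(x):
    # stand-in for the real UNet; just returns something latent-shaped
    return x * 0.1

# one forward pass over a doubled batch: [unconditional | text-conditioned]
noise_pred = unet(torch.cat([latents] * 2))
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

# classifier-free guidance combination, matching the expression in the hunk above
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)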
@@ -473,7 +483,7 @@ def main(_):
                     # backward pass
                     accelerator.backward(loss)
                     if accelerator.sync_gradients:
-                        accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
+                        accelerator.clip_grad_norm_(unet.parameters(), config.train.max_grad_norm)
                     optimizer.step()
                     optimizer.zero_grad()

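Finally, the clipping guard in the last hunk follows the standard Accelerate gradient-accumulation pattern: accelerator.sync_gradients is only True on steps where gradients are actually synchronized, so clipping happens once per effective batch. A compressed, runnable sketch of that structure with a toy model (hypothetical max_grad_norm standing in for config.train.max_grad_norm):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=2)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)
max_grad_norm = 1.0  # hypothetical, stands in for config.train.max_grad_norm

for _ in range(4):
    with accelerator.accumulate(model):
        loss = model(torch.randn(8, 4, device=accelerator.device)).pow(2).mean()
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            # clip only on steps where gradients were just accumulated/synchronized
            accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()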