Fix gradient sync for lora
parent d7a63516cb
commit 5955244f37
@@ -134,9 +134,17 @@ def main(_):
             lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         pipeline.unet.set_attn_processor(lora_attn_procs)
-        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+
+        # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
+        # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
+        # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
+        class _Wrapper(AttnProcsLayers):
+            def forward(self, *args, **kwargs):
+                return pipeline.unet(*args, **kwargs)
+
+        unet = _Wrapper(pipeline.unet.attn_processors)
     else:
-        trainable_layers = pipeline.unet
+        unet = pipeline.unet

     # set up diffusers-friendly checkpoint saving with Accelerate

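With `accelerate`/DDP, gradient-synchronization hooks are attached to the module passed to `prepare`, and they only fire when that module's `forward` runs. `AttnProcsLayers` registers the LoRA parameters but was never called during the forward pass, so its gradients were not synchronized; wrapping it so `forward` routes through `pipeline.unet` fixes that. A minimal, self-contained sketch of the same pattern (`BigFrozenModel` and `AdapterWrapper` are made-up stand-ins, not names from this repo):

```python
import torch
import torch.nn as nn


class BigFrozenModel(nn.Module):
    """Stand-in for pipeline.unet: mostly frozen, with one trainable adapter."""

    def __init__(self):
        super().__init__()
        self.frozen = nn.Linear(8, 8)
        self.adapter = nn.Linear(8, 8)  # plays the role of the LoRA weights
        self.frozen.requires_grad_(False)

    def forward(self, x):
        return self.adapter(self.frozen(x))


big_model = BigFrozenModel()


class AdapterWrapper(nn.Module):
    """Plays the role of _Wrapper(AttnProcsLayers): owns only the trainable params."""

    def __init__(self, adapter):
        super().__init__()
        self.adapter = adapter  # registers only the trainable parameters

    def forward(self, *args, **kwargs):
        # forward through the full model captured by closure, so gradient-sync
        # hooks attached to this module are driven by the real forward pass
        return big_model(*args, **kwargs)


wrapper = AdapterWrapper(big_model.adapter)
out = wrapper(torch.randn(2, 8))
out.sum().backward()
assert wrapper.adapter.weight.grad is not None  # gradients land on the adapter
```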
@@ -191,7 +199,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW

     optimizer = optimizer_cls(
-        trainable_layers.parameters(),
+        unet.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
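Because `_Wrapper` subclasses `AttnProcsLayers`, `unet.parameters()` still yields only the LoRA attention-processor weights in the LoRA branch (and the full unet otherwise), so the optimizer setup is unchanged. A small helper one could use to check this (not part of the commit; `count_trainable(unet, pipeline.unet)` is a hypothetical call):

```python
def count_trainable(module, full_model):
    """Return (trainable, total) parameter counts, e.g. count_trainable(unet, pipeline.unet)."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in full_model.parameters())
    return trainable, total
```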
@@ -225,9 +233,10 @@ def main(_):
     # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
     # more memory
     autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+    # autocast = accelerator.autocast

     # Prepare everything with our `accelerator`.
-    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
+    unet, optimizer = accelerator.prepare(unet, optimizer)

     # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
     # remote server running llava inference.
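`contextlib.nullcontext` is a no-op context manager, which is what lets the `with autocast():` call sites in the training loop stay identical whether mixed precision is enabled or not. A trivial standalone illustration (not code from this file):

```python
import contextlib

# nullcontext() does nothing on enter/exit, so the body runs exactly as it
# would without the `with` statement.
autocast = contextlib.nullcontext
with autocast():
    print("runs with no autocasting applied")
```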
@@ -335,7 +344,6 @@ def main(_):
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}

-
         # this is a hack to force wandb to log the images as JPEGs instead of PNGs
         with tempfile.TemporaryDirectory() as tmpdir:
             for i, image in enumerate(images):
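The collation line above turns a list of per-batch dicts into a single dict of tensors stacked along the batch dimension. A small self-contained example of the same idiom (toy shapes, not the real sample layout):

```python
import torch

# two "batches" of samples with the same keys and trailing shapes
samples = [
    {"rewards": torch.zeros(4), "latents": torch.zeros(4, 3)},
    {"rewards": torch.ones(4), "latents": torch.ones(4, 3)},
]
collated = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
assert collated["rewards"].shape == (8,)
assert collated["latents"].shape == (8, 3)
```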
@@ -351,7 +359,7 @@ def main(_):
                 },
                 step=global_step,
             )

         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()

@@ -425,10 +433,10 @@ def main(_):
                 leave=False,
                 disable=not accelerator.is_local_main_process,
             ):
-                with accelerator.accumulate(pipeline.unet):
+                with accelerator.accumulate(unet):
                     with autocast():
                         if config.train.cfg:
-                            noise_pred = pipeline.unet(
+                            noise_pred = unet(
                                 torch.cat([sample["latents"][:, j]] * 2),
                                 torch.cat([sample["timesteps"][:, j]] * 2),
                                 embeds,
@@ -438,8 +446,10 @@ def main(_):
                                 noise_pred_text - noise_pred_uncond
                             )
                         else:
-                            noise_pred = pipeline.unet(
-                                sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                            noise_pred = unet(
+                                sample["latents"][:, j],
+                                sample["timesteps"][:, j],
+                                embeds,
                             ).sample
                         # compute the log prob of next_latents given latents under the current model
                         _, log_prob = ddim_step_with_logprob(
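In the classifier-free guidance branch, the latents and timesteps are duplicated so a single unet call yields both the unconditional and the text-conditioned predictions, which are then recombined with the guidance scale. A stripped-down sketch of that arithmetic (random tensors stand in for the unet output; `guidance_scale` plays the role of `config.sample.guidance_scale`):

```python
import torch

guidance_scale = 5.0
latents = torch.randn(2, 4, 64, 64)  # (batch, channels, h, w)

# duplicate the batch: first half pairs with empty-prompt embeds, second half with text embeds
model_input = torch.cat([latents] * 2)
noise_pred = torch.randn_like(model_input)  # stand-in for unet(...).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert noise_pred.shape == latents.shape
```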
@@ -473,7 +483,7 @@ def main(_):
                     # backward pass
                     accelerator.backward(loss)
                     if accelerator.sync_gradients:
-                        accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
+                        accelerator.clip_grad_norm_(unet.parameters(), config.train.max_grad_norm)
                     optimizer.step()
                     optimizer.zero_grad()

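`accelerator.sync_gradients` is only True on the step that ends a gradient-accumulation cycle, i.e. when gradients are actually all-reduced, so clipping now runs on the same synchronized parameter set the optimizer updates. For reference, a toy version of the standard accelerate loop reduced to the calls touched above (the linear model and random data are placeholders, not from this repo):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=2)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

data = [torch.randn(8, 4) for _ in range(4)]
for batch in data:
    batch = batch.to(accelerator.device)
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            # clip only when gradients have actually been synchronized
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```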