Working non-lora training; other changes

This commit is contained in:
Kevin Black 2023-06-25 11:28:42 -07:00
parent c680890d5c
commit 269615a35e
4 changed files with 124 additions and 90 deletions

1
.gitignore vendored
View File

@ -303,3 +303,4 @@ tags
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim
wandb/

View File

@ -10,6 +10,7 @@ def get_config():
config.num_epochs = 100
config.mixed_precision = "fp16"
config.allow_tf32 = True
config.use_lora = True
# pretrained model initialization
config.pretrained = pretrained = ml_collections.ConfigDict()
@ -20,7 +21,6 @@ def get_config():
config.train = train = ml_collections.ConfigDict()
train.batch_size = 1
train.use_8bit_adam = False
train.scale_lr = False
train.learning_rate = 1e-4
train.adam_beta1 = 0.9
train.adam_beta2 = 0.999
@ -35,7 +35,7 @@ def get_config():
# sampling
config.sample = sample = ml_collections.ConfigDict()
sample.num_steps = 30
sample.num_steps = 5
sample.eta = 1.0
sample.guidance_scale = 5.0
sample.batch_size = 1

View File

@ -4,16 +4,19 @@ from ddpo_pytorch.config import base
def get_config():
config = base.get_config()
config.mixed_precision = "bf16"
config.mixed_precision = "no"
config.allow_tf32 = True
config.use_lora = False
config.train.batch_size = 8
config.train.gradient_accumulation_steps = 4
config.train.batch_size = 4
config.train.gradient_accumulation_steps = 8
config.train.learning_rate = 1e-5
config.train.clip_range = 1.0
# sampling
config.sample.num_steps = 50
config.sample.batch_size = 8
config.sample.num_batches_per_epoch = 4
config.sample.batch_size = 16
config.sample.num_batches_per_epoch = 2
config.per_prompt_stat_tracking = None

View File

@ -1,4 +1,6 @@
from collections import defaultdict
import contextlib
import os
from absl import app, flags, logging
from ml_collections import config_flags
from accelerate import Accelerator
@ -17,6 +19,8 @@ import torch
import wandb
from functools import partial
import tqdm
import tempfile
from PIL import Image
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)
@ -46,9 +50,9 @@ def main(_):
# load scheduler, tokenizer and models.
pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
# freeze parameters of models to save more memory
pipeline.unet.requires_grad_(False)
pipeline.vae.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)
pipeline.unet.requires_grad_(not config.use_lora)
# disable safety checker
pipeline.safety_checker = None
# make the progress bar nicer
@ -56,40 +60,47 @@ def main(_):
position=1,
disable=not accelerator.is_local_main_process,
leave=False,
desc="Timestep",
dynamic_ncols=True,
)
# switch to DDIM scheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
inference_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
inference_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
inference_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
pipeline.unet.to(accelerator.device, dtype=weight_dtype)
pipeline.vae.to(accelerator.device, dtype=weight_dtype)
pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype)
# Move unet, vae and text_encoder to device and cast to inference_dtype
pipeline.vae.to(accelerator.device, dtype=inference_dtype)
pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
if config.use_lora:
pipeline.unet.to(accelerator.device, dtype=inference_dtype)
# Set correct lora layers
lora_attn_procs = {}
for name in pipeline.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = pipeline.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = pipeline.unet.config.block_out_channels[block_id]
if config.use_lora:
# Set correct lora layers
lora_attn_procs = {}
for name in pipeline.unet.attn_processors.keys():
cross_attention_dim = (
None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = pipeline.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = pipeline.unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
pipeline.unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(pipeline.unet.attn_processors)
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
pipeline.unet.set_attn_processor(lora_attn_procs)
trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
else:
trainable_layers = pipeline.unet
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@ -110,7 +121,7 @@ def main(_):
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
lora_layers.parameters(),
trainable_layers.parameters(),
lr=config.train.learning_rate,
betas=(config.train.adam_beta1, config.train.adam_beta2),
weight_decay=config.train.adam_weight_decay,
@ -121,8 +132,31 @@ def main(_):
prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn)
reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)()
# generate negative prompt embeddings
neg_prompt_embed = pipeline.text_encoder(
pipeline.tokenizer(
[""],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=pipeline.tokenizer.model_max_length,
).input_ids.to(accelerator.device)
)[0]
sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
# initialize stat tracker
if config.per_prompt_stat_tracking:
stat_tracker = PerPromptStatTracker(
config.per_prompt_stat_tracking.buffer_size,
config.per_prompt_stat_tracking.min_count,
)
# for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
# Prepare everything with our `accelerator`.
lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer)
trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
# Train!
samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
@ -144,27 +178,10 @@ def main(_):
assert config.sample.batch_size % config.train.batch_size == 0
assert samples_per_epoch % total_train_batch_size == 0
neg_prompt_embed = pipeline.text_encoder(
pipeline.tokenizer(
[""],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=pipeline.tokenizer.model_max_length,
).input_ids.to(accelerator.device)
)[0]
sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
if config.per_prompt_stat_tracking:
stat_tracker = PerPromptStatTracker(
config.per_prompt_stat_tracking.buffer_size,
config.per_prompt_stat_tracking.min_count,
)
global_step = 0
for epoch in range(config.num_epochs):
#################### SAMPLING ####################
pipeline.unet.eval()
samples = []
prompts = []
for i in tqdm(
@ -189,17 +206,16 @@ def main(_):
prompt_embeds = pipeline.text_encoder(prompt_ids)[0]
# sample
pipeline.unet.eval()
pipeline.vae.eval()
images, _, latents, log_probs = pipeline_with_logprob(
pipeline,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=sample_neg_prompt_embeds,
num_inference_steps=config.sample.num_steps,
guidance_scale=config.sample.guidance_scale,
eta=config.sample.eta,
output_type="pt",
)
with autocast():
images, _, latents, log_probs = pipeline_with_logprob(
pipeline,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=sample_neg_prompt_embeds,
num_inference_steps=config.sample.num_steps,
guidance_scale=config.sample.guidance_scale,
eta=config.sample.eta,
output_type="pt",
)
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, 4, 64, 64)
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
@ -226,14 +242,26 @@ def main(_):
# gather rewards across processes
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
# log sample-related stuff
accelerator.log({"reward": rewards, "epoch": epoch}, step=global_step)
# log rewards and images
accelerator.log(
{"images": [wandb.Image(image, caption=prompt) for image, prompt in zip(images, prompts)]},
{"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
step=global_step,
)
# from PIL import Image
# Image.fromarray((images[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)).save(f"test.png")
# this is a hack to force wandb to log the images as JPEGs instead of PNGs
with tempfile.TemporaryDirectory() as tmpdir:
for i, image in enumerate(images):
pil = Image.fromarray((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
pil = pil.resize((256, 256))
pil.save(os.path.join(tmpdir, f"{i}.jpg"))
accelerator.log(
{
"images": [
wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=prompt)
for i, prompt in enumerate(prompts)
],
},
step=global_step,
)
# per-prompt mean/std tracking
if config.per_prompt_stat_tracking:
@ -271,6 +299,7 @@ def main(_):
samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]
# train
pipeline.unet.train()
for i, sample in tqdm(
list(enumerate(samples_batched)),
desc=f"Epoch {epoch}.{inner_epoch}: training",
@ -286,34 +315,35 @@ def main(_):
info = defaultdict(list)
for j in tqdm(
range(num_timesteps),
desc=f"Timestep",
desc="Timestep",
position=1,
leave=False,
disable=not accelerator.is_local_main_process,
):
with accelerator.accumulate(pipeline.unet):
if config.train.cfg:
noise_pred = pipeline.unet(
torch.cat([sample["latents"][:, j]] * 2),
torch.cat([sample["timesteps"][:, j]] * 2),
embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
noise_pred_text - noise_pred_uncond
with autocast():
if config.train.cfg:
noise_pred = pipeline.unet(
torch.cat([sample["latents"][:, j]] * 2),
torch.cat([sample["timesteps"][:, j]] * 2),
embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
else:
noise_pred = pipeline.unet(
sample["latents"][:, j], sample["timesteps"][:, j], embeds
).sample
_, log_prob = ddim_step_with_logprob(
pipeline.scheduler,
noise_pred,
sample["timesteps"][:, j],
sample["latents"][:, j],
eta=config.sample.eta,
prev_sample=sample["next_latents"][:, j],
)
else:
noise_pred = pipeline.unet(
sample["latents"][:, j], sample["timesteps"][:, j], embeds
).sample
_, log_prob = ddim_step_with_logprob(
pipeline.scheduler,
noise_pred,
sample["timesteps"][:, j],
sample["latents"][:, j],
eta=config.sample.eta,
prev_sample=sample["next_latents"][:, j],
)
# ppo logic
advantages = torch.clamp(
@ -337,7 +367,7 @@ def main(_):
# backward pass
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm)
accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
optimizer.step()
optimizer.zero_grad()