From 269615a35e0dcd5058dd7a0467ded1b6d88671e2 Mon Sep 17 00:00:00 2001 From: Kevin Black Date: Sun, 25 Jun 2023 11:28:42 -0700 Subject: [PATCH] Working non-lora training; other changes --- .gitignore | 1 + ddpo_pytorch/config/base.py | 4 +- ddpo_pytorch/config/dgx.py | 13 ++- scripts/train.py | 196 +++++++++++++++++++++--------------- 4 files changed, 124 insertions(+), 90 deletions(-) diff --git a/.gitignore b/.gitignore index c4f2c1c..30b5f32 100644 --- a/.gitignore +++ b/.gitignore @@ -303,3 +303,4 @@ tags # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim +wandb/ \ No newline at end of file diff --git a/ddpo_pytorch/config/base.py b/ddpo_pytorch/config/base.py index b70042f..ca0878f 100644 --- a/ddpo_pytorch/config/base.py +++ b/ddpo_pytorch/config/base.py @@ -10,6 +10,7 @@ def get_config(): config.num_epochs = 100 config.mixed_precision = "fp16" config.allow_tf32 = True + config.use_lora = True # pretrained model initialization config.pretrained = pretrained = ml_collections.ConfigDict() @@ -20,7 +21,6 @@ def get_config(): config.train = train = ml_collections.ConfigDict() train.batch_size = 1 train.use_8bit_adam = False - train.scale_lr = False train.learning_rate = 1e-4 train.adam_beta1 = 0.9 train.adam_beta2 = 0.999 @@ -35,7 +35,7 @@ def get_config(): # sampling config.sample = sample = ml_collections.ConfigDict() - sample.num_steps = 30 + sample.num_steps = 5 sample.eta = 1.0 sample.guidance_scale = 5.0 sample.batch_size = 1 diff --git a/ddpo_pytorch/config/dgx.py b/ddpo_pytorch/config/dgx.py index 80bc342..cd387ee 100644 --- a/ddpo_pytorch/config/dgx.py +++ b/ddpo_pytorch/config/dgx.py @@ -4,16 +4,19 @@ from ddpo_pytorch.config import base def get_config(): config = base.get_config() - config.mixed_precision = "bf16" + config.mixed_precision = "no" config.allow_tf32 = True + config.use_lora = False - config.train.batch_size = 8 - config.train.gradient_accumulation_steps = 4 + config.train.batch_size = 4 + config.train.gradient_accumulation_steps = 8 + config.train.learning_rate = 1e-5 + config.train.clip_range = 1.0 # sampling config.sample.num_steps = 50 - config.sample.batch_size = 8 - config.sample.num_batches_per_epoch = 4 + config.sample.batch_size = 16 + config.sample.num_batches_per_epoch = 2 config.per_prompt_stat_tracking = None diff --git a/scripts/train.py b/scripts/train.py index 6cba4a2..970153e 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,4 +1,6 @@ from collections import defaultdict +import contextlib +import os from absl import app, flags, logging from ml_collections import config_flags from accelerate import Accelerator @@ -17,6 +19,8 @@ import torch import wandb from functools import partial import tqdm +import tempfile +from PIL import Image tqdm = partial(tqdm.tqdm, dynamic_ncols=True) @@ -46,9 +50,9 @@ def main(_): # load scheduler, tokenizer and models. 
pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision) # freeze parameters of models to save more memory - pipeline.unet.requires_grad_(False) pipeline.vae.requires_grad_(False) pipeline.text_encoder.requires_grad_(False) + pipeline.unet.requires_grad_(not config.use_lora) # disable safety checker pipeline.safety_checker = None # make the progress bar nicer @@ -56,40 +60,47 @@ def main(_): position=1, disable=not accelerator.is_local_main_process, leave=False, + desc="Timestep", + dynamic_ncols=True, ) # switch to DDIM scheduler pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 + inference_dtype = torch.float32 if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 + inference_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 + inference_dtype = torch.bfloat16 - # Move unet, vae and text_encoder to device and cast to weight_dtype - pipeline.unet.to(accelerator.device, dtype=weight_dtype) - pipeline.vae.to(accelerator.device, dtype=weight_dtype) - pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype) + # Move unet, vae and text_encoder to device and cast to inference_dtype + pipeline.vae.to(accelerator.device, dtype=inference_dtype) + pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype) + if config.use_lora: + pipeline.unet.to(accelerator.device, dtype=inference_dtype) - # Set correct lora layers - lora_attn_procs = {} - for name in pipeline.unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = pipeline.unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = pipeline.unet.config.block_out_channels[block_id] + if config.use_lora: + # Set correct lora layers + lora_attn_procs = {} + for name in pipeline.unet.attn_processors.keys(): + cross_attention_dim = ( + None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim + ) + if name.startswith("mid_block"): + hidden_size = pipeline.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = pipeline.unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) - - pipeline.unet.set_attn_processor(lora_attn_procs) - lora_layers = AttnProcsLayers(pipeline.unet.attn_processors) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + pipeline.unet.set_attn_processor(lora_attn_procs) + trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors) + else: + trainable_layers = pipeline.unet # Enable TF32 for faster training on Ampere GPUs, # cf 
https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices @@ -110,7 +121,7 @@ def main(_): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - lora_layers.parameters(), + trainable_layers.parameters(), lr=config.train.learning_rate, betas=(config.train.adam_beta1, config.train.adam_beta2), weight_decay=config.train.adam_weight_decay, @@ -121,8 +132,31 @@ def main(_): prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn) reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)() + # generate negative prompt embeddings + neg_prompt_embed = pipeline.text_encoder( + pipeline.tokenizer( + [""], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=pipeline.tokenizer.model_max_length, + ).input_ids.to(accelerator.device) + )[0] + sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1) + train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1) + + # initialize stat tracker + if config.per_prompt_stat_tracking: + stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking.buffer_size, + config.per_prompt_stat_tracking.min_count, + ) + + # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory + autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast + # Prepare everything with our `accelerator`. - lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer) + trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer) # Train! samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch @@ -144,27 +178,10 @@ def main(_): assert config.sample.batch_size % config.train.batch_size == 0 assert samples_per_epoch % total_train_batch_size == 0 - neg_prompt_embed = pipeline.text_encoder( - pipeline.tokenizer( - [""], - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=pipeline.tokenizer.model_max_length, - ).input_ids.to(accelerator.device) - )[0] - sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1) - train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1) - - if config.per_prompt_stat_tracking: - stat_tracker = PerPromptStatTracker( - config.per_prompt_stat_tracking.buffer_size, - config.per_prompt_stat_tracking.min_count, - ) - global_step = 0 for epoch in range(config.num_epochs): #################### SAMPLING #################### + pipeline.unet.eval() samples = [] prompts = [] for i in tqdm( @@ -189,17 +206,16 @@ def main(_): prompt_embeds = pipeline.text_encoder(prompt_ids)[0] # sample - pipeline.unet.eval() - pipeline.vae.eval() - images, _, latents, log_probs = pipeline_with_logprob( - pipeline, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=sample_neg_prompt_embeds, - num_inference_steps=config.sample.num_steps, - guidance_scale=config.sample.guidance_scale, - eta=config.sample.eta, - output_type="pt", - ) + with autocast(): + images, _, latents, log_probs = pipeline_with_logprob( + pipeline, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=config.sample.num_steps, + guidance_scale=config.sample.guidance_scale, + eta=config.sample.eta, + output_type="pt", + ) latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, 4, 64, 64) log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) @@ -226,14 +242,26 @@ def main(_): # gather rewards across processes 
rewards = accelerator.gather(samples["rewards"]).cpu().numpy() - # log sample-related stuff - accelerator.log({"reward": rewards, "epoch": epoch}, step=global_step) + # log rewards and images accelerator.log( - {"images": [wandb.Image(image, caption=prompt) for image, prompt in zip(images, prompts)]}, + {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()}, step=global_step, ) - # from PIL import Image - # Image.fromarray((images[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)).save(f"test.png") + # this is a hack to force wandb to log the images as JPEGs instead of PNGs + with tempfile.TemporaryDirectory() as tmpdir: + for i, image in enumerate(images): + pil = Image.fromarray((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)) + pil = pil.resize((256, 256)) + pil.save(os.path.join(tmpdir, f"{i}.jpg")) + accelerator.log( + { + "images": [ + wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=prompt) + for i, prompt in enumerate(prompts) + ], + }, + step=global_step, + ) # per-prompt mean/std tracking if config.per_prompt_stat_tracking: @@ -271,6 +299,7 @@ def main(_): samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())] # train + pipeline.unet.train() for i, sample in tqdm( list(enumerate(samples_batched)), desc=f"Epoch {epoch}.{inner_epoch}: training", @@ -286,34 +315,35 @@ def main(_): info = defaultdict(list) for j in tqdm( range(num_timesteps), - desc=f"Timestep", + desc="Timestep", position=1, leave=False, disable=not accelerator.is_local_main_process, ): with accelerator.accumulate(pipeline.unet): - if config.train.cfg: - noise_pred = pipeline.unet( - torch.cat([sample["latents"][:, j]] * 2), - torch.cat([sample["timesteps"][:, j]] * 2), - embeds, - ).sample - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + config.sample.guidance_scale * ( - noise_pred_text - noise_pred_uncond + with autocast(): + if config.train.cfg: + noise_pred = pipeline.unet( + torch.cat([sample["latents"][:, j]] * 2), + torch.cat([sample["timesteps"][:, j]] * 2), + embeds, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + config.sample.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + noise_pred = pipeline.unet( + sample["latents"][:, j], sample["timesteps"][:, j], embeds + ).sample + _, log_prob = ddim_step_with_logprob( + pipeline.scheduler, + noise_pred, + sample["timesteps"][:, j], + sample["latents"][:, j], + eta=config.sample.eta, + prev_sample=sample["next_latents"][:, j], ) - else: - noise_pred = pipeline.unet( - sample["latents"][:, j], sample["timesteps"][:, j], embeds - ).sample - _, log_prob = ddim_step_with_logprob( - pipeline.scheduler, - noise_pred, - sample["timesteps"][:, j], - sample["latents"][:, j], - eta=config.sample.eta, - prev_sample=sample["next_latents"][:, j], - ) # ppo logic advantages = torch.clamp( @@ -337,7 +367,7 @@ def main(_): # backward pass accelerator.backward(loss) if accelerator.sync_gradients: - accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm) + accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm) optimizer.step() optimizer.zero_grad()
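The central change in scripts/train.py is the new config.use_lora switch: with LoRA the base UNet stays frozen and only the injected attention processors are optimized, while non-LoRA training makes the whole UNet trainable and relies on autocast around its forward passes. The sketch below consolidates that pattern from the hunks above into one place; the build_trainable wrapper is a hypothetical helper added for illustration, and the import paths assume the diffusers/accelerate versions this patch was written against.

import contextlib

from accelerate import Accelerator
from diffusers import StableDiffusionPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor


def build_trainable(pipeline: StableDiffusionPipeline, accelerator: Accelerator, use_lora: bool):
    # Freeze the parts that are never trained.
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    # LoRA: base UNet frozen, adapters trainable. Non-LoRA: full UNet trainable.
    pipeline.unet.requires_grad_(not use_lora)

    if use_lora:
        # Attach a LoRA processor to every attention block, sized to that block.
        lora_attn_procs = {}
        for name in pipeline.unet.attn_processors.keys():
            cross_attention_dim = (
                None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = pipeline.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
            else:  # down_blocks
                block_id = int(name[len("down_blocks.")])
                hidden_size = pipeline.unet.config.block_out_channels[block_id]
            lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
            )
        pipeline.unet.set_attn_processor(lora_attn_procs)
        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
    else:
        trainable_layers = pipeline.unet

    # Per the comment in the diff, autocast is necessary for non-LoRA training
    # but costs extra memory with LoRA, hence the nullcontext switch.
    autocast = contextlib.nullcontext if use_lora else accelerator.autocast
    return trainable_layers, autocast

trainable_layers feeds both the optimizer construction and the gradient clipping call, which is why the diff renames lora_layers throughout.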
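The training hunks show the advantages being clamped and the loss being backpropagated, but the surrogate computation itself sits in unchanged context that the diff does not display. For reference, here is a minimal sketch of the standard PPO clipped objective applied per denoising timestep, using the sampling-time log-probabilities stored during rollout and the log-probabilities recomputed by ddim_step_with_logprob; the names old_log_prob, clip_range, and adv_clip_max are assumptions for illustration rather than the exact identifiers in scripts/train.py.

import torch


def ppo_loss(
    log_prob: torch.Tensor,      # log prob of next_latents under the current UNet
    old_log_prob: torch.Tensor,  # log prob recorded when the sample was drawn
    advantages: torch.Tensor,    # normalized rewards for these samples
    clip_range: float,           # e.g. config.train.clip_range
    adv_clip_max: float,         # bound on the advantage magnitude
) -> torch.Tensor:
    # Clamp advantages, mirroring the torch.clamp visible in the hunk above.
    advantages = torch.clamp(advantages, -adv_clip_max, adv_clip_max)
    # Importance ratio between the current policy and the sampling-time policy.
    ratio = torch.exp(log_prob - old_log_prob)
    # Clipped surrogate: take the pessimistic of the unclipped and clipped terms.
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return torch.mean(torch.maximum(unclipped, clipped))

When per_prompt_stat_tracking is enabled, the advantages come from the PerPromptStatTracker constructed earlier in the script; with it set to None (as in the dgx config), a global mean/std normalization of the gathered rewards is presumably used instead.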