Working non-lora training; other changes

.gitignore (1 line changed)
@@ -303,3 +303,4 @@ tags

 # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim

+wandb/
@@ -10,6 +10,7 @@ def get_config():
     config.num_epochs = 100
     config.mixed_precision = "fp16"
     config.allow_tf32 = True
+    config.use_lora = True

     # pretrained model initialization
     config.pretrained = pretrained = ml_collections.ConfigDict()
@@ -20,7 +21,6 @@ def get_config():
     config.train = train = ml_collections.ConfigDict()
     train.batch_size = 1
     train.use_8bit_adam = False
-    train.scale_lr = False
     train.learning_rate = 1e-4
     train.adam_beta1 = 0.9
     train.adam_beta2 = 0.999
@@ -35,7 +35,7 @@ def get_config():

     # sampling
     config.sample = sample = ml_collections.ConfigDict()
-    sample.num_steps = 30
+    sample.num_steps = 5
     sample.eta = 1.0
     sample.guidance_scale = 5.0
     sample.batch_size = 1

@@ -4,16 +4,19 @@ from ddpo_pytorch.config import base
 def get_config():
     config = base.get_config()

-    config.mixed_precision = "bf16"
+    config.mixed_precision = "no"
     config.allow_tf32 = True
+    config.use_lora = False

-    config.train.batch_size = 8
-    config.train.gradient_accumulation_steps = 4
+    config.train.batch_size = 4
+    config.train.gradient_accumulation_steps = 8
+    config.train.learning_rate = 1e-5
+    config.train.clip_range = 1.0

     # sampling
     config.sample.num_steps = 50
-    config.sample.batch_size = 8
-    config.sample.num_batches_per_epoch = 4
+    config.sample.batch_size = 16
+    config.sample.num_batches_per_epoch = 2

     config.per_prompt_stat_tracking = None

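The three hunks before this one edit the base config (the module imported as `ddpo_pytorch.config.base` in the hunk header above), and the last hunk edits a derived config that starts from `base.get_config()` and overrides fields for a non-LoRA run; the derived file's name is not shown in this view. For context, a minimal sketch of how such ml_collections configs are typically loaded by the training script; the flag name and example path are assumptions, not taken from the repository:

```python
# Minimal sketch (not the repository's exact code) of loading an ml_collections config.
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config")  # e.g. --config=path/to/base.py (path assumed)


def main(_):
    config = FLAGS.config
    print("use_lora:", config.use_lora)
    print("sample steps:", config.sample.num_steps)
    print("train batch size:", config.train.batch_size)


if __name__ == "__main__":
    app.run(main)
```

Values set in the derived config override the base defaults, and config_flags also accepts command-line overrides of nested fields, e.g. `--config.train.learning_rate=3e-5`.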
							
								
								
									
scripts/train.py (196 lines changed)
@@ -1,4 +1,6 @@
 from collections import defaultdict
+import contextlib
+import os
 from absl import app, flags, logging
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -17,6 +19,8 @@ import torch
 import wandb
 from functools import partial
 import tqdm
+import tempfile
+from PIL import Image

 tqdm = partial(tqdm.tqdm, dynamic_ncols=True)

@@ -46,9 +50,9 @@ def main(_):
     # load scheduler, tokenizer and models.
     pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
     # freeze parameters of models to save more memory
-    pipeline.unet.requires_grad_(False)
     pipeline.vae.requires_grad_(False)
     pipeline.text_encoder.requires_grad_(False)
+    pipeline.unet.requires_grad_(not config.use_lora)
     # disable safety checker
     pipeline.safety_checker = None
     # make the progress bar nicer
@@ -56,40 +60,47 @@ def main(_):
         position=1,
         disable=not accelerator.is_local_main_process,
         leave=False,
+        desc="Timestep",
+        dynamic_ncols=True,
     )
     # switch to DDIM scheduler
     pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

     # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
+    inference_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
+        inference_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
+        inference_dtype = torch.bfloat16

-    # Move unet, vae and text_encoder to device and cast to weight_dtype
-    pipeline.unet.to(accelerator.device, dtype=weight_dtype)
-    pipeline.vae.to(accelerator.device, dtype=weight_dtype)
-    pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype)
+    # Move unet, vae and text_encoder to device and cast to inference_dtype
+    pipeline.vae.to(accelerator.device, dtype=inference_dtype)
+    pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
+    if config.use_lora:
+        pipeline.unet.to(accelerator.device, dtype=inference_dtype)

-    # Set correct lora layers
-    lora_attn_procs = {}
-    for name in pipeline.unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = pipeline.unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = pipeline.unet.config.block_out_channels[block_id]
+    if config.use_lora:
+        # Set correct lora layers
+        lora_attn_procs = {}
+        for name in pipeline.unet.attn_processors.keys():
+            cross_attention_dim = (
+                None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
+            )
+            if name.startswith("mid_block"):
+                hidden_size = pipeline.unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = pipeline.unet.config.block_out_channels[block_id]

-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-    pipeline.unet.set_attn_processor(lora_attn_procs)
-    lora_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        pipeline.unet.set_attn_processor(lora_attn_procs)
+        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+    else:
+        trainable_layers = pipeline.unet

     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
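The net effect of the hunk above is that `trainable_layers` is either the small set of LoRA attention weights (when `config.use_lora` is on) or the entire UNet (when it is off). Purely as an illustration, and not part of the diff, this is one way to see how much of the model actually trains in each mode, using the objects built above:

```python
# Illustration only (not in the commit): compare trainable vs. total UNet parameters.
total_unet_params = sum(p.numel() for p in pipeline.unet.parameters())
trainable_params = sum(p.numel() for p in trainable_layers.parameters() if p.requires_grad)
print(
    f"trainable parameters: {trainable_params:,} "
    f"({100 * trainable_params / total_unet_params:.2f}% of the UNet)"
)
```

With LoRA this is typically a small fraction of the UNet; with `use_lora = False` it is effectively all of it, which lines up with the non-LoRA config above switching to full precision and a lower learning rate.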
@@ -110,7 +121,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW

     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        trainable_layers.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
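The unchanged context line `optimizer_cls = torch.optim.AdamW` sits below a branch on `train.use_8bit_adam` (a field visible in the base config diff). A sketch of the usual selection pattern; treating bitsandbytes as the 8-bit backend is an assumption, not something shown in this commit:

```python
# Sketch of a typical 8-bit Adam toggle (assumption: bitsandbytes provides the 8-bit optimizer).
import torch

if config.train.use_8bit_adam:
    import bitsandbytes as bnb  # optional dependency

    optimizer_cls = bnb.optim.AdamW8bit
else:
    optimizer_cls = torch.optim.AdamW
```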
@@ -121,8 +132,31 @@ def main(_):
     prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn)
     reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)()

+    # generate negative prompt embeddings
+    neg_prompt_embed = pipeline.text_encoder(
+        pipeline.tokenizer(
+            [""],
+            return_tensors="pt",
+            padding="max_length",
+            truncation=True,
+            max_length=pipeline.tokenizer.model_max_length,
+        ).input_ids.to(accelerator.device)
+    )[0]
+    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
+    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
+
+    # initialize stat tracker
+    if config.per_prompt_stat_tracking:
+        stat_tracker = PerPromptStatTracker(
+            config.per_prompt_stat_tracking.buffer_size,
+            config.per_prompt_stat_tracking.min_count,
+        )
+
+    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+
     # Prepare everything with our `accelerator`.
-    lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer)
+    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)

     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
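The last context line above computes `samples_per_epoch`, and the next hunk keeps two assertions tying it to the training batch layout. A worked check with the non-LoRA values from this commit; the `total_train_batch_size` formula and the single-process assumption are mine, since that code is outside the hunks shown:

```python
# Worked example with the non-LoRA config values (assumes accelerator.num_processes == 1
# and the usual total_train_batch_size = batch_size * num_processes * grad_accum_steps).
sample_batch_size = 16        # config.sample.batch_size
num_batches_per_epoch = 2     # config.sample.num_batches_per_epoch
train_batch_size = 4          # config.train.batch_size
grad_accum_steps = 8          # config.train.gradient_accumulation_steps
num_processes = 1             # assumption for this sketch

samples_per_epoch = sample_batch_size * num_processes * num_batches_per_epoch   # 32
total_train_batch_size = train_batch_size * num_processes * grad_accum_steps    # 32

assert sample_batch_size % train_batch_size == 0        # 16 % 4 == 0
assert samples_per_epoch % total_train_batch_size == 0  # 32 % 32 == 0
```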
@@ -144,27 +178,10 @@ def main(_):
     assert config.sample.batch_size % config.train.batch_size == 0
     assert samples_per_epoch % total_train_batch_size == 0

-    neg_prompt_embed = pipeline.text_encoder(
-        pipeline.tokenizer(
-            [""],
-            return_tensors="pt",
-            padding="max_length",
-            truncation=True,
-            max_length=pipeline.tokenizer.model_max_length,
-        ).input_ids.to(accelerator.device)
-    )[0]
-    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
-    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
-
-    if config.per_prompt_stat_tracking:
-        stat_tracker = PerPromptStatTracker(
-            config.per_prompt_stat_tracking.buffer_size,
-            config.per_prompt_stat_tracking.min_count,
-        )
-
     global_step = 0
     for epoch in range(config.num_epochs):
         #################### SAMPLING ####################
+        pipeline.unet.eval()
         samples = []
         prompts = []
         for i in tqdm(
@@ -189,17 +206,16 @@ def main(_):
             prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

             # sample
-            pipeline.unet.eval()
-            pipeline.vae.eval()
-            images, _, latents, log_probs = pipeline_with_logprob(
-                pipeline,
-                prompt_embeds=prompt_embeds,
-                negative_prompt_embeds=sample_neg_prompt_embeds,
-                num_inference_steps=config.sample.num_steps,
-                guidance_scale=config.sample.guidance_scale,
-                eta=config.sample.eta,
-                output_type="pt",
-            )
+            with autocast():
+                images, _, latents, log_probs = pipeline_with_logprob(
+                    pipeline,
+                    prompt_embeds=prompt_embeds,
+                    negative_prompt_embeds=sample_neg_prompt_embeds,
+                    num_inference_steps=config.sample.num_steps,
+                    guidance_scale=config.sample.guidance_scale,
+                    eta=config.sample.eta,
+                    output_type="pt",
+                )

             latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, 4, 64, 64)
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
@@ -226,14 +242,26 @@ def main(_):
         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()

-        # log sample-related stuff
-        accelerator.log({"reward": rewards, "epoch": epoch}, step=global_step)
+        # log rewards and images
         accelerator.log(
-            {"images": [wandb.Image(image, caption=prompt) for image, prompt in zip(images, prompts)]},
+            {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
             step=global_step,
         )
-        # from PIL import Image
-        # Image.fromarray((images[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)).save(f"test.png")
+        # this is a hack to force wandb to log the images as JPEGs instead of PNGs
+        with tempfile.TemporaryDirectory() as tmpdir:
+            for i, image in enumerate(images):
+                pil = Image.fromarray((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
+                pil = pil.resize((256, 256))
+                pil.save(os.path.join(tmpdir, f"{i}.jpg"))
+            accelerator.log(
+                {
+                    "images": [
+                        wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=prompt)
+                        for i, prompt in enumerate(prompts)
+                    ],
+                },
+                step=global_step,
+            )

         # per-prompt mean/std tracking
         if config.per_prompt_stat_tracking:
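The closing context lines lead into per-prompt advantage normalization using the `PerPromptStatTracker` constructed earlier from `buffer_size` and `min_count`. As a rough orientation, a tracker of this kind usually keeps a rolling buffer of rewards per prompt and standardizes each new reward against it; the sketch below is an illustration, not the repository's implementation:

```python
# Illustrative per-prompt reward normalizer (hypothetical class, not ddpo_pytorch's own).
from collections import defaultdict, deque

import numpy as np


class SimplePerPromptStatTracker:
    def __init__(self, buffer_size, min_count):
        self.buffers = defaultdict(lambda: deque(maxlen=buffer_size))
        self.min_count = min_count

    def update(self, prompts, rewards):
        rewards = np.asarray(rewards, dtype=np.float64)
        advantages = np.empty_like(rewards)
        for i, (prompt, reward) in enumerate(zip(prompts, rewards)):
            self.buffers[prompt].append(reward)
            buf = np.asarray(self.buffers[prompt])
            if len(buf) < self.min_count:
                # not enough history for this prompt yet: fall back to batch statistics
                mean, std = rewards.mean(), rewards.std() + 1e-6
            else:
                mean, std = buf.mean(), buf.std() + 1e-6
            advantages[i] = (reward - mean) / std
        return advantages
```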
@@ -271,6 +299,7 @@ def main(_):
             samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]

             # train
+            pipeline.unet.train()
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -286,34 +315,35 @@ def main(_):
                 info = defaultdict(list)
                 for j in tqdm(
                     range(num_timesteps),
-                    desc=f"Timestep",
+                    desc="Timestep",
                     position=1,
                     leave=False,
                     disable=not accelerator.is_local_main_process,
                 ):
                     with accelerator.accumulate(pipeline.unet):
-                        if config.train.cfg:
-                            noise_pred = pipeline.unet(
-                                torch.cat([sample["latents"][:, j]] * 2),
-                                torch.cat([sample["timesteps"][:, j]] * 2),
-                                embeds,
-                            ).sample
-                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                            noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
-                                noise_pred_text - noise_pred_uncond
+                        with autocast():
+                            if config.train.cfg:
+                                noise_pred = pipeline.unet(
+                                    torch.cat([sample["latents"][:, j]] * 2),
+                                    torch.cat([sample["timesteps"][:, j]] * 2),
+                                    embeds,
+                                ).sample
+                                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                                noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
+                                    noise_pred_text - noise_pred_uncond
+                                )
+                            else:
+                                noise_pred = pipeline.unet(
+                                    sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                                ).sample
+                            _, log_prob = ddim_step_with_logprob(
+                                pipeline.scheduler,
+                                noise_pred,
+                                sample["timesteps"][:, j],
+                                sample["latents"][:, j],
+                                eta=config.sample.eta,
+                                prev_sample=sample["next_latents"][:, j],
                             )
-                        else:
-                            noise_pred = pipeline.unet(
-                                sample["latents"][:, j], sample["timesteps"][:, j], embeds
-                            ).sample
-                        _, log_prob = ddim_step_with_logprob(
-                            pipeline.scheduler,
-                            noise_pred,
-                            sample["timesteps"][:, j],
-                            sample["latents"][:, j],
-                            eta=config.sample.eta,
-                            prev_sample=sample["next_latents"][:, j],
-                        )

                         # ppo logic
                         advantages = torch.clamp(
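This hunk stops exactly where the PPO logic begins (`advantages = torch.clamp(`); the objective itself is untouched by the commit and therefore not shown. For orientation only, the clipped importance-sampling objective that code of this shape usually computes looks roughly like the sketch below, reusing names visible in the diff (`log_prob`, `sample["log_probs"]`, `advantages`, and the new `config.train.clip_range`); it is illustrative, not the repository's exact code:

```python
# Hedged sketch of a standard PPO-style clipped objective (illustration only).
ratio = torch.exp(log_prob - sample["log_probs"][:, j])          # new policy vs. sampling policy
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(
    ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range
)
loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))   # pessimistic (clipped) surrogate
```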
@@ -337,7 +367,7 @@ def main(_):
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
-                            accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm)
+                            accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
                         optimizer.step()
                         optimizer.zero_grad()