Working non-lora training; other changes
This commit changes scripts/train.py (196 changed lines):
@@ -1,4 +1,6 @@
 from collections import defaultdict
+import contextlib
+import os
 from absl import app, flags, logging
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -17,6 +19,8 @@ import torch
 import wandb
 from functools import partial
 import tqdm
+import tempfile
+from PIL import Image

 tqdm = partial(tqdm.tqdm, dynamic_ncols=True)

@@ -46,9 +50,9 @@ def main(_):
     # load scheduler, tokenizer and models.
     pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
     # freeze parameters of models to save more memory
-    pipeline.unet.requires_grad_(False)
     pipeline.vae.requires_grad_(False)
     pipeline.text_encoder.requires_grad_(False)
+    pipeline.unet.requires_grad_(not config.use_lora)
     # disable safety checker
     pipeline.safety_checker = None
     # make the progress bar nicer
@@ -56,40 +60,47 @@ def main(_):
         position=1,
         disable=not accelerator.is_local_main_process,
         leave=False,
+        desc="Timestep",
+        dynamic_ncols=True,
     )
     # switch to DDIM scheduler
     pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

     # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
+    inference_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
+        inference_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
+        inference_dtype = torch.bfloat16

-    # Move unet, vae and text_encoder to device and cast to weight_dtype
-    pipeline.unet.to(accelerator.device, dtype=weight_dtype)
-    pipeline.vae.to(accelerator.device, dtype=weight_dtype)
-    pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype)
+    # Move unet, vae and text_encoder to device and cast to inference_dtype
+    pipeline.vae.to(accelerator.device, dtype=inference_dtype)
+    pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
+    if config.use_lora:
+        pipeline.unet.to(accelerator.device, dtype=inference_dtype)

-    # Set correct lora layers
-    lora_attn_procs = {}
-    for name in pipeline.unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = pipeline.unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = pipeline.unet.config.block_out_channels[block_id]
+    if config.use_lora:
+        # Set correct lora layers
+        lora_attn_procs = {}
+        for name in pipeline.unet.attn_processors.keys():
+            cross_attention_dim = (
+                None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
+            )
+            if name.startswith("mid_block"):
+                hidden_size = pipeline.unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = pipeline.unet.config.block_out_channels[block_id]

-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-    pipeline.unet.set_attn_processor(lora_attn_procs)
-    lora_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        pipeline.unet.set_attn_processor(lora_attn_procs)
+        trainable_layers = AttnProcsLayers(pipeline.unet.attn_processors)
+    else:
+        trainable_layers = pipeline.unet

     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
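The name-based dispatch in this hunk mirrors how diffusers names attention processors: "attn1" entries are self-attention (so cross_attention_dim is None) and "attn2" entries are cross-attention, while the block prefix determines the channel width. A standalone sketch of the hidden-size lookup, assuming Stable Diffusion v1's block_out_channels of (320, 640, 1280, 1280); the attn_hidden_size helper is illustrative and not part of the diff:

# Illustrative helper, not part of the diff: derive hidden_size the same way
# the script does, from an attention processor's dotted module name.
# Assumes Stable Diffusion v1's UNet channel widths.
BLOCK_OUT_CHANNELS = (320, 640, 1280, 1280)

def attn_hidden_size(name: str) -> int:
    if name.startswith("mid_block"):
        return BLOCK_OUT_CHANNELS[-1]
    if name.startswith("up_blocks"):
        # up blocks mirror the down blocks, so channel widths run in reverse
        block_id = int(name[len("up_blocks.")])
        return tuple(reversed(BLOCK_OUT_CHANNELS))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return BLOCK_OUT_CHANNELS[block_id]
    raise ValueError(f"unexpected processor name: {name}")

print(attn_hidden_size("down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"))  # 320
print(attn_hidden_size("up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor"))    # 320
print(attn_hidden_size("mid_block.attentions.0.transformer_blocks.0.attn1.processor"))      # 1280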
@@ -110,7 +121,7 @@ def main(_):
         optimizer_cls = torch.optim.AdamW

     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        trainable_layers.parameters(),
         lr=config.train.learning_rate,
         betas=(config.train.adam_beta1, config.train.adam_beta2),
         weight_decay=config.train.adam_weight_decay,
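The optimizer construction is unchanged apart from the variable name, but what it optimizes differs sharply between the two branches. A hypothetical helper (not in the repo) makes the point:

import torch.nn as nn

# Hypothetical helper, not in the repo: count the parameters AdamW will
# actually update for a given choice of trainable_layers.
def count_trainable(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

With use_lora, AttnProcsLayers wraps only the small low-rank adapter matrices, so the gradients and AdamW's two per-parameter moment buffers stay small; passing the full unet instead scales all of that state up to the full model size, which is the memory cost of non-lora training.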
@@ -121,8 +132,31 @@ def main(_):
     prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn)
     reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)()

+    # generate negative prompt embeddings
+    neg_prompt_embed = pipeline.text_encoder(
+        pipeline.tokenizer(
+            [""],
+            return_tensors="pt",
+            padding="max_length",
+            truncation=True,
+            max_length=pipeline.tokenizer.model_max_length,
+        ).input_ids.to(accelerator.device)
+    )[0]
+    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
+    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
+
+    # initialize stat tracker
+    if config.per_prompt_stat_tracking:
+        stat_tracker = PerPromptStatTracker(
+            config.per_prompt_stat_tracking.buffer_size,
+            config.per_prompt_stat_tracking.min_count,
+        )
+
+    # for some reason, autocast is necessary for non-lora training but for lora training it uses more memory
+    autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
+
     # Prepare everything with our `accelerator`.
-    lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer)
+    trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)

     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
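The autocast line works because both branches bind a zero-argument callable that returns a context manager, so every call site can uniformly write "with autocast():". A runnable toy sketch; fake_autocast stands in for accelerate's Accelerator.autocast and is not part of the repo:

import contextlib

# Toy stand-in for accelerate's Accelerator.autocast (illustrative only); the
# real method returns a mixed-precision context manager when called.
@contextlib.contextmanager
def fake_autocast():
    print("mixed precision on")
    yield
    print("mixed precision off")

for use_lora in (True, False):
    autocast = contextlib.nullcontext if use_lora else fake_autocast
    with autocast():  # no-op under LoRA, mixed precision otherwise
        print(f"forward pass, use_lora={use_lora}")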
@@ -144,27 +178,10 @@ def main(_):
     assert config.sample.batch_size % config.train.batch_size == 0
     assert samples_per_epoch % total_train_batch_size == 0

-    neg_prompt_embed = pipeline.text_encoder(
-        pipeline.tokenizer(
-            [""],
-            return_tensors="pt",
-            padding="max_length",
-            truncation=True,
-            max_length=pipeline.tokenizer.model_max_length,
-        ).input_ids.to(accelerator.device)
-    )[0]
-    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
-    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
-
-    if config.per_prompt_stat_tracking:
-        stat_tracker = PerPromptStatTracker(
-            config.per_prompt_stat_tracking.buffer_size,
-            config.per_prompt_stat_tracking.min_count,
-        )
-
     global_step = 0
     for epoch in range(config.num_epochs):
         #################### SAMPLING ####################
+        pipeline.unet.eval()
         samples = []
        prompts = []
         for i in tqdm(
@@ -189,17 +206,16 @@ def main(_):
             prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

             # sample
-            pipeline.unet.eval()
-            pipeline.vae.eval()
-            images, _, latents, log_probs = pipeline_with_logprob(
-                pipeline,
-                prompt_embeds=prompt_embeds,
-                negative_prompt_embeds=sample_neg_prompt_embeds,
-                num_inference_steps=config.sample.num_steps,
-                guidance_scale=config.sample.guidance_scale,
-                eta=config.sample.eta,
-                output_type="pt",
-            )
+            with autocast():
+                images, _, latents, log_probs = pipeline_with_logprob(
+                    pipeline,
+                    prompt_embeds=prompt_embeds,
+                    negative_prompt_embeds=sample_neg_prompt_embeds,
+                    num_inference_steps=config.sample.num_steps,
+                    guidance_scale=config.sample.guidance_scale,
+                    eta=config.sample.eta,
+                    output_type="pt",
+                )

             latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, 4, 64, 64)
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
@@ -226,14 +242,26 @@ def main(_):
         # gather rewards across processes
         rewards = accelerator.gather(samples["rewards"]).cpu().numpy()

-        # log sample-related stuff
-        accelerator.log({"reward": rewards, "epoch": epoch}, step=global_step)
+        # log rewards and images
         accelerator.log(
-            {"images": [wandb.Image(image, caption=prompt) for image, prompt in zip(images, prompts)]},
+            {"reward": rewards, "epoch": epoch, "reward_mean": rewards.mean(), "reward_std": rewards.std()},
             step=global_step,
         )
-        # from PIL import Image
-        # Image.fromarray((images[0].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)).save(f"test.png")
+        # this is a hack to force wandb to log the images as JPEGs instead of PNGs
+        with tempfile.TemporaryDirectory() as tmpdir:
+            for i, image in enumerate(images):
+                pil = Image.fromarray((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
+                pil = pil.resize((256, 256))
+                pil.save(os.path.join(tmpdir, f"{i}.jpg"))
+            accelerator.log(
+                {
+                    "images": [
+                        wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=prompt)
+                        for i, prompt in enumerate(prompts)
+                    ],
+                },
+                step=global_step,
+            )

         # per-prompt mean/std tracking
         if config.per_prompt_stat_tracking:
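A minimal sketch of the tensor-to-JPEG step above, assuming images is a float batch in [0, 1] with shape (N, C, H, W); save_as_jpegs is an illustrative name, not a repo function:

import os

import numpy as np
from PIL import Image

# Illustrative helper, not a repo function: convert a float image batch in
# [0, 1] with shape (N, C, H, W) into small JPEG files and return their paths.
def save_as_jpegs(images, tmpdir, size=(256, 256)):
    paths = []
    for i, image in enumerate(images):
        array = (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        pil = Image.fromarray(array).resize(size)
        path = os.path.join(tmpdir, f"{i}.jpg")
        pil.save(path)  # the .jpg extension selects JPEG encoding in Pillow
        paths.append(path)
    return paths

Downscaling to 256x256 plus JPEG keeps the logged run small; per the comment in the diff, handing wandb.Image a .jpg path rather than a raw array is what stops wandb from encoding the image as PNG.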
@@ -271,6 +299,7 @@ def main(_):
             samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]

             # train
+            pipeline.unet.train()
             for i, sample in tqdm(
                 list(enumerate(samples_batched)),
                 desc=f"Epoch {epoch}.{inner_epoch}: training",
@@ -286,34 +315,35 @@ def main(_):
                 info = defaultdict(list)
                 for j in tqdm(
                     range(num_timesteps),
-                    desc=f"Timestep",
+                    desc="Timestep",
                     position=1,
                     leave=False,
                     disable=not accelerator.is_local_main_process,
                 ):
                     with accelerator.accumulate(pipeline.unet):
-                        if config.train.cfg:
-                            noise_pred = pipeline.unet(
-                                torch.cat([sample["latents"][:, j]] * 2),
-                                torch.cat([sample["timesteps"][:, j]] * 2),
-                                embeds,
-                            ).sample
-                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                            noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
-                                noise_pred_text - noise_pred_uncond
-                            )
-                        else:
-                            noise_pred = pipeline.unet(
-                                sample["latents"][:, j], sample["timesteps"][:, j], embeds
-                            ).sample
-                        _, log_prob = ddim_step_with_logprob(
-                            pipeline.scheduler,
-                            noise_pred,
-                            sample["timesteps"][:, j],
-                            sample["latents"][:, j],
-                            eta=config.sample.eta,
-                            prev_sample=sample["next_latents"][:, j],
-                        )
+                        with autocast():
+                            if config.train.cfg:
+                                noise_pred = pipeline.unet(
+                                    torch.cat([sample["latents"][:, j]] * 2),
+                                    torch.cat([sample["timesteps"][:, j]] * 2),
+                                    embeds,
+                                ).sample
+                                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                                noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
+                                    noise_pred_text - noise_pred_uncond
+                                )
+                            else:
+                                noise_pred = pipeline.unet(
+                                    sample["latents"][:, j], sample["timesteps"][:, j], embeds
+                                ).sample
+                            _, log_prob = ddim_step_with_logprob(
+                                pipeline.scheduler,
+                                noise_pred,
+                                sample["timesteps"][:, j],
+                                sample["latents"][:, j],
+                                eta=config.sample.eta,
+                                prev_sample=sample["next_latents"][:, j],
+                            )

                         # ppo logic
                         advantages = torch.clamp(
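The cfg branch is standard classifier-free guidance: the unet runs once on a doubled batch, with unconditional embeddings in one half and prompt embeddings in the other, and the two predictions are recombined as eps_uncond + s * (eps_text - eps_uncond). A standalone sketch of the recombination on dummy tensors:

import torch

# Standalone sketch of the guidance recombination, on dummy tensors.
def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # first half of the batch is unconditional, second half is text-conditioned
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

doubled = torch.randn(4, 4, 64, 64)  # a batch of 2 latents, duplicated along dim 0
guided = cfg_combine(doubled, guidance_scale=5.0)
assert guided.shape == (2, 4, 64, 64)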
@@ -337,7 +367,7 @@ def main(_):
                         # backward pass
                         accelerator.backward(loss)
                         if accelerator.sync_gradients:
-                            accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm)
+                            accelerator.clip_grad_norm_(trainable_layers.parameters(), config.train.max_grad_norm)
                         optimizer.step()
                         optimizer.zero_grad()