Initial commit

scripts/train.py (new file, 341 lines)
@@ -0,0 +1,341 @@
from absl import app, flags, logging
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import ddpo_pytorch.prompts
import ddpo_pytorch.rewards
from ddpo_pytorch.stat_tracking import PerPromptStatTracker
from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob
from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob
import torch
import tqdm


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")
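# any config field can be overridden on the command line via ml_collections'
# config_flags, e.g.:
#   python scripts/train.py --config config/base.py --config.train.learning_rate=1e-4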

logger = get_logger(__name__)


def main(_):
    # basic Accelerate and logging setup
    config = FLAGS.config
    accelerator = Accelerator(
        log_with="all",
        mixed_precision=config.mixed_precision,
        project_dir=config.logdir,
        # without this, `accelerator.accumulate` below would not actually accumulate
        # gradients across `config.train.gradient_accumulation_steps` steps
        gradient_accumulation_steps=config.train.gradient_accumulation_steps,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers(project_name="ddpo-pytorch", config=config)
    logger.info(config)

    # set seed
    set_seed(config.seed)

    # load scheduler, tokenizer and models
    pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
    # freeze parameters of models to save more memory
    pipeline.unet.requires_grad_(False)
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    # disable safety checker
    pipeline.safety_checker = None
    # make the progress bar nicer
    pipeline.set_progress_bar_config(
        position=1,
        disable=not accelerator.is_local_main_process,
        leave=False,
    )
    # switch to DDIM scheduler
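    # (the log-prob computations below assume the eta-parameterized DDIM update
    # rule, so the pipeline's default scheduler is replaced with a DDIMScheduler
    # built from the same config)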
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

    # For mixed precision training, we cast all non-trainable weights (VAE, non-LoRA text encoder,
    # and non-LoRA UNet) to half precision; since these weights are only used for inference,
    # keeping them in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
    pipeline.unet.to(accelerator.device, dtype=weight_dtype)
    pipeline.vae.to(accelerator.device, dtype=weight_dtype)
    pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype)

    # Set the correct LoRA layers
    lora_attn_procs = {}
    for name in pipeline.unet.attn_processors.keys():
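        # processors named "attn1" handle self-attention (no cross-attention dim);
        # "attn2" processors cross-attend to the text-encoder hidden states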
        cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = pipeline.unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = pipeline.unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

    pipeline.unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(pipeline.unet.attn_processors)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if config.train.scale_lr:
        config.train.learning_rate = (
            config.train.learning_rate
            * config.train.gradient_accumulation_steps
            * config.train.batch_size
            * accelerator.num_processes
        )
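        # linear scaling rule: e.g. with a base learning rate of 1e-4, 4 accumulation
        # steps, a per-device batch size of 8, and 2 processes, the effective
        # learning rate becomes 1e-4 * 4 * 8 * 2 = 6.4e-3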

    # Initialize the optimizer
    if config.train.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        lora_layers.parameters(),
        lr=config.train.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        weight_decay=config.train.adam_weight_decay,
        eps=config.train.adam_epsilon,
    )

    # prepare prompt and reward fn
    prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn)
    reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)()

    # Prepare everything with our `accelerator`.
    lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer)

    # Train!
    samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
    total_train_batch_size = (
        config.train.batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps
    )

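    # each epoch's samples are later reshaped into train-sized minibatches, so the
    # sample batch size must divide evenly by the train batch size, and the samples
    # per epoch must divide evenly by the total train batch size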
    assert config.sample.batch_size % config.train.batch_size == 0
    assert samples_per_epoch % total_train_batch_size == 0

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {config.num_epochs}")
    logger.info(f"  Sample batch size per device = {config.sample.batch_size}")
    logger.info(f"  Train batch size per device = {config.train.batch_size}")
    logger.info(f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}")
    logger.info("")
    logger.info(f"  Total number of samples per epoch = {samples_per_epoch}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f"  Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}")
    logger.info(f"  Number of inner epochs = {config.train.num_inner_epochs}")

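    # the unconditional ("negative") prompt embedding used for classifier-free
    # guidance is the text encoder's embedding of the empty string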
    neg_prompt_embed = pipeline.text_encoder(
        pipeline.tokenizer(
            [""],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=pipeline.tokenizer.model_max_length,
        ).input_ids.to(accelerator.device)
    )[0]
    sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)

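    # keep a running mean/std of rewards for each prompt, so that advantages can be
    # computed against a per-prompt baseline rather than the global batch statistics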
    if config.per_prompt_stat_tracking:
        stat_tracker = PerPromptStatTracker(
            config.per_prompt_stat_tracking.buffer_size,
            config.per_prompt_stat_tracking.min_count,
        )

    for epoch in range(config.num_epochs):
        #################### SAMPLING ####################
        samples = []
        prompts = []
        for i in tqdm.tqdm(
            range(config.sample.num_batches_per_epoch),
            desc=f"Epoch {epoch}: sampling",
            disable=not accelerator.is_local_main_process,
            position=0,
        ):
            # generate prompts
            prompts, prompt_metadata = zip(
                *[prompt_fn(**config.prompt_fn_kwargs) for _ in range(config.sample.batch_size)]
            )

            # encode prompts
            prompt_ids = pipeline.tokenizer(
                prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=pipeline.tokenizer.model_max_length,
            ).input_ids.to(accelerator.device)
            prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

            # sample
            pipeline.unet.eval()
            pipeline.vae.eval()
            images, _, latents, log_probs = pipeline_with_logprob(
                pipeline,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=config.sample.num_steps,
                guidance_scale=config.sample.guidance_scale,
                eta=config.sample.eta,
                output_type="pt",
            )

            latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, 4, 64, 64)
            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps)
            timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)

            # compute rewards
            rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)

            samples.append(
                {
                    "prompt_ids": prompt_ids,
                    "prompt_embeds": prompt_embeds,
                    "timesteps": timesteps,
                    "latents": latents[:, :-1],  # each entry is the latent before timestep t
                    "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                    "log_probs": log_probs,
                    # rewards must live on the accelerator device so they can be gathered below
                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
                }
            )

        # collate samples into a dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}

        # gather rewards across processes
        rewards = accelerator.gather(samples["rewards"]).cpu().numpy()

        # per-prompt mean/std tracking
        if config.per_prompt_stat_tracking:
            # gather the prompts across processes
            prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
            prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
            advantages = stat_tracker.update(prompts, rewards)
        else:
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # ungather advantages; we only need to keep the entries corresponding to the samples on this process
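        # (accelerator.gather concatenates along dim 0 in process order, so reshaping
        # to (num_processes, -1) and indexing by process_index recovers this process's slice)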
        samples["advantages"] = (
            torch.as_tensor(advantages)
            .reshape(accelerator.num_processes, -1)[accelerator.process_index]
            .to(accelerator.device)
        )

        del samples["rewards"]
        del samples["prompt_ids"]

        total_batch_size, num_timesteps = samples["timesteps"].shape
        assert total_batch_size == config.sample.batch_size * config.sample.num_batches_per_epoch
        assert num_timesteps == config.sample.num_steps

        #################### TRAINING ####################
        for inner_epoch in range(config.train.num_inner_epochs):
            # shuffle samples along batch dimension
            indices = torch.randperm(total_batch_size, device=accelerator.device)
            samples = {k: v[indices] for k, v in samples.items()}

            # shuffle along time dimension, independently for each sample
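            # (each (latent, timestep, next_latent) triple is treated as an independent
            # transition by the PPO objective, so their order within a trajectory is irrelevant)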
            for i in range(total_batch_size):
                indices = torch.randperm(num_timesteps, device=accelerator.device)
                # log_probs must be permuted together with the timesteps they correspond to
                for key in ["timesteps", "latents", "next_latents", "log_probs"]:
                    samples[key][i] = samples[key][i][indices]

            # rebatch for training
            samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}

            # dict of lists -> list of dicts for easier iteration
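            # e.g. {"a": [x1, x2], "b": [y1, y2]} -> [{"a": x1, "b": y1}, {"a": x2, "b": y2}]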
            samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]

            # train
            for i, sample in tqdm.tqdm(
                list(enumerate(samples_batched)),
                desc=f"Outer epoch {epoch}, inner epoch {inner_epoch}: training",
                position=0,
            ):
                if config.train.cfg:
                    # concat negative prompts to sample prompts to avoid two forward passes
                    embeds = torch.cat([train_neg_prompt_embeds, sample["prompt_embeds"]])
                else:
                    embeds = sample["prompt_embeds"]

                for j in tqdm.trange(
                    num_timesteps,
                    desc="Timestep",
                    position=1,
                    leave=False,
                ):
                    with accelerator.accumulate(pipeline.unet):
                        if config.train.cfg:
                            noise_pred = pipeline.unet(
                                torch.cat([sample["latents"][:, j]] * 2),
                                torch.cat([sample["timesteps"][:, j]] * 2),
                                embeds,
                            ).sample
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
                                noise_pred_text - noise_pred_uncond
                            )
                        else:
                            noise_pred = pipeline.unet(
                                sample["latents"][:, j], sample["timesteps"][:, j], embeds
                            ).sample
                        _, log_prob = ddim_step_with_logprob(
                            pipeline.scheduler,
                            noise_pred,
                            sample["timesteps"][:, j],
                            sample["latents"][:, j],
                            eta=config.sample.eta,
                            prev_sample=sample["next_latents"][:, j],
                        )

                        # PPO logic: clipped surrogate objective,
                        #   loss = E[max(-A * r, -A * clip(r, 1 - eps, 1 + eps))]
                        # where r = exp(log_prob - old_log_prob) is the probability ratio
                        # between the current policy and the sampling policy.
                        # advantages are per-trajectory (one reward per image), so they
                        # carry no time index
                        advantages = torch.clamp(
                            sample["advantages"], -config.train.adv_clip_max, config.train.adv_clip_max
                        )
                        ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                        unclipped_loss = -advantages * ratio
                        clipped_loss = -advantages * torch.clamp(
                            ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range
                        )
                        loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))

                        # debugging values
                        info = {}
                        # John Schulman says that (ratio - 1) - log(ratio) is a better
                        # estimator, but most existing code uses this so...
                        # http://joschu.net/blog/kl-approx.html
                        info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
                        # fraction of samples where the ratio was clipped; cast to float
                        # since torch.mean is not defined for boolean tensors
                        info["clipfrac"] = torch.mean((torch.abs(ratio - 1.0) > config.train.clip_range).float())
                        info["loss"] = loss

                        # backward pass
                        accelerator.backward(loss)
                        if accelerator.sync_gradients:
                            accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm)
                        optimizer.step()
                        optimizer.zero_grad()


if __name__ == "__main__":
    app.run(main)