Initial commit
This commit adds the following files:

  ddpo_pytorch/__pycache__/prompts.cpython-310.pyc        BIN   Normal file (binary file not shown)
  ddpo_pytorch/__pycache__/rewards.cpython-310.pyc        BIN   Normal file (binary file not shown)
  ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc  BIN   Normal file (binary file not shown)
  ddpo_pytorch/assets/imagenet_classes.txt                1000  Normal file (diff suppressed because it is too large)
  ddpo_pytorch/diffusers_patch/ddim_with_logprob.py       143   Normal file
  ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py   225   Normal file
  ddpo_pytorch/prompts.py                                 54    Normal file
  ddpo_pytorch/rewards.py                                 29    Normal file
  ddpo_pytorch/stat_tracking.py                           34    Normal file

ddpo_pytorch/diffusers_patch/ddim_with_logprob.py (143 lines, Normal file)
@@ -0,0 +1,143 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
# with the following modifications:
# - indexes the alpha schedule with tensor timesteps
# - accepts an optional pre-computed `prev_sample`, whose log probability is
#   evaluated instead of sampling a new one
# - returns `(prev_sample, log_prob)`, where `log_prob` is the log probability
#   of `prev_sample` under the DDIM transition distribution

from typing import Optional, Tuple

import math
import torch

from diffusers.utils import randn_tensor
from diffusers.schedulers.scheduling_ddim import DDIMScheduler


def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from the learned diffusion model.
        timestep (`int` or `torch.LongTensor`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of the sample being created by the diffusion process.
        eta (`float`): weight of the noise added in the diffusion step.
        use_clipped_model_output (`bool`): if `True`, compute a "corrected" `model_output` from the clipped
            predicted original sample. Necessary because the predicted original sample is clipped to [-1, 1] when
            `self.config.clip_sample` is `True`. If no clipping has happened, the "corrected" `model_output`
            coincides with the one provided as input and `use_clipped_model_output` has no effect.
        generator: random number generator.
        prev_sample (`torch.FloatTensor`, *optional*): if provided, the log probability is evaluated at this
            sample instead of drawing a fresh one with `generator`; cannot be passed together with `generator`.

    Returns:
        `Tuple[torch.FloatTensor, torch.FloatTensor]`: the previous sample x_{t-1} and the log probability of
        that sample under the DDIM transition distribution, averaged over all non-batch dimensions.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read the DDIM paper for an in-depth understanding.

    # Notation (<variable name> -> <name in paper>):
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # `timestep` may arrive as a python int or a 0-dim tensor from the sampling
    # loop; normalize it to a 1-D tensor so that `gather` below works
    timestep = torch.as_tensor(timestep, device=sample.device).reshape(-1)

    # 1. get the previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
    self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
    # gather with a clamped index to avoid an out-of-bounds access at the final
    # step, but still select `final_alpha_cumprod` when prev_timestep < 0
    alpha_prod_t_prev = torch.where(
        prev_timestep >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.clamp(min=0)),
        self.final_alpha_cumprod,
    )

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute the predicted original sample from the predicted noise, also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # pred_epsilon is always re-derived from the clipped x_0 in GLIDE
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute the "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # 8. log probability of prev_sample under a Gaussian with mean prev_sample_mean
    # and standard deviation std_dev_t
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean over all but the batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample, log_prob
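
The closed-form expression in step 8 is just the elementwise log density of a Gaussian with mean prev_sample_mean and standard deviation std_dev_t. A minimal, self-contained check (not part of this commit; all names below are illustrative stand-ins) that it matches torch.distributions.Normal:

import math

import torch

mean = torch.randn(4, 8)                 # plays the role of prev_sample_mean
std = torch.rand(1) + 0.1                # plays the role of std_dev_t (eta > 0)
x = mean + std * torch.randn_like(mean)  # plays the role of prev_sample

manual = (
    -((x - mean) ** 2) / (2 * (std**2))
    - torch.log(std)
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
reference = torch.distributions.Normal(mean, std).log_prob(x)
assert torch.allclose(manual, reference, atol=1e-5)

Note that with eta = 0.0 (deterministic DDIM) std_dev_t is zero and the log density is undefined, so callers that need log probs must sample with eta > 0.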
							
								
								
									
ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py (225 lines, Normal file)
@@ -0,0 +1,225 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
# with the following modifications:
# - uses `ddim_step_with_logprob` for the scheduler step and returns, in
#   addition to the decoded images, every intermediate latent and the per-step
#   log probabilities of the sampling trajectory

from typing import Any, Callable, Dict, List, Optional, Union

import torch

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
)
from .ddim_with_logprob import ddim_step_with_logprob


@torch.no_grad()
def pipeline_with_logprob(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. A higher guidance scale encourages the model to generate images that are closely linked to the
            text `prompt`, usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Kept for API compatibility with the original pipeline; this patched function always returns a plain
            tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
            the same paper. Guidance rescale should fix overexposure when using zero terminal SNR.

    Returns:
        `tuple`: `(images, has_nsfw_concepts, all_latents, all_log_probs)`. `images` are the generated images and
        `has_nsfw_concepts` is a list of `bool`s denoting whether the corresponding image likely represents
        "not-safe-for-work" (nsfw) content, according to the `safety_checker`. `all_latents` is the list of
        latents at every denoising step, including the initial noise, and `all_log_probs` is the list of per-step
        log probabilities returned by `ddim_step_with_logprob`.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1, along with its log probability
            latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload the last model to CPU if offload hooks are enabled
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs
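
A hedged usage sketch (not part of this commit; the model id and sampler settings are illustrative): the patched call is a drop-in replacement for pipe(...) that additionally returns the sampling trajectory.

import torch

from diffusers import DDIMScheduler, StableDiffusionPipeline

from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)  # the patch targets DDIM

# eta > 0 makes each DDIM transition stochastic, so the returned log probs are finite
images, has_nsfw, all_latents, all_log_probs = pipeline_with_logprob(
    pipe, prompt="a photo of a corgi", num_inference_steps=50, eta=1.0
)
# all_latents has num_inference_steps + 1 entries (including the initial noise);
# all_log_probs has one entry per denoising step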
							
								
								
									
ddpo_pytorch/prompts.py (54 lines, Normal file)
@@ -0,0 +1,54 @@
from importlib import resources
import functools
import random
import inflect

IE = inflect.engine()
ASSETS_PATH = resources.files("ddpo_pytorch.assets")


@functools.cache
def load_lines(name):
    # load the named asset file as a list of stripped lines (cached per file)
    with ASSETS_PATH.joinpath(name).open() as f:
        return [line.strip() for line in f.readlines()]


def imagenet(low, high):
    return random.choice(load_lines("imagenet_classes.txt")[low:high]), {}


def imagenet_all():
    return imagenet(0, 1000)


def imagenet_animals():
    # ImageNet classes 0-397 are animals
    return imagenet(0, 398)


def imagenet_dogs():
    # ImageNet classes 151-268 are dog breeds
    return imagenet(151, 269)


def nouns_activities(nouns_file, activities_file):
    nouns = load_lines(nouns_file)
    activities = load_lines(activities_file)
    return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}


def counting(nouns_file, low, high):
    nouns = load_lines(nouns_file)
    number = IE.number_to_words(random.randint(low, high))
    noun = random.choice(nouns)
    plural_noun = IE.plural(noun)
    prompt = f"{number} {plural_noun}"
    metadata = {
        "questions": [
            f"How many {plural_noun} are there in this image?",
            "What animal is in this image?",
        ],
        "answers": [
            number,
            noun,
        ],
    }
    return prompt, metadata
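
Illustrative usage (not part of this commit): each prompt function returns a (prompt, metadata) pair, where the metadata carries any ground truth a reward function might need (e.g. the VQA-style questions from counting).

from ddpo_pytorch import prompts

prompt, metadata = prompts.imagenet_dogs()  # e.g. ("golden retriever", {})
print(prompt, metadata)

# `counting` reads a noun list from ddpo_pytorch/assets; the file name below is
# an assumption for illustration
prompt, metadata = prompts.counting("simple_animals.txt", 2, 9)
print(prompt, metadata["questions"], metadata["answers"])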
							
								
								
									
ddpo_pytorch/rewards.py (29 lines, Normal file)
@@ -0,0 +1,29 @@
from PIL import Image
import io
import numpy as np
import torch


def jpeg_incompressibility():
    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        images = [Image.fromarray(image) for image in images]
        buffers = [io.BytesIO() for _ in images]
        for image, buffer in zip(images, buffers):
            image.save(buffer, format="JPEG", quality=95)
        # reward is the encoded JPEG size in kilobytes
        sizes = [buffer.tell() / 1000 for buffer in buffers]
        return np.array(sizes), {}

    return _fn


def jpeg_compressibility():
    # negate the incompressibility reward: smaller JPEGs score higher
    jpeg_fn = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        rew, meta = jpeg_fn(images, prompts, metadata)
        return -rew, meta

    return _fn
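
A small sketch (not part of this commit) of the reward-function protocol these factories follow: each returns a callable mapping (images, prompts, metadata) to a numpy array of per-image rewards plus a metadata dict. Images may be a float tensor in NCHW or, as here, a uint8 NHWC numpy array.

import numpy as np

reward_fn = jpeg_compressibility()
images = np.random.randint(0, 256, size=(4, 64, 64, 3), dtype=np.uint8)  # NHWC uint8
rewards, meta = reward_fn(images, prompts=None, metadata=None)
print(rewards)  # negated JPEG size in kB: more compressible -> higher reward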
							
								
								
									
ddpo_pytorch/stat_tracking.py (34 lines, Normal file)
@@ -0,0 +1,34 @@
import numpy as np
from collections import deque


class PerPromptStatTracker:
    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}

    def update(self, prompts, rewards):
        unique = np.unique(prompts)
        advantages = np.empty_like(rewards)
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            self.stats[prompt].extend(prompt_rewards)

            # fall back to batch-level statistics until the per-prompt buffer
            # has accumulated at least `min_count` rewards
            if len(self.stats[prompt]) < self.min_count:
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(self.stats[prompt])
                std = np.std(self.stats[prompt]) + 1e-6
            advantages[prompts == prompt] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        return {
            k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)}
            for k, v in self.stats.items()
        }
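
Illustrative usage (not part of this commit): until a prompt has accumulated min_count rewards, its advantages are normalized with the statistics of the whole batch; after that, with the prompt's own running buffer.

import numpy as np

tracker = PerPromptStatTracker(buffer_size=32, min_count=4)
prompts = np.array(["a cat", "a dog", "a cat", "a dog"])
rewards = np.array([1.0, 2.0, 3.0, 4.0])

advantages = tracker.update(prompts, rewards)
print(advantages)           # per-sample (reward - mean) / std
print(tracker.get_stats())  # running mean/std/count per prompt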