Initial commit

Kevin Black 2023-06-23 19:25:54 -07:00
commit 2fda3d4e78
17 changed files with 2198 additions and 0 deletions

.gitignore (vendored, new file)

@@ -0,0 +1,305 @@
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,intellij+all,vim
### Intellij+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### Intellij+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.
.idea/*
!.idea/codeStyles
!.idea/runConfigurations
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
# Session
Session.vim
Sessionx.vim
# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim

README.md (new file)

@@ -0,0 +1 @@
# ddpo-pytorch

Binary file not shown.

config/base.py (new file)

@@ -0,0 +1,56 @@
import ml_collections
def get_config():
config = ml_collections.ConfigDict()
# misc
config.seed = 42
config.logdir = "logs"
config.num_epochs = 100
config.mixed_precision = "fp16"
config.allow_tf32 = True
# pretrained model initialization
config.pretrained = pretrained = ml_collections.ConfigDict()
pretrained.model = "runwayml/stable-diffusion-v1-5"
pretrained.revision = "main"
# training
config.train = train = ml_collections.ConfigDict()
train.mixed_precision = "fp16"
train.batch_size = 1
train.use_8bit_adam = False
train.scale_lr = False
train.learning_rate = 1e-4
train.adam_beta1 = 0.9
train.adam_beta2 = 0.999
train.adam_weight_decay = 1e-2
train.adam_epsilon = 1e-8
train.gradient_accumulation_steps = 1
train.max_grad_norm = 1.0
train.num_inner_epochs = 1
train.cfg = True
train.adv_clip_max = 10
train.clip_range = 1e-4
# sampling
config.sample = sample = ml_collections.ConfigDict()
sample.num_steps = 5
sample.eta = 1.0
sample.guidance_scale = 5.0
sample.batch_size = 1
sample.num_batches_per_epoch = 4
# prompting
config.prompt_fn = "imagenet_animals"
config.prompt_fn_kwargs = {}
# rewards
config.reward_fn = "jpeg_compressibility"
config.per_prompt_stat_tracking = ml_collections.ConfigDict()
config.per_prompt_stat_tracking.buffer_size = 128
config.per_prompt_stat_tracking.min_count = 16
return config
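The training script loads this config through ml_collections config flags (see scripts/train.py below); a minimal sketch, assuming the config/ directory is importable from the working directory, of loading and overriding a few fields programmatically:

from config.base import get_config

config = get_config()
config.train.learning_rate = 3e-5  # e.g. try a smaller learning rate (illustrative value)
config.sample.num_steps = 50       # more DDIM steps per sampled image (illustrative value)
print(config.train.learning_rate, config.sample.num_steps)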

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large.

ddpo_pytorch/diffusers_patch/ddim_with_logprob.py (new file)

@@ -0,0 +1,143 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
# with the following modifications:
# -
from typing import Optional, Tuple, Union
import math
import torch
from diffusers.utils import randn_tensor
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
def ddim_step_with_logprob(
self: DDIMScheduler,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
prev_sample: Optional[torch.FloatTensor] = None,
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have no effect.
generator: random number generator.
prev_sample (`torch.FloatTensor`, *optional*): instead of sampling the previous sample from the step
distribution, evaluate the log probability of this precomputed previous sample under it. Useful for
scoring stored denoising trajectories during training.
Returns:
`tuple`: `(prev_sample, log_prob)`, where `prev_sample` is the sample at the previous timestep and
`log_prob` is its log probability under the DDIM step distribution, reduced (mean) over all but the
batch dimension.
"""
assert isinstance(self, DDIMScheduler)
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail to understand the following.
# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 2. compute alphas, betas
self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
# 4. Clip or threshold "predicted x_0"
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 - α_{t-1})/(1 - α_t)) * sqrt(1 - α_t/α_{t-1})
variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output:
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if prev_sample is not None and generator is not None:
raise ValueError(
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
" `prev_sample` stays `None`."
)
if prev_sample is None:
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
prev_sample = prev_sample_mean + std_dev_t * variance_noise
# log prob of prev_sample given prev_sample_mean and std_dev_t
log_prob = (
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
- torch.log(std_dev_t)
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
# mean along all but batch dimension
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
return prev_sample, log_prob
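The closed-form log probability above is just the elementwise Gaussian log-density with mean prev_sample_mean and standard deviation std_dev_t, averaged over all non-batch dimensions; a small standalone sanity-check sketch (illustrative shapes only, not part of the repo) comparing it against torch.distributions:

import math
import torch

mean = torch.randn(2, 4, 64, 64)                 # stands in for prev_sample_mean
std = torch.tensor(0.3)                          # stands in for std_dev_t
x = mean + std * torch.randn_like(mean)          # stands in for prev_sample

# manual form, as in ddim_step_with_logprob
manual = -((x - mean) ** 2) / (2 * std**2) - torch.log(std) - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
manual = manual.mean(dim=tuple(range(1, manual.ndim)))

# reference: the same density via torch.distributions.Normal
ref = torch.distributions.Normal(mean, std).log_prob(x).mean(dim=tuple(range(1, 4)))
assert torch.allclose(manual, ref, atol=1e-5)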

ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py (new file)

@@ -0,0 +1,225 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
# with the following modifications:
# -
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
rescale_noise_cfg,
)
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from .ddim_with_logprob import ddim_step_with_logprob
@torch.no_grad()
def pipeline_with_logprob(
self: StableDiffusionPipeline,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
Examples:
Returns:
`tuple`: `(images, has_nsfw_concept, all_latents, all_log_probs)`. `images` is the generated images,
`has_nsfw_concept` is a list of `bool`s denoting whether the corresponding generated image likely
represents "not-safe-for-work" (nsfw) content according to the `safety_checker`, `all_latents` is the
list of intermediate latents (one per denoising step, plus the initial noise), and `all_log_probs` is
the list of per-step log probabilities returned by `ddim_step_with_logprob`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
all_latents = [latents]
all_log_probs = []
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)
all_latents.append(latents)
all_log_probs.append(log_prob)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
return image, has_nsfw_concept, all_latents, all_log_probs
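A minimal usage sketch of the patched sampling call (illustrative only; assumes diffusers is installed, a CUDA device is available, and the model matches the config default):

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

images, nsfw, latents, log_probs = pipeline_with_logprob(
    pipe,
    prompt="a photo of a corgi",   # hypothetical prompt
    num_inference_steps=20,
    eta=1.0,                       # eta > 0 so each step's Gaussian has nonzero variance and finite log-probs
    output_type="pt",
)
# len(latents) == num_inference_steps + 1, len(log_probs) == num_inference_steps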

ddpo_pytorch/prompts.py (new file)

@@ -0,0 +1,54 @@
from importlib import resources
import functools
import random
import inflect
IE = inflect.engine()
ASSETS_PATH = resources.files("ddpo_pytorch.assets")
@functools.cache
def load_lines(name):
with ASSETS_PATH.joinpath(name).open() as f:
return [line.strip() for line in f.readlines()]
def imagenet(low, high):
return random.choice(load_lines("imagenet_classes.txt")[low:high]), {}
def imagenet_all():
return imagenet(0, 1000)
def imagenet_animals():
return imagenet(0, 398)
def imagenet_dogs():
return imagenet(151, 269)
def nouns_activities(nouns_file, activities_file):
nouns = load_lines(nouns_file)
activities = load_lines(activities_file)
return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}
def counting(nouns_file, low, high):
nouns = load_lines(nouns_file)
number = IE.number_to_words(random.randint(low, high))
noun = random.choice(nouns)
plural_noun = IE.plural(noun)
prompt = f"{number} {plural_noun}"
metadata = {
"questions": [
f"How many {plural_noun} are there in this image?",
f"What animal is in this image?",
],
"answers": [
number,
noun,
],
}
return prompt, metadata
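Each prompt function returns a (prompt, metadata) pair; a quick usage sketch (assumes the text assets under ddpo_pytorch/assets are installed with the package):

from ddpo_pytorch import prompts

prompt, metadata = prompts.imagenet_animals()  # e.g. ("tiger shark", {})
print(prompt, metadata)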

ddpo_pytorch/rewards.py (new file)

@@ -0,0 +1,29 @@
from PIL import Image
import io
import numpy as np
import torch
def jpeg_incompressibility():
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
buffers = [io.BytesIO() for _ in images]
for image, buffer in zip(images, buffers):
image.save(buffer, format="JPEG", quality=95)
sizes = [buffer.tell() / 1000 for buffer in buffers]
return np.array(sizes), {}
return _fn
def jpeg_compressibility():
jpeg_fn = jpeg_incompressibility()
def _fn(images, prompts, metadata):
rew, meta = jpeg_fn(images, prompts, metadata)
return -rew, meta
return _fn
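A usage sketch for the compressibility reward (illustrative tensors; higher, i.e. less negative, means the image compresses to a smaller JPEG):

import torch
from ddpo_pytorch import rewards

reward_fn = rewards.jpeg_compressibility()
images = torch.rand(2, 3, 512, 512)  # NCHW in [0, 1], as produced with output_type="pt"
scores, _ = reward_fn(images, ["prompt a", "prompt b"], [{}, {}])
print(scores)  # negative JPEG sizes in kilobytes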

ddpo_pytorch/stat_tracking.py (new file)

@@ -0,0 +1,34 @@
import numpy as np
from collections import deque
class PerPromptStatTracker:
def __init__(self, buffer_size, min_count):
self.buffer_size = buffer_size
self.min_count = min_count
self.stats = {}
def update(self, prompts, rewards):
unique = np.unique(prompts)
advantages = np.empty_like(rewards)
for prompt in unique:
prompt_rewards = rewards[prompts == prompt]
if prompt not in self.stats:
self.stats[prompt] = deque(maxlen=self.buffer_size)
self.stats[prompt].extend(prompt_rewards)
if len(self.stats[prompt]) < self.min_count:
mean = np.mean(rewards)
std = np.std(rewards) + 1e-6
else:
mean = np.mean(self.stats[prompt])
std = np.std(self.stats[prompt]) + 1e-6
advantages[prompts == prompt] = (prompt_rewards - mean) / std
return advantages
def get_stats(self):
return {
k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)}
for k, v in self.stats.items()
}
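A small sketch of the per-prompt tracker on toy data (illustrative values): while a prompt has fewer than min_count recorded rewards, the advantage falls back to the global mean/std of the current batch.

import numpy as np
from ddpo_pytorch.stat_tracking import PerPromptStatTracker

tracker = PerPromptStatTracker(buffer_size=128, min_count=16)
prompts = np.array(["a cat", "a cat", "a dog"])
rewards = np.array([1.0, 2.0, 3.0])
advantages = tracker.update(prompts, rewards)
print(advantages, tracker.get_stats())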

scripts/train.py (new file)

@@ -0,0 +1,341 @@
from absl import app, flags, logging
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import ddpo_pytorch.prompts
import ddpo_pytorch.rewards
from ddpo_pytorch.stat_tracking import PerPromptStatTracker
from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob
from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob
import torch
import tqdm
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")
logger = get_logger(__name__)
def main(_):
# basic Accelerate and logging setup
config = FLAGS.config
accelerator = Accelerator(
log_with="all",
mixed_precision=config.mixed_precision,
project_dir=config.logdir,
)
if accelerator.is_main_process:
accelerator.init_trackers(project_name="ddpo-pytorch", config=config)
logger.info(config)
# set seed
set_seed(config.seed)
# load scheduler, tokenizer and models.
pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
# freeze parameters of models to save more memory
pipeline.unet.requires_grad_(False)
pipeline.vae.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)
# disable safety checker
pipeline.safety_checker = None
# make the progress bar nicer
pipeline.set_progress_bar_config(
position=1,
disable=not accelerator.is_local_main_process,
leave=False,
)
# switch to DDIM scheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
# since these weights are only used for inference and keeping them in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
pipeline.unet.to(accelerator.device, dtype=weight_dtype)
pipeline.vae.to(accelerator.device, dtype=weight_dtype)
pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype)
# Set correct lora layers
lora_attn_procs = {}
for name in pipeline.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = pipeline.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = pipeline.unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
pipeline.unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(pipeline.unet.attn_processors)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if config.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if config.train.scale_lr:
config.train.learning_rate = (
config.train.learning_rate
* config.train.gradient_accumulation_steps
* config.train.batch_size
* accelerator.num_processes
)
# Initialize the optimizer
if config.train.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
lora_layers.parameters(),
lr=config.train.learning_rate,
betas=(config.train.adam_beta1, config.train.adam_beta2),
weight_decay=config.train.adam_weight_decay,
eps=config.train.adam_epsilon,
)
# prepare prompt and reward fn
prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn)
reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)()
# Prepare everything with our `accelerator`.
lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer)
# Train!
samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
total_train_batch_size = (
config.train.batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps
)
assert config.sample.batch_size % config.train.batch_size == 0
assert samples_per_epoch % total_train_batch_size == 0
logger.info("***** Running training *****")
logger.info(f" Num Epochs = {config.num_epochs}")
logger.info(f" Sample batch size per device = {config.sample.batch_size}")
logger.info(f" Train batch size per device = {config.train.batch_size}")
logger.info(f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}")
logger.info("")
logger.info(f" Total number of samples per epoch = {samples_per_epoch}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}")
logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}")
neg_prompt_embed = pipeline.text_encoder(
pipeline.tokenizer(
[""],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=pipeline.tokenizer.model_max_length,
).input_ids.to(accelerator.device)
)[0]
sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1)
train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
if config.per_prompt_stat_tracking:
stat_tracker = PerPromptStatTracker(
config.per_prompt_stat_tracking.buffer_size,
config.per_prompt_stat_tracking.min_count,
)
for epoch in range(config.num_epochs):
#################### SAMPLING ####################
samples = []
prompts = []
for i in tqdm.tqdm(
range(config.sample.num_batches_per_epoch),
desc=f"Epoch {epoch}: sampling",
disable=not accelerator.is_local_main_process,
position=0,
):
# generate prompts
prompts, prompt_metadata = zip(
*[prompt_fn(**config.prompt_fn_kwargs) for _ in range(config.sample.batch_size)]
)
# encode prompts
prompt_ids = pipeline.tokenizer(
prompts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=pipeline.tokenizer.model_max_length,
).input_ids.to(accelerator.device)
prompt_embeds = pipeline.text_encoder(prompt_ids)[0]
# sample
pipeline.unet.eval()
pipeline.vae.eval()
images, _, latents, log_probs = pipeline_with_logprob(
pipeline,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=sample_neg_prompt_embeds,
num_inference_steps=config.sample.num_steps,
guidance_scale=config.sample.guidance_scale,
eta=config.sample.eta,
output_type="pt",
)
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, 4, 64, 64)
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps)
timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1) # (batch_size, num_steps)
# compute rewards
rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
samples.append(
{
"prompt_ids": prompt_ids,
"prompt_embeds": prompt_embeds,
"timesteps": timesteps,
"latents": latents[:, :-1], # each entry is the latent before timestep t
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
"log_probs": log_probs,
"rewards": torch.as_tensor(rewards),
}
)
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
# gather rewards across processes
rewards = accelerator.gather(samples["rewards"]).cpu().numpy()
# per-prompt mean/std tracking
if config.per_prompt_stat_tracking:
# gather the prompts across processes
prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy()
prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
advantages = stat_tracker.update(prompts, rewards)
else:
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
# ungather advantages; we only need to keep the entries corresponding to the samples on this process
samples["advantages"] = (
torch.as_tensor(advantages)
.reshape(accelerator.num_processes, -1)[accelerator.process_index]
.to(accelerator.device)
)
del samples["rewards"]
del samples["prompt_ids"]
total_batch_size, num_timesteps = samples["timesteps"].shape
assert total_batch_size == config.sample.batch_size * config.sample.num_batches_per_epoch
assert num_timesteps == config.sample.num_steps
#################### TRAINING ####################
for inner_epoch in range(config.train.num_inner_epochs):
# shuffle samples along batch dimension
indices = torch.randperm(total_batch_size, device=accelerator.device)
samples = {k: v[indices] for k, v in samples.items()}
# shuffle along time dimension, independently for each sample
for i in range(total_batch_size):
indices = torch.randperm(num_timesteps, device=accelerator.device)
for key in ["timesteps", "latents", "next_latents"]:
samples[key][i] = samples[key][i][indices]
# rebatch for training
samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
# dict of lists -> list of dicts for easier iteration
samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]
# train
for i, sample in tqdm.tqdm(
list(enumerate(samples_batched)),
desc=f"Outer epoch {epoch}, inner epoch {inner_epoch}: training",
position=0,
):
if config.train.cfg:
# concat negative prompts to sample prompts to avoid two forward passes
embeds = torch.cat([train_neg_prompt_embeds, sample["prompt_embeds"]])
else:
embeds = sample["prompt_embeds"]
for j in tqdm.trange(
num_timesteps,
desc=f"Timestep",
position=1,
leave=False,
):
with accelerator.accumulate(pipeline.unet):
if config.train.cfg:
noise_pred = pipeline.unet(
torch.cat([sample["latents"][:, j]] * 2),
torch.cat([sample["timesteps"][:, j]] * 2),
embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + config.sample.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
else:
noise_pred = pipeline.unet(
sample["latents"][:, j], sample["timesteps"][:, j], embeds
).sample
_, log_prob = ddim_step_with_logprob(
pipeline.scheduler,
noise_pred,
sample["timesteps"][:, j],
sample["latents"][:, j],
eta=config.sample.eta,
prev_sample=sample["next_latents"][:, j],
)
# ppo logic
advantages = torch.clamp(
sample["advantages"][:, j], -config.train.adv_clip_max, config.train.adv_clip_max
)
ratio = torch.exp(log_prob - sample["log_probs"][:, j])
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(
ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range
)
loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
# debugging values
info = {}
# John Schulman says that (ratio - 1) - log(ratio) is a better
# estimator, but most existing code uses this so...
# http://joschu.net/blog/kl-approx.html
info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
info["clipfrac"] = torch.mean(torch.abs(ratio - 1.0) > config.train.clip_range)
info["loss"] = loss
# backward pass
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm)
optimizer.step()
optimizer.zero_grad()
if __name__ == "__main__":
app.run(main)
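The inner loop above uses the standard PPO clipped surrogate with per-timestep importance ratios; a self-contained numeric sketch of that objective (toy values only, not tied to the trainer's tensors):

import torch

log_prob_new = torch.tensor([-1.00, -2.00])   # log prob under the current policy
log_prob_old = torch.tensor([-1.05, -1.80])   # log prob recorded at sampling time
advantages = torch.tensor([0.5, -0.3])
clip_range = 1e-4

ratio = torch.exp(log_prob_new - log_prob_old)
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
print(loss)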

setup.py (new file)

@@ -0,0 +1,10 @@
from setuptools import setup, find_packages
setup(
name='ddpo-pytorch',
version='0.0.1',
packages=["ddpo_pytorch"],
install_requires=[
"ml-collections", "absl-py"
],
)