Initial commit

305  .gitignore  vendored  Normal file
@@ -0,0 +1,305 @@
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,intellij+all,vim

### Intellij+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn.  Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### Intellij+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.

.idea/*

!.idea/codeStyles
!.idea/runConfigurations

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### Vim ###
# Swap
[._]*.s[a-v][a-z]
# comment out the following line if you don't need vector files
!*.svg
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim

BIN  config/__pycache__/base.cpython-310.pyc  Normal file
Binary file not shown.

56  config/base.py  Normal file
@@ -0,0 +1,56 @@
import ml_collections

def get_config():

    config = ml_collections.ConfigDict()

    # misc
    config.seed = 42
    config.logdir = "logs"
    config.num_epochs = 100
    config.mixed_precision = "fp16"
    config.allow_tf32 = True

    # pretrained model initialization
    config.pretrained = pretrained = ml_collections.ConfigDict()
    pretrained.model = "runwayml/stable-diffusion-v1-5"
    pretrained.revision = "main"

    # training
    config.train = train = ml_collections.ConfigDict()
    train.mixed_precision = "fp16"
    train.batch_size = 1
    train.use_8bit_adam = False
    train.scale_lr = False
    train.learning_rate = 1e-4
    train.adam_beta1 = 0.9
    train.adam_beta2 = 0.999
    train.adam_weight_decay = 1e-2
    train.adam_epsilon = 1e-8
    train.gradient_accumulation_steps = 1
    train.max_grad_norm = 1.0
    train.num_inner_epochs = 1
    train.cfg = True
    train.adv_clip_max = 10
    train.clip_range = 1e-4

    # sampling
    config.sample = sample = ml_collections.ConfigDict()
    sample.num_steps = 5
    sample.eta = 1.0
    sample.guidance_scale = 5.0
    sample.batch_size = 1
    sample.num_batches_per_epoch = 4

    # prompting
    config.prompt_fn = "imagenet_animals"
    config.prompt_fn_kwargs = {}

    # rewards
    config.reward_fn = "jpeg_compressibility"

    config.per_prompt_stat_tracking = ml_collections.ConfigDict()
    config.per_prompt_stat_tracking.buffer_size = 128
    config.per_prompt_stat_tracking.min_count = 16

    return config
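
A minimal usage sketch, not part of this commit, showing how this ConfigDict is consumed; the override value is illustrative. scripts/train.py (below) registers this file with absl via ml_collections' config_flags, so any field can also be overridden on the command line:

# hedged sketch: load the config directly
from config.base import get_config

config = get_config()
print(config.train.learning_rate)  # 0.0001
config.sample.num_steps = 50       # ConfigDict fields stay mutable

# equivalently, via the flag that scripts/train.py defines:
#   python scripts/train.py --config config/base.py --config.sample.num_steps=50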

BIN  ddpo_pytorch/__pycache__/prompts.cpython-310.pyc  Normal file
Binary file not shown.

BIN  ddpo_pytorch/__pycache__/rewards.cpython-310.pyc  Normal file
Binary file not shown.

BIN  ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc  Normal file
Binary file not shown.

1000  ddpo_pytorch/assets/imagenet_classes.txt  Normal file
File diff suppressed because it is too large.

143  ddpo_pytorch/diffusers_patch/ddim_with_logprob.py  Normal file
@@ -0,0 +1,143 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
# with the following modifications:
# -

from typing import Optional, Tuple, Union

import math
import torch

from diffusers.utils import randn_tensor
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler


def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: torch.Tensor,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`torch.Tensor`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
            predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
            `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
            coincide with the one provided as input and `use_clipped_model_output` will have no effect.
        generator: random number generator.
        prev_sample (`torch.FloatTensor`, *optional*): instead of sampling the previous latent using `generator`,
            a fixed `prev_sample` can be provided; its log-probability under the DDIM transition is then computed.

    Returns:
        `tuple`: the previous sample `prev_sample` and the per-example log-probability `log_prob` of that sample
        under the Gaussian DDIM transition distribution.

    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read the DDIM paper in detail to understand the following.

    # Notation (<variable name> -> <name in paper>)
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
    self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
    alpha_prod_t_prev = torch.where(
        prev_timestep >= 0,
        # clamp keeps the gather index valid for the final step; the clamped value is discarded by torch.where
        self.alphas_cumprod.gather(0, prev_timestep.clamp(min=0)),
        self.final_alpha_cumprod,
    )

    beta_prod_t = 1 - alpha_prod_t
    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # log prob of prev_sample given prev_sample_mean and std_dev_t
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample, log_prob
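
The closed-form expression above is the log-density of a Gaussian with mean prev_sample_mean and standard deviation std_dev_t, averaged over the non-batch dimensions. A quick sanity-check sketch, not part of this commit; shapes and values here are arbitrary stand-ins:

import math
import torch

mean = torch.randn(2, 4, 8, 8)           # stands in for prev_sample_mean
std = torch.tensor(0.3)                  # stands in for std_dev_t = eta * variance ** 0.5
x = mean + std * torch.randn_like(mean)  # stands in for prev_sample

manual = (
    -((x - mean) ** 2) / (2 * (std**2))
    - torch.log(std)
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
reference = torch.distributions.Normal(mean, std).log_prob(x)
assert torch.allclose(manual, reference, atol=1e-6)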

225  ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py  Normal file
@@ -0,0 +1,225 @@
# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
# with the following modifications:
# -

from typing import Any, Callable, Dict, List, Optional, Union

import torch

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
)
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from .ddim_with_logprob import ddim_step_with_logprob


@torch.no_grad()
def pipeline_with_logprob(
    self: StableDiffusionPipeline,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
            that paper. Guidance rescale should fix overexposure when using zero terminal SNR.

    Examples:

    Returns:
        `tuple`: the generated images, a list of `bool`s denoting whether the corresponding image likely
        represents "not-safe-for-work" (nsfw) content according to the `safety_checker`, the list of intermediate
        latents (one per denoising step, plus the initial noise), and the list of per-step log-probabilities.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    all_latents = [latents]
    all_log_probs = []
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs
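
A minimal calling sketch, not part of this commit; the checkpoint name, step count, and eta mirror config/base.py, the prompt and device are illustrative assumptions:

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

images, nsfw, latents, log_probs = pipeline_with_logprob(
    pipe,
    prompt="a photo of a dog",
    num_inference_steps=5,
    eta=1.0,  # eta > 0 keeps the transition stochastic, so log-probs are finite
    output_type="pt",
)
# one latent per denoising step plus the initial noise, one log-prob per step
assert len(latents) == len(log_probs) + 1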

54  ddpo_pytorch/prompts.py  Normal file
@@ -0,0 +1,54 @@
from importlib import resources
import functools
import random
import inflect

IE = inflect.engine()
ASSETS_PATH = resources.files("ddpo_pytorch.assets")


@functools.cache
def load_lines(name):
    with ASSETS_PATH.joinpath(name).open() as f:
        return [line.strip() for line in f.readlines()]


def imagenet(low, high):
    return random.choice(load_lines("imagenet_classes.txt")[low:high]), {}


def imagenet_all():
    return imagenet(0, 1000)


def imagenet_animals():
    return imagenet(0, 398)


def imagenet_dogs():
    return imagenet(151, 269)


def nouns_activities(nouns_file, activities_file):
    nouns = load_lines(nouns_file)
    activities = load_lines(activities_file)
    return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}


def counting(nouns_file, low, high):
    nouns = load_lines(nouns_file)
    number = IE.number_to_words(random.randint(low, high))
    noun = random.choice(nouns)
    plural_noun = IE.plural(noun)
    prompt = f"{number} {plural_noun}"
    metadata = {
        "questions": [
            f"How many {plural_noun} are there in this image?",
            "What animal is in this image?",
        ],
        "answers": [
            number,
            noun,
        ],
    }
    return prompt, metadata
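
Each prompt function returns a (prompt, metadata) pair, and train.py selects one by name via config.prompt_fn. A small sketch, not part of this commit; the nouns file name is a made-up placeholder:

from ddpo_pytorch.prompts import imagenet_animals, counting

prompt, metadata = imagenet_animals()  # e.g. ("tabby cat", {}); metadata is empty here
# counting() also builds VQA-style question/answer metadata that a reward fn could check
prompt, metadata = counting("my_nouns.txt", 2, 9)  # assumes a my_nouns.txt asset exists
print(prompt)  # e.g. "four dogs"
print(metadata["questions"], metadata["answers"])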

29  ddpo_pytorch/rewards.py  Normal file
@@ -0,0 +1,29 @@
from PIL import Image
import io
import numpy as np
import torch


def jpeg_incompressibility():
    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        images = [Image.fromarray(image) for image in images]
        buffers = [io.BytesIO() for _ in images]
        for image, buffer in zip(images, buffers):
            image.save(buffer, format="JPEG", quality=95)
        sizes = [buffer.tell() / 1000 for buffer in buffers]
        return np.array(sizes), {}

    return _fn


def jpeg_compressibility():
    jpeg_fn = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        rew, meta = jpeg_fn(images, prompts, metadata)
        return -rew, meta

    return _fn
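
Reward callables return a (rewards, metadata) tuple over a batch; jpeg_compressibility is the negated JPEG size in kilobytes, so images that compress well score higher. A usage sketch, not part of this commit; the random images are placeholders:

import torch
from ddpo_pytorch.rewards import jpeg_compressibility

reward_fn = jpeg_compressibility()
images = torch.rand(4, 3, 512, 512)  # NCHW floats in [0, 1], as the sampler produces
rewards, _ = reward_fn(images, None, None)  # prompts and metadata are unused by this reward
print(rewards.shape, rewards)  # (4,), negative JPEG kilobytes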

34  ddpo_pytorch/stat_tracking.py  Normal file
@@ -0,0 +1,34 @@
import numpy as np
from collections import deque


class PerPromptStatTracker:
    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}

    def update(self, prompts, rewards):
        unique = np.unique(prompts)
        advantages = np.empty_like(rewards)
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            self.stats[prompt].extend(prompt_rewards)

            if len(self.stats[prompt]) < self.min_count:
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(self.stats[prompt])
                std = np.std(self.stats[prompt]) + 1e-6
            advantages[prompts == prompt] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        return {
            k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)}
            for k, v in self.stats.items()
        }
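
The tracker normalizes rewards into advantages using a running per-prompt mean/std, falling back to batch-level statistics until min_count samples of a prompt have been seen. A usage sketch, not part of this commit:

import numpy as np
from ddpo_pytorch.stat_tracking import PerPromptStatTracker

tracker = PerPromptStatTracker(buffer_size=128, min_count=16)
prompts = np.array(["cat", "dog", "cat", "dog"])
rewards = np.array([1.0, 3.0, 2.0, 5.0])
advantages = tracker.update(prompts, rewards)  # batch stats used until min_count is reached
print(advantages)
print(tracker.get_stats())  # per-prompt mean, std, and buffer count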

341  scripts/train.py  Normal file
@@ -0,0 +1,341 @@
from absl import app, flags, logging
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import ddpo_pytorch.prompts
import ddpo_pytorch.rewards
from ddpo_pytorch.stat_tracking import PerPromptStatTracker
from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob
from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob
import torch
import tqdm


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")

logger = get_logger(__name__)


def main(_):
    # basic Accelerate and logging setup
    config = FLAGS.config
    accelerator = Accelerator(
        log_with="all",
        mixed_precision=config.mixed_precision,
        project_dir=config.logdir,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers(project_name="ddpo-pytorch", config=config)
    logger.info(config)

    # set seed
    set_seed(config.seed)

    # load scheduler, tokenizer and models.
    pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)
    # freeze parameters of models to save more memory
    pipeline.unet.requires_grad_(False)
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    # disable safety checker
    pipeline.safety_checker = None
    # make the progress bar nicer
    pipeline.set_progress_bar_config(
        position=1,
        disable=not accelerator.is_local_main_process,
        leave=False,
    )
    # switch to DDIM scheduler
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
    # as these weights are only used for inference; keeping them in full precision is not required.
|  |     weight_dtype = torch.float32 | ||||||
|  |     if accelerator.mixed_precision == "fp16": | ||||||
|  |         weight_dtype = torch.float16 | ||||||
|  |     elif accelerator.mixed_precision == "bf16": | ||||||
|  |         weight_dtype = torch.bfloat16 | ||||||
|  |  | ||||||
|  |     # Move unet, vae and text_encoder to device and cast to weight_dtype | ||||||
|  |     pipeline.unet.to(accelerator.device, dtype=weight_dtype) | ||||||
|  |     pipeline.vae.to(accelerator.device, dtype=weight_dtype) | ||||||
|  |     pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype) | ||||||
|  |  | ||||||
|  |     # Set correct lora layers | ||||||
|  |     lora_attn_procs = {} | ||||||
|  |     for name in pipeline.unet.attn_processors.keys(): | ||||||
|  |         cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim | ||||||
|  |         if name.startswith("mid_block"): | ||||||
|  |             hidden_size = pipeline.unet.config.block_out_channels[-1] | ||||||
|  |         elif name.startswith("up_blocks"): | ||||||
|  |             block_id = int(name[len("up_blocks.")]) | ||||||
|  |             hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id] | ||||||
|  |         elif name.startswith("down_blocks"): | ||||||
|  |             block_id = int(name[len("down_blocks.")]) | ||||||
|  |             hidden_size = pipeline.unet.config.block_out_channels[block_id] | ||||||
|  |  | ||||||
|  |         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) | ||||||
|  |  | ||||||
|  |     pipeline.unet.set_attn_processor(lora_attn_procs) | ||||||
|  |     lora_layers = AttnProcsLayers(pipeline.unet.attn_processors) | ||||||
|  |  | ||||||
|  |     # Enable TF32 for faster training on Ampere GPUs, | ||||||
|  |     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices | ||||||
|  |     if config.allow_tf32: | ||||||
|  |         torch.backends.cuda.matmul.allow_tf32 = True | ||||||
|  |  | ||||||
|  |     if config.train.scale_lr: | ||||||
|  |         config.train.learning_rate = ( | ||||||
|  |             config.train.learning_rate | ||||||
|  |             * config.train.gradient_accumulation_steps | ||||||
|  |             * config.train.batch_size | ||||||
|  |             * accelerator.num_processes | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     # Initialize the optimizer | ||||||
|  |     if config.train.use_8bit_adam: | ||||||
|  |         try: | ||||||
|  |             import bitsandbytes as bnb | ||||||
|  |         except ImportError: | ||||||
|  |             raise ImportError( | ||||||
|  |                 "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         optimizer_cls = bnb.optim.AdamW8bit | ||||||
|  |     else: | ||||||
|  |         optimizer_cls = torch.optim.AdamW | ||||||
|  |  | ||||||
|  |     optimizer = optimizer_cls( | ||||||
|  |         lora_layers.parameters(), | ||||||
|  |         lr=config.train.learning_rate, | ||||||
|  |         betas=(config.train.adam_beta1, config.train.adam_beta2), | ||||||
|  |         weight_decay=config.train.adam_weight_decay, | ||||||
|  |         eps=config.train.adam_epsilon, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # prepare prompt and reward fn | ||||||
|  |     prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn) | ||||||
|  |     reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)() | ||||||
|  |  | ||||||
|  |     # Prepare everything with our `accelerator`. | ||||||
|  |     lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer) | ||||||
|  |  | ||||||
|  |     # Train! | ||||||
|  |     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch | ||||||
|  |     total_train_batch_size = ( | ||||||
|  |         config.train.batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     assert config.sample.batch_size % config.train.batch_size == 0 | ||||||
|  |     assert samples_per_epoch % total_train_batch_size == 0 | ||||||
|  |  | ||||||
|  |     logger.info("***** Running training *****") | ||||||
|  |     logger.info(f"  Num Epochs = {config.num_epochs}") | ||||||
|  |     logger.info(f"  Sample batch size per device = {config.sample.batch_size}") | ||||||
|  |     logger.info(f"  Train batch size per device = {config.train.batch_size}") | ||||||
|  |     logger.info(f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}") | ||||||
|  |     logger.info("") | ||||||
|  |     logger.info(f"  Total number of samples per epoch = {samples_per_epoch}") | ||||||
|  |     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") | ||||||
|  |     logger.info(f"  Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}") | ||||||
|  |     logger.info(f"  Number of inner epochs = {config.train.num_inner_epochs}") | ||||||
|  |  | ||||||
|  |     neg_prompt_embed = pipeline.text_encoder( | ||||||
|  |         pipeline.tokenizer( | ||||||
|  |             [""], | ||||||
|  |             return_tensors="pt", | ||||||
|  |             padding="max_length", | ||||||
|  |             truncation=True, | ||||||
|  |             max_length=pipeline.tokenizer.model_max_length, | ||||||
|  |         ).input_ids.to(accelerator.device) | ||||||
|  |     )[0] | ||||||
|  |     sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1) | ||||||
|  |     train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1) | ||||||
|  |  | ||||||
|  |     if config.per_prompt_stat_tracking: | ||||||
|  |         stat_tracker = PerPromptStatTracker( | ||||||
|  |             config.per_prompt_stat_tracking.buffer_size, | ||||||
|  |             config.per_prompt_stat_tracking.min_count, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     for epoch in range(config.num_epochs): | ||||||
|  |         #################### SAMPLING #################### | ||||||
|  |         samples = [] | ||||||
|  |         prompts = [] | ||||||
|  |         for i in tqdm.tqdm( | ||||||
|  |             range(config.sample.num_batches_per_epoch), | ||||||
|  |             desc=f"Epoch {epoch}: sampling", | ||||||
|  |             disable=not accelerator.is_local_main_process, | ||||||
|  |             position=0, | ||||||
|  |         ): | ||||||
|  |             # generate prompts | ||||||
|  |             prompts, prompt_metadata = zip( | ||||||
|  |                 *[prompt_fn(**config.prompt_fn_kwargs) for _ in range(config.sample.batch_size)] | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             # encode prompts | ||||||
|  |             prompt_ids = pipeline.tokenizer( | ||||||
|  |                 prompts, | ||||||
|  |                 return_tensors="pt", | ||||||
|  |                 padding="max_length", | ||||||
|  |                 truncation=True, | ||||||
|  |                 max_length=pipeline.tokenizer.model_max_length, | ||||||
|  |             ).input_ids.to(accelerator.device) | ||||||
|  |             prompt_embeds = pipeline.text_encoder(prompt_ids)[0] | ||||||
|  |  | ||||||
|  |             # sample | ||||||
|  |             pipeline.unet.eval() | ||||||
|  |             pipeline.vae.eval() | ||||||
|  |             images, _, latents, log_probs = pipeline_with_logprob( | ||||||
|  |                 pipeline, | ||||||
|  |                 prompt_embeds=prompt_embeds, | ||||||
|  |                 negative_prompt_embeds=sample_neg_prompt_embeds, | ||||||
|  |                 num_inference_steps=config.sample.num_steps, | ||||||
|  |                 guidance_scale=config.sample.guidance_scale, | ||||||
|  |                 eta=config.sample.eta, | ||||||
|  |                 output_type="pt", | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, 4, 64, 64) | ||||||
|  |             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps) | ||||||
|  |             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps) | ||||||
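|  |             # the denoising trajectory is treated as an MDP: latents holds the | ||||||
|  |             # num_steps + 1 states (initial noise plus each intermediate), so | ||||||
|  |             # consecutive entries form the (state, next_state) transition pairs | ||||||
|  |             # stored below. | ||||||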
|  |  | ||||||
|  |             # compute rewards | ||||||
|  |             rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata) | ||||||
|  |  | ||||||
|  |             samples.append( | ||||||
|  |                 { | ||||||
|  |                     "prompt_ids": prompt_ids, | ||||||
|  |                     "prompt_embeds": prompt_embeds, | ||||||
|  |                     "timesteps": timesteps, | ||||||
|  |                     "latents": latents[:, :-1],  # each entry is the latent before timestep t | ||||||
|  |                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t | ||||||
|  |                     "log_probs": log_probs, | ||||||
|  |                     "rewards": torch.as_tensor(rewards), | ||||||
|  |                 } | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) | ||||||
|  |         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} | ||||||
|  |  | ||||||
|  |         # gather rewards across processes | ||||||
|  |         rewards = accelerator.gather(samples["rewards"]).cpu().numpy() | ||||||
|  |  | ||||||
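|  |         # turn rewards into advantages, either per prompt (running statistics) | ||||||
|  |         # or by whitening across the entire gathered batch; the 1e-8 guards | ||||||
|  |         # against division by zero when all rewards are identical. | ||||||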
|  |         # per-prompt mean/std tracking | ||||||
|  |         if config.per_prompt_stat_tracking: | ||||||
|  |             # gather the prompts across processes | ||||||
|  |             prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy() | ||||||
|  |             prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) | ||||||
|  |             advantages = stat_tracker.update(prompts, rewards) | ||||||
|  |         else: | ||||||
|  |             advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) | ||||||
|  |  | ||||||
|  |         # ungather advantages; we only need to keep the entries corresponding to the samples on this process | ||||||
|  |         samples["advantages"] = ( | ||||||
|  |             torch.as_tensor(advantages) | ||||||
|  |             .reshape(accelerator.num_processes, -1)[accelerator.process_index] | ||||||
|  |             .to(accelerator.device) | ||||||
|  |         ) | ||||||
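|  |         # NOTE: this assumes `accelerator.gather` concatenates tensors in | ||||||
|  |         # process-rank order, so reshaping by num_processes recovers exactly the | ||||||
|  |         # rows this process contributed. | ||||||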
|  |  | ||||||
|  |         del samples["rewards"] | ||||||
|  |         del samples["prompt_ids"] | ||||||
|  |  | ||||||
|  |         total_batch_size, num_timesteps = samples["timesteps"].shape | ||||||
|  |         assert total_batch_size == config.sample.batch_size * config.sample.num_batches_per_epoch | ||||||
|  |         assert num_timesteps == config.sample.num_steps | ||||||
|  |  | ||||||
|  |         #################### TRAINING #################### | ||||||
|  |         for inner_epoch in range(config.train.num_inner_epochs): | ||||||
|  |             # shuffle samples along batch dimension | ||||||
|  |             indices = torch.randperm(total_batch_size, device=accelerator.device) | ||||||
|  |             samples = {k: v[indices] for k, v in samples.items()} | ||||||
|  |  | ||||||
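|  |             # each denoising step is an independent PPO transition (the advantage | ||||||
|  |             # is a single per-trajectory scalar), so timesteps can be reshuffled | ||||||
|  |             # independently per sample as long as latents, next_latents, and the | ||||||
|  |             # stored log-probs stay aligned with them. | ||||||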
|  |             # shuffle along time dimension, independently for each sample | ||||||
|  |             for i in range(total_batch_size): | ||||||
|  |                 indices = torch.randperm(num_timesteps, device=accelerator.device) | ||||||
|  |                 for key in ["timesteps", "latents", "next_latents"]: | ||||||
|  |                     samples[key][i] = samples[key][i][indices] | ||||||
|  |  | ||||||
|  |             # rebatch for training | ||||||
|  |             samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()} | ||||||
|  |  | ||||||
|  |             # dict of lists -> list of dicts for easier iteration | ||||||
|  |             samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())] | ||||||
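|  |             # e.g. {"latents": (num_batches, batch_size, ...)} becomes a list of | ||||||
|  |             # num_batches dicts, each holding one (batch_size, ...) slice. | ||||||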
|  |  | ||||||
|  |             # train | ||||||
|  |             for i, sample in tqdm.tqdm( | ||||||
|  |                 list(enumerate(samples_batched)), | ||||||
|  |                 desc=f"Outer epoch {epoch}, inner epoch {inner_epoch}: training", | ||||||
|  |                 position=0, | ||||||
|  |             ): | ||||||
|  |                 if config.train.cfg: | ||||||
|  |                     # concat negative prompts to sample prompts to avoid two forward passes | ||||||
|  |                     embeds = torch.cat([train_neg_prompt_embeds, sample["prompt_embeds"]]) | ||||||
|  |                 else: | ||||||
|  |                     embeds = sample["prompt_embeds"] | ||||||
|  |  | ||||||
|  |                 for j in tqdm.trange( | ||||||
|  |                     num_timesteps, | ||||||
|  |                     desc=f"Timestep", | ||||||
|  |                     position=1, | ||||||
|  |                     leave=False, | ||||||
|  |                 ): | ||||||
|  |                     with accelerator.accumulate(pipeline.unet): | ||||||
|  |                         if config.train.cfg: | ||||||
|  |                             noise_pred = pipeline.unet( | ||||||
|  |                                 torch.cat([sample["latents"][:, j]] * 2), | ||||||
|  |                                 torch.cat([sample["timesteps"][:, j]] * 2), | ||||||
|  |                                 embeds, | ||||||
|  |                             ).sample | ||||||
|  |                             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | ||||||
|  |                             noise_pred = noise_pred_uncond + config.sample.guidance_scale * ( | ||||||
|  |                                 noise_pred_text - noise_pred_uncond | ||||||
|  |                             ) | ||||||
|  |                         else: | ||||||
|  |                             noise_pred = pipeline.unet( | ||||||
|  |                                 sample["latents"][:, j], sample["timesteps"][:, j], embeds | ||||||
|  |                             ).sample | ||||||
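|  |                         # in the CFG branch above, the final prediction recombines | ||||||
|  |                         # the two halves as eps_uncond + guidance_scale * (eps_text - | ||||||
|  |                         # eps_uncond), mirroring what the sampler did when the data | ||||||
|  |                         # was generated. | ||||||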
|  |                         _, log_prob = ddim_step_with_logprob( | ||||||
|  |                             pipeline.scheduler, | ||||||
|  |                             noise_pred, | ||||||
|  |                             sample["timesteps"][:, j], | ||||||
|  |                             sample["latents"][:, j], | ||||||
|  |                             eta=config.sample.eta, | ||||||
|  |                             prev_sample=sample["next_latents"][:, j], | ||||||
|  |                         ) | ||||||
|  |  | ||||||
|  |                         # ppo logic | ||||||
|  |                         advantages = torch.clamp( | ||||||
|  |                             sample["advantages"][:, j], -config.train.adv_clip_max, config.train.adv_clip_max | ||||||
|  |                         ) | ||||||
|  |                         ratio = torch.exp(log_prob - sample["log_probs"][:, j]) | ||||||
|  |                         unclipped_loss = -advantages * ratio | ||||||
|  |                         clipped_loss = -advantages * torch.clamp( | ||||||
|  |                             ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range | ||||||
|  |                         ) | ||||||
|  |                         loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) | ||||||
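|  |                         # standard PPO clipped surrogate: taking the max of the two | ||||||
|  |                         # negated terms is equivalent to | ||||||
|  |                         # -min(A * ratio, A * clip(ratio, 1 - eps, 1 + eps)). | ||||||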
|  |  | ||||||
|  |                         # debugging values | ||||||
|  |                         info = {} | ||||||
|  |                         # John Schulman says that (ratio - 1) - log(ratio) is a better | ||||||
|  |                         # estimator, but most existing code uses this so... | ||||||
|  |                         # http://joschu.net/blog/kl-approx.html | ||||||
|  |                         info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2) | ||||||
|  |                         info["clipfrac"] = torch.mean(torch.abs(ratio - 1.0) > config.train.clip_range) | ||||||
|  |                         info["loss"] = loss | ||||||
|  |  | ||||||
|  |                         # backward pass | ||||||
|  |                         accelerator.backward(loss) | ||||||
|  |                         if accelerator.sync_gradients: | ||||||
|  |                             accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm) | ||||||
|  |                         optimizer.step() | ||||||
|  |                         optimizer.zero_grad() | ||||||
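|  |                         # accelerate's prepared optimizer skips the actual step and | ||||||
|  |                         # zero_grad while gradients are still being accumulated, so | ||||||
|  |                         # calling them every iteration is safe. | ||||||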
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     app.run(main) | ||||||