Initial commit
305  .gitignore  vendored  Normal file
							| @@ -0,0 +1,305 @@ | ||||
| # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim | ||||
| # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,intellij+all,vim | ||||
|  | ||||
| ### Intellij+all ### | ||||
| # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider | ||||
| # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | ||||
|  | ||||
| # User-specific stuff | ||||
| .idea/**/workspace.xml | ||||
| .idea/**/tasks.xml | ||||
| .idea/**/usage.statistics.xml | ||||
| .idea/**/dictionaries | ||||
| .idea/**/shelf | ||||
|  | ||||
| # AWS User-specific | ||||
| .idea/**/aws.xml | ||||
|  | ||||
| # Generated files | ||||
| .idea/**/contentModel.xml | ||||
|  | ||||
| # Sensitive or high-churn files | ||||
| .idea/**/dataSources/ | ||||
| .idea/**/dataSources.ids | ||||
| .idea/**/dataSources.local.xml | ||||
| .idea/**/sqlDataSources.xml | ||||
| .idea/**/dynamic.xml | ||||
| .idea/**/uiDesigner.xml | ||||
| .idea/**/dbnavigator.xml | ||||
|  | ||||
| # Gradle | ||||
| .idea/**/gradle.xml | ||||
| .idea/**/libraries | ||||
|  | ||||
| # Gradle and Maven with auto-import | ||||
| # When using Gradle or Maven with auto-import, you should exclude module files, | ||||
| # since they will be recreated, and may cause churn.  Uncomment if using | ||||
| # auto-import. | ||||
| # .idea/artifacts | ||||
| # .idea/compiler.xml | ||||
| # .idea/jarRepositories.xml | ||||
| # .idea/modules.xml | ||||
| # .idea/*.iml | ||||
| # .idea/modules | ||||
| # *.iml | ||||
| # *.ipr | ||||
|  | ||||
| # CMake | ||||
| cmake-build-*/ | ||||
|  | ||||
| # Mongo Explorer plugin | ||||
| .idea/**/mongoSettings.xml | ||||
|  | ||||
| # File-based project format | ||||
| *.iws | ||||
|  | ||||
| # IntelliJ | ||||
| out/ | ||||
|  | ||||
| # mpeltonen/sbt-idea plugin | ||||
| .idea_modules/ | ||||
|  | ||||
| # JIRA plugin | ||||
| atlassian-ide-plugin.xml | ||||
|  | ||||
| # Cursive Clojure plugin | ||||
| .idea/replstate.xml | ||||
|  | ||||
| # SonarLint plugin | ||||
| .idea/sonarlint/ | ||||
|  | ||||
| # Crashlytics plugin (for Android Studio and IntelliJ) | ||||
| com_crashlytics_export_strings.xml | ||||
| crashlytics.properties | ||||
| crashlytics-build.properties | ||||
| fabric.properties | ||||
|  | ||||
| # Editor-based Rest Client | ||||
| .idea/httpRequests | ||||
|  | ||||
| # Android studio 3.1+ serialized cache file | ||||
| .idea/caches/build_file_checksums.ser | ||||
|  | ||||
| ### Intellij+all Patch ### | ||||
| # Ignore everything but code style settings and run configurations | ||||
| # that are supposed to be shared within teams. | ||||
|  | ||||
| .idea/* | ||||
|  | ||||
| !.idea/codeStyles | ||||
| !.idea/runConfigurations | ||||
|  | ||||
| ### Python ### | ||||
| # Byte-compiled / optimized / DLL files | ||||
| __pycache__/ | ||||
| *.py[cod] | ||||
| *$py.class | ||||
|  | ||||
| # C extensions | ||||
| *.so | ||||
|  | ||||
| # Distribution / packaging | ||||
| .Python | ||||
| build/ | ||||
| develop-eggs/ | ||||
| dist/ | ||||
| downloads/ | ||||
| eggs/ | ||||
| .eggs/ | ||||
| lib/ | ||||
| lib64/ | ||||
| parts/ | ||||
| sdist/ | ||||
| var/ | ||||
| wheels/ | ||||
| share/python-wheels/ | ||||
| *.egg-info/ | ||||
| .installed.cfg | ||||
| *.egg | ||||
| MANIFEST | ||||
|  | ||||
| # PyInstaller | ||||
| #  Usually these files are written by a python script from a template | ||||
| #  before PyInstaller builds the exe, so as to inject date/other infos into it. | ||||
| *.manifest | ||||
| *.spec | ||||
|  | ||||
| # Installer logs | ||||
| pip-log.txt | ||||
| pip-delete-this-directory.txt | ||||
|  | ||||
| # Unit test / coverage reports | ||||
| htmlcov/ | ||||
| .tox/ | ||||
| .nox/ | ||||
| .coverage | ||||
| .coverage.* | ||||
| .cache | ||||
| nosetests.xml | ||||
| coverage.xml | ||||
| *.cover | ||||
| *.py,cover | ||||
| .hypothesis/ | ||||
| .pytest_cache/ | ||||
| cover/ | ||||
|  | ||||
| # Translations | ||||
| *.mo | ||||
| *.pot | ||||
|  | ||||
| # Django stuff: | ||||
| *.log | ||||
| local_settings.py | ||||
| db.sqlite3 | ||||
| db.sqlite3-journal | ||||
|  | ||||
| # Flask stuff: | ||||
| instance/ | ||||
| .webassets-cache | ||||
|  | ||||
| # Scrapy stuff: | ||||
| .scrapy | ||||
|  | ||||
| # Sphinx documentation | ||||
| docs/_build/ | ||||
|  | ||||
| # PyBuilder | ||||
| .pybuilder/ | ||||
| target/ | ||||
|  | ||||
| # Jupyter Notebook | ||||
| .ipynb_checkpoints | ||||
|  | ||||
| # IPython | ||||
| profile_default/ | ||||
| ipython_config.py | ||||
|  | ||||
| # pyenv | ||||
| #   For a library or package, you might want to ignore these files since the code is | ||||
| #   intended to run in multiple environments; otherwise, check them in: | ||||
| # .python-version | ||||
|  | ||||
| # pipenv | ||||
| #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||||
| #   However, in case of collaboration, if having platform-specific dependencies or dependencies | ||||
| #   having no cross-platform support, pipenv may install dependencies that don't work, or not | ||||
| #   install all needed dependencies. | ||||
| #Pipfile.lock | ||||
|  | ||||
| # poetry | ||||
| #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||||
| #   This is especially recommended for binary packages to ensure reproducibility, and is more | ||||
| #   commonly ignored for libraries. | ||||
| #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||||
| #poetry.lock | ||||
|  | ||||
| # pdm | ||||
| #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | ||||
| #pdm.lock | ||||
| #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | ||||
| #   in version control. | ||||
| #   https://pdm.fming.dev/#use-with-ide | ||||
| .pdm.toml | ||||
|  | ||||
| # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||||
| __pypackages__/ | ||||
|  | ||||
| # Celery stuff | ||||
| celerybeat-schedule | ||||
| celerybeat.pid | ||||
|  | ||||
| # SageMath parsed files | ||||
| *.sage.py | ||||
|  | ||||
| # Environments | ||||
| .env | ||||
| .venv | ||||
| env/ | ||||
| venv/ | ||||
| ENV/ | ||||
| env.bak/ | ||||
| venv.bak/ | ||||
|  | ||||
| # Spyder project settings | ||||
| .spyderproject | ||||
| .spyproject | ||||
|  | ||||
| # Rope project settings | ||||
| .ropeproject | ||||
|  | ||||
| # mkdocs documentation | ||||
| /site | ||||
|  | ||||
| # mypy | ||||
| .mypy_cache/ | ||||
| .dmypy.json | ||||
| dmypy.json | ||||
|  | ||||
| # Pyre type checker | ||||
| .pyre/ | ||||
|  | ||||
| # pytype static type analyzer | ||||
| .pytype/ | ||||
|  | ||||
| # Cython debug symbols | ||||
| cython_debug/ | ||||
|  | ||||
| # PyCharm | ||||
| #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||||
| #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||||
| #  and can be added to the global gitignore or merged into this file.  For a more nuclear | ||||
| #  option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||||
| #.idea/ | ||||
|  | ||||
| ### Python Patch ### | ||||
| # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration | ||||
| poetry.toml | ||||
|  | ||||
| # ruff | ||||
| .ruff_cache/ | ||||
|  | ||||
| # LSP config files | ||||
| pyrightconfig.json | ||||
|  | ||||
| ### Vim ### | ||||
| # Swap | ||||
| [._]*.s[a-v][a-z] | ||||
| !*.svg  # comment out if you don't need vector files | ||||
| [._]*.sw[a-p] | ||||
| [._]s[a-rt-v][a-z] | ||||
| [._]ss[a-gi-z] | ||||
| [._]sw[a-p] | ||||
|  | ||||
| # Session | ||||
| Session.vim | ||||
| Sessionx.vim | ||||
|  | ||||
| # Temporary | ||||
| .netrwhist | ||||
| *~ | ||||
| # Auto-generated tag files | ||||
| tags | ||||
| # Persistent undo | ||||
| [._]*.un~ | ||||
|  | ||||
| ### VisualStudioCode ### | ||||
| .vscode/* | ||||
| !.vscode/settings.json | ||||
| !.vscode/tasks.json | ||||
| !.vscode/launch.json | ||||
| !.vscode/extensions.json | ||||
| !.vscode/*.code-snippets | ||||
|  | ||||
| # Local History for Visual Studio Code | ||||
| .history/ | ||||
|  | ||||
| # Built Visual Studio Code Extensions | ||||
| *.vsix | ||||
|  | ||||
| ### VisualStudioCode Patch ### | ||||
| # Ignore all local history of files | ||||
| .history | ||||
| .ionide | ||||
|  | ||||
| # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim | ||||
|  | ||||
							
								
								
									
										
BIN  config/__pycache__/base.cpython-310.pyc  Normal file
Binary file not shown.

56  config/base.py  Normal file
							| @@ -0,0 +1,56 @@ | ||||
| import ml_collections | ||||
|  | ||||
| def get_config(): | ||||
|  | ||||
|     config = ml_collections.ConfigDict() | ||||
|  | ||||
|     # misc | ||||
|     config.seed = 42 | ||||
|     config.logdir = "logs" | ||||
|     config.num_epochs = 100 | ||||
|     config.mixed_precision = "fp16" | ||||
|     config.allow_tf32 = True | ||||
|  | ||||
|     # pretrained model initialization | ||||
|     config.pretrained = pretrained = ml_collections.ConfigDict() | ||||
|     pretrained.model = "runwayml/stable-diffusion-v1-5" | ||||
|     pretrained.revision = "main" | ||||
|  | ||||
|     # training | ||||
|     config.train = train = ml_collections.ConfigDict() | ||||
|     train.mixed_precision = "fp16" | ||||
|     train.batch_size = 1 | ||||
|     train.use_8bit_adam = False | ||||
|     train.scale_lr = False | ||||
|     train.learning_rate = 1e-4 | ||||
|     train.adam_beta1 = 0.9 | ||||
|     train.adam_beta2 = 0.999 | ||||
|     train.adam_weight_decay = 1e-2 | ||||
|     train.adam_epsilon = 1e-8 | ||||
|     train.gradient_accumulation_steps = 1 | ||||
|     train.max_grad_norm = 1.0 | ||||
|     train.num_inner_epochs = 1 | ||||
|     train.cfg = True | ||||
|     train.adv_clip_max = 10 | ||||
|     train.clip_range = 1e-4 | ||||
|  | ||||
|     # sampling | ||||
|     config.sample = sample = ml_collections.ConfigDict() | ||||
|     sample.num_steps = 5 | ||||
|     sample.eta = 1.0 | ||||
|     sample.guidance_scale = 5.0 | ||||
|     sample.batch_size = 1 | ||||
|     sample.num_batches_per_epoch = 4 | ||||
|  | ||||
|     # prompting | ||||
|     config.prompt_fn = "imagenet_animals" | ||||
|     config.prompt_fn_kwargs = {} | ||||
|  | ||||
|     # rewards | ||||
|     config.reward_fn = "jpeg_compressibility" | ||||
|  | ||||
|     config.per_prompt_stat_tracking = ml_collections.ConfigDict() | ||||
|     config.per_prompt_stat_tracking.buffer_size = 128 | ||||
|     config.per_prompt_stat_tracking.min_count = 16 | ||||
|  | ||||
|     return config | ||||
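Note: the config is a plain `ml_collections.ConfigDict`, and `scripts/train.py` (below) loads it through `config_flags.DEFINE_config_file`, so every field can be overridden from the command line. A minimal sketch of that pattern, mirroring the flag setup used in `scripts/train.py`:

```python
# Minimal sketch of consuming this config via absl + ml_collections,
# mirroring the flag definition in scripts/train.py.
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")

def main(_):
    config = FLAGS.config
    # Any field can be overridden at launch, e.g. --config.train.learning_rate=3e-4
    print(config.train.learning_rate)

if __name__ == "__main__":
    app.run(main)
```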
							
								
								
									
										
BIN  ddpo_pytorch/__pycache__/prompts.cpython-310.pyc  Normal file
Binary file not shown.

BIN  ddpo_pytorch/__pycache__/rewards.cpython-310.pyc  Normal file
Binary file not shown.

BIN  ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc  Normal file
Binary file not shown.

1000  ddpo_pytorch/assets/imagenet_classes.txt  Normal file
File diff suppressed because it is too large.

2 additional binary files not shown.

143  ddpo_pytorch/diffusers_patch/ddim_with_logprob.py  Normal file
							| @@ -0,0 +1,143 @@ | ||||
| # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py | ||||
| # with the following modifications: | ||||
| # - ddim_step_with_logprob returns the log-probability of `prev_sample` alongside the sample, | ||||
| #   and optionally takes a precomputed `prev_sample` whose log-probability should be evaluated. | ||||
|  | ||||
| from typing import Optional, Tuple, Union | ||||
|  | ||||
| import math | ||||
| import torch | ||||
|  | ||||
| from diffusers.utils import randn_tensor | ||||
| from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler | ||||
|  | ||||
|  | ||||
| def ddim_step_with_logprob( | ||||
|     self: DDIMScheduler, | ||||
|     model_output: torch.FloatTensor, | ||||
|     timestep: int, | ||||
|     sample: torch.FloatTensor, | ||||
|     eta: float = 0.0, | ||||
|     use_clipped_model_output: bool = False, | ||||
|     generator=None, | ||||
|     prev_sample: Optional[torch.FloatTensor] = None, | ||||
| ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: | ||||
|     """ | ||||
|     Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||||
|     process from the learned model outputs (most often the predicted noise). | ||||
|  | ||||
|     Args: | ||||
|         model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||||
|         timestep (`int`): current discrete timestep in the diffusion chain. | ||||
|         sample (`torch.FloatTensor`): | ||||
|             current instance of sample being created by diffusion process. | ||||
|         eta (`float`): weight of noise for added noise in diffusion step. | ||||
|         use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped | ||||
|             predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when | ||||
|             `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would | ||||
|             coincide with the one provided as input and `use_clipped_model_output` will have no effect. | ||||
|         generator: random number generator. | ||||
|         prev_sample (`torch.FloatTensor`, *optional*): if provided, this sample is used as the previous | ||||
|             sample instead of drawing a fresh one, so the function evaluates the log-probability of an | ||||
|             existing trajectory step (as done during training). | ||||
|  | ||||
|     Returns: | ||||
|         `tuple`: `(prev_sample, log_prob)`, the sample at the previous timestep and the log-probability | ||||
|         of that sample under the scheduler's Gaussian transition. | ||||
|  | ||||
|     """ | ||||
|     assert isinstance(self, DDIMScheduler) | ||||
|     if self.num_inference_steps is None: | ||||
|         raise ValueError( | ||||
|             "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" | ||||
|         ) | ||||
|  | ||||
|     # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf | ||||
|     # Ideally, read DDIM paper in-detail understanding | ||||
|  | ||||
|     # Notation (<variable name> -> <name in paper> | ||||
|     # - pred_noise_t -> e_theta(x_t, t) | ||||
|     # - pred_original_sample -> f_theta(x_t, t) or x_0 | ||||
|     # - std_dev_t -> sigma_t | ||||
|     # - eta -> η | ||||
|     # - pred_sample_direction -> "direction pointing to x_t" | ||||
|     # - pred_prev_sample -> "x_t-1" | ||||
|  | ||||
|     # 1. get previous step value (=t-1) | ||||
|     prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps | ||||
|  | ||||
|     # 2. compute alphas, betas | ||||
|     self.alphas_cumprod = self.alphas_cumprod.to(timestep.device) | ||||
|     self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device) | ||||
|     alpha_prod_t = self.alphas_cumprod.gather(0, timestep) | ||||
|     alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod) | ||||
|  | ||||
|     beta_prod_t = 1 - alpha_prod_t | ||||
|  | ||||
|     # 3. compute predicted original sample from predicted noise also called | ||||
|     # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf | ||||
|     if self.config.prediction_type == "epsilon": | ||||
|         pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) | ||||
|         pred_epsilon = model_output | ||||
|     elif self.config.prediction_type == "sample": | ||||
|         pred_original_sample = model_output | ||||
|         pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) | ||||
|     elif self.config.prediction_type == "v_prediction": | ||||
|         pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output | ||||
|         pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample | ||||
|     else: | ||||
|         raise ValueError( | ||||
|             f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" | ||||
|             " `v_prediction`" | ||||
|         ) | ||||
|  | ||||
|     # 4. Clip or threshold "predicted x_0" | ||||
|     if self.config.thresholding: | ||||
|         pred_original_sample = self._threshold_sample(pred_original_sample) | ||||
|     elif self.config.clip_sample: | ||||
|         pred_original_sample = pred_original_sample.clamp( | ||||
|             -self.config.clip_sample_range, self.config.clip_sample_range | ||||
|         ) | ||||
|  | ||||
|     # 5. compute variance: "sigma_t(η)" -> see formula (16) | ||||
|     # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) | ||||
|     variance = self._get_variance(timestep, prev_timestep) | ||||
|     std_dev_t = eta * variance ** (0.5) | ||||
|  | ||||
|     if use_clipped_model_output: | ||||
|         # the pred_epsilon is always re-derived from the clipped x_0 in Glide | ||||
|         pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) | ||||
|  | ||||
|     # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf | ||||
|     pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon | ||||
|  | ||||
|     # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf | ||||
|     prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction | ||||
|  | ||||
|     if prev_sample is not None and generator is not None: | ||||
|         raise ValueError( | ||||
|             "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" | ||||
|             " `prev_sample` stays `None`." | ||||
|         ) | ||||
|  | ||||
|     if prev_sample is None: | ||||
|         variance_noise = randn_tensor( | ||||
|             model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype | ||||
|         ) | ||||
|         prev_sample = prev_sample_mean + std_dev_t * variance_noise | ||||
|  | ||||
|     # log prob of prev_sample given prev_sample_mean and std_dev_t | ||||
|     log_prob = ( | ||||
|         -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) | ||||
|         - torch.log(std_dev_t) | ||||
|         - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) | ||||
|     ) | ||||
|     # mean along all but batch dimension | ||||
|     log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) | ||||
|  | ||||
|     return prev_sample, log_prob | ||||
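Note: the returned `log_prob` is the log-density of `prev_sample` under the diagonal Gaussian `N(prev_sample_mean, std_dev_t**2 * I)`, averaged over the non-batch dimensions; `eta` must be positive, otherwise `std_dev_t == 0` and the density is undefined. A quick self-contained check against `torch.distributions` (all names below are stand-ins, not part of this module):

```python
import torch

mean = torch.randn(2, 4, 8, 8)            # stands in for prev_sample_mean
std = torch.tensor(0.3)                   # stands in for std_dev_t = eta * variance**0.5
x = mean + std * torch.randn_like(mean)   # stands in for prev_sample

# same expression as in ddim_step_with_logprob
manual = (
    -((x - mean) ** 2) / (2 * std**2)
    - torch.log(std)
    - 0.5 * torch.log(2 * torch.tensor(torch.pi))
)
manual = manual.mean(dim=tuple(range(1, manual.ndim)))

reference = torch.distributions.Normal(mean, std).log_prob(x).mean(dim=(1, 2, 3))
assert torch.allclose(manual, reference, atol=1e-5)
```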
							
								
								
									
225  ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py  Normal file
							| @@ -0,0 +1,225 @@ | ||||
| # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py | ||||
| # with the following modifications: | ||||
| # - returns the full trajectory of latents and the per-step log-probabilities from | ||||
| #   ddim_step_with_logprob instead of a `StableDiffusionPipelineOutput`. | ||||
|  | ||||
| from typing import Any, Callable, Dict, List, Optional, Union | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( | ||||
|     StableDiffusionPipeline, | ||||
|     rescale_noise_cfg, | ||||
| ) | ||||
| from diffusers.schedulers.scheduling_ddim import DDIMScheduler | ||||
| from .ddim_with_logprob import ddim_step_with_logprob | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| def pipeline_with_logprob( | ||||
|     self: StableDiffusionPipeline, | ||||
|     prompt: Union[str, List[str]] = None, | ||||
|     height: Optional[int] = None, | ||||
|     width: Optional[int] = None, | ||||
|     num_inference_steps: int = 50, | ||||
|     guidance_scale: float = 7.5, | ||||
|     negative_prompt: Optional[Union[str, List[str]]] = None, | ||||
|     num_images_per_prompt: Optional[int] = 1, | ||||
|     eta: float = 0.0, | ||||
|     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | ||||
|     latents: Optional[torch.FloatTensor] = None, | ||||
|     prompt_embeds: Optional[torch.FloatTensor] = None, | ||||
|     negative_prompt_embeds: Optional[torch.FloatTensor] = None, | ||||
|     output_type: Optional[str] = "pil", | ||||
|     return_dict: bool = True, | ||||
|     callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | ||||
|     callback_steps: int = 1, | ||||
|     cross_attention_kwargs: Optional[Dict[str, Any]] = None, | ||||
|     guidance_rescale: float = 0.0, | ||||
| ): | ||||
|     r""" | ||||
|     Function invoked when calling the pipeline for generation. | ||||
|  | ||||
|     Args: | ||||
|         prompt (`str` or `List[str]`, *optional*): | ||||
|             The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. | ||||
|             instead. | ||||
|         height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | ||||
|             The height in pixels of the generated image. | ||||
|         width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | ||||
|             The width in pixels of the generated image. | ||||
|         num_inference_steps (`int`, *optional*, defaults to 50): | ||||
|             The number of denoising steps. More denoising steps usually lead to a higher quality image at the | ||||
|             expense of slower inference. | ||||
|         guidance_scale (`float`, *optional*, defaults to 7.5): | ||||
|             Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | ||||
|             `guidance_scale` is defined as `w` of equation 2. of [Imagen | ||||
|             Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | ||||
|             1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | ||||
|             usually at the expense of lower image quality. | ||||
|         negative_prompt (`str` or `List[str]`, *optional*): | ||||
|             The prompt or prompts not to guide the image generation. If not defined, one has to pass | ||||
|             `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is | ||||
|             less than `1`). | ||||
|         num_images_per_prompt (`int`, *optional*, defaults to 1): | ||||
|             The number of images to generate per prompt. | ||||
|         eta (`float`, *optional*, defaults to 0.0): | ||||
|             Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to | ||||
|             [`schedulers.DDIMScheduler`], will be ignored for others. | ||||
|         generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | ||||
|             One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) | ||||
|             to make generation deterministic. | ||||
|         latents (`torch.FloatTensor`, *optional*): | ||||
|             Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | ||||
|             generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | ||||
|             tensor will be generated by sampling using the supplied random `generator`. | ||||
|         prompt_embeds (`torch.FloatTensor`, *optional*): | ||||
|             Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | ||||
|             provided, text embeddings will be generated from `prompt` input argument. | ||||
|         negative_prompt_embeds (`torch.FloatTensor`, *optional*): | ||||
|             Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt | ||||
|             weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input | ||||
|             argument. | ||||
|         output_type (`str`, *optional*, defaults to `"pil"`): | ||||
|             The output format of the generate image. Choose between | ||||
|             [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | ||||
|         return_dict (`bool`, *optional*, defaults to `True`): | ||||
|             Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | ||||
|             plain tuple. | ||||
|         callback (`Callable`, *optional*): | ||||
|             A function that will be called every `callback_steps` steps during inference. The function will be | ||||
|             called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. | ||||
|         callback_steps (`int`, *optional*, defaults to 1): | ||||
|             The frequency at which the `callback` function will be called. If not specified, the callback will be | ||||
|             called at every step. | ||||
|         cross_attention_kwargs (`dict`, *optional*): | ||||
|             A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | ||||
|             `self.processor` in | ||||
|             [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). | ||||
|         guidance_rescale (`float`, *optional*, defaults to 0.0): | ||||
|             Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are | ||||
|             Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of | ||||
|             [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). | ||||
|             Guidance rescale factor should fix overexposure when using zero terminal SNR. | ||||
|  | ||||
|     Examples: | ||||
|  | ||||
|     Returns: | ||||
|         `tuple`: `(images, has_nsfw_concept, all_latents, all_log_probs)`: the generated images, a list of | ||||
|         `bool`s denoting whether each image likely represents "not-safe-for-work" (nsfw) content (`None` if | ||||
|         the safety checker is disabled), the latents at every step of the sampling trajectory, and the | ||||
|         per-step log-probabilities from `ddim_step_with_logprob`. | ||||
|     """ | ||||
|     # 0. Default height and width to unet | ||||
|     height = height or self.unet.config.sample_size * self.vae_scale_factor | ||||
|     width = width or self.unet.config.sample_size * self.vae_scale_factor | ||||
|  | ||||
|     # 1. Check inputs. Raise error if not correct | ||||
|     self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) | ||||
|  | ||||
|     # 2. Define call parameters | ||||
|     if prompt is not None and isinstance(prompt, str): | ||||
|         batch_size = 1 | ||||
|     elif prompt is not None and isinstance(prompt, list): | ||||
|         batch_size = len(prompt) | ||||
|     else: | ||||
|         batch_size = prompt_embeds.shape[0] | ||||
|  | ||||
|     device = self._execution_device | ||||
|     # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) | ||||
|     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | ||||
|     # corresponds to doing no classifier free guidance. | ||||
|     do_classifier_free_guidance = guidance_scale > 1.0 | ||||
|  | ||||
|     # 3. Encode input prompt | ||||
|     text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None | ||||
|     prompt_embeds = self._encode_prompt( | ||||
|         prompt, | ||||
|         device, | ||||
|         num_images_per_prompt, | ||||
|         do_classifier_free_guidance, | ||||
|         negative_prompt, | ||||
|         prompt_embeds=prompt_embeds, | ||||
|         negative_prompt_embeds=negative_prompt_embeds, | ||||
|         lora_scale=text_encoder_lora_scale, | ||||
|     ) | ||||
|  | ||||
|     # 4. Prepare timesteps | ||||
|     self.scheduler.set_timesteps(num_inference_steps, device=device) | ||||
|     timesteps = self.scheduler.timesteps | ||||
|  | ||||
|     # 5. Prepare latent variables | ||||
|     num_channels_latents = self.unet.config.in_channels | ||||
|     latents = self.prepare_latents( | ||||
|         batch_size * num_images_per_prompt, | ||||
|         num_channels_latents, | ||||
|         height, | ||||
|         width, | ||||
|         prompt_embeds.dtype, | ||||
|         device, | ||||
|         generator, | ||||
|         latents, | ||||
|     ) | ||||
|  | ||||
|     # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | ||||
|     extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | ||||
|  | ||||
|     # 7. Denoising loop | ||||
|     num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | ||||
|     all_latents = [latents] | ||||
|     all_log_probs = [] | ||||
|     with self.progress_bar(total=num_inference_steps) as progress_bar: | ||||
|         for i, t in enumerate(timesteps): | ||||
|             # expand the latents if we are doing classifier free guidance | ||||
|             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||||
|             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | ||||
|  | ||||
|             # predict the noise residual | ||||
|             noise_pred = self.unet( | ||||
|                 latent_model_input, | ||||
|                 t, | ||||
|                 encoder_hidden_states=prompt_embeds, | ||||
|                 cross_attention_kwargs=cross_attention_kwargs, | ||||
|                 return_dict=False, | ||||
|             )[0] | ||||
|  | ||||
|             # perform guidance | ||||
|             if do_classifier_free_guidance: | ||||
|                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | ||||
|                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | ||||
|  | ||||
|             if do_classifier_free_guidance and guidance_rescale > 0.0: | ||||
|                 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf | ||||
|                 noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) | ||||
|  | ||||
|             # compute the previous noisy sample x_t -> x_t-1 | ||||
|             latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs) | ||||
|  | ||||
|             all_latents.append(latents) | ||||
|             all_log_probs.append(log_prob) | ||||
|  | ||||
|             # call the callback, if provided | ||||
|             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | ||||
|                 progress_bar.update() | ||||
|                 if callback is not None and i % callback_steps == 0: | ||||
|                     callback(i, t, latents) | ||||
|  | ||||
|     if not output_type == "latent": | ||||
|         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] | ||||
|         image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | ||||
|     else: | ||||
|         image = latents | ||||
|         has_nsfw_concept = None | ||||
|  | ||||
|     if has_nsfw_concept is None: | ||||
|         do_denormalize = [True] * image.shape[0] | ||||
|     else: | ||||
|         do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] | ||||
|  | ||||
|     image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) | ||||
|  | ||||
|     # Offload last model to CPU | ||||
|     if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: | ||||
|         self.final_offload_hook.offload() | ||||
|  | ||||
|     return image, has_nsfw_concept, all_latents, all_log_probs | ||||
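Note: a minimal sketch of calling the patched pipeline, assuming the Stable Diffusion 1.5 weights are available and a CUDA device is present (model id, prompt, and step count are illustrative):

```python
from diffusers import StableDiffusionPipeline, DDIMScheduler

from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)  # patch assumes DDIM
pipe = pipe.to("cuda")

images, nsfw, latents, log_probs = pipeline_with_logprob(
    pipe,
    prompt="a photo of a dog",
    num_inference_steps=5,
    eta=1.0,  # must be > 0 for finite log-probabilities
    output_type="pt",
)
# latents: num_inference_steps + 1 tensors (the full sampling trajectory)
# log_probs: num_inference_steps per-sample log-probabilities
```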
							
								
								
									
54  ddpo_pytorch/prompts.py  Normal file
							| @@ -0,0 +1,54 @@ | ||||
| from importlib import resources | ||||
| import functools | ||||
| import random | ||||
| import inflect | ||||
|  | ||||
| IE = inflect.engine() | ||||
| ASSETS_PATH = resources.files("ddpo_pytorch.assets") | ||||
|  | ||||
|  | ||||
| @functools.cache | ||||
| def load_lines(name): | ||||
|     with ASSETS_PATH.joinpath(name).open() as f: | ||||
|         return [line.strip() for line in f.readlines()] | ||||
|  | ||||
|  | ||||
| def imagenet(low, high): | ||||
|     return random.choice(load_lines("imagenet_classes.txt")[low:high]), {} | ||||
|  | ||||
|  | ||||
| def imagenet_all(): | ||||
|     return imagenet(0, 1000) | ||||
|  | ||||
|  | ||||
| def imagenet_animals(): | ||||
|     return imagenet(0, 398) | ||||
|  | ||||
|  | ||||
| def imagenet_dogs(): | ||||
|     return imagenet(151, 269) | ||||
|  | ||||
|  | ||||
| def nouns_activities(nouns_file, activities_file): | ||||
|     nouns = load_lines(nouns_file) | ||||
|     activities = load_lines(activities_file) | ||||
|     return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {} | ||||
|  | ||||
|  | ||||
| def counting(nouns_file, low, high): | ||||
|     nouns = load_lines(nouns_file) | ||||
|     number = IE.number_to_words(random.randint(low, high)) | ||||
|     noun = random.choice(nouns) | ||||
|     plural_noun = IE.plural(noun) | ||||
|     prompt = f"{number} {plural_noun}" | ||||
|     metadata = { | ||||
|         "questions": [ | ||||
|             f"How many {plural_noun} are there in this image?", | ||||
|             f"What animal is in this image?", | ||||
|         ], | ||||
|         "answers": [ | ||||
|             number, | ||||
|             noun, | ||||
|         ], | ||||
|     } | ||||
|     return prompt, metadata | ||||
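Note: every prompt function returns a `(prompt, metadata)` pair, and `scripts/train.py` resolves them by name via `getattr(ddpo_pytorch.prompts, config.prompt_fn)`. A small sketch (the nouns file passed to `counting` is illustrative; any line-per-entry file in the assets package works):

```python
from ddpo_pytorch import prompts

prompt, metadata = prompts.imagenet_animals()   # e.g. ("tiger shark", {})
prompt, metadata = prompts.counting("imagenet_classes.txt", 2, 9)
# e.g. ("four tiger sharks", {"questions": [...], "answers": [...]})
```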
							
								
								
									
29  ddpo_pytorch/rewards.py  Normal file
							| @@ -0,0 +1,29 @@ | ||||
| from PIL import Image | ||||
| import io | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
|  | ||||
| def jpeg_incompressibility(): | ||||
|     def _fn(images, prompts, metadata): | ||||
|         if isinstance(images, torch.Tensor): | ||||
|             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() | ||||
|             images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC | ||||
|         images = [Image.fromarray(image) for image in images] | ||||
|         buffers = [io.BytesIO() for _ in images] | ||||
|         for image, buffer in zip(images, buffers): | ||||
|             image.save(buffer, format="JPEG", quality=95) | ||||
|         sizes = [buffer.tell() / 1000 for buffer in buffers] | ||||
|         return np.array(sizes), {} | ||||
|  | ||||
|     return _fn | ||||
|  | ||||
|  | ||||
| def jpeg_compressibility(): | ||||
|     jpeg_fn = jpeg_incompressibility() | ||||
|  | ||||
|     def _fn(images, prompts, metadata): | ||||
|         rew, meta = jpeg_fn(images, prompts, metadata) | ||||
|         return -rew, meta | ||||
|  | ||||
|     return _fn | ||||
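Note: reward functions are factories that return a callable over image batches; `scripts/train.py` resolves them via `getattr(ddpo_pytorch.rewards, config.reward_fn)()`. A small usage sketch (the random images are stand-ins):

```python
import torch

from ddpo_pytorch.rewards import jpeg_compressibility

reward_fn = jpeg_compressibility()
images = torch.rand(4, 3, 512, 512)  # NCHW floats in [0, 1]
rewards, meta = reward_fn(images, None, None)
# rewards: np.ndarray of shape (4,); the negated JPEG size in kilobytes,
# so easier-to-compress images score higher
```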
							
								
								
									
34  ddpo_pytorch/stat_tracking.py  Normal file
							| @@ -0,0 +1,34 @@ | ||||
| import numpy as np | ||||
| from collections import deque | ||||
|  | ||||
|  | ||||
| class PerPromptStatTracker: | ||||
|     def __init__(self, buffer_size, min_count): | ||||
|         self.buffer_size = buffer_size | ||||
|         self.min_count = min_count | ||||
|         self.stats = {} | ||||
|  | ||||
|     def update(self, prompts, rewards): | ||||
|         unique = np.unique(prompts) | ||||
|         advantages = np.empty_like(rewards) | ||||
|         for prompt in unique: | ||||
|             prompt_rewards = rewards[prompts == prompt] | ||||
|             if prompt not in self.stats: | ||||
|                 self.stats[prompt] = deque(maxlen=self.buffer_size) | ||||
|             self.stats[prompt].extend(prompt_rewards) | ||||
|  | ||||
|             if len(self.stats[prompt]) < self.min_count: | ||||
|                 mean = np.mean(rewards) | ||||
|                 std = np.std(rewards) + 1e-6 | ||||
|             else: | ||||
|                 mean = np.mean(self.stats[prompt]) | ||||
|                 std = np.std(self.stats[prompt]) + 1e-6 | ||||
|             advantages[prompts == prompt] = (prompt_rewards - mean) / std | ||||
|  | ||||
|         return advantages | ||||
|  | ||||
|     def get_stats(self): | ||||
|         return { | ||||
|             k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} | ||||
|             for k, v in self.stats.items() | ||||
|         } | ||||
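Note: `PerPromptStatTracker` standardizes rewards per prompt; until a prompt has accumulated `min_count` samples, the mean/std of the whole incoming batch is used as a fallback. A small sketch:

```python
import numpy as np

from ddpo_pytorch.stat_tracking import PerPromptStatTracker

tracker = PerPromptStatTracker(buffer_size=128, min_count=2)
prompts = np.array(["cat", "cat", "dog"])
rewards = np.array([1.0, 3.0, 2.0])
advantages = tracker.update(prompts, rewards)  # shape (3,)
# "cat" has 2 >= min_count samples, so it uses its own buffer stats;
# "dog" has only 1, so it falls back to the batch-wide mean/std.
print(tracker.get_stats())  # per-prompt mean/std/count
```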
							
								
								
									
341  scripts/train.py  Normal file
							| @@ -0,0 +1,341 @@ | ||||
| from absl import app, flags, logging | ||||
| from ml_collections import config_flags | ||||
| from accelerate import Accelerator | ||||
| from accelerate.utils import set_seed | ||||
| from accelerate.logging import get_logger | ||||
| from diffusers import StableDiffusionPipeline, DDIMScheduler | ||||
| from diffusers.loaders import AttnProcsLayers | ||||
| from diffusers.models.attention_processor import LoRAAttnProcessor | ||||
| import ddpo_pytorch.prompts | ||||
| import ddpo_pytorch.rewards | ||||
| from ddpo_pytorch.stat_tracking import PerPromptStatTracker | ||||
| from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob | ||||
| from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob | ||||
| import torch | ||||
| import tqdm | ||||
|  | ||||
|  | ||||
| FLAGS = flags.FLAGS | ||||
| config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.") | ||||
|  | ||||
| logger = get_logger(__name__) | ||||
|  | ||||
|  | ||||
| def main(_): | ||||
|     # basic Accelerate and logging setup | ||||
|     config = FLAGS.config | ||||
|     accelerator = Accelerator( | ||||
|         log_with="all", | ||||
|         mixed_precision=config.mixed_precision, | ||||
|         project_dir=config.logdir, | ||||
|     ) | ||||
|     if accelerator.is_main_process: | ||||
|         accelerator.init_trackers(project_name="ddpo-pytorch", config=config) | ||||
|     logger.info(config) | ||||
|  | ||||
|     # set seed | ||||
|     set_seed(config.seed) | ||||
|  | ||||
|     # load scheduler, tokenizer and models. | ||||
|     pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision) | ||||
|     # freeze parameters of models to save more memory | ||||
|     pipeline.unet.requires_grad_(False) | ||||
|     pipeline.vae.requires_grad_(False) | ||||
|     pipeline.text_encoder.requires_grad_(False) | ||||
|     # disable safety checker | ||||
|     pipeline.safety_checker = None | ||||
|     # make the progress bar nicer | ||||
|     pipeline.set_progress_bar_config( | ||||
|         position=1, | ||||
|         disable=not accelerator.is_local_main_process, | ||||
|         leave=False, | ||||
|     ) | ||||
|     # switch to DDIM scheduler | ||||
|     pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) | ||||
|  | ||||
|     # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) | ||||
|     # to half-precision, since these weights are only used for inference and need not be kept in full precision. | ||||
|     weight_dtype = torch.float32 | ||||
|     if accelerator.mixed_precision == "fp16": | ||||
|         weight_dtype = torch.float16 | ||||
|     elif accelerator.mixed_precision == "bf16": | ||||
|         weight_dtype = torch.bfloat16 | ||||
|  | ||||
|     # Move unet, vae and text_encoder to device and cast to weight_dtype | ||||
|     pipeline.unet.to(accelerator.device, dtype=weight_dtype) | ||||
|     pipeline.vae.to(accelerator.device, dtype=weight_dtype) | ||||
|     pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype) | ||||
|  | ||||
|     # Set correct lora layers | ||||
|     lora_attn_procs = {} | ||||
|     for name in pipeline.unet.attn_processors.keys(): | ||||
|         cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim | ||||
|         if name.startswith("mid_block"): | ||||
|             hidden_size = pipeline.unet.config.block_out_channels[-1] | ||||
|         elif name.startswith("up_blocks"): | ||||
|             block_id = int(name[len("up_blocks.")]) | ||||
|             hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id] | ||||
|         elif name.startswith("down_blocks"): | ||||
|             block_id = int(name[len("down_blocks.")]) | ||||
|             hidden_size = pipeline.unet.config.block_out_channels[block_id] | ||||
|  | ||||
|         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) | ||||
|  | ||||
|     pipeline.unet.set_attn_processor(lora_attn_procs) | ||||
|     lora_layers = AttnProcsLayers(pipeline.unet.attn_processors) | ||||
|  | ||||
|     # Enable TF32 for faster training on Ampere GPUs, | ||||
|     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices | ||||
|     if config.allow_tf32: | ||||
|         torch.backends.cuda.matmul.allow_tf32 = True | ||||
|  | ||||
|     if config.train.scale_lr: | ||||
|         config.train.learning_rate = ( | ||||
|             config.train.learning_rate | ||||
|             * config.train.gradient_accumulation_steps | ||||
|             * config.train.batch_size | ||||
|             * accelerator.num_processes | ||||
|         ) | ||||
|  | ||||
|     # Initialize the optimizer | ||||
|     if config.train.use_8bit_adam: | ||||
|         try: | ||||
|             import bitsandbytes as bnb | ||||
|         except ImportError: | ||||
|             raise ImportError( | ||||
|                 "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" | ||||
|             ) | ||||
|  | ||||
|         optimizer_cls = bnb.optim.AdamW8bit | ||||
|     else: | ||||
|         optimizer_cls = torch.optim.AdamW | ||||
|  | ||||
|     optimizer = optimizer_cls( | ||||
|         lora_layers.parameters(), | ||||
|         lr=config.train.learning_rate, | ||||
|         betas=(config.train.adam_beta1, config.train.adam_beta2), | ||||
|         weight_decay=config.train.adam_weight_decay, | ||||
|         eps=config.train.adam_epsilon, | ||||
|     ) | ||||
|  | ||||
|     # prepare prompt and reward fn | ||||
|     prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn) | ||||
|     reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)() | ||||
|  | ||||
|     # Prepare everything with our `accelerator`. | ||||
|     lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer) | ||||
|  | ||||
|     # Train! | ||||
|     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch | ||||
|     total_train_batch_size = ( | ||||
|         config.train.batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps | ||||
|     ) | ||||
|  | ||||
|     assert config.sample.batch_size % config.train.batch_size == 0 | ||||
|     assert samples_per_epoch % total_train_batch_size == 0 | ||||
|  | ||||
|     logger.info("***** Running training *****") | ||||
|     logger.info(f"  Num Epochs = {config.num_epochs}") | ||||
|     logger.info(f"  Sample batch size per device = {config.sample.batch_size}") | ||||
|     logger.info(f"  Train batch size per device = {config.train.batch_size}") | ||||
|     logger.info(f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}") | ||||
|     logger.info("") | ||||
|     logger.info(f"  Total number of samples per epoch = {samples_per_epoch}") | ||||
|     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") | ||||
|     logger.info(f"  Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}") | ||||
|     logger.info(f"  Number of inner epochs = {config.train.num_inner_epochs}") | ||||
|  | ||||
|     neg_prompt_embed = pipeline.text_encoder( | ||||
|         pipeline.tokenizer( | ||||
|             [""], | ||||
|             return_tensors="pt", | ||||
|             padding="max_length", | ||||
|             truncation=True, | ||||
|             max_length=pipeline.tokenizer.model_max_length, | ||||
|         ).input_ids.to(accelerator.device) | ||||
|     )[0] | ||||
|     sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1) | ||||
|     train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1) | ||||
|  | ||||
|     if config.per_prompt_stat_tracking: | ||||
|         stat_tracker = PerPromptStatTracker( | ||||
|             config.per_prompt_stat_tracking.buffer_size, | ||||
|             config.per_prompt_stat_tracking.min_count, | ||||
|         ) | ||||
|  | ||||
|     for epoch in range(config.num_epochs): | ||||
|         #################### SAMPLING #################### | ||||
|         samples = [] | ||||
|         prompts = [] | ||||
|         for i in tqdm.tqdm( | ||||
|             range(config.sample.num_batches_per_epoch), | ||||
|             desc=f"Epoch {epoch}: sampling", | ||||
|             disable=not accelerator.is_local_main_process, | ||||
|             position=0, | ||||
|         ): | ||||
|             # generate prompts | ||||
|             prompts, prompt_metadata = zip( | ||||
|                 *[prompt_fn(**config.prompt_fn_kwargs) for _ in range(config.sample.batch_size)] | ||||
|             ) | ||||
|  | ||||
|             # encode prompts | ||||
|             prompt_ids = pipeline.tokenizer( | ||||
|                 prompts, | ||||
|                 return_tensors="pt", | ||||
|                 padding="max_length", | ||||
|                 truncation=True, | ||||
|                 max_length=pipeline.tokenizer.model_max_length, | ||||
|             ).input_ids.to(accelerator.device) | ||||
|             prompt_embeds = pipeline.text_encoder(prompt_ids)[0] | ||||
|  | ||||
|             # sample | ||||
|             pipeline.unet.eval() | ||||
|             pipeline.vae.eval() | ||||
|             images, _, latents, log_probs = pipeline_with_logprob( | ||||
|                 pipeline, | ||||
|                 prompt_embeds=prompt_embeds, | ||||
|                 negative_prompt_embeds=sample_neg_prompt_embeds, | ||||
|                 num_inference_steps=config.sample.num_steps, | ||||
|                 guidance_scale=config.sample.guidance_scale, | ||||
|                 eta=config.sample.eta, | ||||
|                 output_type="pt", | ||||
|             ) | ||||
|  | ||||
|             latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, 4, 64, 64) | ||||
|             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps) | ||||
|             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps) | ||||
|  | ||||
|             # compute rewards | ||||
|             rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata) | ||||
|  | ||||
|             samples.append( | ||||
|                 { | ||||
|                     "prompt_ids": prompt_ids, | ||||
|                     "prompt_embeds": prompt_embeds, | ||||
|                     "timesteps": timesteps, | ||||
|                     "latents": latents[:, :-1],  # each entry is the latent before timestep t | ||||
|                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t | ||||
|                     "log_probs": log_probs, | ||||
|                     "rewards": torch.as_tensor(rewards), | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) | ||||
|         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} | ||||
|  | ||||
|         # gather rewards across processes | ||||
|         rewards = accelerator.gather(samples["rewards"]).cpu().numpy() | ||||
|  | ||||
|         # per-prompt mean/std tracking | ||||
|         if config.per_prompt_stat_tracking: | ||||
|             # gather the prompts across processes | ||||
|             prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy() | ||||
|             prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) | ||||
|             advantages = stat_tracker.update(prompts, rewards) | ||||
|         else: | ||||
|             advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) | ||||
|  | ||||
|         # ungather advantages; we only need to keep the entries corresponding to the samples on this process | ||||
|         samples["advantages"] = ( | ||||
|             torch.as_tensor(advantages) | ||||
|             .reshape(accelerator.num_processes, -1)[accelerator.process_index] | ||||
|             .to(accelerator.device) | ||||
|         ) | ||||
|  | ||||
|         del samples["rewards"] | ||||
|         del samples["prompt_ids"] | ||||
|  | ||||
|         total_batch_size, num_timesteps = samples["timesteps"].shape | ||||
|         assert total_batch_size == config.sample.batch_size * config.sample.num_batches_per_epoch | ||||
|         assert num_timesteps == config.sample.num_steps | ||||
|  | ||||
|         #################### TRAINING #################### | ||||
|         for inner_epoch in range(config.train.num_inner_epochs): | ||||
|             # shuffle samples along batch dimension | ||||
|             indices = torch.randperm(total_batch_size, device=accelerator.device) | ||||
|             samples = {k: v[indices] for k, v in samples.items()} | ||||
|  | ||||
|             # shuffle along time dimension, independently for each sample | ||||
|             for i in range(total_batch_size): | ||||
|                 indices = torch.randperm(num_timesteps, device=accelerator.device) | ||||
|                 for key in ["timesteps", "latents", "next_latents"]: | ||||
|                     samples[key][i] = samples[key][i][indices] | ||||
|  | ||||
|             # rebatch for training | ||||
|             samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()} | ||||
|  | ||||
|             # dict of lists -> list of dicts for easier iteration | ||||
|             samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())] | ||||
|  | ||||
|             # train | ||||
|             for i, sample in tqdm.tqdm( | ||||
|                 list(enumerate(samples_batched)), | ||||
|                 desc=f"Outer epoch {epoch}, inner epoch {inner_epoch}: training", | ||||
|                 position=0, | ||||
|             ): | ||||
|                 if config.train.cfg: | ||||
|                     # concat negative prompts to sample prompts to avoid two forward passes | ||||
|                     embeds = torch.cat([train_neg_prompt_embeds, sample["prompt_embeds"]]) | ||||
|                 else: | ||||
|                     embeds = sample["prompt_embeds"] | ||||
|  | ||||
|                 for j in tqdm.trange( | ||||
|                     num_timesteps, | ||||
|                     desc=f"Timestep", | ||||
|                     position=1, | ||||
|                     leave=False, | ||||
|                 ): | ||||
|                     with accelerator.accumulate(pipeline.unet): | ||||
|                         if config.train.cfg: | ||||
|                             noise_pred = pipeline.unet( | ||||
|                                 torch.cat([sample["latents"][:, j]] * 2), | ||||
|                                 torch.cat([sample["timesteps"][:, j]] * 2), | ||||
|                                 embeds, | ||||
|                             ).sample | ||||
|                             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | ||||
|                             noise_pred = noise_pred_uncond + config.sample.guidance_scale * ( | ||||
|                                 noise_pred_text - noise_pred_uncond | ||||
|                             ) | ||||
|                         else: | ||||
|                             noise_pred = pipeline.unet( | ||||
|                                 sample["latents"][:, j], sample["timesteps"][:, j], embeds | ||||
|                             ).sample | ||||
|                         _, log_prob = ddim_step_with_logprob( | ||||
|                             pipeline.scheduler, | ||||
|                             noise_pred, | ||||
|                             sample["timesteps"][:, j], | ||||
|                             sample["latents"][:, j], | ||||
|                             eta=config.sample.eta, | ||||
|                             prev_sample=sample["next_latents"][:, j], | ||||
|                         ) | ||||
|  | ||||
|                         # ppo logic | ||||
|                         advantages = torch.clamp( | ||||
|                             sample["advantages"][:, j], -config.train.adv_clip_max, config.train.adv_clip_max | ||||
|                         ) | ||||
|                         ratio = torch.exp(log_prob - sample["log_probs"][:, j]) | ||||
|                         unclipped_loss = -advantages * ratio | ||||
|                         clipped_loss = -advantages * torch.clamp( | ||||
|                             ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range | ||||
|                         ) | ||||
|                         loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) | ||||
|  | ||||
|                         # debugging values | ||||
|                         info = {} | ||||
|                         # John Schulman says that (ratio - 1) - log(ratio) is a better | ||||
|                         # estimator, but most existing code uses this so... | ||||
|                         # http://joschu.net/blog/kl-approx.html | ||||
|                         info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2) | ||||
|                         info["clipfrac"] = torch.mean(torch.abs(ratio - 1.0) > config.train.clip_range) | ||||
|                         info["loss"] = loss | ||||
|  | ||||
|                         # backward pass | ||||
|                         accelerator.backward(loss) | ||||
|                         if accelerator.sync_gradients: | ||||
|                             accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm) | ||||
|                         optimizer.step() | ||||
|                         optimizer.zero_grad() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     app.run(main) | ||||
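Note: the inner loop is the standard PPO clipped surrogate applied independently at each denoising step. In the code's notation, with `A` the clipped advantage, ε = `config.train.clip_range`, and the log-probabilities coming from `ddim_step_with_logprob`:

```latex
r = \exp\!\big(\log p_\theta(x_{t-1} \mid x_t, c) - \log p_{\theta_\mathrm{old}}(x_{t-1} \mid x_t, c)\big),
\qquad
\mathcal{L} = \mathbb{E}\big[\max\big(-A\,r,\; -A\,\mathrm{clip}(r,\, 1-\varepsilon,\, 1+\varepsilon)\big)\big]
```

A typical launch, assuming an Accelerate environment has been set up with `accelerate config`, would be `accelerate launch scripts/train.py --config config/base.py`, with inline overrides such as `--config.train.learning_rate=3e-4`.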