Continue implementation
@@ -1,6 +1,9 @@
 # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
 # with the following modifications:
-# -
+# - It computes and returns the log prob of `prev_sample` given the UNet prediction.
+# - Instead of `variance_noise`, it takes `prev_sample` as an optional argument. If `prev_sample` is provided,
+#   it uses it to compute the log prob.
+# - Timesteps can be a batched torch.Tensor.
 
 from typing import Optional, Tuple, Union
 
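The first bullet above ("computes and returns the log prob of `prev_sample`") amounts to evaluating `prev_sample` under the Gaussian transition the DDIM step defines. A minimal sketch of that computation, where `prev_sample_mean` and `std_dev_t` are placeholder names for tensors the full step function derives (only parts of it are visible in this diff):

import math
import torch

def gaussian_log_prob(prev_sample, prev_sample_mean, std_dev_t):
    # log N(prev_sample; prev_sample_mean, std_dev_t^2), elementwise;
    # std_dev_t is assumed to be a (broadcastable) tensor
    log_prob = (
        -((prev_sample - prev_sample_mean) ** 2) / (2 * std_dev_t**2)
        - torch.log(std_dev_t)
        - math.log(math.sqrt(2 * math.pi))
    )
    # reduce over all non-batch dimensions: one log prob per batch element
    return log_prob.mean(dim=tuple(range(1, log_prob.ndim)))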
@@ -11,6 +14,19 @@ from diffusers.utils import randn_tensor
 from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
 
 
+def _get_variance(self, timestep, prev_timestep):
+    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t_prev = torch.where(
+        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
+    ).to(timestep.device)
+    beta_prod_t = 1 - alpha_prod_t
+    beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+    return variance
+
+
 def ddim_step_with_logprob(
     self: DDIMScheduler,
     model_output: torch.FloatTensor,
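Because the new `_get_variance` gathers with tensor indices, it handles a whole batch of timesteps at once. A quick check of the sort one might run, using a default-configured scheduler (the specific timestep values are illustrative):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()          # default: 1000 training timesteps
scheduler.set_timesteps(50)

t = torch.tensor([981, 961, 941])    # a batch of three timesteps
prev_t = t - scheduler.config.num_train_timesteps // scheduler.num_inference_steps

var = _get_variance(scheduler, t, prev_t)
print(var.shape)                     # torch.Size([3]): one variance per batch element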
@@ -66,16 +82,13 @@ def ddim_step_with_logprob(
 
     # 1. get previous step value (=t-1)
     prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
 
     # 2. compute alphas, betas
-    self.alphas_cumprod = self.alphas_cumprod.to(timestep.device)
-    self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device)
-    alpha_prod_t = self.alphas_cumprod.gather(0, timestep)
-    alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod)
-    print(timestep)
-    print(alpha_prod_t)
-    print(alpha_prod_t_prev)
-    print(prev_timestep)
+    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()).to(timestep.device)
+    alpha_prod_t_prev = torch.where(
+        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
+    ).to(timestep.device)
 
     beta_prod_t = 1 - alpha_prod_t
 
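The reworked gathers above move the indices to the CPU and only the gathered values back to `timestep.device`, instead of copying the whole `alphas_cumprod` table onto the accelerator as the deleted lines did. The pattern in isolation (the CUDA device here is an assumption for illustration):

import torch

table = torch.linspace(1.0, 0.0, 1000)             # CPU lookup table, like alphas_cumprod
idx = torch.tensor([10, 500, 999], device="cuda")  # indices living on the accelerator

# table.gather(0, idx) would raise a device-mismatch error; instead gather with
# CPU indices, then move just the gathered values across:
values = table.gather(0, idx.cpu()).to(idx.device)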
@@ -106,7 +119,7 @@ def ddim_step_with_logprob(
 
     # 5. compute variance: "sigma_t(η)" -> see formula (16)
     # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-    variance = self._get_variance(timestep, prev_timestep)
+    variance = _get_variance(self, timestep, prev_timestep)
     std_dev_t = eta * variance ** (0.5)
 
     if use_clipped_model_output:
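Putting the pieces together, a hedged sketch of how the patched step might be driven, assuming the remainder of `ddim_step_with_logprob` (not shown in this diff) keeps the upstream argument order (`timestep`, `sample`, `eta`, ...) and returns the pair `(prev_sample, log_prob)`:

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)

sample = torch.randn(2, 4, 8, 8)                 # toy latents, batch of two
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)      # stand-in for a real UNet prediction
    t_batch = t.repeat(sample.shape[0])          # timesteps can be a batched tensor
    sample, log_prob = ddim_step_with_logprob(
        scheduler, model_output, t_batch, sample, eta=1.0
    )
    # log_prob: shape (2,), the log-likelihood of each new sample under the step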