"""Abstract SDE classes, Reverse SDE, and VP SDEs.""" import abc import torch import numpy as np class SDE(abc.ABC): """SDE abstract class. Functions are designed for a mini-batch of inputs.""" def __init__(self, N): """Construct an SDE. Args: N: number of discretization time steps. """ super().__init__() self.N = N @property @abc.abstractmethod def T(self): """End time of the SDE.""" pass @abc.abstractmethod def sde(self, x, t): pass @abc.abstractmethod def marginal_prob(self, x, t): """Parameters to determine the marginal distribution of the SDE, $p_t(x)$""" pass @abc.abstractmethod def prior_sampling(self, shape): """Generate one sample from the prior distribution, $p_T(x)$.""" pass @abc.abstractmethod def prior_logp(self, z, mask): """Compute log-density of the prior distribution. Useful for computing the log-likelihood via probability flow ODE. Args: z: latent code Returns: log probability density """ pass def discretize(self, x, t): """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. Useful for reverse diffusion sampling and probability flow sampling. Defaults to Euler-Maruyama discretization. Args: x: a torch tensor t: a torch float representing the time step (from 0 to `self.T`) Returns: f, G """ dt = 1 / self.N drift, diffusion = self.sde(x, t) f = drift * dt G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device)) return f, G def reverse(self, score_fn, probability_flow=False): """Create the reverse-time SDE/ODE. Args: score_fn: A time-dependent score-based model that takes x and t and returns the score. probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. """ N = self.N T = self.T sde_fn = self.sde discretize_fn = self.discretize # Build the class for reverse-time SDE. class RSDE(self.__class__): def __init__(self): self.N = N self.probability_flow = probability_flow @property def T(self): return T def sde(self, x, t, *args, **kwargs): """Create the drift and diffusion functions for the reverse SDE/ODE.""" drift, diffusion = sde_fn(x, t) score = score_fn(x, t, *args, **kwargs) drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) # Set the diffusion function to zero for ODEs. diffusion = 0. if self.probability_flow else diffusion return drift, diffusion ''' def sde_score(self, x, t, score): """Create the drift and diffusion functions for the reverse SDE/ODE, given score values.""" drift, diffusion = sde_fn(x, t) if len(score.shape) == 4: drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) elif len(score.shape) == 3: drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) else: raise ValueError diffusion = 0. if self.probability_flow else diffusion return drift, diffusion ''' def discretize(self, x, t, *args, **kwargs): """Create discretized iteration rules for the reverse diffusion sampler.""" f, G = discretize_fn(x, t) rev_f = f - G[:, None, None] ** 2 * score_fn(x, t, *args, **kwargs) * \ (0.5 if self.probability_flow else 1.) rev_G = torch.zeros_like(G) if self.probability_flow else G return rev_f, rev_G ''' def discretize_score(self, x, t, score): """Create discretized iteration rules for the reverse diffusion sampler, given score values.""" f, G = discretize_fn(x, t) if len(score.shape) == 4: rev_f = f - G[:, None, None, None] ** 2 * score * \ (0.5 if self.probability_flow else 1.) elif len(score.shape) == 3: rev_f = f - G[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) else: raise ValueError rev_G = torch.zeros_like(G) if self.probability_flow else G return rev_f, rev_G ''' return RSDE() class VPSDE(SDE): def __init__(self, beta_min=0.1, beta_max=20, N=1000): """Construct a Variance Preserving SDE. Args: beta_min: value of beta(0) beta_max: value of beta(1) N: number of discretization steps """ super().__init__(N) self.beta_0 = beta_min self.beta_1 = beta_max self.N = N self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N) self.alphas = 1. - self.discrete_betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod) @property def T(self): return 1 def sde(self, x, t): beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) if len(x.shape) == 4: drift = -0.5 * beta_t[:, None, None, None] * x elif len(x.shape) == 3: drift = -0.5 * beta_t[:, None, None] * x else: raise NotImplementedError diffusion = torch.sqrt(beta_t) return drift, diffusion def marginal_prob(self, x, t): log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 if len(x.shape) == 4: mean = torch.exp(log_mean_coeff[:, None, None, None]) * x elif len(x.shape) == 3: mean = torch.exp(log_mean_coeff[:, None, None]) * x else: raise ValueError("The shape of x in marginal_prob is not correct.") std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) return mean, std # def log_snr(self, t): # log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 # mean = torch.exp(log_mean_coeff) # std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) # log_snr = torch.log(mean / std) # return log_snr, mean, std # # def log_snr_np(self, t): # log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 # mean = np.exp(log_mean_coeff) # std = np.sqrt(1. - np.exp(2. * log_mean_coeff)) # log_snr = np.log(mean / std) # return log_snr # # def lambda2t(self, lambda_ori): # log_val = torch.log(torch.exp(-2. * lambda_ori) + 1.) # t = 2. * log_val / (torch.sqrt(self.beta_0 ** 2 + 2. * (self.beta_1 - self.beta_0) * log_val) + self.beta_0) # return t # # def lambda2t_np(self, lambda_ori): # log_val = np.log(np.exp(-2. * lambda_ori) + 1.) # t = 2. * log_val / (np.sqrt(self.beta_0 ** 2 + 2. * (self.beta_1 - self.beta_0) * log_val) + self.beta_0) # return t # def prior_sampling(self, shape): # sample = torch.randn(*shape) # if len(shape) == 4: # sample = torch.tril(sample, -1) # sample = sample + sample.transpose(-1, -2) # return sample def prior_sampling(self, shape): return torch.randn(*shape) def prior_logp(self, z, mask): N = torch.sum(mask, dim=tuple(range(1, len(mask.shape)))) logps = -N / 2. * np.log(2 * np.pi) - torch.sum((z * mask) ** 2, dim=(1, 2, 3)) / 2. return logps def discretize(self, x, t): """DDPM discretization.""" timestep = (t * (self.N - 1) / self.T).long() beta = self.discrete_betas.to(x.device)[timestep] alpha = self.alphas.to(x.device)[timestep] sqrt_beta = torch.sqrt(beta) if len(x.shape) == 4: f = torch.sqrt(alpha)[:, None, None, None] * x - x elif len(x.shape) == 3: f = torch.sqrt(alpha)[:, None, None] * x - x else: NotImplementedError G = sqrt_beta return f, G class subVPSDE(SDE): def __init__(self, beta_min=0.1, beta_max=20, N=1000): """Construct the sub-VP SDE that excels at likelihoods. Args: beta_min: value of beta(0) beta_max: value of beta(1) N: number of discretization steps """ super().__init__(N) self.beta_0 = beta_min self.beta_1 = beta_max self.N = N @property def T(self): return 1 def sde(self, x, t): beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) drift = -0.5 * beta_t[:, None, None] * x discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2) diffusion = torch.sqrt(beta_t * discount) return drift, diffusion def marginal_prob(self, x, t): log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 mean = torch.exp(log_mean_coeff)[:, None, None] * x std = 1 - torch.exp(2. * log_mean_coeff) return mean, std def prior_sampling(self, shape): return torch.randn(*shape) def prior_logp(self, z): shape = z.shape N = np.prod(shape[1:]) return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2. class VESDE(SDE): def __init__(self, sigma_min=0.01, sigma_max=50, N=1000): """Construct a Variance Exploding SDE. Args: sigma_min: smallest sigma. sigma_max: largest sigma. N: number of discretization steps """ super().__init__(N) self.sigma_min = sigma_min self.sigma_max = sigma_max self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) self.N = N @property def T(self): return 1 def sde(self, x, t): sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t drift = torch.zeros_like(x) diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=t.device)) return drift, diffusion def marginal_prob(self, x, t): std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t mean = x return mean, std def prior_sampling(self, shape): return torch.randn(*shape) * self.sigma_max def prior_logp(self, z): shape = z.shape N = np.prod(shape[1:]) return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2) def discretize(self, x, t): """SMLD(NCSN) discretization.""" timestep = (t * (self.N - 1) / self.T).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep.cpu() - 1].to(t.device)) f = torch.zeros_like(x) G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) return f, G