"""Various sampling methods."""
|
|
|
|
import functools
|
|
|
|
import torch
|
|
import numpy as np
|
|
import abc
|
|
import sys
|
|
import os
|
|
|
|
from models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn
|
|
|
|
from scipy import integrate
|
|
from torchdiffeq import odeint
|
|
import sde_lib
|
|
from models import utils as mutils
|
|
from tqdm import trange
|
|
|
|
from datasets_nas import MetaTestDataset
|
|
# from configs.ckpt import META_DATAROOT_NB201, META_DATAROOT_OFA
|
|
from all_path import PROCESSED_DATA_PATH
|
|
|
|
_CORRECTORS = {}
|
|
_PREDICTORS = {}
|
|
|
|
|
|


def register_predictor(cls=None, *, name=None):
  """A decorator for registering predictor classes."""

  def _register(cls):
    if name is None:
      local_name = cls.__name__
    else:
      local_name = name
    if local_name in _PREDICTORS:
      raise ValueError(f'Already registered predictor with name: {local_name}')
    _PREDICTORS[local_name] = cls
    return cls

  if cls is None:
    return _register
  else:
    return _register(cls)


def register_corrector(cls=None, *, name=None):
  """A decorator for registering corrector classes."""

  def _register(cls):
    if name is None:
      local_name = cls.__name__
    else:
      local_name = name
    if local_name in _CORRECTORS:
      raise ValueError(f'Already registered corrector with name: {local_name}')
    _CORRECTORS[local_name] = cls
    return cls

  if cls is None:
    return _register
  else:
    return _register(cls)


def get_predictor(name):
  return _PREDICTORS[name]


def get_corrector(name):
  return _CORRECTORS[name]
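

# Example (hypothetical names): registering and looking up a sampler component.
#
#   @register_predictor(name='my_predictor')
#   class MyPredictor(Predictor):
#     def update_fn(self, x, t, *args, **kwargs):
#       return x, x
#
#   predictor_cls = get_predictor('my_predictor')  # -> MyPredictor
#
# Registering two classes under the same name raises a ValueError.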


def get_sampling_fn(
    config, sde, shape, inverse_scaler, eps, data, conditional=False,
    p=1, prod_w=False, weight_ratio_abs=False,
    is_meta=False, data_name='cifar10', num_sample=20, is_multi_obj=False):
  """Create a sampling function.

  Args:
    config: A `ml_collections.ConfigDict` object that contains all configuration information.
    sde: A `sde_lib.SDE` object that represents the forward SDE.
    shape: A sequence of integers representing the expected shape of a single sample.
    inverse_scaler: The inverse data normalizer function.
    eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.
    data: The dataset/search-space name; 'NASBench201' and 'ofa' select the NAS samplers.
    conditional: If `True`, use the class-conditional PC sampler.
    p: The order of the norm used when rescaling the classifier gradient.
    prod_w: If `True`, include the guidance weight inside the norm ratio.
    weight_ratio_abs: If `True`, rescale the classifier gradient element-wise.
    is_meta: If `True`, use the meta-NAS conditional sampler.
    data_name: Name of the dataset used to build meta-test tasks.
    num_sample: Number of samples per task used by `MetaTestDataset`.
    is_multi_obj: If `True`, use the conditional sampler for multi-objective search.

  Returns:
    A sampling function whose exact signature depends on the selected sampler; it
    outputs samples with trailing dimensions matching `shape`.
  """
  sampler_name = config.sampling.method
  # Probability flow ODE sampling with black-box ODE solvers
  if sampler_name.lower() == 'ode':
    sampling_fn = get_ode_sampler(sde=sde,
                                  shape=shape,
                                  inverse_scaler=inverse_scaler,
                                  denoise=config.sampling.noise_removal,
                                  eps=eps,
                                  rtol=config.sampling.rtol,
                                  atol=config.sampling.atol,
                                  device=config.device)
  elif sampler_name.lower() == 'diffeq':
    sampling_fn = get_diffeq_sampler(sde=sde,
                                     shape=shape,
                                     inverse_scaler=inverse_scaler,
                                     denoise=config.sampling.noise_removal,
                                     eps=eps,
                                     rtol=config.sampling.rtol,
                                     atol=config.sampling.atol,
                                     step_size=config.sampling.ode_step,
                                     method=config.sampling.ode_method,
                                     device=config.device)
  # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases.
  elif sampler_name.lower() == 'pc':
    predictor = get_predictor(config.sampling.predictor.lower())
    corrector = get_corrector(config.sampling.corrector.lower())
    if data in ['NASBench201', 'ofa']:
      if is_meta:
        sampling_fn = get_pc_conditional_sampler_meta_nas(
            sde=sde,
            shape=shape,
            predictor=predictor,
            corrector=corrector,
            inverse_scaler=inverse_scaler,
            snr=config.sampling.snr,
            n_steps=config.sampling.n_steps_each,
            probability_flow=config.sampling.probability_flow,
            continuous=config.training.continuous,
            denoise=config.sampling.noise_removal,
            eps=eps,
            device=config.device,
            regress=config.sampling.regress,
            labels=config.sampling.labels,
            classifier_scale=config.sampling.classifier_scale,
            weight_scheduling=config.sampling.weight_scheduling,
            weight_ratio=config.sampling.weight_ratio,
            t_spot=config.sampling.t_spot,
            t_spot_end=config.sampling.t_spot_end,
            p=p,
            prod_w=prod_w,
            weight_ratio_abs=weight_ratio_abs,
            data_name=data_name,
            num_sample=num_sample,
            search_space=config.data.name)
      elif is_multi_obj or conditional:
        # The multi-objective and the plain conditional cases use the same sampler
        # with identical arguments.
        sampling_fn = get_pc_conditional_sampler_nas(
            sde=sde,
            shape=shape,
            predictor=predictor,
            corrector=corrector,
            inverse_scaler=inverse_scaler,
            snr=config.sampling.snr,
            n_steps=config.sampling.n_steps_each,
            probability_flow=config.sampling.probability_flow,
            continuous=config.training.continuous,
            denoise=config.sampling.noise_removal,
            eps=eps,
            device=config.device,
            regress=config.sampling.regress,
            labels=config.sampling.labels,
            classifier_scale=config.sampling.classifier_scale,
            weight_scheduling=config.sampling.weight_scheduling,
            weight_ratio=config.sampling.weight_ratio,
            t_spot=config.sampling.t_spot,
            t_spot_end=config.sampling.t_spot_end,
            p=p,
            prod_w=prod_w,
            weight_ratio_abs=weight_ratio_abs)
      else:
        sampling_fn = get_pc_sampler_nas(sde=sde,
                                         shape=shape,
                                         predictor=predictor,
                                         corrector=corrector,
                                         inverse_scaler=inverse_scaler,
                                         snr=config.sampling.snr,
                                         n_steps=config.sampling.n_steps_each,
                                         probability_flow=config.sampling.probability_flow,
                                         continuous=config.training.continuous,
                                         denoise=config.sampling.noise_removal,
                                         eps=eps,
                                         device=config.device)
    else:
      sampling_fn = get_pc_sampler(sde=sde,
                                   shape=shape,
                                   predictor=predictor,
                                   corrector=corrector,
                                   inverse_scaler=inverse_scaler,
                                   snr=config.sampling.snr,
                                   n_steps=config.sampling.n_steps_each,
                                   probability_flow=config.sampling.probability_flow,
                                   continuous=config.training.continuous,
                                   denoise=config.sampling.noise_removal,
                                   eps=eps,
                                   device=config.device)
  else:
    raise ValueError(f"Sampler name {sampler_name} unknown.")

  return sampling_fn
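

# Usage sketch (illustrative; assumes a populated `config`):
#
#   sampling_fn = get_sampling_fn(config, sde, shape, inverse_scaler, eps=1e-3,
#                                 data='NASBench201', conditional=True)
#   x, nfe, chain, grad_norms = sampling_fn(score_model, mask, classifier)
#
# Which arguments the returned function expects depends on the branch taken
# above; see the individual `get_*_sampler` factories below.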


class Predictor(abc.ABC):
  """The abstract class for a predictor algorithm."""

  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__()
    self.sde = sde
    # Compute the reverse SDE/ODE
    if isinstance(sde, tuple):
      self.rsde = (sde[0].reverse(score_fn, probability_flow), sde[1].reverse(score_fn, probability_flow))
    else:
      self.rsde = sde.reverse(score_fn, probability_flow)
    self.score_fn = score_fn

  @abc.abstractmethod
  def update_fn(self, x, t, *args, **kwargs):
    """One update of the predictor.

    Args:
      x: A PyTorch tensor representing the current state.
      t: A PyTorch tensor representing the current time step.

    Returns:
      x: A PyTorch tensor of the next state.
      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """
    pass


class Corrector(abc.ABC):
  """The abstract class for a corrector algorithm."""

  def __init__(self, sde, score_fn, snr, n_steps):
    super().__init__()
    self.sde = sde
    self.score_fn = score_fn
    self.snr = snr
    self.n_steps = n_steps

  @abc.abstractmethod
  def update_fn(self, x, t, *args, **kwargs):
    """One update of the corrector.

    Args:
      x: A PyTorch tensor representing the current state.
      t: A PyTorch tensor representing the current time step.

    Returns:
      x: A PyTorch tensor of the next state.
      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """
    pass


@register_predictor(name='euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)

  # def update_fn(self, x, t, *args, **kwargs):
  #   dt = -1. / self.rsde.N
  #   z = torch.randn_like(x)
  #   z = torch.tril(z, -1)
  #   z = z + z.transpose(-1, -2)
  #   drift, diffusion = self.rsde.sde(x, t, *args, **kwargs)
  #   drift = torch.tril(drift, -1)
  #   drift = drift + drift.transpose(-1, -2)
  #   x_mean = x + drift * dt
  #   x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
  #   return x, x_mean

  def update_fn(self, x, t, *args, **kwargs):
    dt = -1. / self.rsde.N
    z = torch.randn_like(x)
    # z = torch.tril(z, -1)
    # z = z + z.transpose(-1, -2)
    drift, diffusion = self.rsde.sde(x, t, *args, **kwargs)
    # drift = torch.tril(drift, -1)
    # drift = drift + drift.transpose(-1, -2)
    x_mean = x + drift * dt
    x = x_mean + diffusion[:, None, None] * np.sqrt(-dt) * z
    return x, x_mean
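

# Euler-Maruyama note: with dt = -1 / N in reverse time, one step of the
# reverse SDE  dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw  is
# discretized as
#   x_mean = x + drift * dt
#   x      = x_mean + diffusion * sqrt(-dt) * z,  z ~ N(0, I),
# which is exactly what `EulerMaruyamaPredictor.update_fn` computes above.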


@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)

  # def update_fn(self, x, t, *args, **kwargs):
  #   f, G = self.rsde.discretize(x, t, *args, **kwargs)
  #   f = torch.tril(f, -1)
  #   f = f + f.transpose(-1, -2)
  #   z = torch.randn_like(x)
  #   z = torch.tril(z, -1)
  #   z = z + z.transpose(-1, -2)
  #
  #   x_mean = x - f
  #   x = x_mean + G[:, None, None, None] * z
  #   return x, x_mean

  def update_fn(self, x, t, *args, **kwargs):
    f, G = self.rsde.discretize(x, t, *args, **kwargs)
    # f = torch.tril(f, -1)
    # f = f + f.transpose(-1, -2)
    z = torch.randn_like(x)
    # z = torch.tril(z, -1)
    # z = z + z.transpose(-1, -2)

    x_mean = x - f
    x = x_mean + G[:, None, None] * z
    return x, x_mean


@register_predictor(name='none')
class NonePredictor(Predictor):
  """An empty predictor that does nothing."""

  def __init__(self, sde, score_fn, probability_flow=False):
    pass

  def update_fn(self, x, t, *args, **kwargs):
    return x, x


@register_corrector(name='langevin')
class LangevinCorrector(Corrector):
  def __init__(self, sde, score_fn, snr, n_steps):
    super().__init__(sde, score_fn, snr, n_steps)

  # def update_fn(self, x, t, *args, **kwargs):
  #   sde = self.sde
  #   score_fn = self.score_fn
  #   n_steps = self.n_steps
  #   target_snr = self.snr
  #   if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
  #     timestep = (t * (sde.N - 1) / sde.T).long()
  #     # Note: it seems that subVPSDE doesn't set alphas
  #     alpha = sde.alphas.to(t.device)[timestep]
  #   else:
  #     alpha = torch.ones_like(t)
  #
  #   for i in range(n_steps):
  #     grad = score_fn(x, t, *args, **kwargs)
  #     noise = torch.randn_like(x)
  #     noise = torch.tril(noise, -1)
  #     noise = noise + noise.transpose(-1, -2)
  #
  #     mask = kwargs['mask']
  #
  #     # mask invalid elements and calculate norm
  #     mask_tmp = mask.reshape(mask.shape[0], -1)
  #     grad_norm = torch.norm(mask_tmp * grad.reshape(grad.shape[0], -1), dim=-1).mean()
  #     noise_norm = torch.norm(mask_tmp * noise.reshape(noise.shape[0], -1), dim=-1).mean()
  #
  #     step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
  #     x_mean = x + step_size[:, None, None, None] * grad
  #     x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
  #
  #   return x, x_mean

  def update_fn(self, x, t, *args, **kwargs):
    sde = self.sde
    score_fn = self.score_fn
    n_steps = self.n_steps
    target_snr = self.snr
    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
      timestep = (t * (sde.N - 1) / sde.T).long()
      # Note: it seems that subVPSDE doesn't set alphas
      alpha = sde.alphas.to(t.device)[timestep]
    else:
      alpha = torch.ones_like(t)

    for i in range(n_steps):
      grad = score_fn(x, t, *args, **kwargs)
      noise = torch.randn_like(x)
      # noise = torch.tril(noise, -1)
      # noise = noise + noise.transpose(-1, -2)

      # mask = kwargs['maskX']
      # mask invalid elements and calculate norm
      # mask_tmp = mask.reshape(mask.shape[0], -1)
      # grad_norm = torch.norm(mask_tmp * grad.reshape(grad.shape[0], -1), dim=-1).mean()
      # noise_norm = torch.norm(mask_tmp * noise.reshape(noise.shape[0], -1), dim=-1).mean()
      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
      noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()

      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
      x_mean = x + step_size[:, None, None] * grad
      x = x_mean + torch.sqrt(step_size * 2)[:, None, None] * noise

    return x, x_mean
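

# Langevin corrector note: each of the `n_steps` iterations performs
#   x <- x + step_size * score + sqrt(2 * step_size) * z,  z ~ N(0, I),
# where the step size is set from the ratio of the noise norm to the score
# norm so the update matches the target signal-to-noise ratio:
#   step_size = 2 * alpha * (target_snr * ||z|| / ||score||)^2.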


@register_corrector(name='none')
class NoneCorrector(Corrector):
  """An empty corrector that does nothing."""

  def __init__(self, sde, score_fn, snr, n_steps):
    pass

  def update_fn(self, x, t, *args, **kwargs):
    return x, x


def shared_predictor_update_fn(x, t, sde, model,
                               predictor, probability_flow, continuous, *args, **kwargs):
  """A wrapper that configures and returns the update function of predictors."""
  score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
  if predictor is None:
    # Corrector-only sampler
    predictor_obj = NonePredictor(sde, score_fn, probability_flow)
  else:
    predictor_obj = predictor(sde, score_fn, probability_flow)

  return predictor_obj.update_fn(x, t, *args, **kwargs)


def shared_corrector_update_fn(x, t, sde, model,
                               corrector, continuous, snr, n_steps, *args, **kwargs):
  """A wrapper that configures and returns the update function of correctors."""
  score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)

  if corrector is None:
    # Predictor-only sampler
    corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
  else:
    corrector_obj = corrector(sde, score_fn, snr, n_steps)

  return corrector_obj.update_fn(x, t, *args, **kwargs)


def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
                   n_steps=1, probability_flow=False, continuous=False,
                   denoise=True, eps=1e-3, device='cuda'):
  """Create a Predictor-Corrector (PC) sampler.

  Args:
    sde: An `sde_lib.SDE` object representing the forward SDE.
    shape: A sequence of integers. The expected shape of a single sample.
    predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
    corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
    inverse_scaler: The inverse data normalizer.
    snr: A `float` number. The signal-to-noise ratio for configuring correctors.
    n_steps: An integer. The number of corrector steps per predictor update.
    probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
    continuous: `True` indicates that the score model was continuously trained.
    denoise: If `True`, add one-step denoising to the final samples.
    eps: A `float` number. The reverse-time SDE and ODE are integrated to `eps` to avoid numerical issues.
    device: PyTorch device.

  Returns:
    A sampling function that returns samples, the number of function evaluations
    during sampling, and the sampled node counts.
  """
  # Create predictor & corrector update functions
  predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                          sde=sde,
                                          predictor=predictor,
                                          probability_flow=probability_flow,
                                          continuous=continuous)
  corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                          sde=sde,
                                          corrector=corrector,
                                          continuous=continuous,
                                          snr=snr,
                                          n_steps=n_steps)

  def pc_sampler(model, n_nodes_pmf):
    """The PC sampler function.

    Args:
      model: A score model.
      n_nodes_pmf: Probability mass function of graph nodes.

    Returns:
      Samples, number of function evaluations, sampled node counts.
    """
    with torch.no_grad():
      # Initial sample
      x = sde.prior_sampling(shape).to(device)
      timesteps = torch.linspace(sde.T, eps, sde.N, device=device)

      # Sample the number of nodes
      n_nodes = torch.multinomial(n_nodes_pmf, shape[0], replacement=True)
      mask = torch.zeros((shape[0], shape[-1]), device=device)
      for i in range(shape[0]):
        mask[i][:n_nodes[i]] = 1.
      mask = (mask[:, None, :] * mask[:, :, None]).unsqueeze(1)
      mask = torch.tril(mask, -1)
      mask = mask + mask.transpose(-1, -2)

      x = x * mask

      for i in range(sde.N):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t
        x, x_mean = corrector_update_fn(x, vec_t, model=model, mask=mask)
        x = x * mask
        x, x_mean = predictor_update_fn(x, vec_t, model=model, mask=mask)
        x = x * mask

      return inverse_scaler(x_mean if denoise else x) * mask, sde.N * (n_steps + 1), n_nodes

  return pc_sampler
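

# Usage sketch (illustrative): the returned sampler alternates one corrector
# update and one predictor update per time step, from t = sde.T down to
# t = eps, masking out entries beyond the sampled number of nodes:
#
#   sampler = get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr)
#   samples, nfe, n_nodes = sampler(model, n_nodes_pmf)
#
# The number of score-function evaluations is sde.N * (n_steps + 1).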


def get_pc_sampler_nas(sde, shape, predictor, corrector, inverse_scaler, snr,
                       n_steps=1, probability_flow=False, continuous=False,
                       denoise=True, eps=1e-3, device='cuda'):
  """Create a Predictor-Corrector (PC) sampler for NAS search spaces.

  Args:
    sde: An `sde_lib.SDE` object representing the forward SDE.
    shape: A sequence of integers. The expected shape of a single sample.
    predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
    corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
    inverse_scaler: The inverse data normalizer.
    snr: A `float` number. The signal-to-noise ratio for configuring correctors.
    n_steps: An integer. The number of corrector steps per predictor update.
    probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
    continuous: `True` indicates that the score model was continuously trained.
    denoise: If `True`, add one-step denoising to the final samples.
    eps: A `float` number. The reverse-time SDE and ODE are integrated to `eps` to avoid numerical issues.
    device: PyTorch device.

  Returns:
    A sampling function that returns samples and the number of function evaluations during sampling.
  """
  # Create predictor & corrector update functions
  predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                          sde=sde,
                                          predictor=predictor,
                                          probability_flow=probability_flow,
                                          continuous=continuous)
  corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                          sde=sde,
                                          corrector=corrector,
                                          continuous=continuous,
                                          snr=snr,
                                          n_steps=n_steps)

  def pc_sampler(model, mask):
    """The PC sampler function.

    Args:
      model: A score model.
      mask: An adjacency mask; the first mask in the batch is broadcast to all samples.

    Returns:
      Samples, number of function evaluations.
    """
    with torch.no_grad():
      # Initial sample
      x = sde.prior_sampling(shape).to(device)
      timesteps = torch.linspace(sde.T, eps, sde.N, device=device)

      # Sample the number of nodes
      # n_nodes = torch.multinomial(n_nodes_pmf, shape[0], replacement=True)
      # mask = torch.zeros((shape[0], shape[-1]), device=device)
      # for i in range(shape[0]):
      #   mask[i][:n_nodes[i]] = 1.
      # mask = (mask[:, None, :] * mask[:, :, None]).unsqueeze(1)
      # mask = torch.tril(mask, -1)
      # mask = mask + mask.transpose(-1, -2)
      # x = x * mask
      mask = mask[0].unsqueeze(0).repeat(x.size(0), 1, 1)

      for i in trange(sde.N, desc='[PC sampling]', position=1, leave=False):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t
        x, x_mean = corrector_update_fn(x, vec_t, model=model, maskX=mask)
        # x = x * mask
        x, x_mean = predictor_update_fn(x, vec_t, model=model, maskX=mask)
        # x = x * mask

      # return inverse_scaler(x_mean if denoise else x) * mask, sde.N * (n_steps + 1), n_nodes
      return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1), None

  return pc_sampler


def get_pc_conditional_sampler_nas(sde, shape,
                                   predictor, corrector, inverse_scaler, snr,
                                   n_steps=1, probability_flow=False,
                                   continuous=False, denoise=True, eps=1e-5, device='cuda',
                                   regress=True, labels='max', classifier_scale=0.5,
                                   weight_scheduling=True, weight_ratio=True, t_spot=1., t_spot_end=None,
                                   p=1, prod_w=False, weight_ratio_abs=False):
  """Class-conditional sampling with Predictor-Corrector (PC) samplers.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    shape: A sequence of integers. The expected shape of a single sample.
    predictor: A subclass of `sampling.Predictor` that represents a predictor algorithm.
    corrector: A subclass of `sampling.Corrector` that represents a corrector algorithm.
    inverse_scaler: The inverse data normalizer.
    snr: A `float` number. The signal-to-noise ratio for correctors.
    n_steps: An integer. The number of corrector steps per update of the predictor.
    probability_flow: If `True`, solve the probability flow ODE for sampling with the predictor.
    continuous: `True` indicates the score-based model was trained with continuous time.
    denoise: If `True`, add one-step denoising to final samples.
    eps: A `float` number. The SDE/ODE will be integrated to `eps` to avoid numerical issues.
    device: PyTorch device.
    regress: If `True`, the noise-dependent classifier is treated as a regressor.
    labels: Target labels (or a target such as 'max') used for classifier guidance.
    classifier_scale: Base weight `w` on the classifier gradient.
    weight_scheduling: If `True`, decay the weight over time as `classifier_scale * 0.1 ** t`.
    weight_ratio: If `True`, rescale the weight by ||score|| / ||classifier_grad|| per sample.
    t_spot: Apply guidance only for t <= `t_spot`.
    t_spot_end: Apply guidance only for t >= `t_spot_end` (defaults to `eps`).
    p: The order of the norm used in the rescaling above.
    prod_w: If `True`, include `w` inside the norm ratio.
    weight_ratio_abs: If `True`, rescale the classifier gradient element-wise instead
      (mutually exclusive with `weight_ratio`).

  Returns:
    A class-conditional Predictor-Corrector sampler.
  """
  score_grad_norm_p, classifier_grad_norm_p = [], []
  score_grad_norm_c, classifier_grad_norm_c = [], []
  if t_spot_end is None or t_spot_end == 0.:
    t_spot_end = eps

  def weight_scheduling_fn(w, t):
    return w * 0.1 ** t

  def conditional_predictor_update_fn(score_model, classifier, x, t, labels, maskX, *args, **kwargs):
    """The predictor update function for class-conditional sampling."""
    score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
    # The gradient function of the noise-dependent classifier
    classifier_grad_fn = mutils.get_classifier_grad_fn(sde, classifier, train=False, continuous=continuous,
                                                       regress=regress, labels=labels)

    def total_grad_fn(x, t, *args, **kwargs):
      # score = score_fn(x, t, *args, **kwargs)
      score = score_fn(x, t, maskX)
      classifier_grad = classifier_grad_fn(x, t, maskX, *args, **kwargs)

      # Sample weight
      if weight_scheduling:
        w = weight_scheduling_fn(classifier_scale, t[0].item())
      else:
        w = classifier_scale

      if weight_ratio:
        if prod_w:
          ratio = score.view(x.shape[0], -1).norm(p=p, dim=-1) / (w * classifier_grad).view(x.shape[0], -1).norm(p=p, dim=-1)
        else:
          ratio = score.view(x.shape[0], -1).norm(p=p, dim=-1) / classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1)
        w *= ratio[:, None, None]

      if weight_ratio_abs:
        assert not weight_ratio
        ratio = torch.div(torch.abs(score), torch.abs(classifier_grad))
        w *= ratio

      score_grad_norm_p.append(torch.mean(score.view(x.shape[0], -1).norm(p=p, dim=-1)).item())

      if weight_ratio:  # ratio per sample
        classifier_grad_norm_p.append(torch.mean(classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1) * ratio[:, None, None]).item())
      elif weight_ratio_abs:  # ratio per element
        classifier_grad_norm_p.append(torch.mean((classifier_grad * ratio).norm(p=p)).item())
      else:
        classifier_grad_norm_p.append(torch.mean(classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1)).item())

      if t_spot < 1.:
        if t[0].item() <= t_spot and t[0] >= t_spot_end:
          return score + w * classifier_grad
          # return (1 - w) * score + w * classifier_grad
        else:
          return score
      else:
        # return (1 - w) * score + w * classifier_grad
        return score + w * classifier_grad

    if predictor is None:
      predictor_obj = NonePredictor(sde, total_grad_fn, probability_flow)
    else:
      predictor_obj = predictor(sde, total_grad_fn, probability_flow)
    return predictor_obj.update_fn(x, t, *args, **kwargs)

  def conditional_corrector_update_fn(score_model, classifier, x, t, labels, maskX, *args, **kwargs):
    """The corrector update function for class-conditional sampling."""
    score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
    classifier_grad_fn = mutils.get_classifier_grad_fn(sde, classifier, train=False, continuous=continuous,
                                                       regress=regress, labels=labels)

    def total_grad_fn(x, t, *args, **kwargs):
      # score = score_fn(x, t, *args, **kwargs)
      score = score_fn(x, t, maskX)
      classifier_grad = classifier_grad_fn(x, t, maskX, *args, **kwargs)

      # Sample weight
      if weight_scheduling:
        w = weight_scheduling_fn(classifier_scale, t[0].item())
      else:
        w = classifier_scale

      if weight_ratio:
        if prod_w:
          ratio = score.view(x.shape[0], -1).norm(p=p, dim=-1) / (w * classifier_grad).view(x.shape[0], -1).norm(p=p, dim=-1)
        else:
          ratio = score.view(x.shape[0], -1).norm(p=p, dim=-1) / classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1)
        w *= ratio[:, None, None]

      score_grad_norm_c.append(torch.mean(score.view(x.shape[0], -1).norm(p=p, dim=-1)).item())

      if weight_ratio:
        classifier_grad_norm_c.append(torch.mean(classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1) * ratio[:, None, None]).item())
      else:
        classifier_grad_norm_c.append(torch.mean(classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1)).item())

      if t_spot < 1.:
        if t[0].item() <= t_spot and t[0] >= t_spot_end:
          return score + w * classifier_grad
          # return (1 - w) * score + w * classifier_grad
        else:
          return score
      else:
        return score + w * classifier_grad
        # return (1 - w) * score + w * classifier_grad

    if corrector is None:
      corrector_obj = NoneCorrector(sde, total_grad_fn, snr, n_steps)
    else:
      corrector_obj = corrector(sde, total_grad_fn, snr, n_steps)
    return corrector_obj.update_fn(x, t, *args, **kwargs)

  def pc_conditional_sampler(score_model, mask, classifier,
                             eval_chain=False, keep_chain=None, number_chain_steps=None):
    """Generate class-conditional samples with Predictor-Corrector (PC) samplers.

    Args:
      score_model: A `torch.nn.Module` representing the score-based model.
      mask: An adjacency mask; broadcast to every sample in the batch.
      classifier: The noise-dependent classifier used for guidance.
      eval_chain: If `True`, record intermediate samples along the reverse chain.
      keep_chain: Number of samples per snapshot to keep (defaults to the full batch).
      number_chain_steps: Number of chain snapshots (defaults to `sde.N`).

    Returns:
      Class-conditional samples, the number of function evaluations, the recorded
      chain (or `None`), and the logged score/classifier gradient norms.
    """
    chain_x = None
    if eval_chain:
      if number_chain_steps is None:
        number_chain_steps = sde.N
      if keep_chain is None:
        keep_chain = shape[0]  # keep all samples
      assert number_chain_steps <= sde.N
      chain_x_size = torch.Size((number_chain_steps, keep_chain, *shape[1:]))
      chain_x = torch.zeros(chain_x_size)

    with torch.no_grad():
      # Initial sample
      x = sde.prior_sampling(shape).to(device)
      timesteps = torch.linspace(sde.T, eps, sde.N, device=device)

      if len(mask.shape) == 3:
        mask = mask[0]
      mask = mask.unsqueeze(0).repeat(x.size(0), 1, 1)  # adj

      for i in trange(sde.N, desc='[PC conditional sampling]', position=1, leave=False):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t
        x, x_mean = conditional_corrector_update_fn(score_model, classifier, x, vec_t, labels=labels, maskX=mask)
        # x = x * mask
        x, x_mean = conditional_predictor_update_fn(score_model, classifier, x, vec_t, labels=labels, maskX=mask)
        # x = x * mask

        if eval_chain:
          # Save the first keep_chain graphs
          write_index = number_chain_steps - 1 - int((i * number_chain_steps) // sde.N)
          # write_index = int((t * number_chain_steps) // sde.T)
          chain_x[write_index] = inverse_scaler(x_mean if denoise else x)[:keep_chain]

      # Overwrite last frame with the resulting x
      # if keep_chain > 0:
      #   final_x_chain = inverse_scaler(x_mean if denoise else x)[:keep_chain]
      #   chain_x[0] = final_x_chain
      #   # Repeat last frame to see final sample better
      #   chain_x = torch.cat([chain_x, chain_x[-1:].repeat(10, 1, 1)], dim=0)
      #   assert chain_x.size(0) == (number_chain_steps + 10)

      return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1), chain_x, (score_grad_norm_p, classifier_grad_norm_p, score_grad_norm_c, classifier_grad_norm_c)

  return pc_conditional_sampler
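

# Classifier-guidance note (applies to both conditional samplers): the update
# functions above and below replace the plain score with
#   score(x, t) + w * grad_x log p(y | x, t),
# where `w` starts at `classifier_scale`, optionally decays as 0.1 ** t
# (`weight_scheduling`), and is optionally rescaled by the ratio of the score
# norm to the classifier-gradient norm (`weight_ratio` / `weight_ratio_abs`)
# so that the two terms have comparable magnitude. Guidance is active only for
# t in [t_spot_end, t_spot].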


def get_pc_conditional_sampler_meta_nas(sde, shape,
                                        predictor, corrector, inverse_scaler, snr,
                                        n_steps=1, probability_flow=False,
                                        continuous=False, denoise=True, eps=1e-5, device='cuda',
                                        regress=True, labels='max', classifier_scale=0.5,
                                        weight_scheduling=True, weight_ratio=True, t_spot=1., t_spot_end=None,
                                        p=1, prod_w=False, weight_ratio_abs=False,
                                        data_name='cifar10', num_sample=20, search_space=None):
  """Class-conditional PC sampling conditioned on a meta-learned task embedding.

  Args:
    See `get_pc_conditional_sampler_nas` for the shared arguments, plus:
    data_name: Name of the dataset from which meta-test tasks are built.
    num_sample: Number of samples per task drawn by `MetaTestDataset`.
    search_space: Name of the NAS search space.

  Returns:
    A class-conditional Predictor-Corrector sampler conditioned on a task embedding.
  """

  # --------- Meta-NAS (START) ---------- #
  test_dataset = MetaTestDataset(
      data_path=PROCESSED_DATA_PATH,
      data_name=data_name,
      num_sample=num_sample
  )
  # --------- Meta-NAS (END) ---------- #

  score_grad_norm_p, classifier_grad_norm_p = [], []
  score_grad_norm_c, classifier_grad_norm_c = [], []
  if t_spot_end is None or t_spot_end == 0.:
    t_spot_end = eps

  def weight_scheduling_fn(w, t):
    return w * 0.1 ** t

  def conditional_predictor_update_fn(score_model, classifier, x, t, labels, maskX, classifier_scale, *args, **kwargs):
    """The predictor update function for class-conditional sampling."""
    score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
    # The gradient function of the noise-dependent classifier
    classifier_grad_fn = mutils.get_classifier_grad_fn(sde, classifier, train=False, continuous=continuous,
                                                       regress=regress, labels=labels)

    def total_grad_fn(x, t, *args, **kwargs):
      # score = score_fn(x, t, *args, **kwargs)
      score = score_fn(x, t, maskX)
      classifier_grad = classifier_grad_fn(x, t, maskX, *args, **kwargs)

      # Sample weight
      if weight_scheduling:
        w = weight_scheduling_fn(classifier_scale, t[0].item())
      else:
        w = classifier_scale

      if weight_ratio:
        if prod_w:
          ratio = score.view(x.shape[0], -1).norm(p=p, dim=-1) / (w * classifier_grad).view(x.shape[0], -1).norm(p=p, dim=-1)
        else:
          ratio = score.view(x.shape[0], -1).norm(p=p, dim=-1) / classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1)
        w *= ratio[:, None, None]

      if weight_ratio_abs:
        assert not weight_ratio
        ratio = torch.div(torch.abs(score), torch.abs(classifier_grad))
        w *= ratio

      score_grad_norm_p.append(torch.mean(score.view(x.shape[0], -1).norm(p=p, dim=-1)).item())

      if weight_ratio:  # ratio per sample
        classifier_grad_norm_p.append(torch.mean(classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1) * ratio[:, None, None]).item())
      elif weight_ratio_abs:  # ratio per element
        classifier_grad_norm_p.append(torch.mean((classifier_grad * ratio).norm(p=p)).item())
      else:
        classifier_grad_norm_p.append(torch.mean(classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1)).item())

      if t_spot < 1.:
        if t[0].item() <= t_spot and t[0] >= t_spot_end:
          return score + w * classifier_grad
          # return (1 - w) * score + w * classifier_grad
        else:
          return score
      else:
        # return (1 - w) * score + w * classifier_grad
        return score + w * classifier_grad

    if predictor is None:
      predictor_obj = NonePredictor(sde, total_grad_fn, probability_flow)
    else:
      predictor_obj = predictor(sde, total_grad_fn, probability_flow)
    return predictor_obj.update_fn(x, t, *args, **kwargs)

  def conditional_corrector_update_fn(score_model, classifier, x, t, labels, maskX, classifier_scale, *args, **kwargs):
    """The corrector update function for class-conditional sampling."""
    score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
    classifier_grad_fn = mutils.get_classifier_grad_fn(sde, classifier, train=False, continuous=continuous,
                                                       regress=regress, labels=labels)

    def total_grad_fn(x, t, *args, **kwargs):
      # score = score_fn(x, t, *args, **kwargs)
      score = score_fn(x, t, maskX)
      classifier_grad = classifier_grad_fn(x, t, maskX, *args, **kwargs)

      # Sample weight
      if weight_scheduling:
        w = weight_scheduling_fn(classifier_scale, t[0].item())
      else:
        w = classifier_scale

      if weight_ratio:
        if prod_w:
          ratio = score.view(x.shape[0], -1).norm(p=p, dim=-1) / (w * classifier_grad).view(x.shape[0], -1).norm(p=p, dim=-1)
        else:
          ratio = score.view(x.shape[0], -1).norm(p=p, dim=-1) / classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1)
        w *= ratio[:, None, None]

      score_grad_norm_c.append(torch.mean(score.view(x.shape[0], -1).norm(p=p, dim=-1)).item())

      if weight_ratio:
        classifier_grad_norm_c.append(torch.mean(classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1) * ratio[:, None, None]).item())
      else:
        classifier_grad_norm_c.append(torch.mean(classifier_grad.view(x.shape[0], -1).norm(p=p, dim=-1)).item())

      if t_spot < 1.:
        if t[0].item() <= t_spot and t[0] >= t_spot_end:
          return score + w * classifier_grad
          # return (1 - w) * score + w * classifier_grad
        else:
          return score
      else:
        return score + w * classifier_grad
        # return (1 - w) * score + w * classifier_grad

    if corrector is None:
      corrector_obj = NoneCorrector(sde, total_grad_fn, snr, n_steps)
    else:
      corrector_obj = corrector(sde, total_grad_fn, snr, n_steps)
    return corrector_obj.update_fn(x, t, *args, **kwargs)

  def pc_conditional_sampler(score_model, mask, classifier,
                             eval_chain=False, keep_chain=None,
                             number_chain_steps=None, classifier_scale=None,
                             task=None, sample_bs=None):
    """Generate class-conditional samples with Predictor-Corrector (PC) samplers.

    Args:
      score_model: A `torch.nn.Module` representing the score-based model.
      mask: An adjacency mask; broadcast to every sample in the batch.
      classifier: The noise-dependent classifier used for guidance.
      eval_chain: If `True`, record intermediate samples along the reverse chain.
      keep_chain: Number of samples per snapshot to keep (defaults to the full batch).
      number_chain_steps: Number of chain snapshots (defaults to `sde.N`).
      classifier_scale: The guidance weight used for this call.
      task: Optional task embedding; if `None`, one is drawn from the meta-test dataset.
      sample_bs: Sampling batch size.

    Returns:
      Class-conditional samples, the number of function evaluations, the recorded
      chain (or `None`), and the logged score/classifier gradient norms.
    """
    chain_x = None
    if eval_chain:
      if number_chain_steps is None:
        number_chain_steps = sde.N
      if keep_chain is None:
        keep_chain = shape[0]  # keep all samples
      assert number_chain_steps <= sde.N
      chain_x_size = torch.Size((number_chain_steps, keep_chain, *shape[1:]))
      chain_x = torch.zeros(chain_x_size)

    with torch.no_grad():
      # ----------- Meta-NAS (START) ---------- #
      # Different task embedding for each sample in the batch:
      # task_batch = []
      # for _ in range(shape[0]):
      #   task_batch.append(test_dataset[0])
      # task = torch.stack(task_batch, dim=0)

      if task is None:
        # Use the same task embedding for every sample in the batch.
        task = test_dataset[0]
      task = task.repeat(shape[0], 1, 1)
      task = task.to(device)

      # For accelerating sampling
      classifier.sample_state = True
      classifier.D_mu = None
      # ----------- Meta-NAS (END) ---------- #

      # Initial sample
      x = sde.prior_sampling(shape).to(device)
      timesteps = torch.linspace(sde.T, eps, sde.N, device=device)

      if len(mask.shape) == 3:
        mask = mask[0]
      mask = mask.unsqueeze(0).repeat(x.size(0), 1, 1)  # adj

      for i in trange(sde.N, desc='[PC conditional sampling]', position=1, leave=False):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t

        x, x_mean = conditional_corrector_update_fn(score_model, classifier, x, vec_t, labels=labels, maskX=mask, task=task, classifier_scale=classifier_scale)
        x, x_mean = conditional_predictor_update_fn(score_model, classifier, x, vec_t, labels=labels, maskX=mask, task=task, classifier_scale=classifier_scale)

        if eval_chain:
          # Save the first keep_chain graphs
          write_index = number_chain_steps - 1 - int((i * number_chain_steps) // sde.N)
          # write_index = int((t * number_chain_steps) // sde.T)
          chain_x[write_index] = inverse_scaler(x_mean if denoise else x)[:keep_chain]

      classifier.sample_state = False
      return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1), chain_x, (score_grad_norm_p, classifier_grad_norm_p, score_grad_norm_c, classifier_grad_norm_c)

  return pc_conditional_sampler


def get_ode_sampler(sde, shape, inverse_scaler, denoise=False,
                    rtol=1e-5, atol=1e-5, method='RK45', eps=1e-3, device='cuda'):
  """Probability flow ODE sampler with the black-box ODE solver.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    shape: A sequence of integers. The expected shape of a single sample.
    inverse_scaler: The inverse data normalizer.
    denoise: If `True`, add one-step denoising to final samples.
    rtol: A `float` number. The relative tolerance level of the ODE solver.
    atol: A `float` number. The absolute tolerance level of the ODE solver.
    method: A `str`. The algorithm used for the black-box ODE solver.
      See the documentation of `scipy.integrate.solve_ivp`.
    eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
    device: PyTorch device.

  Returns:
    A sampling function that returns samples and the number of function evaluations during sampling.
  """

  def denoise_update_fn(model, x, mask):
    score_fn = get_score_fn(sde, model, train=False, continuous=True)
    # Reverse diffusion predictor for denoising
    predictor_obj = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
    vec_eps = torch.ones(x.shape[0], device=x.device) * eps
    _, x = predictor_obj.update_fn(x, vec_eps, mask=mask)
    return x

  def drift_fn(model, x, t, mask):
    """Get the drift function of the reverse-time SDE."""
    score_fn = get_score_fn(sde, model, train=False, continuous=True)
    rsde = sde.reverse(score_fn, probability_flow=True)
    return rsde.sde(x, t, mask=mask)[0]

  def ode_sampler(model, n_nodes_pmf, z=None):
    """The probability flow ODE sampler with black-box ODE solver.

    Args:
      model: A score model.
      n_nodes_pmf: Probability mass function of graph nodes.
      z: If present, generate samples from latent code `z`.

    Returns:
      Samples, number of function evaluations.
    """
    with torch.no_grad():
      # Initial sample
      if z is None:
        # If not given, sample the latent code from the prior distribution of the SDE.
        x = sde.prior_sampling(shape).to(device)
      else:
        x = z

      # Sample the number of nodes
      n_nodes = torch.multinomial(n_nodes_pmf, shape[0], replacement=True)
      mask = torch.zeros((shape[0], shape[-1]), device=device)
      for i in range(shape[0]):
        mask[i][:n_nodes[i]] = 1.
      mask = (mask[:, None, :] * mask[:, :, None]).unsqueeze(1)

      def ode_func(t, x):
        x = from_flattened_numpy(x, shape).to(device).type(torch.float32)
        vec_t = torch.ones(shape[0], device=x.device) * t
        drift = drift_fn(model, x, vec_t, mask)
        return to_flattened_numpy(drift)

      # Black-box ODE solver for the probability flow ODE
      solution = integrate.solve_ivp(ode_func, (sde.T, eps), to_flattened_numpy(x),
                                     rtol=rtol, atol=atol, method=method)
      nfe = solution.nfev
      x = torch.tensor(solution.y[:, -1]).reshape(shape).to(device).type(torch.float32)

      # Denoising is equivalent to running one predictor step without adding noise
      if denoise:
        x = denoise_update_fn(model, x, mask)

      x = inverse_scaler(x) * mask
      return x, nfe, n_nodes

  return ode_sampler
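

# Probability-flow ODE note: `get_ode_sampler` above (and `get_diffeq_sampler`
# below) integrates
#   dx/dt = f(x, t) - (1/2) * g(t)^2 * score(x, t)
# from t = sde.T down to t = eps; `drift_fn` obtains this drift from
# `sde.reverse(score_fn, probability_flow=True)`. Optional denoising is one
# reverse-diffusion predictor step at t = eps without added noise.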


def get_diffeq_sampler(sde, shape, inverse_scaler, denoise=False,
                       rtol=1e-5, atol=1e-5, step_size=0.01, method='dopri5', eps=1e-3, device='cuda'):
  """Probability flow ODE sampler using the ODE solvers in `torchdiffeq`.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    shape: A sequence of integers. The expected shape of a single sample.
    inverse_scaler: The inverse data normalizer.
    denoise: If `True`, add one-step denoising to final samples.
    rtol: A `float` number. The relative tolerance level of the ODE solver.
    atol: A `float` number. The absolute tolerance level of the ODE solver.
    step_size: A `float` number. The step size used by fixed-step solvers.
    method: A `str`. The solver used by `torchdiffeq`; see its documentation.
      Supported here: the adaptive solvers 'dopri5', 'bosh3', and 'fehlberg2',
      and the fixed-step solvers 'euler', 'midpoint', 'rk4', 'explicit_adams',
      and 'implicit_adams'.
    eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
    device: PyTorch device.

  Returns:
    A sampling function that returns samples and the number of function evaluations during sampling.
  """

  def denoise_update_fn(model, x, mask):
    score_fn = get_score_fn(sde, model, train=False, continuous=True)
    # Reverse diffusion predictor for denoising
    predictor_obj = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
    vec_eps = torch.ones(x.shape[0], device=x.device) * eps
    _, x = predictor_obj.update_fn(x, vec_eps, mask=mask)
    return x

  def drift_fn(model, x, t, mask):
    """Get the drift function of the reverse-time SDE."""
    score_fn = get_score_fn(sde, model, train=False, continuous=True)
    rsde = sde.reverse(score_fn, probability_flow=True)
    return rsde.sde(x, t, mask=mask)[0]

  def diffeq_sampler(model, n_nodes_pmf, z=None):
    """The probability flow ODE sampler with ODE solver from torchdiffeq.

    Args:
      model: A score model.
      n_nodes_pmf: Probability mass function of graph nodes.
      z: If present, generate samples from latent code `z`.

    Returns:
      Samples, number of function evaluations.
    """
    with torch.no_grad():
      # Initial sample
      if z is None:
        # If not given, sample the latent code from the prior distribution of the SDE.
        x = sde.prior_sampling(shape).to(device)
      else:
        x = z

      # Sample the number of nodes
      n_nodes = torch.multinomial(n_nodes_pmf, shape[0], replacement=True)
      mask = torch.zeros((shape[0], shape[-1]), device=device)
      for i in range(shape[0]):
        mask[i][:n_nodes[i]] = 1.
      mask = (mask[:, None, :] * mask[:, :, None]).unsqueeze(1)

      class ODEfunc(torch.nn.Module):
        def __init__(self):
          super().__init__()
          self.nfe = 0

        def forward(self, t, x):
          self.nfe += 1
          x = x.reshape(shape)
          vec_t = torch.ones(shape[0], device=x.device) * t
          drift = drift_fn(model, x, vec_t, mask)
          return drift.reshape((-1,))

      # Black-box ODE solver for the probability flow ODE
      ode_func = ODEfunc()
      if method in ['dopri5', 'bosh3', 'fehlberg2']:
        solution = odeint(ode_func, x.reshape((-1,)), torch.tensor([sde.T, eps], device=x.device),
                          rtol=rtol, atol=atol, method=method,
                          options={'step_t': torch.tensor([1e-3], device=x.device)})
      elif method in ['euler', 'midpoint', 'rk4', 'explicit_adams', 'implicit_adams']:
        solution = odeint(ode_func, x.reshape((-1,)), torch.tensor([sde.T, eps], device=x.device),
                          rtol=rtol, atol=atol, method=method,
                          options={'step_size': step_size})
      else:
        raise ValueError(f"Unknown ODE solver method: {method}")

      x = solution[-1, :].reshape(shape)

      # Denoising is equivalent to running one predictor step without adding noise
      if denoise:
        x = denoise_update_fn(model, x, mask)

      x = inverse_scaler(x) * mask
      return x, ode_func.nfe, n_nodes

  return diffeq_sampler
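

# Usage sketch (illustrative): adaptive solvers use rtol/atol, while fixed-step
# solvers use `step_size`:
#
#   sampler = get_diffeq_sampler(sde, shape, inverse_scaler, method='dopri5')
#   samples, nfe, n_nodes = sampler(model, n_nodes_pmf)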