import torch
import torch.nn.functional as F
import numpy as np

import sde_lib

_MODELS = {}


def register_model(cls=None, *, name=None):
    """A decorator for registering model classes."""

    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _MODELS:
            raise ValueError(
                f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)


def get_model(name):
    """Return the model class registered under `name`."""
    return _MODELS[name]


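# A minimal registry sketch (the `MyScoreNet` class below is hypothetical, not
# part of this repo): decorating a `torch.nn.Module` subclass registers it under
# a name, and `create_model` later instantiates it via `get_model(config.model.name)`.
#
#   @register_model(name='my_score_net')
#   class MyScoreNet(torch.nn.Module):
#       def __init__(self, config):
#           super().__init__()
#           ...
#
#       def forward(self, x, labels, *args, **kwargs):
#           ...
#
#   score_net_cls = get_model('my_score_net')   # -> MyScoreNet

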
def create_model(config):
    """Create the score model."""
    model_name = config.model.name
    score_model = get_model(model_name)(config)
    score_model = score_model.to(config.device)
    if 'load_pretrained' in config['training'].keys() and config.training.load_pretrained:
        from utils import restore_checkpoint_partial
        score_model = restore_checkpoint_partial(
            score_model,
            torch.load(config.training.pretrained_model_path,
                       map_location=config.device)['model'])
    # score_model = torch.nn.DataParallel(score_model)
    return score_model


def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A model function.
    """

    def model_fn(x, labels, *args, **kwargs):
        """Compute the output of the score-based model.

        Args:
          x: A mini-batch of input data (adjacency matrices).
          labels: A mini-batch of conditioning variables for time steps. Should be
            interpreted differently for different models.
          *args, **kwargs: Extra inputs forwarded to the model, e.g. the mask for
            adjacency matrices.

        Returns:
          The model output.
        """
        if not train:
            model.eval()
        else:
            model.train()
        return model(x, labels, *args, **kwargs)

    return model_fn


def get_score_fn(sde, model, train=False, continuous=False):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
      A score function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
        def score_fn(x, t, *args, **kwargs):
            # Scale the neural network output by the standard deviation and flip its sign.
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999 for
                # continuously-trained models.
                labels = t * 999
                score = model_fn(x, labels, *args, **kwargs)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                score = model_fn(x, labels, *args, **kwargs)
                std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[labels.long()]

            score = -score / std[:, None, None]
            return score

    elif isinstance(sde, sde_lib.VESDE):
        def score_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            score = model_fn(x, labels, *args, **kwargs)
            return score

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn


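# A hedged usage sketch (`score_model`, `adj_t`, `t` and `mask` are placeholders,
# not names defined in this repo): for a VP SDE the network is assumed to predict
# the added noise, and `score_fn` converts it into the score -eps / std(t) via
# the rescaling above.
#
#   score_fn = get_score_fn(sde, score_model, train=False, continuous=True)
#   score = score_fn(adj_t, t, mask)   # same shape as adj_t, i.e. [B, N, N]

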
def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
                           regress=True, labels='max'):
    """Create a function that returns classifier gradients w.r.t. the input, for guidance.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      classifier: A classifier (or regressor) model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the classifier is expected to directly take continuous time steps.
      regress: If `True`, sum the raw regression outputs; otherwise sum the log-softmax
        probabilities of the given class `labels`.
      labels: 'max' or 'min' in regression mode, or a tensor of class indices otherwise.

    Returns:
      A function mapping `(x, t, ...)` to the gradient of the objective w.r.t. `x`.
    """
    logit_fn = get_logit_fn(sde, classifier, train, continuous)

    def classifier_grad_fn(x, t, *args, **kwargs):
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            if regress:
                assert labels in ['max', 'min']
                logit = logit_fn(x_in, t, *args, **kwargs)
                prob = logit.sum()
            else:
                logit = logit_fn(x_in, t, *args, **kwargs)
                # prob = torch.nn.functional.log_softmax(logit, dim=-1)[torch.arange(labels.shape[0]), labels].sum()
                log_prob = F.log_softmax(logit, dim=-1)
                prob = log_prob[range(len(logit)), labels.view(-1)].sum()
            # prob.backward()
            # classifier_grad = x_in.grad
            classifier_grad = torch.autograd.grad(prob, x_in)[0]
        return classifier_grad

    return classifier_grad_fn


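# A hedged guidance sketch (`classifier`, `target_labels` and `guidance_scale`
# are placeholders; the sampler in this repo may combine the terms differently):
# the gradient returned above can be added to the score to steer sampling toward
# the classifier target, as in standard classifier guidance.
#
#   score_fn = get_score_fn(sde, score_model, train=False, continuous=True)
#   grad_fn = get_classifier_grad_fn(sde, classifier, continuous=True,
#                                    regress=False, labels=target_labels)
#   guided_score = score_fn(adj_t, t, mask) + guidance_scale * grad_fn(adj_t, t, mask)

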
def get_logit_fn(sde, classifier, train=False, continuous=False):
    """Create a function that returns classifier logits with the proper time conditioning."""
    classifier_fn = get_model_fn(classifier, train=train)

    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
        def logit_fn(x, t, *args, **kwargs):
            # Unlike the score model, the classifier output is not rescaled by the noise std.
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999 for
                # continuously-trained models.
                labels = t * 999
                logit = classifier_fn(x, labels, *args, **kwargs)
                # std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                logit = classifier_fn(x, labels, *args, **kwargs)
                # std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[labels.long()]

            # score = -score / std[:, None, None]
            return logit

    elif isinstance(sde, sde_lib.VESDE):
        def logit_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            logit = classifier_fn(x, labels, *args, **kwargs)
            return logit

    return logit_fn


def get_predictor_fn(sde, model, train=False, continuous=False):
    """Wraps a predictor model so that it receives the proper time-step conditioning.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      model: A predictor model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the predictor model is expected to directly take continuous time steps.

    Returns:
      A predictor function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
        def predictor_fn(x, t, *args, **kwargs):
            # Unlike the score model, the predictor output is returned without rescaling by the noise std.
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999 for
                # continuously-trained models.
                labels = t * 999
                pred = model_fn(x, labels, *args, **kwargs)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                pred = model_fn(x, labels, *args, **kwargs)
                std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[labels.long()]

            # score = -score / std[:, None, None]
            return pred

    elif isinstance(sde, sde_lib.VESDE):
        def predictor_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            pred = model_fn(x, labels, *args, **kwargs)
            return pred

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return predictor_fn


def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))


@torch.no_grad()
def mask_adj2node(adj_mask):
    """Convert batched adjacency mask matrices to batched node mask matrices.

    Args:
      adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edges.

    Returns:
      node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
    """
    batch_size, max_num_nodes, _ = adj_mask.shape

    node_mask = adj_mask[:, 0, :].clone()
    # The first node has no entry in its own row (no self-loops), so mark it valid explicitly.
    node_mask[:, 0] = 1

    return node_mask


@torch.no_grad()
def get_rw_feat(k_step, dense_adj):
    """Compute k_step random-walk features for a batched dense adjacency matrix.

    Args:
      k_step: Number of random-walk steps.
      dense_adj: [B, N, N] Batched dense adjacency matrices.

    Returns:
      rw_landing: [B, N, k_step] Random-walk landing probabilities (diagonals of the walk powers).
      spd_onehot: [B, k_step+1, N, N] One-hot encoding of truncated shortest-path distances
        derived from the random-walk powers.
    """
    rw_list = []
    deg = dense_adj.sum(-1, keepdim=True)
    AD = dense_adj / (deg + 1e-8)
    rw_list.append(AD)

    for _ in range(k_step):
        rw = torch.bmm(rw_list[-1], AD)
        rw_list.append(rw)
    rw_map = torch.stack(rw_list[1:], dim=1)  # [B, k_step, N, N]

    rw_landing = torch.diagonal(
        rw_map, offset=0, dim1=2, dim2=3)  # [B, k_step, N]
    rw_landing = rw_landing.permute(0, 2, 1)  # [B, N, k_step]

    # Get the shortest-path-distance indices.
    tmp_rw = rw_map.sort(dim=1)[0]
    spd_ind = (tmp_rw <= 0).sum(dim=1)  # [B, N, N]

    spd_onehot = torch.nn.functional.one_hot(
        spd_ind, num_classes=k_step + 1).to(torch.float)
    spd_onehot = spd_onehot.permute(0, 3, 1, 2)  # [B, k_step+1, N, N]

    return rw_landing, spd_onehot
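

if __name__ == "__main__":
    # Minimal smoke test for the tensor utilities above. This is only a sketch
    # with made-up shapes (B=2 graphs, N=4 nodes, k_step=3); it is not part of
    # the training or sampling pipeline and runs only when the file is executed
    # directly.
    B, N, k_step = 2, 4, 3

    # Random symmetric adjacency matrices without self-loops.
    upper = (torch.rand(B, N, N) > 0.5).float().triu(diagonal=1)
    adj = upper + upper.transpose(1, 2)

    rw_landing, spd_onehot = get_rw_feat(k_step, adj)
    print(rw_landing.shape)   # torch.Size([2, 4, 3])
    print(spd_onehot.shape)   # torch.Size([2, 4, 4, 4])

    # Adjacency masks for graphs with 3 and 4 valid nodes respectively.
    n_valid = torch.tensor([3, 4])
    node_valid = (torch.arange(N)[None, :] < n_valid[:, None]).float()  # [B, N]
    adj_mask = node_valid[:, :, None] * node_valid[:, None, :]
    adj_mask = adj_mask * (1. - torch.eye(N))  # remove self-loops
    node_mask = mask_adj2node(adj_mask)
    print(node_mask)          # [[1., 1., 1., 0.], [1., 1., 1., 1.]]

    # Round-trip through the numpy helpers.
    flat = to_flattened_numpy(adj)
    restored = from_flattened_numpy(flat, adj.shape)
    assert torch.allclose(adj, restored)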