import torch
import torch.nn.functional as F

import sde_lib

_MODELS = {}


def register_model(cls=None, *, name=None):
    """A decorator for registering model classes."""

    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _MODELS:
            raise ValueError(
                f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)
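
# Example (illustrative sketch, not part of the original module): registering a
# hypothetical model class under a custom name and retrieving it again via
# `get_model`. The class name `ToyScoreNet` is an assumption for demonstration.
#
#   @register_model(name='toy_score_net')
#   class ToyScoreNet(torch.nn.Module):
#       def __init__(self, config):
#           super().__init__()
#           self.scale = torch.nn.Parameter(torch.ones(1))
#
#       def forward(self, x, labels, *args, **kwargs):
#           return self.scale * x
#
#   assert get_model('toy_score_net') is ToyScoreNet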


def get_model(name):
    return _MODELS[name]


def create_model(config):
    """Create the model."""
    model_name = config.model.name
    model = get_model(model_name)(config)
    model = model.to(config.device)
    return model
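
# Example (illustrative sketch): `create_model` only reads `config.model.name`
# and `config.device`; everything else is consumed by the registered class's
# constructor. The config layout below is an assumption for demonstration and
# reuses the hypothetical `ToyScoreNet` registered above.
#
#   from types import SimpleNamespace
#
#   config = SimpleNamespace(
#       model=SimpleNamespace(name='toy_score_net'),
#       device='cpu')
#   model = create_model(config)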


def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
        model: The score model.
        train: `True` for training and `False` for evaluation.

    Returns:
        A model function.
    """

    def model_fn(x, labels, *args, **kwargs):
        """Compute the output of the score-based model.

        Args:
            x: A mini-batch of input data (adjacency matrices).
            labels: A mini-batch of conditioning variables for time steps.
                Should be interpreted differently for different models.
            *args, **kwargs: Extra inputs forwarded to the model, e.g. a mask
                for the adjacency matrices.

        Returns:
            The model output.
        """
        if not train:
            model.eval()
            return model(x, labels, *args, **kwargs)
        else:
            model.train()
            return model(x, labels, *args, **kwargs)

    return model_fn


def get_score_fn(sde, model, train=False, continuous=False):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        model: A score model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
        A score function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
        def score_fn(x, t, *args, **kwargs):
            # Scale the neural network output by the standard deviation and flip the sign.
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999
                # for continuously-trained models.
                labels = t * 999
                score = model_fn(x, labels, *args, **kwargs)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                score = model_fn(x, labels, *args, **kwargs)
                std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[labels.long()]

            score = -score / std[:, None, None]
            return score

    elif isinstance(sde, sde_lib.VESDE):
        def score_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            score = model_fn(x, labels, *args, **kwargs)
            return score

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn
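
# Example (illustrative sketch): for a VP SDE with a model trained on discrete
# noise scales, the wrapper converts the raw network output eps_theta into a
# score estimate via score(x, t) = -eps_theta(x, t * (N - 1)) / std(t). The
# `VPSDE` constructor arguments below follow the reference score_sde code and
# are assumptions here, as is the 8-node batch shape.
#
#   sde = sde_lib.VPSDE(beta_min=0.1, beta_max=20., N=1000)
#   score_fn = get_score_fn(sde, model, train=False, continuous=False)
#   x = torch.randn(4, 8, 8)     # a batch of adjacency-like inputs
#   t = torch.rand(4)            # diffusion times in [0, 1)
#   score = score_fn(x, t)       # extra inputs such as a mask pass through *args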


def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
                           regress=True, labels='max'):
    """Create a function that returns the gradient of the classifier output w.r.t. its input.

    If `regress` is `True`, `labels` must be 'max' or 'min' and the gradient points
    towards increasing (or decreasing) the regressed property. Otherwise, `labels` is
    a batch of class indices and the gradient is taken of the summed log-probability
    of those classes.
    """
    logit_fn = get_logit_fn(sde, classifier, train, continuous)

    def classifier_grad_fn(x, t, *args, **kwargs):
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            if regress:
                assert labels in ['max', 'min']
                logit = logit_fn(x_in, t, *args, **kwargs)
                if labels == 'max':
                    prob = logit.sum()
                elif labels == 'min':
                    prob = -logit.sum()
            else:
                logit = logit_fn(x_in, t, *args, **kwargs)
                log_prob = F.log_softmax(logit, dim=-1)
                prob = log_prob[range(len(logit)), labels.view(-1)].sum()
            classifier_grad = torch.autograd.grad(prob, x_in)[0]
        return classifier_grad

    return classifier_grad_fn
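
# Example (illustrative sketch): classifier (or regressor) guidance adds the
# gradient of a property predictor to the unconditional score during sampling.
# The `guidance_scale` weight and the additive combination below are
# illustrative assumptions, not the sampler used by this module.
#
#   grad_fn = get_classifier_grad_fn(sde, classifier, regress=True, labels='max')
#   guided_score = score_fn(x, t, mask) + guidance_scale * grad_fn(x, t, mask)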


def get_logit_fn(sde, classifier, train=False, continuous=False):
    """Create a function that evaluates the classifier at noise-scale-dependent time labels."""
    classifier_fn = get_model_fn(classifier, train=train)

    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
        def logit_fn(x, t, *args, **kwargs):
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999
                # for continuously-trained models.
                labels = t * 999
                logit = classifier_fn(x, labels, *args, **kwargs)
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                logit = classifier_fn(x, labels, *args, **kwargs)
            return logit

    elif isinstance(sde, sde_lib.VESDE):
        def logit_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()
            logit = classifier_fn(x, labels, *args, **kwargs)
            return logit

    return logit_fn


def get_predictor_fn(sde, model, train=False, continuous=False):
    """Wraps a prediction model so that it takes noise-scale-dependent time labels.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        model: A predictor model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the model is expected to directly take continuous time steps.

    Returns:
        A predictor function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
        def predictor_fn(x, t, *args, **kwargs):
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999
                # for continuously-trained models.
                labels = t * 999
                pred = model_fn(x, labels, *args, **kwargs)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                pred = model_fn(x, labels, *args, **kwargs)
                std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[labels.long()]

            # Unlike `get_score_fn`, the output is returned as-is; `std` is
            # computed here but not applied to the prediction.
            return pred

    elif isinstance(sde, sde_lib.VESDE):
        def predictor_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            pred = model_fn(x, labels, *args, **kwargs)
            return pred

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return predictor_fn


def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))
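
# Example (illustrative sketch): the two helpers are inverses of each other up
# to dtype and device, which is convenient when handing tensors to numpy-based
# solvers and converting the results back.
#
#   x = torch.randn(2, 3, 3)
#   y = from_flattened_numpy(to_flattened_numpy(x), x.shape)
#   assert torch.allclose(x, y)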


@torch.no_grad()
def mask_adj2node(adj_mask):
    """Convert batched adjacency mask matrices to batched node mask matrices.

    Args:
        adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edges.

    Returns:
        node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
    """
    batch_size, max_num_nodes, _ = adj_mask.shape

    # Row 0 of the mask marks every valid node except node 0 itself (no
    # self-loops), so node 0 is set to valid explicitly.
    node_mask = adj_mask[:, 0, :].clone()
    node_mask[:, 0] = 1

    return node_mask
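
# Example (illustrative sketch): a batch with one 3-node graph padded to N=4.
#
#   adj_mask = torch.zeros(1, 4, 4)
#   adj_mask[0, :3, :3] = 1 - torch.eye(3)   # valid block, no self-loops
#   mask_adj2node(adj_mask)                  # tensor([[1., 1., 1., 0.]])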


@torch.no_grad()
def get_rw_feat(k_step, dense_adj):
    """Compute k_step Random Walk features for a given dense adjacency matrix."""

    rw_list = []
    deg = dense_adj.sum(-1, keepdim=True)
    AD = dense_adj / (deg + 1e-8)  # degree-normalized transition matrix
    rw_list.append(AD)

    for _ in range(k_step):
        rw = torch.bmm(rw_list[-1], AD)
        rw_list.append(rw)
    rw_map = torch.stack(rw_list[1:], dim=1)  # [B, k_step, N, N]

    # Random-walk landing probabilities (diagonal of each walk power).
    rw_landing = torch.diagonal(
        rw_map, offset=0, dim1=2, dim2=3)  # [B, k_step, N]
    rw_landing = rw_landing.permute(0, 2, 1)  # [B, N, k_step]

    # Get the shortest-path-distance indices: for each node pair, count how many
    # of the stacked walk powers still have zero landing probability (capped at k_step).
    tmp_rw = rw_map.sort(dim=1)[0]
    spd_ind = (tmp_rw <= 0).sum(dim=1)  # [B, N, N]

    spd_onehot = torch.nn.functional.one_hot(
        spd_ind, num_classes=k_step + 1).to(torch.float)
    spd_onehot = spd_onehot.permute(0, 3, 1, 2)  # [B, k_step + 1, N, N]

    return rw_landing, spd_onehot
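
# Example (illustrative sketch): random-walk structural features for a batch of
# two random 4-node graphs; the output shapes follow the comments above.
#
#   adj = torch.randint(0, 2, (2, 4, 4)).float()
#   adj = torch.tril(adj, diagonal=-1)
#   adj = adj + adj.transpose(1, 2)              # symmetric, no self-loops
#   rw_landing, spd_onehot = get_rw_feat(k_step=3, dense_adj=adj)
#   # rw_landing: [2, 4, 3], spd_onehot: [2, 4, 4, 4]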