# diffusionNAG/MobileNetV3/models/utils.py
import torch
import torch.nn.functional as F
import numpy as np

import sde_lib

_MODELS = {}

def register_model(cls=None, *, name=None):
    """A decorator for registering model classes."""
    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _MODELS:
            raise ValueError(
                f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)
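
# Example: registering a model class under a custom name (an illustrative
# sketch; `MyScoreNet` is a hypothetical class, not part of this repo):
#
#   @register_model(name='my_score_net')
#   class MyScoreNet(torch.nn.Module):
#       def __init__(self, config):
#           super().__init__()
#           ...
#
#   model_cls = get_model('my_score_net')   # retrieves MyScoreNet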

def get_model(name):
    return _MODELS[name]

def create_model(config):
    """Create the score model."""
    model_name = config.model.name
    score_model = get_model(model_name)(config)
    score_model = score_model.to(config.device)
    if 'load_pretrained' in config['training'].keys() and config.training.load_pretrained:
        from utils import restore_checkpoint_partial
        score_model = restore_checkpoint_partial(
            score_model,
            torch.load(config.training.pretrained_model_path,
                       map_location=config.device)['model'])
    # score_model = torch.nn.DataParallel(score_model)
    return score_model
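
# Illustrative sketch of the config fields `create_model` reads (assumes an
# ml_collections-style ConfigDict; all field values here are hypothetical):
#
#   import ml_collections
#   config = ml_collections.ConfigDict()
#   config.device = 'cuda'
#   config.model = ml_collections.ConfigDict({'name': 'my_score_net'})
#   config.training = ml_collections.ConfigDict({
#       'load_pretrained': False,
#       'pretrained_model_path': '/path/to/checkpoint.pth',
#   })
#   score_model = create_model(config)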

def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
        model: The score model.
        train: `True` for training and `False` for evaluation.

    Returns:
        A model function.
    """
    def model_fn(x, labels, *args, **kwargs):
        """Compute the output of the score-based model.

        Args:
            x: A mini-batch of input data (adjacency matrices).
            labels: A mini-batch of conditioning variables for time steps.
                Should be interpreted differently for different models.

        Any extra `*args`/`**kwargs` (e.g. a mask for the adjacency matrices)
        are forwarded to the model.

        Returns:
            The model output.
        """
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels, *args, **kwargs)

    return model_fn

def get_score_fn(sde, model, train=False, continuous=False):
    """Wraps `model_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        model: A score model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
        A score function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
        def score_fn(x, t, *args, **kwargs):
            # Scale the neural network output by the standard deviation and flip the sign.
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999
                # for continuously-trained models.
                labels = t * 999
                score = model_fn(x, labels, *args, **kwargs)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                score = model_fn(x, labels, *args, **kwargs)
                std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[labels.long()]
            score = -score / std[:, None, None]
            return score

    elif isinstance(sde, sde_lib.VESDE):
        def score_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()
            score = model_fn(x, labels, *args, **kwargs)
            return score

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn
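
# Illustrative usage (a sketch; the VPSDE constructor arguments follow the
# score_sde convention and may differ in this repo's `sde_lib`):
#
#   sde = sde_lib.VPSDE(beta_min=0.1, beta_max=20., N=1000)
#   score_fn = get_score_fn(sde, score_model, train=False, continuous=True)
#   x = torch.randn(batch_size, max_nodes, max_nodes)  # noisy adjacency batch
#   t = torch.rand(batch_size) * sde.T                 # continuous time steps
#   score = score_fn(x, t)                             # approximates grad_x log p_t(x)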

def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
                           regress=True, labels='max'):
    logit_fn = get_logit_fn(sde, classifier, train, continuous)

    def classifier_grad_fn(x, t, *args, **kwargs):
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            if regress:
                assert labels in ['max', 'min']
                logit = logit_fn(x_in, t, *args, **kwargs)
                prob = logit.sum()
            else:
                logit = logit_fn(x_in, t, *args, **kwargs)
                log_prob = F.log_softmax(logit, dim=-1)
                prob = log_prob[range(len(logit)), labels.view(-1)].sum()
            classifier_grad = torch.autograd.grad(prob, x_in)[0]
        return classifier_grad

    return classifier_grad_fn
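
# Illustrative classifier-guidance sketch (assumes a noise-conditional
# `classifier` has already been trained; names and shapes are hypothetical):
#
#   # Regression guidance (e.g. toward higher predicted accuracy):
#   grad_fn = get_classifier_grad_fn(sde, classifier, regress=True, labels='max')
#   guidance = grad_fn(x, t)                  # same shape as x
#
#   # Classification guidance toward target classes y ([B] long tensor):
#   grad_fn = get_classifier_grad_fn(sde, classifier, regress=False, labels=y)
#   guidance = grad_fn(x, t)
#   x = x + guidance_scale * guidance         # guided update; scale is a tuning knob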

def get_logit_fn(sde, classifier, train=False, continuous=False):
    classifier_fn = get_model_fn(classifier, train=train)

    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
        def logit_fn(x, t, *args, **kwargs):
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999
                # for continuously-trained models.
                labels = t * 999
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
            logit = classifier_fn(x, labels, *args, **kwargs)
            return logit

    elif isinstance(sde, sde_lib.VESDE):
        def logit_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()
            logit = classifier_fn(x, labels, *args, **kwargs)
            return logit

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return logit_fn

def get_predictor_fn(sde, model, train=False, continuous=False):
    """Wraps `model_fn` so that the predictor model receives the appropriate time conditioning.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        model: A predictor model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the model is expected to directly take continuous time steps.

    Returns:
        A predictor function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
        def predictor_fn(x, t, *args, **kwargs):
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999
                # for continuously-trained models.
                labels = t * 999
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
            pred = model_fn(x, labels, *args, **kwargs)
            return pred

    elif isinstance(sde, sde_lib.VESDE):
        def predictor_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()
            pred = model_fn(x, labels, *args, **kwargs)
            return pred

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return predictor_fn
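
# Usage mirrors `get_score_fn`, e.g. (`surrogate` is a hypothetical
# noise-conditional predictor model):
#
#   predictor_fn = get_predictor_fn(sde, surrogate, continuous=True)
#   pred = predictor_fn(x, t)   # surrogate prediction at noise level t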

def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))
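
# Round-trip example:
#
#   x = torch.randn(2, 3)
#   y = from_flattened_numpy(to_flattened_numpy(x), (2, 3))
#   assert torch.allclose(x, y)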

@torch.no_grad()
def mask_adj2node(adj_mask):
    """Convert batched adjacency mask matrices to batched node mask matrices.

    Args:
        adj_mask: [B, N, N] batched adjacency mask matrices without self-loop edges.

    Returns:
        node_mask: [B, N] batched node masks indicating the valid nodes.
    """
    batch_size, max_num_nodes, _ = adj_mask.shape
    node_mask = adj_mask[:, 0, :].clone()
    # Self-loops are excluded, so entry (0, 0) is zero even when node 0 is
    # valid; mark node 0 as valid explicitly.
    node_mask[:, 0] = 1
    return node_mask
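
# Example: a batch with one valid 2-node graph padded to N=3 (no self-loops):
#
#   adj_mask = torch.tensor([[[0., 1., 0.],
#                             [1., 0., 0.],
#                             [0., 0., 0.]]])
#   mask_adj2node(adj_mask)  # -> tensor([[1., 1., 0.]])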

@torch.no_grad()
def get_rw_feat(k_step, dense_adj):
    """Compute k_step random-walk features for a given dense adjacency matrix."""
    rw_list = []
    deg = dense_adj.sum(-1, keepdim=True)
    AD = dense_adj / (deg + 1e-8)  # row-normalized transition matrix
    rw_list.append(AD)

    for _ in range(k_step):
        rw = torch.bmm(rw_list[-1], AD)
        rw_list.append(rw)
    rw_map = torch.stack(rw_list[1:], dim=1)  # [B, k_step, N, N]

    # Random-walk landing probabilities: the diagonal gives the probability of
    # returning to the starting node after each number of steps.
    rw_landing = torch.diagonal(rw_map, offset=0, dim1=2, dim2=3)  # [B, k_step, N]
    rw_landing = rw_landing.permute(0, 2, 1)  # [B, N, k_step]

    # Shortest-path-distance indices: counting how many of the sorted k-step
    # walk probabilities are zero gives the smallest step count at which each
    # node pair becomes reachable.
    tmp_rw = rw_map.sort(dim=1)[0]
    spd_ind = (tmp_rw <= 0).sum(dim=1)  # [B, N, N]
    spd_onehot = F.one_hot(spd_ind, num_classes=k_step + 1).to(torch.float)
    spd_onehot = spd_onehot.permute(0, 3, 1, 2)  # [B, k_step+1, N, N]
    return rw_landing, spd_onehot
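
# Example on a single 3-node path graph (0-1-2) with k_step=2:
#
#   adj = torch.tensor([[[0., 1., 0.],
#                        [1., 0., 1.],
#                        [0., 1., 0.]]])
#   rw_landing, spd_onehot = get_rw_feat(2, adj)
#   rw_landing.shape   # torch.Size([1, 3, 2])
#   spd_onehot.shape   # torch.Size([1, 3, 3, 3])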