# diffusionNAG/MobileNetV3/utils.py
# (last modified 2024-03-15 14:38:51 +00:00)
import os
import logging
import torch
from torch_scatter import scatter
import shutil
@torch.no_grad()
def to_dense_adj(edge_index, batch=None, edge_attr=None, max_num_nodes=None):
    r"""Converts batched sparse adjacency matrices given by edge indices and
    edge attributes to a single dense batched adjacency matrix.
    Args:
        edge_index (LongTensor): The edge indices, shape [2, num_edges].
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        max_num_nodes (int, optional): The size of the output node dimension.
            (default: :obj:`None`)
    Returns:
        adj: [batch_size, max_num_nodes, max_num_nodes] Dense adjacency matrices.
        mask: Mask for dense adjacency matrices (same dtype as adj; 1 where
            both endpoints are real nodes of the example).
    """
    if batch is None:
        # Single graph: every node referenced by edge_index belongs to example 0.
        # Guard edge_index.max() against empty edge sets (used to raise).
        num_ref = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
        batch = edge_index.new_zeros(num_ref)
    batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1
    # Per-example node counts and exclusive prefix sums (node index offsets).
    # bincount is the pure-torch equivalent of torch_scatter's add-reduce here.
    num_nodes = torch.bincount(batch, minlength=batch_size)
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
    # Per-edge (graph id, local source index, local target index).
    idx0 = batch[edge_index[0]]
    idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
    idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]
    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max()) if num_nodes.numel() > 0 else 0
    elif idx1.numel() > 0 and (int(idx1.max()) >= max_num_nodes
                               or int(idx2.max()) >= max_num_nodes):
        # Drop edges that touch nodes beyond the requested output size.
        mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
        idx0 = idx0[mask]
        idx1 = idx1[mask]
        idx2 = idx2[mask]
        edge_attr = None if edge_attr is None else edge_attr[mask]
    if edge_attr is None:
        edge_attr = torch.ones(idx0.numel(), device=edge_index.device)
    size = [batch_size, max_num_nodes, max_num_nodes]
    size += list(edge_attr.size())[1:]
    adj = torch.zeros(size, dtype=edge_attr.dtype, device=edge_index.device)
    flattened_size = batch_size * max_num_nodes * max_num_nodes
    adj = adj.view([flattened_size] + list(adj.size())[3:])
    # Accumulate edge features into the flattened adjacency tensor; duplicate
    # edges add up, matching torch_scatter's scatter(..., reduce='add').
    idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2
    adj.index_add_(0, idx, edge_attr)
    adj = adj.view(size)
    # Build the validity mask: 1 at (b, i, j) iff nodes i and j both exist
    # in example b.
    node_idx = torch.arange(batch.size(0), dtype=torch.long, device=edge_index.device)
    node_idx = (node_idx - cum_nodes[batch]) + (batch * max_num_nodes)
    mask = torch.zeros(batch_size * max_num_nodes, dtype=adj.dtype, device=adj.device)
    mask[node_idx] = 1
    mask = mask.view(batch_size, max_num_nodes)
    mask = mask[:, None, :] * mask[:, :, None]
    return adj, mask
def restore_checkpoint_partial(model, pretrained_stdict):
    """Partially restore `model` from a pretrained state dict.

    Only keys that already exist in the model's own state dict are copied
    over; unknown keys in `pretrained_stdict` are silently ignored, and
    model keys absent from the pretrained dict keep their current values.
    Returns the (mutated) model for convenience.
    """
    current = model.state_dict()
    for key, value in pretrained_stdict.items():
        # Skip keys the model does not know about.
        if key in current:
            current[key] = value
    model.load_state_dict(current)
    return model
def restore_checkpoint(ckpt_dir, state, device, resume=False):
    """Restore training state from a checkpoint file.

    Args:
        ckpt_dir: Path to the checkpoint *file* (despite the name).
        state: Dict whose 'optimizer'/'model'/'ema' entries are restored via
            their `load_state_dict`, while any other entry is replaced by the
            loaded value.
        device: Map location forwarded to `torch.load`.
        resume: If False, only ensure the parent directory exists and return
            `state` unchanged.

    Returns:
        The (possibly updated) `state` dict.
    """
    if not resume:
        # Fresh run: just make sure the directory for future saves exists.
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state
    if not os.path.exists(ckpt_dir):
        # exist_ok avoids a race if another process creates the dir first
        # (the original branch used a bare makedirs guarded by exists()).
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    loaded_state = torch.load(ckpt_dir, map_location=device)
    for k in state:
        if k in ['optimizer', 'model', 'ema']:
            # Stateful objects restore themselves from their saved state dict.
            state[k].load_state_dict(loaded_state[k])
        else:
            state[k] = loaded_state[k]
    return state
def save_checkpoint(ckpt_dir, state, step, save_step, is_best):
    """Save the training state, keeping only the newest checkpoint and the
    best-model copy.

    Args:
        ckpt_dir: Directory where checkpoint files are written.
        state: Dict whose 'optimizer'/'model'/'ema' entries are serialized via
            `state_dict()`; every other entry is saved as-is.
        step: Current step, embedded in the checkpoint filename.
        save_step: Save-interval tag, embedded in the checkpoint filename.
        is_best: If True, also copy this checkpoint to 'model_best.pth.tar'.
    """
    saved_state = {}
    for k in state:
        if k in ['optimizer', 'model', 'ema']:
            saved_state.update({k: state[k].state_dict()})
        else:
            saved_state.update({k: state[k]})
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_name = f'checkpoint_{step}_{save_step}.pth.tar'
    torch.save(saved_state, os.path.join(ckpt_dir, ckpt_name))
    if is_best:
        shutil.copy(os.path.join(ckpt_dir, ckpt_name), os.path.join(ckpt_dir, 'model_best.pth.tar'))
    # Prune older checkpoints. Keep the file just written and the best-model
    # copy (whose name does not start with 'checkpoint'). The previous code
    # also deleted the checkpoint it had just saved, which made resuming
    # from the latest checkpoint impossible.
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not ckpt_file.startswith('checkpoint'):
            continue
        if ckpt_file != ckpt_name:
            os.remove(os.path.join(ckpt_dir, ckpt_file))
def floyed(r):
    """Compute the transitive closure of a 0/1 adjacency matrix
    (Floyd-Warshall style reachability).

    :param r: an NxN torch.Tensor or numpy array with float 0/1 entries
    :return: a new numpy NxN matrix with float 0/1 entries; entry (i, j) is 1
        iff j was already adjacent to i or is reachable from i
    The input is left unmodified. (Previously the result was written back
    into the caller's buffer: `.cpu().numpy()` shares memory with a CPU
    tensor, and an ndarray argument was mutated directly — which is why
    callers such as aug_mask defensively clone first.)
    """
    if isinstance(r, torch.Tensor):
        r = r.cpu().numpy()
    # Work on a copy so the caller's matrix/tensor is never mutated in place.
    r = r.copy()
    N = r.shape[0]
    for k in range(N):
        for i in range(N):
            for j in range(N):
                if r[i, k] > 0 and r[k, j] > 0:
                    r[i, j] = 1
    return r
def aug_mask(adj, algo='long_range', data='NASBench201'):
    """Build augmentation masks for a batch of adjacency matrices.

    For NASBench201/OFA data every graph in the batch shares one topology,
    so a single mask is computed from the first adjacency matrix and
    repeated; otherwise one mask is computed per graph. `algo` selects the
    mask transform: 'long_range', 'floyed', or anything else for identity.
    """
    if adj.dim() == 2:
        adj = adj.unsqueeze(0)

    def _build(mat):
        # Dispatch on the requested mask algorithm; identity by default.
        if algo == 'long_range':
            return torch.from_numpy(long_range(mat)).float().to(adj.device)
        if algo == 'floyed':
            return torch.from_numpy(floyed(mat)).float().to(adj.device)
        return mat

    if data.lower() in ['nasbench201', 'ofa']:
        assert adj.dim() == 3
        # All graphs share the same topology: compute once, repeat B times.
        shared = _build(adj[0].clone().detach())
        return torch.stack([shared] * adj.size(0))
    return torch.stack([_build(graph) for graph in adj])
def long_range(r):
    """Add long-range shortcut edges to an upper-triangular DAG adjacency.

    For every node j, every predecessor i of j (i < j), and every
    predecessor k of i (k < i), an edge k -> j is set. Edges added for
    earlier columns are visible when later columns are processed.

    :param r: an NxN torch.Tensor or numpy array with float 0/1 entries
    :return: a new numpy NxN matrix with float 0/1 entries
    The input is left unmodified. (Previously the result was written back
    into the caller's buffer, because `.cpu().numpy()` shares memory with a
    CPU tensor and an ndarray argument was mutated directly.)
    """
    if isinstance(r, torch.Tensor):
        r = r.cpu().numpy()
    # Work on a copy so the caller's matrix/tensor is never mutated in place.
    r = r.copy()
    N = r.shape[0]
    for j in range(1, N):
        # Direct predecessors of j among nodes above the diagonal.
        col_j = r[:, j][:j]
        in_to_j = [i for i, val in enumerate(col_j) if val > 0]
        for i in in_to_j:
            # Predecessors of i; connect each of them straight to j.
            col_i = r[:, i][:i]
            in_to_i = [k for k, val in enumerate(col_i) if val > 0]
            for k in in_to_i:
                r[k, j] = 1
    return r
def dense_adj(graph_data, max_num_nodes, scaler=None, dequantization=False):
    """Convert a PyG DataBatch to dense adjacency matrices.

    Args:
        graph_data: DataBatch object providing `.edge_index` and `.batch`.
        max_num_nodes: The size of the output node dimension.
        scaler: Optional data normalizer applied to the [B, 1, N, N] tensor.
            Skipped when None (previously the default None was called
            unconditionally and raised a TypeError).
        dequantization: If True, apply uniform dequantization by averaging
            the 0/1 adjacency with symmetric uniform noise.

    Returns:
        adj: [B, 1, N, N] dense adjacency matrices.
        mask: [B, 1, N, N] mask for adjacency matrices, diagonal zeroed.
    """
    adj, adj_mask = to_dense_adj(graph_data.edge_index, graph_data.batch, max_num_nodes=max_num_nodes)  # [B, N, N]
    if dequantization:
        # Symmetric noise: draw in the strict lower triangle, mirror it up.
        noise = torch.rand_like(adj)
        noise = torch.tril(noise, -1)
        noise = noise + noise.transpose(1, 2)
        adj = (noise + adj) / 2.
    adj = adj[:, None, :, :]  # add channel dim -> [B, 1, N, N]
    if scaler is not None:
        adj = scaler(adj)
    # Zero the diagonal of the mask while keeping it symmetric.
    adj_mask = torch.tril(adj_mask, -1)
    adj_mask = adj_mask + adj_mask.transpose(1, 2)
    return adj, adj_mask[:, None, :, :]
def adj2graph(adj, sample_nodes):
    """Convert dense adjacency tensors to a list of numpy arrays.

    Args:
        adj: [Batch_size, channel, Max_node, Max_node] tensor; channel is
            assumed to be 1. Binarized at 0.5 *in place*.
        sample_nodes: [Batch_size] node counts used to truncate each matrix.

    Returns:
        List of symmetric 0/1 numpy matrices, one per graph.
    """
    # Binarize in place at the 0.5 threshold (this intentionally modifies
    # the caller's tensor, matching the original behavior).
    adj[adj >= 0.5] = 1.
    adj[adj < 0.5] = 0.
    graphs = []
    for idx in range(adj.shape[0]):
        # Symmetrize from the strict lower triangle, then crop the matrix
        # to this sample's actual node count.
        lower = torch.tril(adj[idx, 0], -1)
        sym = lower + lower.transpose(0, 1)
        n = sample_nodes[idx]
        graphs.append(sym.cpu().numpy()[:n, :n])
    return graphs
def quantize(x, adj, alpha=0.5, qtype='threshold'):
    """Quantize node-feature tensors and return them as numpy arrays.

    Args:
        x: [Batch_size, Max_node, N_vocab] tensor of soft scores.
        adj: [Batch_size, Max_node, Max_node] — accepted for interface
            compatibility; not used by either quantization mode.
        alpha: Threshold for qtype='threshold'.
        qtype: 'threshold' binarizes x *in place* at alpha; 'argmax' builds
            one-hot rows from the per-row argmax. Any other value yields [].

    Returns:
        List of numpy arrays, one [Max_node, N_vocab] matrix per batch item.
    """
    if qtype == 'threshold':
        # In-place binarization (intentionally modifies the caller's tensor,
        # matching the original behavior).
        x[x >= alpha] = 1.
        x[x < alpha] = 0.
    elif qtype == 'argmax':
        # One-hot encode the per-row argmax.
        hot = torch.argmax(x, dim=2, keepdim=True)  # [Batch_size, Max_node, 1]
        x = torch.zeros_like(x).scatter(2, hot, value=1)
    else:
        # Unknown mode: preserve the original silent empty-list return.
        return []
    return [x[i].cpu().numpy() for i in range(x.shape[0])]