import os
import logging
import shutil

import torch
from torch_scatter import scatter


@torch.no_grad()
def to_dense_adj(edge_index, batch=None, edge_attr=None, max_num_nodes=None):
    r"""Converts batched sparse adjacency matrices given by edge indices and
    edge attributes to a single dense batched adjacency matrix.

    Args:
        edge_index (LongTensor): The edge indices.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1 \}}^N`, which assigns
            each node to a specific example. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        max_num_nodes (int, optional): The size of the output node dimension.
            (default: :obj:`None`)

    Returns:
        adj: Dense adjacency matrices of shape
            [batch_size, max_num_nodes, max_num_nodes].
        mask: Mask for the dense adjacency matrices, 1 where both the row
            and the column index a real (non-padded) node.
    """
    if batch is None:
        batch = edge_index.new_zeros(edge_index.max().item() + 1)

    batch_size = batch.max().item() + 1
    one = batch.new_ones(batch.size(0))
    num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='add')
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    # Per-edge batch index plus source/target node indices local to each graph.
    idx0 = batch[edge_index[0]]
    idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
    idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]

    if max_num_nodes is None:
        max_num_nodes = num_nodes.max().item()
    elif idx1.max() >= max_num_nodes or idx2.max() >= max_num_nodes:
        # Drop edges that fall outside the requested node dimension.
        mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
        idx0 = idx0[mask]
        idx1 = idx1[mask]
        idx2 = idx2[mask]
        edge_attr = None if edge_attr is None else edge_attr[mask]

    if edge_attr is None:
        edge_attr = torch.ones(idx0.numel(), device=edge_index.device)

    size = [batch_size, max_num_nodes, max_num_nodes]
    size += list(edge_attr.size())[1:]
    adj = torch.zeros(size, dtype=edge_attr.dtype, device=edge_index.device)

    # Scatter the edge attributes into a flattened view of the dense tensor.
    flattened_size = batch_size * max_num_nodes * max_num_nodes
    adj = adj.view([flattened_size] + list(adj.size())[3:])
    idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2
    scatter(edge_attr, idx, dim=0, out=adj, reduce='add')
    adj = adj.view(size)

    # Build the per-node validity mask and lift it to a pairwise mask.
    node_idx = torch.arange(batch.size(0), dtype=torch.long,
                            device=edge_index.device)
    node_idx = (node_idx - cum_nodes[batch]) + (batch * max_num_nodes)
    mask = torch.zeros(batch_size * max_num_nodes, dtype=adj.dtype,
                       device=adj.device)
    mask[node_idx] = 1
    mask = mask.view(batch_size, max_num_nodes)
    mask = mask[:, None, :] * mask[:, :, None]

    return adj, mask
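# Usage sketch for to_dense_adj (illustrative only; the example tensors below
# are made up and not part of the original module). Two graphs with 3 and 2
# nodes are densified into a zero-padded [2, 3, 3] adjacency plus its mask.
def _demo_to_dense_adj():
    # Global node ids 0-2 belong to graph 0, ids 3-4 to graph 1.
    edge_index = torch.tensor([[0, 1, 3],
                               [1, 2, 4]], dtype=torch.long)
    batch = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)
    adj, mask = to_dense_adj(edge_index, batch)
    # adj[0] holds edges 0->1 and 1->2; adj[1] holds its single edge 0->1 and
    # is zero-padded. Row/column 2 of mask[1] is 0, since graph 1 has only
    # 2 real nodes.
    assert adj.shape == (2, 3, 3) and mask.shape == (2, 3, 3)
    return adj, mask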
" f"Returned the same state as input") return state else: loaded_state = torch.load(ckpt_dir, map_location=device) for k in state: if k in ['optimizer', 'model', 'ema']: state[k].load_state_dict(loaded_state[k]) else: state[k] = loaded_state[k] return state def save_checkpoint(ckpt_dir, state, step, save_step, is_best): saved_state = {} for k in state: if k in ['optimizer', 'model', 'ema']: saved_state.update({k: state[k].state_dict()}) else: saved_state.update({k: state[k]}) os.makedirs(ckpt_dir, exist_ok=True) torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{step}_{save_step}.pth.tar')) if is_best: shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{step}_{save_step}.pth.tar'), os.path.join(ckpt_dir, 'model_best.pth.tar')) # remove the ckpt except is_best state for ckpt_file in sorted(os.listdir(ckpt_dir)): if not ckpt_file.startswith('checkpoint'): continue if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'): os.remove(os.path.join(ckpt_dir, ckpt_file)) def floyed(r): """ :param r: a numpy NxN matrix with float 0,1 :return: a numpy NxN matrix with float 0,1 """ # r = np.array(r) if type(r) == torch.Tensor: r = r.cpu().numpy() N = r.shape[0] # import pdb; pdb.set_trace() for k in range(N): for i in range(N): for j in range(N): if r[i, k] > 0 and r[k, j] > 0: r[i, j] = 1 return r def aug_mask(adj, algo='long_range', data='NASBench201'): if len(adj.shape) == 2: adj = adj.unsqueeze(0) if data.lower() in ['nasbench201', 'ofa']: assert len(adj.shape) == 3 r = adj[0].clone().detach() if algo == 'long_range': mask_i = torch.from_numpy(long_range(r)).float().to(adj.device) elif algo == 'floyed': mask_i = torch.from_numpy(floyed(r)).float().to(adj.device) else: mask_i = r masks = [mask_i] * adj.size(0) return torch.stack(masks) else: masks = [] for r in adj: if algo == 'long_range': mask_i = torch.from_numpy(long_range(r)).float().to(adj.device) elif algo == 'floyed': mask_i = torch.from_numpy(floyed(r)).float().to(adj.device) else: mask_i = r masks.append(mask_i) return torch.stack(masks) def long_range(r): """ :param r: a numpy NxN matrix with float 0,1 :return: a numpy NxN matrix with float 0,1 """ # r = np.array(r) if type(r) == torch.Tensor: r = r.cpu().numpy() N = r.shape[0] for j in range(1, N): col_j = r[:, j][:j] in_to_j = [i for i, val in enumerate(col_j) if val > 0] if len(in_to_j) > 0: for i in in_to_j: col_i = r[:, i][:i] in_to_i = [i for i, val in enumerate(col_i) if val > 0] if len(in_to_i) > 0: for k in in_to_i: r[k, j] = 1 return r def dense_adj(graph_data, max_num_nodes, scaler=None, dequantization=False): """Convert PyG DataBatch to dense adjacency matrices. Args: graph_data: DataBatch object. max_num_nodes: The size of the output node dimension. scaler: Data normalizer. dequantization: uniform dequantization. Returns: adj: Dense adjacency matrices. mask: Mask for adjacency matrices. """ adj, adj_mask = to_dense_adj(graph_data.edge_index, graph_data.batch, max_num_nodes=max_num_nodes) # [B, N, N] # adj: [32, 20, 20] / adj_mask: [32, 20, 20] if dequantization: noise = torch.rand_like(adj) noise = torch.tril(noise, -1) noise = noise + noise.transpose(1, 2) adj = (noise + adj) / 2. adj = scaler(adj[:, None, :, :]) # [32, 1, 20, 20] # set diag = 0 in adj_mask adj_mask = torch.tril(adj_mask, -1) # [32, 20, 20] adj_mask = adj_mask + adj_mask.transpose(1, 2) return adj, adj_mask[:, None, :, :] def adj2graph(adj, sample_nodes): """Covert the PyTorch tensor adjacency matrices to numpy array. 
def adj2graph(adj, sample_nodes):
    """Convert PyTorch tensor adjacency matrices to a list of numpy arrays.

    Args:
        adj: [Batch_size, channel, Max_node, Max_node], assume channel=1
        sample_nodes: [Batch_size]
    """
    adj_list = []
    # discretize by thresholding at 0.5
    adj[adj >= 0.5] = 1.
    adj[adj < 0.5] = 0.
    for i in range(adj.shape[0]):
        adj_tmp = adj[i, 0]
        # symmetrize from the strict lower triangle
        adj_tmp = torch.tril(adj_tmp, -1)
        adj_tmp = adj_tmp + adj_tmp.transpose(0, 1)
        # truncate to the sampled number of nodes
        adj_tmp = adj_tmp.cpu().numpy()[:sample_nodes[i], :sample_nodes[i]]
        adj_list.append(adj_tmp)
    return adj_list


def quantize(x, adj, alpha=0.5, qtype='threshold'):
    """Convert the PyTorch tensor x to a list of quantized numpy arrays.

    Args:
        x: [Batch_size, Max_node, N_vocab]
        adj: [Batch_size, Max_node, Max_node] (currently unused)
        alpha: threshold used when qtype == 'threshold'
        qtype: 'threshold' or 'argmax'
    """
    x_list = []
    if qtype == 'threshold':
        # discretize by thresholding at alpha
        x[x >= alpha] = 1.
        x[x < alpha] = 0.
        for i in range(x.shape[0]):
            x_list.append(x[i].cpu().numpy())
    elif qtype == 'argmax':
        # one-hot encode the highest-scoring vocabulary entry per node
        am = torch.argmax(x, dim=2, keepdim=True)  # [Batch_size, Max_node, 1]
        x = torch.zeros_like(x).scatter(2, am, value=1)
        for i in range(x.shape[0]):
            x_list.append(x[i].cpu().numpy())
    return x_list
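# Illustrative sketch of quantize (example tensors made up, not part of the
# original module): argmax quantization turns per-node scores over a vocabulary
# of size 3 into one-hot rows.
def _demo_quantize():
    x = torch.tensor([[[0.2, 0.7, 0.1],
                       [0.6, 0.3, 0.1]]])  # [1, 2, 3]
    adj = torch.zeros(1, 2, 2)             # placeholder; unused by quantize
    (x_q,) = quantize(x, adj, qtype='argmax')
    # x_q[0] == [0, 1, 0] and x_q[1] == [1, 0, 0]
    assert x_q.shape == (2, 3) and x_q.sum() == 2
    return x_q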