import torch
from torch.nn import Parameter
import torch.nn.functional as F
import math
from typing import Any


def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


def reset(value: Any):
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)


# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`."""

    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
                The adjacency tensor is broadcastable in the batch dimension,
                resulting in a shared adjacency matrix for the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            # Self-loop weight is 2 for the "improved" variant, 1 otherwise.
            adj[:, idx, idx] = 1 if not self.improved else 2

        # Linear transformation of node features.
        out = torch.matmul(x, self.weight)

        # Symmetric normalization: D^{-1/2} (A + I) D^{-1/2}.
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)
        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)

        # Propagate the transformed features along the normalized adjacency.
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
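# A minimal usage sketch (not part of the original module) illustrating the
# tensor shapes DenseGCNConv expects. The batch size, node count, and channel
# sizes below are illustrative assumptions, not values used elsewhere in this file.
def _dense_gcn_conv_example():
    conv = DenseGCNConv(in_channels=8, out_channels=16)
    x = torch.randn(2, 5, 8)                          # node features, [B=2, N=5, F=8]
    a = torch.rand(2, 5, 5)
    adj = ((a + a.transpose(-1, -2)) > 1.0).float()   # symmetric dense adjacency, [B, N, N]
    mask = torch.ones(2, 5, dtype=torch.bool)         # all nodes valid, [B, N]
    out = conv(x, adj, mask=mask)                     # -> [2, 5, 16]
    return out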
# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False,
                 activate_func=F.relu):
        """
        num_layers: number of layers in the neural network (EXCLUDING the input
            layer). If num_layers=1, this reduces to a linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        """
        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        """
        :param x: feature tensor whose last dimension equals `input_dim`, e.g.
            node features of shape [batch_size, N, F_i]; linear layers act on
            the last dimension. When `use_bn=True`, a 2-D input of shape
            [batch_size, input_dim] is expected, since `BatchNorm1d` normalizes
            over dim 1.
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)
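# A minimal usage sketch (not part of the original module) showing the MLP on a
# 2-D input; the layer sizes are illustrative assumptions. A 2-D input is used
# here because `use_bn=True` relies on BatchNorm1d normalizing over dim 1.
def _mlp_example():
    mlp = MLP(num_layers=3, input_dim=16, hidden_dim=32, output_dim=4, use_bn=True)
    x = torch.randn(10, 16)    # [batch_size, input_dim]
    out = mlp(x)               # two hidden layers + output layer -> [10, 4]
    return out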