import torch
from torch.nn import Parameter
import torch.nn.functional as F
import math
from typing import Any


def glorot(tensor):
    if tensor is not None:
        # Glorot / Xavier uniform initialization: U(-a, a) with
        # a = sqrt(6 / (fan_in + fan_out)).
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


def reset(value: Any):
    # Recursively re-initialize `value` and its children.
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)
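# Illustrative usage of `reset` (a sketch, not part of the original module):
# it walks an arbitrary module tree and re-runs `reset_parameters` wherever
# that hook is defined, e.g.
#
#   seq = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
#   reset(seq)  # re-initializes the Linear submodule in place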


# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            # Add self-loops; the "improved" variant uses a weight of 2.
            adj[:, idx, idx] = 1 if not self.improved else 2

        out = torch.matmul(x, self.weight)

        # Symmetric normalization: D^{-1/2} * A * D^{-1/2}.
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)
        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
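# Illustrative usage (a sketch, not part of the original module): DenseGCNConv
# operates on dense, batched tensors as described in `forward` above. All
# shapes and values below are made up for demonstration only.
#
#   conv = DenseGCNConv(in_channels=16, out_channels=32)
#   x = torch.randn(8, 10, 16)            # [B, N, F] node features
#   adj = torch.rand(8, 10, 10).round()   # [B, N, N] {0, 1} adjacency
#   out = conv(x, adj)                    # -> [8, 10, 32]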


# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
        """
        num_layers: number of layers in the neural network (EXCLUDING the input layer).
            If num_layers=1, this reduces to a linear model.
        input_dim: dimensionality of input features
        hidden_dim: dimensionality of hidden units at ALL layers
        output_dim: number of classes for prediction
        use_bn: if True, apply BatchNorm1d after every hidden layer
        activate_func: activation applied after every hidden layer (default: F.relu)
        """

        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        """
        :param x: [batch_size, N, F_i], batch of node features
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)
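

# Minimal smoke test (illustrative sketch, not part of the original module).
# The layer and dimension choices below are arbitrary and only demonstrate
# the expected input/output shapes of MLP.
if __name__ == "__main__":
    mlp = MLP(num_layers=3, input_dim=16, hidden_dim=64, output_dim=4)
    out = mlp(torch.randn(8, 10, 16))  # [B, N, F_in] -> [B, N, F_out]
    assert out.shape == (8, 10, 4)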