"""Modified from https://github.com/uoguelph-mlrg/GGM-metrics""" import torch import torch.nn as nn import torch.nn.functional as F import dgl.function as fn from dgl.utils import expand_as_pair from dgl.nn import SumPooling, AvgPooling, MaxPooling class GINConv(nn.Module): def __init__(self, apply_func, aggregator_type, init_eps=0, learn_eps=False): super(GINConv, self).__init__() self.apply_func = apply_func self._aggregator_type = aggregator_type if aggregator_type == 'sum': self._reducer = fn.sum elif aggregator_type == 'max': self._reducer = fn.max elif aggregator_type == 'mean': self._reducer = fn.mean else: raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type)) # to specify whether eps is trainable or not. if learn_eps: self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps])) else: self.register_buffer('eps', torch.FloatTensor([init_eps])) def forward(self, graph, feat, edge_weight=None): r""" Description ----------- Compute Graph Isomorphism Network layer. Parameters ---------- graph : DGLGraph The graph. feat : torch.Tensor or pair of torch.Tensor If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. If a pair of torch.Tensor is given, the pair must contain two tensors of shape :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`. If ``apply_func`` is not None, :math:`D_{in}` should fit the input dimensionality requirement of ``apply_func``. edge_weight : torch.Tensor, optional Optional tensor on the edge. If given, the convolution will weight with regard to the message. Returns ------- torch.Tensor The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` is the output dimensionality of ``apply_func``. If ``apply_func`` is None, :math:`D_{out}` should be the same as input dimensionality. 
""" with graph.local_scope(): aggregate_fn = self.concat_edge_msg # aggregate_fn = fn.copy_src('h', 'm') if edge_weight is not None: assert edge_weight.shape[0] == graph.number_of_edges() graph.edata['_edge_weight'] = edge_weight aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm') feat_src, feat_dst = expand_as_pair(feat, graph) graph.srcdata['h'] = feat_src graph.update_all(aggregate_fn, self._reducer('m', 'neigh')) diff = torch.tensor(graph.dstdata['neigh'].shape[1: ]) - torch.tensor(feat_dst.shape[1: ]) zeros = torch.zeros(feat_dst.shape[0], *diff).to(feat_dst.device) feat_dst = torch.cat([feat_dst, zeros], dim=1) rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh'] if self.apply_func is not None: rst = self.apply_func(rst) return rst def concat_edge_msg(self, edges): if self.edge_feat_loc not in edges.data: return {'m': edges.src['h']} else: m = torch.cat([edges.src['h'], edges.data[self.edge_feat_loc]], dim=1) return {'m': m} class ApplyNodeFunc(nn.Module): """Update the node feature hv with MLP, BN and ReLU.""" def __init__(self, mlp): super(ApplyNodeFunc, self).__init__() self.mlp = mlp self.bn = nn.BatchNorm1d(self.mlp.output_dim) def forward(self, h): h = self.mlp(h) h = self.bn(h) h = F.relu(h) return h class MLP(nn.Module): """MLP with linear output""" def __init__(self, num_layers, input_dim, hidden_dim, output_dim): """MLP layers construction Paramters --------- num_layers: int The number of linear layers input_dim: int The dimensionality of input features hidden_dim: int The dimensionality of hidden units at ALL layers output_dim: int The number of classes for prediction """ super(MLP, self).__init__() self.linear_or_not = True # default is linear model self.num_layers = num_layers self.output_dim = output_dim if num_layers < 1: raise ValueError("number of layers should be positive!") elif num_layers == 1: # Linear model self.linear = nn.Linear(input_dim, output_dim) else: # Multi-layer model self.linear_or_not = False self.linears = torch.nn.ModuleList() self.batch_norms = torch.nn.ModuleList() self.linears.append(nn.Linear(input_dim, hidden_dim)) for layer in range(num_layers - 2): self.linears.append(nn.Linear(hidden_dim, hidden_dim)) self.linears.append(nn.Linear(hidden_dim, output_dim)) for layer in range(num_layers - 1): self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) def forward(self, x): if self.linear_or_not: # If linear model return self.linear(x) else: # If MLP h = x for i in range(self.num_layers - 1): h = F.relu(self.batch_norms[i](self.linears[i](h))) return self.linears[-1](h) class GIN(nn.Module): """GIN model""" def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, graph_pooling_type, neighbor_pooling_type, edge_feat_dim=0, final_dropout=0.0, learn_eps=False, output_dim=1, **kwargs): """model parameters setting Paramters --------- num_layers: int The number of linear layers in the neural network num_mlp_layers: int The number of linear layers in mlps input_dim: int The dimensionality of input features hidden_dim: int The dimensionality of hidden units at ALL layers output_dim: int The number of classes for prediction final_dropout: float dropout ratio on the final linear layer learn_eps: boolean If True, learn epsilon to distinguish center nodes from neighbors If False, aggregate neighbors and center nodes altogether. 
        neighbor_pooling_type: str
            how to aggregate neighbors (sum, mean, or max)
        graph_pooling_type: str
            how to aggregate all nodes in a graph (sum, mean or max)
        """
        super().__init__()

        def init_weights_orthogonal(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.orthogonal_(m.weight)
            elif isinstance(m, MLP):
                if hasattr(m, 'linears'):
                    m.linears.apply(init_weights_orthogonal)
                else:
                    m.linear.apply(init_weights_orthogonal)
            elif isinstance(m, nn.ModuleList):
                pass
            else:
                raise Exception('Unexpected module type: {}'.format(type(m)))

        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # List of GIN layers (each wrapping an MLP) and their batch norms
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        # self.preprocess_nodes = PreprocessNodeAttrs(
        #     node_attrs=node_preprocess, output_dim=node_preprocess_output_dim)
        # print(input_dim)

        for layer in range(self.num_layers - 1):
            if layer == 0:
                mlp = MLP(num_mlp_layers, input_dim + edge_feat_dim, hidden_dim, hidden_dim)
            else:
                mlp = MLP(num_mlp_layers, hidden_dim + edge_feat_dim, hidden_dim, hidden_dim)

            if kwargs.get('init') == 'orthogonal':
                init_weights_orthogonal(mlp)

            self.ginlayers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Linear function for graph poolings of output of each layer
        # which maps the output of different layers into a prediction score
        self.linears_prediction = torch.nn.ModuleList()

        for layer in range(num_layers):
            if layer == 0:
                self.linears_prediction.append(
                    nn.Linear(input_dim, output_dim))
            else:
                self.linears_prediction.append(
                    nn.Linear(hidden_dim, output_dim))

        if kwargs.get('init') == 'orthogonal':
            # print('orthogonal')
            self.linears_prediction.apply(init_weights_orthogonal)

        self.drop = nn.Dropout(final_dropout)

        if graph_pooling_type == 'sum':
            self.pool = SumPooling()
        elif graph_pooling_type == 'mean':
            self.pool = AvgPooling()
        elif graph_pooling_type == 'max':
            self.pool = MaxPooling()
        else:
            raise NotImplementedError

    def forward(self, g, h):
        # list of hidden representations at each layer (including the input)
        hidden_rep = [h]
        # h = self.preprocess_nodes(h)
        for i in range(self.num_layers - 1):
            h = self.ginlayers[i](g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)

        score_over_layer = 0

        # perform pooling over all nodes in each graph in every layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linears_prediction[i](pooled_h))

        return score_over_layer

    def get_graph_embed(self, g, h):
        self.eval()
        with torch.no_grad():
            # return self.forward(g, h).detach().numpy()
            hidden_rep = []
            # h = self.preprocess_nodes(h)
            for i in range(self.num_layers - 1):
                h = self.ginlayers[i](g, h)
                h = self.batch_norms[i](h)
                h = F.relu(h)
                hidden_rep.append(h)

            # Pool over all nodes of each graph at every layer and concatenate the
            # per-layer graph embeddings along the feature axis. The result stays on
            # the same device as the features (nn.Module provides no ``device``
            # attribute, so none is assumed here).
            pooled = [self.pool(g, rep) for rep in hidden_rep]
            return torch.cat(pooled, dim=1)

    def get_graph_embed_no_cat(self, g, h):
        self.eval()
        with torch.no_grad():
            hidden_rep = []
            # h = self.preprocess_nodes(h)
            for i in range(self.num_layers - 1):
                h = self.ginlayers[i](g, h)
                h = self.batch_norms[i](h)
                h = F.relu(h)
                hidden_rep.append(h)

            # Return only the last layer's pooled embedding; it is already on the
            # same device as the input features.
            return self.pool(g, hidden_rep[-1])

    @property
    def edge_feat_loc(self):
        return self.ginlayers[0].edge_feat_loc

    @edge_feat_loc.setter
    def edge_feat_loc(self, loc):
        for layer in self.ginlayers:
            layer.edge_feat_loc = loc
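

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# GGM-metrics code). It builds a small random DGL graph and runs the GIN model
# on it; the hyperparameters, the 'attr' edge-feature field name, and the graph
# itself are all assumptions made for the example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import dgl

    g = dgl.rand_graph(10, 30)                # 10 nodes, 30 random edges
    feats = torch.randn(10, 16)               # node features with D_in = 16

    model = GIN(num_layers=3, num_mlp_layers=2, input_dim=16, hidden_dim=32,
                graph_pooling_type='sum', neighbor_pooling_type='sum',
                init='orthogonal')
    # Point every GINConv layer at the edge-feature field; messages fall back
    # to node features alone when that field is absent from the graph.
    model.edge_feat_loc = 'attr'

    scores = model(g, feats)                  # summed per-layer prediction scores, shape (1, 1)
    embed = model.get_graph_embed(g, feats)   # concatenated per-layer graph embedding, shape (1, 64)
    print(scores.shape, embed.shape)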