first commit
MobileNetV3/models/GDSS/attention.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import math
import torch
from torch.nn import Parameter
import torch.nn.functional as F

from models.GDSS.layers import DenseGCNConv, MLP
# from ..utils.graph_utils import mask_adjs, mask_x
from .graph_utils import mask_x, mask_adjs


# -------- Graph Multi-Head Attention (GMH) --------
# -------- From Baek et al. (2021) --------
class Attention(torch.nn.Module):

    def __init__(self, in_dim, attn_dim, out_dim, num_heads=4, conv='GCN'):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.attn_dim = attn_dim
        self.out_dim = out_dim
        self.conv = conv

        self.gnn_q, self.gnn_k, self.gnn_v = self.get_gnn(in_dim, attn_dim, out_dim, conv)
        self.activation = torch.tanh
        self.softmax_dim = 2

    def forward(self, x, adj, flags, attention_mask=None):

        if self.conv == 'GCN':
            Q = self.gnn_q(x, adj)
            K = self.gnn_k(x, adj)
        else:
            Q = self.gnn_q(x)
            K = self.gnn_k(x)

        V = self.gnn_v(x, adj)
        dim_split = self.attn_dim // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask for _ in range(self.num_heads)], 0)
            attention_score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim)
            A = self.activation(attention_mask + attention_score)
        else:
            A = self.activation(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.out_dim))  # (B x num_heads) x N x N

        # -------- (B x num_heads) x N x N --------
        A = A.view(-1, *adj.shape)
        A = A.mean(dim=0)
        A = (A + A.transpose(-1, -2)) / 2

        return V, A

    def get_gnn(self, in_dim, attn_dim, out_dim, conv='GCN'):

        if conv == 'GCN':
            gnn_q = DenseGCNConv(in_dim, attn_dim)
            gnn_k = DenseGCNConv(in_dim, attn_dim)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        elif conv == 'MLP':
            num_layers = 2
            gnn_q = MLP(num_layers, in_dim, 2*attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_k = MLP(num_layers, in_dim, 2*attn_dim, attn_dim, activate_func=torch.tanh)
            gnn_v = DenseGCNConv(in_dim, out_dim)

            return gnn_q, gnn_k, gnn_v

        else:
            raise NotImplementedError(f'{conv} not implemented.')


# -------- Layer of ScoreNetworkA --------
class AttentionLayer(torch.nn.Module):

    def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
                 num_heads=4, conv='GCN'):

        super(AttentionLayer, self).__init__()

        self.attn = torch.nn.ModuleList()
        for _ in range(input_dim):
            self.attn_dim = attn_dim
            self.attn.append(Attention(conv_input_dim, self.attn_dim, conv_output_dim,
                                       num_heads=num_heads, conv=conv))

        self.hidden_dim = 2 * max(input_dim, output_dim)
        self.mlp = MLP(num_linears, 2*input_dim, self.hidden_dim, output_dim, use_bn=False, activate_func=F.elu)
        self.multi_channel = MLP(2, input_dim*conv_output_dim, self.hidden_dim, conv_output_dim,
                                 use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags):
        """
        :param x: B x N x F_i
        :param adj: B x C_i x N x N
        :return: x_out: B x N x F_o, adj_out: B x C_o x N x N
        """
        mask_list = []
        x_list = []
        for _ in range(len(self.attn)):
            _x, mask = self.attn[_](x, adj[:, _, :, :], flags)
            mask_list.append(mask.unsqueeze(-1))
            x_list.append(_x)
        x_out = mask_x(self.multi_channel(torch.cat(x_list, dim=-1)), flags)
        x_out = torch.tanh(x_out)

        mlp_in = torch.cat([torch.cat(mask_list, dim=-1), adj.permute(0, 2, 3, 1)], dim=-1)
        shape = mlp_in.shape
        mlp_out = self.mlp(mlp_in.view(-1, shape[-1]))
        _adj = mlp_out.view(shape[0], shape[1], shape[2], -1).permute(0, 3, 1, 2)
        _adj = _adj + _adj.transpose(-1, -2)
        adj_out = mask_adjs(_adj, flags)

        return x_out, adj_out
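A minimal shape check for AttentionLayer above (a sketch, not part of the committed file): the sizes B=2, N=9, two adjacency channels are made up, and it assumes the models package is importable from the repo root, as the absolute import in this file already presumes.

import torch
from models.GDSS.attention import AttentionLayer
from models.GDSS.graph_utils import node_flags, pow_tensor

# Hypothetical sizes, chosen only for this smoke test.
B, N, F_in, C_in, C_out, F_out = 2, 9, 16, 2, 4, 8

x = torch.randn(B, N, F_in)
adj = torch.rand(B, N, N)
adj = ((adj + adj.transpose(-1, -2)) / 2 > 0.5).float()   # symmetric 0/1 adjacency
adjc = pow_tensor(adj, C_in)                               # B x C_in x N x N
flags = node_flags(adj)                                    # B x N

layer = AttentionLayer(num_linears=2, conv_input_dim=F_in, attn_dim=16,
                       conv_output_dim=F_out, input_dim=C_in, output_dim=C_out)
x_out, adj_out = layer(x, adjc, flags)
print(x_out.shape, adj_out.shape)   # torch.Size([2, 9, 8]) torch.Size([2, 4, 9, 9])

Note that attn_dim should be divisible by num_heads, since the query/key tensors are split head-wise along the feature dimension before the batched matrix product.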
MobileNetV3/models/GDSS/graph_utils.py (new file, 209 lines)
@@ -0,0 +1,209 @@
import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np


# -------- Mask batch of node features with 0-1 flags tensor --------
def mask_x(x, flags):

    if flags is None:
        flags = torch.ones((x.shape[0], x.shape[1]), device=x.device)
    return x * flags[:, :, None]


# -------- Mask batch of adjacency matrices with 0-1 flags tensor --------
def mask_adjs(adjs, flags):
    """
    :param adjs:  B x N x N or B x C x N x N
    :param flags: B x N
    :return:
    """
    if flags is None:
        flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)

    if len(adjs.shape) == 4:
        flags = flags.unsqueeze(1)  # B x 1 x N
    adjs = adjs * flags.unsqueeze(-1)
    adjs = adjs * flags.unsqueeze(-2)
    return adjs


# -------- Create flags tensor from graph dataset --------
def node_flags(adj, eps=1e-5):

    flags = torch.abs(adj).sum(-1).gt(eps).to(dtype=torch.float32)

    if len(flags.shape) == 3:
        flags = flags[:, 0, :]
    return flags


# -------- Create initial node features --------
def init_features(init, adjs=None, nfeat=10):

    if init == 'zeros':
        feature = torch.zeros((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init == 'ones':
        feature = torch.ones((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init == 'deg':
        feature = adjs.sum(dim=-1).to(torch.long)
        num_classes = nfeat
        try:
            feature = F.one_hot(feature, num_classes=num_classes).to(torch.float32)
        except RuntimeError:
            raise NotImplementedError(f'max_feat_num mismatch: max degree {feature.max().item()} >= {num_classes}')
    else:
        raise NotImplementedError(f'{init} not implemented')

    flags = node_flags(adjs)

    return mask_x(feature, flags)


# -------- Sample initial flags tensor from the training graph set --------
def init_flags(graph_list, config, batch_size=None):
    if batch_size is None:
        batch_size = config.data.batch_size
    max_node_num = config.data.max_node_num
    graph_tensor = graphs_to_tensor(graph_list, max_node_num)
    idx = np.random.randint(0, len(graph_list), batch_size)
    flags = node_flags(graph_tensor[idx])

    return flags


# -------- Generate noise --------
def gen_noise(x, flags, sym=True):
    z = torch.randn_like(x)
    if sym:
        z = z.triu(1)
        z = z + z.transpose(-1, -2)
        z = mask_adjs(z, flags)
    else:
        z = mask_x(z, flags)
    return z


# -------- Quantize generated graphs --------
def quantize(adjs, thr=0.5):
    adjs_ = torch.where(adjs < thr, torch.zeros_like(adjs), torch.ones_like(adjs))
    return adjs_


# -------- Quantize generated molecules --------
# adjs: 32 x 9 x 9
def quantize_mol(adjs):
    if type(adjs).__name__ == 'Tensor':
        adjs = adjs.detach().cpu()
    else:
        adjs = torch.tensor(adjs)
    adjs[adjs >= 2.5] = 3
    adjs[torch.bitwise_and(adjs >= 1.5, adjs < 2.5)] = 2
    adjs[torch.bitwise_and(adjs >= 0.5, adjs < 1.5)] = 1
    adjs[adjs < 0.5] = 0
    return np.array(adjs.to(torch.int64))


def adjs_to_graphs(adjs, is_cuda=False):
    graph_list = []
    for adj in adjs:
        if is_cuda:
            adj = adj.detach().cpu().numpy()
        G = nx.from_numpy_matrix(adj)
        G.remove_edges_from(nx.selfloop_edges(G))
        G.remove_nodes_from(list(nx.isolates(G)))
        if G.number_of_nodes() < 1:
            G.add_node(1)
        graph_list.append(G)
    return graph_list


# -------- Check if the adjacency matrices are symmetric --------
def check_sym(adjs, print_val=False):
    sym_error = (adjs - adjs.transpose(-1, -2)).abs().sum([0, 1, 2])
    if not sym_error < 1e-2:
        raise ValueError(f'Not symmetric: {sym_error:.4e}')
    if print_val:
        print(f'{sym_error:.4e}')


# -------- Create higher order adjacency matrices --------
def pow_tensor(x, cnum):
    # x : B x N x N
    x_ = x.clone()
    xc = [x.unsqueeze(1)]
    for _ in range(cnum - 1):
        x_ = torch.bmm(x_, x)
        xc.append(x_.unsqueeze(1))
    xc = torch.cat(xc, dim=1)

    return xc


# -------- Create padded adjacency matrices --------
def pad_adjs(ori_adj, node_number):
    a = ori_adj
    ori_len = a.shape[-1]
    if ori_len == node_number:
        return a
    if ori_len > node_number:
        raise ValueError(f'ori_len {ori_len} > node_number {node_number}')
    a = np.concatenate([a, np.zeros([ori_len, node_number - ori_len])], axis=-1)
    a = np.concatenate([a, np.zeros([node_number - ori_len, node_number])], axis=0)
    return a


def graphs_to_tensor(graph_list, max_node_num):
    adjs_list = []

    for g in graph_list:
        assert isinstance(g, nx.Graph)
        node_list = []
        for v, feature in g.nodes.data('feature'):
            node_list.append(v)

        adj = nx.to_numpy_matrix(g, nodelist=node_list)
        padded_adj = pad_adjs(adj, node_number=max_node_num)
        adjs_list.append(padded_adj)

    del graph_list

    adjs_np = np.asarray(adjs_list)
    del adjs_list

    adjs_tensor = torch.tensor(adjs_np, dtype=torch.float32)
    del adjs_np

    return adjs_tensor


def graphs_to_adj(graph, max_node_num):

    assert isinstance(graph, nx.Graph)
    node_list = []
    for v, feature in graph.nodes.data('feature'):
        node_list.append(v)

    adj = nx.to_numpy_matrix(graph, nodelist=node_list)
    padded_adj = pad_adjs(adj, node_number=max_node_num)

    adj = torch.tensor(padded_adj, dtype=torch.float32)
    del padded_adj

    return adj


def node_feature_to_matrix(x):
    """
    :param x:  BS x N x F
    :return:
    x_pair: BS x N x N x 2F
    """
    x_b = x.unsqueeze(-2).expand(x.size(0), x.size(1), x.size(1), -1)  # BS x N x N x F
    x_pair = torch.cat([x_b, x_b.transpose(1, 2)], dim=-1)  # BS x N x N x 2F

    return x_pair
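A small sketch of how the masking and noise helpers above compose (the shapes B=2, N=5 are arbitrary, not from any config in this commit). It sticks to the torch-only helpers; adjs_to_graphs and graphs_to_tensor additionally rely on nx.from_numpy_matrix / nx.to_numpy_matrix, which only exist in networkx < 3.0.

import torch
from models.GDSS.graph_utils import node_flags, mask_adjs, gen_noise, quantize, pow_tensor

B, N = 2, 5
adjs = torch.rand(B, N, N)
adjs = quantize((adjs + adjs.transpose(-1, -2)) / 2)   # symmetric 0/1 adjacency
flags = node_flags(adjs)                               # B x N, 1 where a node has any edge

noise = gen_noise(adjs, flags, sym=True)               # symmetric noise, zeroed outside valid nodes
noisy = mask_adjs(adjs + noise, flags)
adjc = pow_tensor(noisy, cnum=3)                       # B x 3 x N x N stack: A, A^2, A^3
print(flags.shape, noise.shape, adjc.shape)            # (2, 5), (2, 5, 5), (2, 3, 5, 5)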
MobileNetV3/models/GDSS/layers.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import torch
from torch.nn import Parameter
import torch.nn.functional as F
import math
from typing import Any


def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)

def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)

def reset(value: Any):
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)

# -------- GCN layer --------
class DenseGCNConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels, out_channels, improved=False, bias=True):
        super(DenseGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved

        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, adj, mask=None, add_loop=True):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
            add_loop (bool, optional): If set to :obj:`False`, the layer will
                not automatically add self-loops to the adjacency matrices.
                (default: :obj:`True`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, _ = adj.size()

        if add_loop:
            adj = adj.clone()
            idx = torch.arange(N, dtype=torch.long, device=adj.device)
            adj[:, idx, idx] = 1 if not self.improved else 2

        out = torch.matmul(x, self.weight)
        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)

        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        out = torch.matmul(adj, out)

        if self.bias is not None:
            out = out + self.bias

        if mask is not None:
            out = out * mask.view(B, N, 1).to(x.dtype)

        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

# -------- MLP layer --------
class MLP(torch.nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
        """
            num_layers: number of layers in the neural network (EXCLUDING the input layer). If num_layers=1, this reduces to a linear model.
            input_dim: dimensionality of input features
            hidden_dim: dimensionality of hidden units at ALL layers
            output_dim: dimensionality of the output
        """

        super(MLP, self).__init__()

        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.use_bn = use_bn
        self.activate_func = activate_func

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = torch.nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()

            self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(torch.nn.Linear(hidden_dim, output_dim))

            if self.use_bn:
                self.batch_norms = torch.nn.ModuleList()
                for layer in range(num_layers - 1):
                    self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))


    def forward(self, x):
        """
        :param x: [batch_size, N, F_i] (or any tensor whose last dimension is F_i), batch of node features
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for layer in range(self.num_layers - 1):
                h = self.linears[layer](h)
                if self.use_bn:
                    h = self.batch_norms[layer](h)
                h = self.activate_func(h)
            return self.linears[self.num_layers - 1](h)
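A quick sanity check for the two building blocks above; the sizes are arbitrary examples. DenseGCNConv adds self-loops and applies symmetric degree normalization internally, and the MLP's linear layers act on the last dimension, so both accept batched B x N x F inputs directly.

import torch
import torch.nn.functional as F
from models.GDSS.layers import DenseGCNConv, MLP

B, N, F_in, F_out = 2, 6, 8, 4
x = torch.randn(B, N, F_in)
adj = (torch.rand(B, N, N) > 0.5).float()
adj = adj + adj.transpose(-1, -2)                 # symmetric; entries in {0, 1, 2} act as edge weights

conv = DenseGCNConv(F_in, F_out)
h = conv(x, adj)                                  # B x N x F_out
print(h.shape)                                    # torch.Size([2, 6, 4])

mlp = MLP(num_layers=3, input_dim=F_out, hidden_dim=16, output_dim=2,
          use_bn=False, activate_func=F.elu)
print(mlp(h).shape)                               # torch.Size([2, 6, 2])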
MobileNetV3/models/GDSS/scorenetx.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F

from models.GDSS.layers import DenseGCNConv, MLP
from .graph_utils import mask_x, pow_tensor
from .attention import AttentionLayer
from .. import utils


@utils.register_model(name='ScoreNetworkX')
class ScoreNetworkX(torch.nn.Module):

    # def __init__(self, max_feat_num, depth, nhid):
    def __init__(self, config):

        super(ScoreNetworkX, self).__init__()

        self.nfeat = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid

        self.layers = torch.nn.ModuleList()
        for _ in range(self.depth):
            if _ == 0:
                self.layers.append(DenseGCNConv(self.nfeat, self.nhid))
            else:
                self.layers.append(DenseGCNConv(self.nhid, self.nhid))

        self.fdim = self.nfeat + self.depth * self.nhid
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=self.nfeat,
                         use_bn=False, activate_func=F.elu)

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):

        x_list = [x]
        for _ in range(self.depth):
            x = self.layers[_](x, maskX)
            x = self.activation(x)
            x_list.append(x)

        xs = torch.cat(x_list, dim=-1)  # B x N x (F + num_layers x H)
        out_shape = (x.shape[0], x.shape[1], -1)
        x = self.final(xs).view(*out_shape)

        x = mask_x(x, flags)
        return x


@utils.register_model(name='ScoreNetworkX_GMH')
class ScoreNetworkX_GMH(torch.nn.Module):
    # def __init__(self, max_feat_num, depth, nhid, num_linears,
    #              c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
    def __init__(self, config):
        super().__init__()

        self.max_feat_num = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid
        self.c_init = config.model.c_init
        self.c_hid = config.model.c_hid
        self.c_final = config.model.c_final
        self.num_linears = config.model.num_linears
        self.num_heads = config.model.num_heads
        self.conv = config.model.conv
        self.adim = config.model.adim

        self.layers = torch.nn.ModuleList()
        for _ in range(self.depth):
            if _ == 0:
                self.layers.append(AttentionLayer(self.num_linears, self.max_feat_num,
                                                  self.nhid, self.nhid, self.c_init,
                                                  self.c_hid, self.num_heads, self.conv))
            elif _ == self.depth - 1:
                self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
                                                  self.nhid, self.c_hid,
                                                  self.c_final, self.num_heads, self.conv))
            else:
                self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
                                                  self.nhid, self.c_hid,
                                                  self.c_hid, self.num_heads, self.conv))

        fdim = self.max_feat_num + self.depth * self.nhid
        self.final = MLP(num_layers=3, input_dim=fdim, hidden_dim=2*fdim, output_dim=self.max_feat_num,
                         use_bn=False, activate_func=F.elu)

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):
        adjc = pow_tensor(maskX, self.c_init)

        x_list = [x]
        for _ in range(self.depth):
            x, adjc = self.layers[_](x, adjc, flags)
            x = self.activation(x)
            x_list.append(x)

        xs = torch.cat(x_list, dim=-1)  # B x N x (F + num_layers x H)
        out_shape = (x.shape[0], x.shape[1], -1)
        x = self.final(xs).view(*out_shape)
        x = mask_x(x, flags)

        return x
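A construction sketch for ScoreNetworkX. The config stub below only mirrors the fields read in __init__ (data.n_vocab, model.depth, model.nhid) with made-up sizes; it is not one of the project's configs. It also assumes the models.utils registry imported at the top of this file exists in the surrounding repo, since it is not part of this commit.

import torch
from types import SimpleNamespace

from models.GDSS.scorenetx import ScoreNetworkX   # requires models.utils.register_model to resolve

# Hypothetical config stub with illustrative sizes.
config = SimpleNamespace(
    data=SimpleNamespace(n_vocab=10),
    model=SimpleNamespace(depth=2, nhid=16),
)

net = ScoreNetworkX(config)
B, N = 2, 9
x = torch.randn(B, N, config.data.n_vocab)
adj = (torch.rand(B, N, N) > 0.5).float()
adj = ((adj + adj.transpose(-1, -2)) > 0).float()  # symmetric 0/1 adjacency used as maskX

score = net(x, time_cond=None, maskX=adj)          # time_cond is not used inside this forward pass
print(score.shape)                                 # torch.Size([2, 9, 10]) == B x N x n_vocab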