103 lines
3.8 KiB
Python
103 lines
3.8 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from models.GDSS.layers import DenseGCNConv, MLP
|
|
from .graph_utils import mask_x, pow_tensor
|
|
from .attention import AttentionLayer
|
|
from .. import utils
|
|
|
|
@utils.register_model(name='ScoreNetworkX')
class ScoreNetworkX(torch.nn.Module):
    """Stack of ``DenseGCNConv`` layers followed by an ``MLP`` read-out.

    The input features and the tanh-activated output of every conv layer are
    concatenated along the last dimension and passed through a 3-layer ``MLP``
    whose ``output_dim`` is ``config.data.n_vocab``.  The class is registered
    under the name ``'ScoreNetworkX'`` through ``utils.register_model``.
    """

    def __init__(self, config):
        """Build the network from ``config``.

        Fields read:
            config.data.n_vocab: stored as ``nfeat``; input width of the first
                conv layer and output width of the final ``MLP``.
            config.model.depth: number of ``DenseGCNConv`` layers.
            config.model.nhid: output width of every conv layer.
        """
        super().__init__()

        self.nfeat = config.data.n_vocab
        self.depth = config.model.depth
        self.nhid = config.model.nhid

        # Only the first layer consumes the raw features; later layers map
        # nhid -> nhid.
        self.layers = torch.nn.ModuleList([
            DenseGCNConv(self.nfeat if layer_idx == 0 else self.nhid, self.nhid)
            for layer_idx in range(self.depth)
        ])

        # The read-out sees the input plus one nhid-wide slice per conv layer.
        self.fdim = self.nfeat + self.depth * self.nhid
        self.final = MLP(
            num_layers=3,
            input_dim=self.fdim,
            hidden_dim=2 * self.fdim,
            output_dim=self.nfeat,
            use_bn=False,
            activate_func=F.elu,
        )

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):
        """Run the conv stack and the read-out ``MLP``, then mask the result.

        Args:
            x: input features; concatenated with the layer outputs along the
                last dimension (presumably B x N x F -- confirm with caller).
            time_cond: accepted for interface compatibility; not used here.
            maskX: forwarded unchanged as the second argument of every
                ``DenseGCNConv`` layer.
            flags: forwarded to ``mask_x`` together with the output.

        Returns:
            The result of ``mask_x`` applied to the ``MLP`` output, viewed with
            the first two dimensions of the last layer's activation and an
            inferred third dimension.
        """
        collected = [x]
        for conv in self.layers:
            x = self.activation(conv(x, maskX))
            collected.append(x)

        stacked = torch.cat(collected, dim=-1)  # B x N x (F + num_layers x H)
        n_batch, n_nodes = x.shape[0], x.shape[1]
        out = self.final(stacked).view(n_batch, n_nodes, -1)

        return mask_x(out, flags)
|
|
|
|
|
|
@utils.register_model(name='ScoreNetworkX_GMH')
class ScoreNetworkX_GMH(torch.nn.Module):
    """Stack of ``AttentionLayer`` modules followed by an ``MLP`` read-out.

    Each layer returns updated node features together with an updated
    ``adjc`` tensor that is handed to the next layer.  The input features and
    the tanh-activated features from every layer are concatenated along the
    last dimension and passed through a 3-layer ``MLP`` whose ``output_dim``
    is ``config.data.n_vocab``.  The class is registered under the name
    ``'ScoreNetworkX_GMH'`` through ``utils.register_model``.
    """

    def __init__(self, config):
        """Build the network from ``config``.

        Fields read: ``config.data.n_vocab`` (stored as ``max_feat_num``) and,
        from ``config.model``: ``depth``, ``nhid``, ``c_init``, ``c_hid``,
        ``c_final``, ``num_linears``, ``num_heads``, ``conv`` and ``adim``.
        """
        super().__init__()

        model_cfg = config.model
        self.max_feat_num = config.data.n_vocab
        self.depth = model_cfg.depth
        self.nhid = model_cfg.nhid
        self.c_init = model_cfg.c_init
        self.c_hid = model_cfg.c_hid
        self.c_final = model_cfg.c_final
        self.num_linears = model_cfg.num_linears
        self.num_heads = model_cfg.num_heads
        self.conv = model_cfg.conv
        self.adim = model_cfg.adim

        self.layers = torch.nn.ModuleList()
        for layer_idx in range(self.depth):
            # Pick the four positional arguments that differ between layers.
            # NOTE(review): the local names below are guesses based on the
            # config field names; AttentionLayer's signature is not visible
            # from this file -- confirm against .attention.
            if layer_idx == 0:
                # The first-layer test comes first, so a depth-1 network uses
                # c_init -> c_hid rather than the last-layer c_final.
                feat_in, attn_dim = self.max_feat_num, self.nhid
                chan_in, chan_out = self.c_init, self.c_hid
            elif layer_idx == self.depth - 1:
                feat_in, attn_dim = self.nhid, self.adim
                chan_in, chan_out = self.c_hid, self.c_final
            else:
                feat_in, attn_dim = self.nhid, self.adim
                chan_in, chan_out = self.c_hid, self.c_hid

            self.layers.append(
                AttentionLayer(
                    self.num_linears,
                    feat_in,
                    attn_dim,
                    self.nhid,
                    chan_in,
                    chan_out,
                    self.num_heads,
                    self.conv,
                )
            )

        # The read-out sees the input plus one nhid-wide slice per layer.
        fdim = self.max_feat_num + self.depth * self.nhid
        self.final = MLP(
            num_layers=3,
            input_dim=fdim,
            hidden_dim=2 * fdim,
            output_dim=self.max_feat_num,
            use_bn=False,
            activate_func=F.elu,
        )

        self.activation = torch.tanh

    def forward(self, x, time_cond, maskX, flags=None):
        """Run the attention stack and the read-out ``MLP``, then mask the result.

        Args:
            x: input features; concatenated with the layer outputs along the
                last dimension (presumably B x N x F -- confirm with caller).
            time_cond: accepted for interface compatibility; not used here.
            maskX: passed with ``self.c_init`` to ``pow_tensor`` to build the
                initial ``adjc`` tensor fed to the first layer.
            flags: forwarded to every ``AttentionLayer`` and to ``mask_x``.

        Returns:
            The result of ``mask_x`` applied to the ``MLP`` output, viewed with
            the first two dimensions of the last layer's activation and an
            inferred third dimension.
        """
        adjc = pow_tensor(maskX, self.c_init)

        collected = [x]
        for attn_layer in self.layers:
            x, adjc = attn_layer(x, adjc, flags)
            x = self.activation(x)
            collected.append(x)

        stacked = torch.cat(collected, dim=-1)  # B x N x (F + num_layers x H)
        n_batch, n_nodes = x.shape[0], x.shape[1]
        out = self.final(stacked).view(n_batch, n_nodes, -1)

        return mask_x(out, flags)