diffusionNAG/MobileNetV3/models/GDSS/scorenetx.py
2024-03-15 14:38:51 +00:00

103 lines
3.8 KiB
Python

import torch
import torch.nn.functional as F
from models.GDSS.layers import DenseGCNConv, MLP
from .graph_utils import mask_x, pow_tensor
from .attention import AttentionLayer
from .. import utils
@utils.register_model(name='ScoreNetworkX')
class ScoreNetworkX(torch.nn.Module):
# def __init__(self, max_feat_num, depth, nhid):
def __init__(self, config):
super(ScoreNetworkX, self).__init__()
self.nfeat = config.data.n_vocab
self.depth = config.model.depth
self.nhid = config.model.nhid
self.layers = torch.nn.ModuleList()
for _ in range(self.depth):
if _ == 0:
self.layers.append(DenseGCNConv(self.nfeat, self.nhid))
else:
self.layers.append(DenseGCNConv(self.nhid, self.nhid))
self.fdim = self.nfeat + self.depth * self.nhid
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=self.nfeat,
use_bn=False, activate_func=F.elu)
self.activation = torch.tanh
def forward(self, x, time_cond, maskX, flags=None):
x_list = [x]
for _ in range(self.depth):
x = self.layers[_](x, maskX)
x = self.activation(x)
x_list.append(x)
xs = torch.cat(x_list, dim=-1) # B x N x (F + num_layers x H)
out_shape = (x.shape[0], x.shape[1], -1)
x = self.final(xs).view(*out_shape)
x = mask_x(x, flags)
return x
@utils.register_model(name='ScoreNetworkX_GMH')
class ScoreNetworkX_GMH(torch.nn.Module):
# def __init__(self, max_feat_num, depth, nhid, num_linears,
# c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
def __init__(self, config):
super().__init__()
self.max_feat_num = config.data.n_vocab
self.depth = config.model.depth
self.nhid = config.model.nhid
self.c_init = config.model.c_init
self.c_hid = config.model.c_hid
self.c_final = config.model.c_final
self.num_linears = config.model.num_linears
self.num_heads = config.model.num_heads
self.conv = config.model.conv
self.adim = config.model.adim
self.layers = torch.nn.ModuleList()
for _ in range(self.depth):
if _ == 0:
self.layers.append(AttentionLayer(self.num_linears, self.max_feat_num,
self.nhid, self.nhid, self.c_init,
self.c_hid, self.num_heads, self.conv))
elif _ == self.depth - 1:
self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
self.nhid, self.c_hid,
self.c_final, self.num_heads, self.conv))
else:
self.layers.append(AttentionLayer(self.num_linears, self.nhid, self.adim,
self.nhid, self.c_hid,
self.c_hid, self.num_heads, self.conv))
fdim = self.max_feat_num + self.depth * self.nhid
self.final = MLP(num_layers=3, input_dim=fdim, hidden_dim=2*fdim, output_dim=self.max_feat_num,
use_bn=False, activate_func=F.elu)
self.activation = torch.tanh
def forward(self, x, time_cond, maskX, flags=None):
adjc = pow_tensor(maskX, self.c_init)
x_list = [x]
for _ in range(self.depth):
x, adjc = self.layers[_](x, adjc, flags)
x = self.activation(x)
x_list.append(x)
xs = torch.cat(x_list, dim=-1) # B x N x (F + num_layers x H)
out_shape = (x.shape[0], x.shape[1], -1)
x = self.final(xs).view(*out_shape)
x = mask_x(x, flags)
return x