import functools

import torch
import torch.nn as nn
from torch_geometric.utils import dense_to_sparse

from . import utils, layers, gnns

get_act = layers.get_act
conv1x1 = layers.conv1x1


@utils.register_model(name='PGSN')
class PGSN(nn.Module):
    """Position-enhanced graph score network.

    Takes a noisy dense adjacency tensor of shape [B, 1, N, N] together with a
    diffusion timestep / noise level and predicts a score tensor of the same
    shape, conditioned on random-walk positional features.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.act = act = get_act(config)

        # read model construction parameters from the config
        self.nf = nf = config.model.nf
        self.num_gnn_layers = num_gnn_layers = config.model.num_gnn_layers
        dropout = config.model.dropout
        self.embedding_type = embedding_type = config.model.embedding_type.lower()
        self.rw_depth = rw_depth = config.model.rw_depth
        self.edge_th = config.model.edge_th

        # Submodules are collected in a flat list; forward() consumes them in
        # the same order via a running index (m_idx).
        modules = []

        # timestep/noise-level embedding; only for continuous training
        if embedding_type == 'positional':
            embed_dim = nf
        else:
            raise ValueError(f'embedding type {embedding_type} unknown.')

        # timestep embedding MLP: embed_dim -> 4 * nf -> 4 * nf
        modules.append(nn.Linear(embed_dim, nf * 4))
        modules.append(nn.Linear(nf * 4, nf * 4))

        # graph size condition embedding: one-hot node count passed through an MLP
        self.size_cond = size_cond = config.model.size_cond
        if size_cond:
            self.size_onehot = functools.partial(nn.functional.one_hot,
                                                 num_classes=config.data.max_node + 1)
            modules.append(nn.Linear(config.data.max_node + 1, nf * 4))
            modules.append(nn.Linear(nf * 4, nf * 4))

        channels = config.data.num_channels
        assert channels == 1, "Only a single adjacency channel (no edge features) is supported."

        # degree one-hot encoding; node degrees are clamped to degree_max
        self.degree_max = self.config.data.max_node // 2
        self.degree_onehot = functools.partial(
            nn.functional.one_hot,
            num_classes=self.degree_max + 1)

        # project edge features: raw adjacency and SPD one-hot -> nf // 2 channels each
        modules.append(conv1x1(channels, nf // 2))
        modules.append(conv1x1(rw_depth + 1, nf // 2))

        # project node features: degree one-hot and random-walk encodings
        self.x_ch = nf
        self.pos_ch = nf // 2
        modules.append(nn.Linear(self.degree_max + 1, self.x_ch))
        modules.append(nn.Linear(rw_depth, self.pos_ch))
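        # Channel bookkeeping: node features use x_ch = nf, positional features use
        # pos_ch = nf // 2, edge features use nf // 2, and the time embedding uses
        # 4 * nf; these match the edge_dim and temb_dim passed to the GNN below.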

        # position-aware GNN backbone operating on node, positional and edge features
        modules.append(gnns.pos_gnn(act, self.x_ch, self.pos_ch, nf, config.data.max_node,
                                    config.model.graph_layer, num_gnn_layers,
                                    heads=config.model.heads, edge_dim=nf // 2, temb_dim=nf * 4,
                                    dropout=dropout, attn_clamp=config.model.attn_clamp))

        # output head: two 1x1 convolutions mapping edge features back to `channels`
        modules.append(conv1x1(nf // 2, nf // 2))
        modules.append(conv1x1(nf // 2, channels))

        self.all_modules = nn.ModuleList(modules)
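        # Resulting module order (consumed in this order by forward):
        #   time MLP (2 Linear) -> optional size MLP (2 Linear) -> adjacency conv1x1
        #   -> SPD conv1x1 -> degree Linear -> positional Linear -> pos_gnn
        #   -> two output conv1x1 layers.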

    def forward(self, x, time_cond, *args, **kwargs):
        mask = kwargs['mask']  # [B, 1, N, N] mask of valid entries for padded graphs
        modules = self.all_modules
        m_idx = 0

        # sinusoidal timestep / noise-level embedding
        timesteps = time_cond
        temb = layers.get_timestep_embedding(timesteps, self.nf)
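        # get_timestep_embedding presumably follows the standard Transformer-style
        # sinusoidal scheme: temb[i] = concat(sin(t_i * w_k), cos(t_i * w_k)) with
        # geometrically spaced frequencies w_k, giving an nf-dimensional vector per step.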

        # time embedding MLP
        temb = modules[m_idx](temb)  # [B, 4 * nf]
        m_idx += 1
        temb = modules[m_idx](self.act(temb))  # [B, 4 * nf]
        m_idx += 1

        if self.size_cond:
            # condition on the number of nodes in each graph
            with torch.no_grad():
                node_mask = utils.mask_adj2node(mask.squeeze(1))  # [B, N]
                num_node = torch.sum(node_mask, dim=-1)  # [B]
                num_node = self.size_onehot(num_node.to(torch.long)).to(torch.float)
            num_node_emb = modules[m_idx](num_node)
            m_idx += 1
            num_node_emb = modules[m_idx](self.act(num_node_emb))
            m_idx += 1
            temb = temb + num_node_emb

        if not self.config.data.centered:
            # rescale the input data from [0, 1] to [-1, 1]
            x = x * 2. - 1.

        with torch.no_grad():
            # continuous-valued adjacency in [0, 1], used for soft degrees and edge indices
            cont_adj = ((x + 1.) / 2.).clone()
            cont_adj = (cont_adj * mask).squeeze(1)  # [B, N, N]
            cont_adj = cont_adj.clamp(min=0., max=1.)
            if self.edge_th > 0.:
                # drop weak edges below the threshold
                cont_adj[cont_adj < self.edge_th] = 0.

            # discretized (0/1) adjacency matrices
            adj = x.squeeze(1).clone()  # [B, N, N]
            adj[adj >= 0.] = 1.
            adj[adj < 0.] = 0.
            adj = adj * mask.squeeze(1)

            # extract random-walk structural encodings and shortest-path distances
            x_pos, spd_onehot = utils.get_rw_feat(self.rw_depth, adj)
            # x_pos: [B, N, rw_depth], spd_onehot: [B, rw_depth + 1, N, N]
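        # Note: based on the helper's name, x_pos likely holds return probabilities of
        # random walks of length 1..rw_depth (an RWSE-style encoding) and spd_onehot a
        # one-hot shortest-path distance per node pair, truncated at rw_depth; the exact
        # definitions live in utils.get_rw_feat.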

        # project dense edge features to [B, N, N, F]
        dense_edge_ori = modules[m_idx](x).permute(0, 2, 3, 1)  # [B, N, N, nf // 2]
        m_idx += 1
        dense_edge_spd = modules[m_idx](spd_onehot).permute(0, 2, 3, 1)  # [B, N, N, nf // 2]
        m_idx += 1

        # use the (soft) node degree as the node feature
        x_degree = torch.sum(cont_adj, dim=-1)  # [B, N]
        x_degree = x_degree.clamp(max=float(self.degree_max))  # [B, N]
        x_degree = self.degree_onehot(x_degree.to(torch.long)).to(torch.float)  # [B, N, degree_max + 1]
        x_degree = modules[m_idx](x_degree)  # projection layer, [B, N, nf]
        m_idx += 1

        # positional encoding projection
        x_pos = modules[m_idx](x_pos)  # [B, N, nf // 2]
        m_idx += 1

        # dense-to-sparse node layout [B * N, -1]
        x_degree = x_degree.reshape(-1, self.x_ch)  # [B * N, nf]
        x_pos = x_pos.reshape(-1, self.pos_ch)  # [B * N, nf // 2]
        dense_index = cont_adj.nonzero(as_tuple=True)
        edge_index, _ = dense_to_sparse(cont_adj)  # [2, num_edges]
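        # dense_to_sparse treats cont_adj as a batch of adjacency matrices, so edge_index
        # addresses the flattened B * N node set and lines up with the reshaped x_degree /
        # x_pos above; dense_index keeps the (batch, row, col) positions of nonzero entries,
        # presumably so the GNN can scatter sparse messages back into the dense edge tensors.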

        # run the position-enhanced GNN layers
        h_dense_edge = modules[m_idx](x_degree, x_pos, edge_index, dense_edge_ori, dense_edge_spd, dense_index, temb)
        m_idx += 1

        # output head
        h = self.act(modules[m_idx](self.act(h_dense_edge)))
        m_idx += 1
        h = modules[m_idx](h)  # [B, channels, N, N]
        m_idx += 1

        # make the edge estimation symmetric and zero out padded entries
        h = (h + h.transpose(2, 3)) / 2. * mask

        # sanity check: every registered module has been consumed exactly once
        assert m_idx == len(modules)

        return h
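

# Minimal usage sketch (illustrative; the shapes below are hypothetical and the model
# is normally constructed through the utils.register_model registry):
#
#   model = PGSN(config)                    # config must provide the model.* / data.* fields read above
#   x = torch.randn(B, 1, N, N)             # noisy dense adjacency, num_channels == 1
#   mask = torch.ones(B, 1, N, N)           # 1 for valid node pairs, 0 for padding
#   t = torch.rand(B)                       # diffusion timestep / noise level
#   score = model(x, t, mask=mask)          # [B, 1, N, N], symmetric and masked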