import functools

import torch
import torch.nn as nn
from torch_geometric.utils import dense_to_sparse

from . import utils, layers, gnns

get_act = layers.get_act
conv1x1 = layers.conv1x1


@utils.register_model(name='PGSN')
class PGSN(nn.Module):
    """Position enhanced graph score network.

    All learnable sub-modules are stored in ``self.all_modules`` and are
    consumed strictly in registration order by :meth:`forward`. Keep the two
    methods in sync when adding or removing a layer: the index of each entry
    is also part of the ``state_dict`` key layout.
    """

    def __init__(self, config):
        """Build the network from a configuration object.

        Args:
            config: Configuration object. The attributes read here are
                ``config.model.{nf, num_gnn_layers, dropout, embedding_type,
                rw_depth, edge_th, size_cond, graph_layer, heads,
                attn_clamp}`` and ``config.data.{max_node, num_channels}``
                (``config.data.centered`` is read later, in ``forward``).
                ``config`` is also passed to ``layers.get_act``.

        Raises:
            ValueError: If ``config.model.embedding_type`` is not
                ``'positional'`` (case-insensitive).
            AssertionError: If ``config.data.num_channels`` is not 1.
        """
        super().__init__()

        self.config = config
        self.act = act = get_act(config)

        # get model construction paras
        self.nf = nf = config.model.nf
        self.num_gnn_layers = num_gnn_layers = config.model.num_gnn_layers
        dropout = config.model.dropout
        self.embedding_type = embedding_type = config.model.embedding_type.lower()
        self.rw_depth = rw_depth = config.model.rw_depth
        self.edge_th = config.model.edge_th

        max_node = config.data.max_node

        modules = []

        # timestep/noise_level embedding; only for continuous training
        if embedding_type == 'positional':
            embed_dim = nf
        else:
            raise ValueError(f'embedding type {embedding_type} unknown.')

        # timestep embedding layers: embed_dim -> 4 * nf -> 4 * nf
        modules.append(nn.Linear(embed_dim, nf * 4))
        modules.append(nn.Linear(nf * 4, nf * 4))

        # graph size condition embedding (optional); the node count is
        # one-hot encoded over max_node + 1 classes and added to the time
        # embedding in forward()
        self.size_cond = size_cond = config.model.size_cond
        if size_cond:
            self.size_onehot = functools.partial(
                nn.functional.one_hot, num_classes=max_node + 1)
            modules.append(nn.Linear(max_node + 1, nf * 4))
            modules.append(nn.Linear(nf * 4, nf * 4))

        channels = config.data.num_channels
        assert channels == 1, "Without edge features."

        # degree onehot: degrees are clamped to degree_max in forward()
        self.degree_max = max_node // 2
        self.degree_onehot = functools.partial(
            nn.functional.one_hot, num_classes=self.degree_max + 1)

        # project edge features (input adjacency and the rw_depth + 1 channel
        # one-hot returned by utils.get_rw_feat)
        modules.append(conv1x1(channels, nf // 2))
        modules.append(conv1x1(rw_depth + 1, nf // 2))

        # project node features (degree one-hot and positional encoding)
        self.x_ch = nf
        self.pos_ch = nf // 2
        modules.append(nn.Linear(self.degree_max + 1, self.x_ch))
        modules.append(nn.Linear(rw_depth, self.pos_ch))

        # GNN
        modules.append(gnns.pos_gnn(act, self.x_ch, self.pos_ch, nf,
                                    max_node, config.model.graph_layer,
                                    num_gnn_layers,
                                    heads=config.model.heads,
                                    edge_dim=nf // 2, temb_dim=nf * 4,
                                    dropout=dropout,
                                    attn_clamp=config.model.attn_clamp))

        # output
        modules.append(conv1x1(nf // 2, nf // 2))
        modules.append(conv1x1(nf // 2, channels))

        self.all_modules = nn.ModuleList(modules)

    def forward(self, x, time_cond, *args, **kwargs):
        """Estimate the per-edge output for a batch of noisy adjacencies.

        Args:
            x: Input adjacency tensor. It is squeezed on dim 1 and fed to a
                ``conv1x1`` with ``channels == 1`` inputs, so it is presumably
                shaped ``[B, 1, N, N]`` -- confirm against the caller.
            time_cond: Time / noise-level condition handed to
                ``layers.get_timestep_embedding``.
            *args: Unused.
            **kwargs: Must contain ``mask``, the adjacency mask that is
                multiplied with ``x`` (presumably ``[B, 1, N, N]``).

        Returns:
            Tensor produced by the output convolutions, averaged with its
            transpose over dims 2 and 3 and multiplied by ``mask``.

        Raises:
            KeyError: If ``mask`` is not supplied as a keyword argument.
        """
        mask = kwargs['mask']
        modules = self.all_modules
        m_idx = 0  # cursor into self.all_modules; advanced after every use

        # Sinusoidal positional embeddings
        timesteps = time_cond
        temb = layers.get_timestep_embedding(timesteps, self.nf)

        # time embedding
        temb = modules[m_idx](temb)  # [..., 4 * nf]
        m_idx += 1
        temb = modules[m_idx](self.act(temb))  # [..., 4 * nf]
        m_idx += 1

        if self.size_cond:
            # The node count is a non-learnable statistic of the mask, so it
            # is computed without tracking gradients; the embedding layers
            # below run with gradients enabled.
            with torch.no_grad():
                node_mask = utils.mask_adj2node(mask.squeeze(1))  # [B, N]
                num_node = torch.sum(node_mask, dim=-1)  # [B]
                num_node = self.size_onehot(
                    num_node.to(torch.long)).to(torch.float)
            num_node_emb = modules[m_idx](num_node)
            m_idx += 1
            num_node_emb = modules[m_idx](self.act(num_node_emb))
            m_idx += 1
            temb = temb + num_node_emb

        if not self.config.data.centered:
            # rescale the input data to [-1, 1]
            x = x * 2. - 1.

        # Graph-structure preprocessing has no learnable parameters; only
        # this part is wrapped in no_grad, the projections further down are
        # evaluated with gradients enabled.
        with torch.no_grad():
            # continuous-valued graph adjacency matrices
            cont_adj = ((x + 1.) / 2.).clone()
            cont_adj = (cont_adj * mask).squeeze(1)  # [B, N, N]
            cont_adj = cont_adj.clamp(min=0., max=1.)
            if self.edge_th > 0.:
                # drop weak edges below the configured threshold
                cont_adj[cont_adj < self.edge_th] = 0.

            # discretized graph adjacency matrices (threshold at 0 in the
            # [-1, 1] range)
            adj = x.squeeze(1).clone()  # [B, N, N]
            adj[adj >= 0.] = 1.
            adj[adj < 0.] = 0.
            adj = adj * mask.squeeze(1)

        # extract RWSE and Shortest-Path Distance
        # The layers consuming these expect a last dim of rw_depth for x_pos
        # and rw_depth + 1 input channels for spd_onehot; presumably
        # x_pos: [B, N, rw_depth], spd_onehot: [B, rw_depth + 1, N, N].
        x_pos, spd_onehot = utils.get_rw_feat(self.rw_depth, adj)

        # edge features, moved to channel-last layout: presumably
        # [B, N, N, nf // 2]
        dense_edge_ori = modules[m_idx](x).permute(0, 2, 3, 1)
        m_idx += 1
        dense_edge_spd = modules[m_idx](spd_onehot).permute(0, 2, 3, 1)
        m_idx += 1

        # Use Degree as node feature
        x_degree = torch.sum(cont_adj, dim=-1)  # [B, N]
        x_degree = x_degree.clamp(max=float(self.degree_max))  # [B, N]
        # [B, N, degree_max + 1]
        x_degree = self.degree_onehot(x_degree.to(torch.long)).to(torch.float)
        x_degree = modules[m_idx](x_degree)  # projection layer [B, N, nf]
        m_idx += 1

        # pos encoding
        x_pos = modules[m_idx](x_pos)  # [..., nf // 2]
        m_idx += 1

        # Dense to sparse node [BxN, -1]
        x_degree = x_degree.reshape(-1, self.x_ch)  # [B * N, nf]
        x_pos = x_pos.reshape(-1, self.pos_ch)  # [B * N, nf // 2]
        dense_index = cont_adj.nonzero(as_tuple=True)
        edge_index, _ = dense_to_sparse(cont_adj)  # [2, num_edges]

        # Run GNN layers
        h_dense_edge = modules[m_idx](x_degree, x_pos, edge_index,
                                      dense_edge_ori, dense_edge_spd,
                                      dense_index, temb)
        m_idx += 1

        # Output
        h = self.act(modules[m_idx](self.act(h_dense_edge)))
        m_idx += 1
        h = modules[m_idx](h)
        m_idx += 1

        # make edge estimation symmetric
        h = (h + h.transpose(2, 3)) / 2. * mask

        # every registered module must have been consumed exactly once
        assert m_idx == len(modules)

        return h