Compare commits

...

2 Commits
0.01 ... main

2 changed files with 5 additions and 4 deletions

View File

@ -716,7 +716,6 @@ class DataInfos(AbstractDatasetInfos):
graphs.append((adj_matrix, ops))
meta_dict = graphs_to_json(graphs, 'nasbench-201')
self.base_path = base_path
self.active_atoms = meta_dict['active_atoms']
self.max_n_nodes = meta_dict['max_node']

View File

@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule):
@torch.no_grad()
def validation_step(self, data, i):
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
dense_data = dense_data.mask(node_mask)
dense_data = dense_data.mask(node_mask, collapse=False)
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
pred = self.forward(noisy_data)
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
@ -444,9 +444,11 @@ class Graph_DiT(pl.LightningModule):
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
print(f"alpha_t_bar.shape {alpha_t_bar.shape}")
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
print(f"E.shape {E.shape}")
print(f"X.shape {X.shape}")
bs, n, d = X.shape
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
prob_all = X_all @ Qtb.X