diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py index 9d26ecf..8f533e4 100644 --- a/graph_dit/diffusion_model.py +++ b/graph_dit/diffusion_model.py @@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule): @torch.no_grad() def validation_step(self, data, i): data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] - data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() + data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float() dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) - dense_data = dense_data.mask(node_mask) + dense_data = dense_data.mask(node_mask, collapse=False) noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) pred = self.forward(noisy_data) nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False) @@ -444,9 +444,11 @@ class Graph_DiT(pl.LightningModule): beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1) alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1) alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1) + print(f"alpha_t_bar.shape {alpha_t_bar.shape}") Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out) - + print(f"E.shape {E.shape}") + print(f"X.shape {X.shape}") bs, n, d = X.shape X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) prob_all = X_all @ Qtb.X