From d57575586d7fa0ef8bba2f34cda7e16190fa7861 Mon Sep 17 00:00:00 2001 From: mhz Date: Sun, 30 Jun 2024 16:43:08 +0200 Subject: [PATCH] bring the metrics code back --- graph_dit/diffusion_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py index 9d26ecf..656a695 100644 --- a/graph_dit/diffusion_model.py +++ b/graph_dit/diffusion_model.py @@ -13,11 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL import utils class Graph_DiT(pl.LightningModule): - # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): - def __init__(self, cfg, dataset_infos, visualization_tools): + def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): + # def __init__(self, cfg, dataset_infos, visualization_tools): super().__init__() - # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) + self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) self.test_only = cfg.general.test_only self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) @@ -57,8 +57,8 @@ class Graph_DiT(pl.LightningModule): self.test_E_logp = SumExceptBatchMetric() self.test_y_collection = [] - # self.train_metrics = train_metrics - # self.sampling_metrics = sampling_metrics + self.train_metrics = train_metrics + self.sampling_metrics = sampling_metrics self.visualization_tools = visualization_tools self.max_n_nodes = dataset_infos.max_n_nodes @@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule): @torch.no_grad() def validation_step(self, data, i): data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index] - data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float() + data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float() dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes) 
- dense_data = dense_data.mask(node_mask) + dense_data = dense_data.mask(node_mask, collapse=True) noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) pred = self.forward(noisy_data) nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)