From f5911be7810d11f9acde093be9b8dc085ef808ab Mon Sep 17 00:00:00 2001 From: mhz Date: Sun, 30 Jun 2024 21:09:16 +0200 Subject: [PATCH] some onehot issue --- graph_dit/datasets/abstract_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graph_dit/datasets/abstract_dataset.py b/graph_dit/datasets/abstract_dataset.py index 2031409..2c6bf90 100644 --- a/graph_dit/datasets/abstract_dataset.py +++ b/graph_dit/datasets/abstract_dataset.py @@ -116,7 +116,7 @@ class AbstractDatasetInfos: def compute_input_output_dims(self, datamodule): example_batch = datamodule.example_batch() example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] - example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float() + example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float() self.input_dims = {'X': example_batch_x.size(1), 'E': example_batch_edge_attr.size(1),