diff --git a/graph_dit/datasets/abstract_dataset.py b/graph_dit/datasets/abstract_dataset.py index 2031409..2c6bf90 100644 --- a/graph_dit/datasets/abstract_dataset.py +++ b/graph_dit/datasets/abstract_dataset.py @@ -116,7 +116,7 @@ class AbstractDatasetInfos: def compute_input_output_dims(self, datamodule): example_batch = datamodule.example_batch() example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] - example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float() + example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float() self.input_dims = {'X': example_batch_x.size(1), 'E': example_batch_edge_attr.size(1),