diff --git a/graph_dit/datasets/abstract_dataset.py b/graph_dit/datasets/abstract_dataset.py index 63f1ea5..2031409 100644 --- a/graph_dit/datasets/abstract_dataset.py +++ b/graph_dit/datasets/abstract_dataset.py @@ -127,4 +127,19 @@ class AbstractDatasetInfos: print('input dims') print(self.input_dims) print('output dims') + print(self.output_dims) + def compute_graph_input_output_dims(self, datamodule): + example_batch = datamodule.example_batch() + example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index] + example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float() + + self.input_dims = {'X': example_batch_x.size(1), + 'E': example_batch_edge_attr.size(1), + 'y': example_batch['y'].size(1)} + self.output_dims = {'X': example_batch_x.size(1), + 'E': example_batch_edge_attr.size(1), + 'y': example_batch['y'].size(1)} + print('input dims') + print(self.input_dims) + print('output dims') print(self.output_dims) \ No newline at end of file