some onehot issue
This commit is contained in:
		| @@ -116,7 +116,7 @@ class AbstractDatasetInfos: | ||||
|     def compute_input_output_dims(self, datamodule): | ||||
|         example_batch = datamodule.example_batch() | ||||
|         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] | ||||
|         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float() | ||||
|         example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float() | ||||
|  | ||||
|         self.input_dims = {'X': example_batch_x.size(1),  | ||||
|                            'E': example_batch_edge_attr.size(1),  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user