write graph code for the absctract dataset
This commit is contained in:
parent
14186fa97f
commit
a7f7010da7
@ -118,6 +118,21 @@ class AbstractDatasetInfos:
|
|||||||
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
|
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
|
||||||
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
|
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
|
||||||
|
|
||||||
|
self.input_dims = {'X': example_batch_x.size(1),
|
||||||
|
'E': example_batch_edge_attr.size(1),
|
||||||
|
'y': example_batch['y'].size(1)}
|
||||||
|
self.output_dims = {'X': example_batch_x.size(1),
|
||||||
|
'E': example_batch_edge_attr.size(1),
|
||||||
|
'y': example_batch['y'].size(1)}
|
||||||
|
print('input dims')
|
||||||
|
print(self.input_dims)
|
||||||
|
print('output dims')
|
||||||
|
print(self.output_dims)
|
||||||
|
def compute_graph_input_output_dims(self, datamodule):
|
||||||
|
example_batch = datamodule.example_batch()
|
||||||
|
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index]
|
||||||
|
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float()
|
||||||
|
|
||||||
self.input_dims = {'X': example_batch_x.size(1),
|
self.input_dims = {'X': example_batch_x.size(1),
|
||||||
'E': example_batch_edge_attr.size(1),
|
'E': example_batch_edge_attr.size(1),
|
||||||
'y': example_batch['y'].size(1)}
|
'y': example_batch['y'].size(1)}
|
||||||
|
Loading…
Reference in New Issue
Block a user