From dd31fda8d5f3fc0da9f34a75be2fed21912cb5fb Mon Sep 17 00:00:00 2001
From: mhz
Date: Mon, 1 Jul 2024 10:03:40 +0200
Subject: [PATCH] comment some output statements

---
 graph_dit/utils.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/graph_dit/utils.py b/graph_dit/utils.py
index 23776ea..589b1ed 100644
--- a/graph_dit/utils.py
+++ b/graph_dit/utils.py
@@ -46,13 +46,17 @@ def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
 
 
 def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
+    # print(f"to dense X: {x.shape}, edge_index: {edge_index.shape}, edge_attr: {edge_attr.shape}, batch: {batch}, max_num_nodes: {max_num_nodes}")
     X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
     # node_mask = node_mask.float()
     edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
     if max_num_nodes is None:
         max_num_nodes = X.size(1)
+    # print(f"to dense X: {X.shape}, edge_index: {edge_index.shape}, edge_attr: {edge_attr.shape}, batch: {batch}, max_num_nodes: {max_num_nodes}")
     E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
     E = encode_no_edge(E)
+    # print(f"to dense X: {X.shape}, edge_index: {edge_index.shape}, edge_attr: {edge_attr.shape}, batch: {batch}, max_num_nodes: {max_num_nodes}")
+    # print(f"to dense X: {X.shape}, E: {E.shape}, batch: {batch}, lenE: {len(E)}")
 
     return PlaceHolder(X=X, E=E, y=None), node_mask
 
@@ -119,6 +123,7 @@ class PlaceHolder:
         x_mask = node_mask.unsqueeze(-1)          # bs, n, 1
         e_mask1 = x_mask.unsqueeze(2)             # bs, n, 1, 1
         e_mask2 = x_mask.unsqueeze(1)             # bs, 1, n, 1
+        # print(f"mask X: {self.X.shape}, E: {self.E.shape}, node_mask: {node_mask.shape}, x_mask: {x_mask.shape}, e_mask1: {e_mask1.shape}, e_mask2: {e_mask2.shape}")
 
         if collapse:
             self.X = torch.argmax(self.X, dim=-1)
@@ -127,8 +132,13 @@ class PlaceHolder:
             self.X[node_mask == 0] = - 1
             self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
         else:
+            # print(f"X: {self.X.shape}, E: {self.E.shape}")
+            # print(f"X: {self.X}, E: {self.E}")
+            # print(f"x_mask: {x_mask}, e_mask1: {e_mask1}, e_mask2: {e_mask2}")
             self.X = self.X * x_mask
             self.E = self.E * e_mask1 * e_mask2
+            # print(f"X: {self.X.shape}, E: {self.E.shape}")
+            # print(f"X: {self.X}, E: {self.E}")
             assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
 
         return self
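
Reviewer note (not part of the patch): all ten inserted lines are commented-out f-string prints that trace tensor shapes through to_dense and PlaceHolder.mask. A less invasive way to keep that trace available is to route it through Python's standard logging module, so it can be toggled at runtime instead of by uncommenting source lines. Below is a minimal sketch of that idea for to_dense only. It assumes the surrounding graph_dit/utils.py context (encode_no_edge and PlaceHolder are defined elsewhere in that module); the logger name and the basicConfig setup are illustrative, not something the repository ships.

import logging

import torch_geometric
from torch_geometric.utils import to_dense_adj, to_dense_batch

# Hypothetical module-level logger; any name works.
logger = logging.getLogger(__name__)


def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
    # Same trace as the first commented-out print, but silent unless
    # DEBUG logging is enabled for this module.
    logger.debug("to_dense in: x=%s, edge_index=%s, edge_attr=%s",
                 tuple(x.shape), tuple(edge_index.shape), tuple(edge_attr.shape))
    X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
    edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
    if max_num_nodes is None:
        max_num_nodes = X.size(1)
    E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr,
                     max_num_nodes=max_num_nodes)
    E = encode_no_edge(E)  # defined elsewhere in graph_dit/utils.py
    # Same trace as the last commented-out print in this hunk.
    logger.debug("to_dense out: X=%s, E=%s, lenE=%d",
                 tuple(X.shape), tuple(E.shape), len(E))
    return PlaceHolder(X=X, E=E, y=None), node_mask  # PlaceHolder from this module

Enabling the trace then needs no source edit, e.g. logging.basicConfig(level=logging.DEBUG) once at program start.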