diff --git a/graph_dit/models/transformer.py b/graph_dit/models/transformer.py index 4fcb2f2..15d77e9 100644 --- a/graph_dit/models/transformer.py +++ b/graph_dit/models/transformer.py @@ -87,7 +87,7 @@ class Denoiser(nn.Module): def forward(self, x, e, node_mask, y, t, unconditioned): print("Denoiser Forward") - print(x.shape, e.shape, y.shape, t.shape, unconditioned) + # print(x.shape, e.shape, y.shape, t.shape, unconditioned) force_drop_id = torch.zeros_like(y.sum(-1)) # drop the nan values force_drop_id[torch.isnan(y.sum(-1))] = 1 @@ -98,32 +98,32 @@ class Denoiser(nn.Module): # bs = batch size, n = number of nodes bs, n, _ = x.size() x = torch.cat([x, e.reshape(bs, n, -1)], dim=-1) - print("X after concat with E") - print(x.shape) + # print("X after concat with E") + # print(x.shape) # self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False) x = self.x_embedder(x) - print("X after x_embedder") - print(x.shape) + # print("X after x_embedder") + # print(x.shape) # self.t_embedder = TimestepEmbedder(hidden_size) c1 = self.t_embedder(t) - print("C1 after t_embedder") - print(c1.shape) + # print("C1 after t_embedder") + # print(c1.shape) for i in range(1, self.ydim): if i == 1: c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t) else: c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t) - print("C2 after y_embedding_list") - print(c2.shape) - print("C1 + C2") + # print("C2 after y_embedding_list") + # print(c2.shape) + # print("C1 + C2") c = c1 + c2 - print(c.shape) + # print(c.shape) for i, block in enumerate(self.encoders): x = block(x, c, node_mask) - print("X after block") - print(x.shape) + # print("X after block") + # print(x.shape) # X: B * N * dx, E: B * N * N * de X, E, y = self.out_layer(x, x_in, e_in, c, t, node_mask)