diff --git a/graph_dit/utils.py b/graph_dit/utils.py index 589b1ed..7368dbb 100644 --- a/graph_dit/utils.py +++ b/graph_dit/utils.py @@ -139,7 +139,7 @@ class PlaceHolder: self.E = self.E * e_mask1 * e_mask2 # print(f"X: {self.X.shape}, E: {self.E.shape}") # print(f"X: {self.X}, E: {self.E}") - assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) + # assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) return self