comment some output statements and record dimension infos
This commit is contained in:
		| @@ -65,10 +65,11 @@ def reverse_tensor(x): | ||||
|  | ||||
| def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True): | ||||
|     ''' Sample features from multinomial distribution with given probabilities (probX, probE, proby) | ||||
|         :param probX: bs, n, dx_out        node features | ||||
|         :param probE: bs, n, n, de_out     edge features | ||||
|         :param proby: bs, dy_out           global features. | ||||
|         :param probX: bs, n, dx_out        node features        1200 8 7 | ||||
|         :param probE: bs, n, n, de_out     edge features        1200 8 8 2 | ||||
|         :param proby: bs, dy_out           global features.     1200 8 | ||||
|     ''' | ||||
|     # print(f"sample_discrete_features in: probX: {probX.shape}, probE: {probE.shape}, node_mask: {node_mask.shape}") | ||||
|     bs, n, _ = probX.shape | ||||
|  | ||||
|     # Noise X | ||||
| @@ -97,8 +98,11 @@ def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True): | ||||
|  | ||||
|     # Sample E | ||||
|     E_t = probE.multinomial(1).reshape(bs, n, n)    # (bs, n, n) | ||||
|     # print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}") | ||||
|     E_t = torch.triu(E_t, diagonal=1) | ||||
|     # print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}") | ||||
|     E_t = (E_t + torch.transpose(E_t, 1, 2)) | ||||
|     # print(f"sample_discrete_features out: X_t: {X_t.shape}, E_t: {E_t.shape}") | ||||
|  | ||||
|     return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t)) | ||||
|  | ||||
|   | ||||
| @@ -103,16 +103,25 @@ class MarginalTransition: | ||||
|         self.e_marginals = e_marginals # Dx, De | ||||
|         self.xe_conditions = xe_conditions | ||||
|  | ||||
|         self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) # 1, Dx, Dx | ||||
|         self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) # 1, De, De | ||||
|         self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De | ||||
|         self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx | ||||
|         self.u_x = x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) # 1, Dx, Dx 1 7 7 | ||||
|         self.u_e = e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) # 1, De, De 1 2 2 | ||||
|         self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De 1 7 2 | ||||
|         self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx 1 2 7 | ||||
|         self.u = self.get_union_transition(self.u_x, self.u_e, self.u_xe, self.u_ex, n_nodes) # 1, Dx + n*De, Dx + n*De | ||||
|         # print(f"Shape of u_x: {self.u_x.shape}") | ||||
|         # print(f"Shape of u_e: {self.u_e.shape}") | ||||
|         # print(f"Shape of u_xe: {self.u_xe.shape}") | ||||
|         # print(f"Shape of u_ex: {self.u_ex.shape}") | ||||
|         # print(f"Shape of u: {self.u.shape}") | ||||
|  | ||||
|     def get_union_transition(self, u_x, u_e, u_xe, u_ex, n_nodes): | ||||
|         # print(f"before processing Shape of u_e: {u_e.shape}") | ||||
|         # print(f"before processing Shape of u_ex: {u_ex.shape}") | ||||
|         u_e = u_e.repeat(1, n_nodes, n_nodes) # (1, n*de, n*de) | ||||
|         u_xe = u_xe.repeat(1, 1, n_nodes) # (1, dx, n*de) | ||||
|         u_ex = u_ex.repeat(1, n_nodes, 1) # (1, n*de, dx) | ||||
|         # print(f"After processing Shape of u_ex: {u_ex.shape}") | ||||
|         # print(f"After processing Shape of u_e: {u_e.shape}") | ||||
|         u0 = torch.cat([u_x, u_xe], dim=2) # (1, dx, dx + n*de) | ||||
|         u1 = torch.cat([u_ex, u_e], dim=2) # (1, n*de, dx + n*de) | ||||
|         u = torch.cat([u0, u1], dim=1) # (1, dx + n*de, dx + n*de) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user