diff --git a/graph_dit/models/conditions.py b/graph_dit/models/conditions.py
index ba2d4c6..62d0d7b 100644
--- a/graph_dit/models/conditions.py
+++ b/graph_dit/models/conditions.py
@@ -76,6 +76,8 @@ class CategoricalEmbedder(nn.Module):
         embeddings = embeddings + noise
         return embeddings
 
+# Cluster similar conditions together
+# size
 class ClusterContinuousEmbedder(nn.Module):
     def __init__(self, input_size, hidden_size, dropout_prob):
         super().__init__()
@@ -108,6 +110,8 @@ class ClusterContinuousEmbedder(nn.Module):
 
         if drop_ids is not None:
             embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device)
+            # print(labels[~drop_ids].shape)
+            # torch.Size([1200])
             embeddings[~drop_ids] = self.mlp(labels[~drop_ids])
             embeddings[drop_ids] += self.embedding_drop.weight[0]
         else:
diff --git a/graph_dit/models/transformer.py b/graph_dit/models/transformer.py
index e9b8bfa..2ff7477 100644
--- a/graph_dit/models/transformer.py
+++ b/graph_dit/models/transformer.py
@@ -17,20 +17,22 @@ class Denoiser(nn.Module):
         num_heads=16,
         mlp_ratio=4.0,
         drop_condition=0.1,
-        Xdim=118,
-        Edim=5,
-        ydim=3,
+        Xdim=7,
+        Edim=2,
+        ydim=1,
         task_type='regression',
     ):
         super().__init__()
+        print(f"Denoiser, xdim: {Xdim}, edim: {Edim}, ydim: {ydim}, hidden_size: {hidden_size}, depth: {depth}, num_heads: {num_heads}, mlp_ratio: {mlp_ratio}, drop_condition: {drop_condition}")
        self.num_heads = num_heads
         self.ydim = ydim
         self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False)
 
         self.t_embedder = TimestepEmbedder(hidden_size)
+        #
         self.y_embedding_list = torch.nn.ModuleList()
 
-        self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition))
+        self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition))
         for i in range(ydim - 2):
             if task_type == 'regression':
                 self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition))
@@ -88,6 +90,8 @@ class Denoiser(nn.Module):
 
         # print("Denoiser Forward")
         # print(x.shape, e.shape, y.shape, t.shape, unconditioned)
+        # torch.Size([1200, 8, 7]) torch.Size([1200, 8, 8, 2]) torch.Size([1200, 2]) torch.Size([1200, 1]) False
+        # print(y)
         force_drop_id = torch.zeros_like(y.sum(-1))
         # drop the nan values
         force_drop_id[torch.isnan(y.sum(-1))] = 1
@@ -109,11 +113,12 @@ class Denoiser(nn.Module):
         c1 = self.t_embedder(t)
         # print("C1 after t_embedder")
         # print(c1.shape)
-        for i in range(1, self.ydim):
-            if i == 1:
-                c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
-            else:
-                c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
+        c2 = self.y_embedding_list[0](y[:,0].unsqueeze(-1), self.training, force_drop_id, t)
+        # for i in range(1, self.ydim):
+        #     if i == 1:
+        #         c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t)
+        #     else:
+        #         c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t)
         # print("C2 after y_embedding_list")
         # print(c2.shape)
         # print("C1 + C2")
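
Note: for context, below is a minimal, self-contained sketch of the single-property conditioning path the diff switches to: one continuous label y[:, 0] is embedded by a ClusterContinuousEmbedder(1, ...), and rows with NaN targets (or rows randomly dropped during training) receive a learned unconditional embedding instead. SimpleContinuousEmbedder, its SiLU MLP, and hidden_size=384 are illustrative assumptions and not the repository's actual implementation; only the drop/NaN handling mirrors the snippet shown in conditions.py, and the tensor shapes follow the debug prints in the diff.

import torch
import torch.nn as nn

class SimpleContinuousEmbedder(nn.Module):
    # Simplified stand-in (not the repo's ClusterContinuousEmbedder): an MLP over
    # the raw property value plus a learned "unconditional" embedding used when
    # the condition is dropped (NaN target or classifier-free-guidance dropout).
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.embedding_drop = nn.Embedding(1, hidden_size)

    def forward(self, labels, train, force_drop_id=None, t=None):
        drop_ids = None
        if force_drop_id is not None:
            drop_ids = force_drop_id == 1
        if train and self.dropout_prob > 0:
            rand_drop = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            drop_ids = rand_drop if drop_ids is None else (drop_ids | rand_drop)
        if drop_ids is not None:
            # Dropped rows never reach the MLP, so NaN labels in those rows are harmless.
            embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device)
            embeddings[~drop_ids] = self.mlp(labels[~drop_ids])
            embeddings[drop_ids] += self.embedding_drop.weight[0]
        else:
            embeddings = self.mlp(labels)
        return embeddings

# Shapes follow the debug prints in the diff; hidden_size=384 is an arbitrary choice.
y = torch.randn(1200, 2)
y[:5, 0] = float('nan')                    # a few missing targets
force_drop_id = torch.zeros_like(y.sum(-1))
force_drop_id[torch.isnan(y.sum(-1))] = 1  # NaN conditions are forced to "unconditional"
embedder = SimpleContinuousEmbedder(1, 384, 0.1)
c2 = embedder(y[:, 0].unsqueeze(-1), True, force_drop_id)
print(c2.shape)                            # torch.Size([1200, 384])

Inside the Denoiser forward, this c2 plays the role of the condition embedding that is combined with the timestep embedding c1 (as the commented "C1 + C2" print suggests), replacing the previous loop that embedded the first two properties jointly and the remaining ydim - 2 properties one at a time.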