Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
dcfefb91a3 | |||
8bbadce19c |
@ -716,7 +716,6 @@ class DataInfos(AbstractDatasetInfos):
|
||||
graphs.append((adj_matrix, ops))
|
||||
|
||||
meta_dict = graphs_to_json(graphs, 'nasbench-201')
|
||||
|
||||
self.base_path = base_path
|
||||
self.active_atoms = meta_dict['active_atoms']
|
||||
self.max_n_nodes = meta_dict['max_node']
|
||||
|
@ -179,9 +179,9 @@ class Graph_DiT(pl.LightningModule):
|
||||
@torch.no_grad()
|
||||
def validation_step(self, data, i):
|
||||
data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
|
||||
data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
|
||||
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
|
||||
dense_data = dense_data.mask(node_mask)
|
||||
dense_data = dense_data.mask(node_mask, collapse=False)
|
||||
noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
|
||||
pred = self.forward(noisy_data)
|
||||
nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
|
||||
@ -444,9 +444,11 @@ class Graph_DiT(pl.LightningModule):
|
||||
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
|
||||
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
|
||||
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
|
||||
print(f"alpha_t_bar.shape {alpha_t_bar.shape}")
|
||||
|
||||
Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) # (bs, dx_in, dx_out), (bs, de_in, de_out)
|
||||
|
||||
print(f"E.shape {E.shape}")
|
||||
print(f"X.shape {X.shape}")
|
||||
bs, n, d = X.shape
|
||||
X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
|
||||
prob_all = X_all @ Qtb.X
|
||||
|
Loading…
Reference in New Issue
Block a user