add some shape commits
@@ -78,8 +78,8 @@ class Graph_DiT(pl.LightningModule):
                                                               timesteps=cfg.model.diffusion_steps)
 
 
-        print("__init__")
-        print("dataset_info.node_types", self.dataset_info.node_types)
+        # print("__init__")
+        # print("dataset_info.node_types", self.dataset_info.node_types)
         # dataset_info.node_types tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
         x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
 
@@ -123,8 +123,8 @@ class Graph_DiT(pl.LightningModule):
         return pred
 
     def training_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
 
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
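Note on this hunk: the one-hot widths shrink from the generic encoding (118 node classes, presumably the full periodic table, and 5 edge classes) to dataset-specific sizes (8 node classes, 2 edge classes). A minimal sketch of what the encoding step produces, with made-up label tensors and a hypothetical active_index standing in for the real data:

    import torch
    import torch.nn.functional as F

    # Hypothetical stand-ins: integer class labels per node and per edge.
    node_labels = torch.tensor([0, 2, 5, 7])      # each value must be < num_classes
    edge_labels = torch.tensor([0, 1, 1])
    active_index = torch.arange(8)                # keep all 8 columns in this sketch

    data_x = F.one_hot(node_labels, num_classes=8).float()[:, active_index]  # (num_nodes, 8)
    data_edge_attr = F.one_hot(edge_labels, num_classes=2).float()           # (num_edges, 2)

    print(data_x.shape, data_edge_attr.shape)  # torch.Size([4, 8]) torch.Size([3, 2])

Any label value of 8 or more would make F.one_hot fail here, so the smaller num_classes only works if the dataset's labels are already remapped into that range.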
@@ -138,6 +138,9 @@ class Graph_DiT(pl.LightningModule):
         self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                         log=i % self.log_every_steps == 0)
         self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
+        print(f"training loss: {loss}")
+        with open("training-loss.csv", "a") as f:
+            f.write(f"{loss}, {i}\n")
         return {'loss': loss}
 
 
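Note on this hunk: the added lines reopen training-loss.csv on every batch and write the tensor's repr. A sketch of a slightly more robust variant (an alternative, not what the commit does; the helper name and column order are illustrative):

    import csv

    def append_training_loss(path, loss, step):
        """Append one (step, loss) row; loss may be a scalar tensor or a float."""
        value = loss.item() if hasattr(loss, "item") else float(loss)
        with open(path, "a", newline="") as f:
            csv.writer(f).writerow([step, f"{value:.6f}"])

    # usage sketch inside training_step:
    # append_training_loss("training-loss.csv", loss, i)

Calling .item() stores a plain float instead of a string like "tensor(0.1234, device='cuda:0')", which keeps the CSV easy to parse later.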
@@ -150,7 +153,7 @@ class Graph_DiT(pl.LightningModule):
     def on_fit_start(self) -> None:
         self.train_iterations = self.trainer.datamodule.training_iterations
         print('on fit train iteration:', self.train_iterations)
-        print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
+        # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
 
     def on_train_epoch_start(self) -> None:
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
@@ -160,10 +163,12 @@ class Graph_DiT(pl.LightningModule):
         self.train_metrics.reset()
 
     def on_train_epoch_end(self) -> None:
+
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
             log = True
         else:
             log = False
+        log = True
         self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log)
         self.train_metrics.log_epoch_metrics(self.current_epoch, log)
 
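Note on this hunk: the appended `log = True` overrides whatever the if/else computed, so epoch metrics are now logged every epoch rather than only at the 25/50/75/100% marks, and the preceding branch becomes dead code. A self-contained illustration of that behavior (function name and arguments are invented for the sketch):

    def should_log(current_epoch: int, max_epochs: int) -> bool:
        # The ratio test only matches at exact quarter-mark epochs...
        log = current_epoch / max_epochs in (0.25, 0.5, 0.75, 1.0)
        # ...and the trailing override makes it moot, as in the commit.
        log = True
        return log

    print(should_log(3, 100))  # True for every epoch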
@@ -178,22 +183,31 @@ class Graph_DiT(pl.LightningModule):
 
     @torch.no_grad()
     def validation_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
-        dense_data = dense_data.mask(node_mask, collapse=True)
+        dense_data = dense_data.mask(node_mask, collapse=False)
         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
         pred = self.forward(noisy_data)
         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
         self.val_y_collection.append(data.y)
         self.log(f'valid_nll', nll, batch_size=data.x.size(0), sync_dist=True)
+        print(f'validation loss: {nll}, epoch: {self.current_epoch}')
         return {'loss': nll}
 
     def on_validation_epoch_end(self) -> None:
         metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
+
                    self.val_X_logp.compute(), self.val_E_logp.compute()]
 
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+            print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
+                f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll))
+        with open("validation-metrics.csv", "a") as f:
+            # save the metrics as csv file
+            f.write(f"{self.current_epoch}, {metrics[0]}, {metrics[1]}, {metrics[2]}, {metrics[3]}, {metrics[4]}\n")
+
+
         print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
               f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll))
 
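Note on this hunk: on_validation_epoch_end now appends one row per epoch to validation-metrics.csv (epoch, NLL, X KL, E KL, X log-prob, E log-prob), mirroring what gets printed. A hedged sketch of the same idea with a header row and scalar conversion; the file name matches the commit, the helper name and header labels are illustrative:

    import csv
    import os

    def append_epoch_metrics(path, epoch, metrics):
        """Append epoch-level metrics, writing a header the first time the file is created."""
        header = ["epoch", "val_nll", "val_X_kl", "val_E_kl", "val_X_logp", "val_E_logp"]
        write_header = not os.path.exists(path)
        values = [m.item() if hasattr(m, "item") else float(m) for m in metrics]
        with open(path, "a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(header)
            writer.writerow([epoch, *values])

    # usage sketch: append_epoch_metrics("validation-metrics.csv", self.current_epoch, metrics)

If the Trainer runs with more than one device, such a plain file append would execute on every rank; guarding it with a rank-zero check is a common precaution, though the commit does not do so.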
@@ -241,15 +255,15 @@ class Graph_DiT(pl.LightningModule):
                 samples_left_to_generate -= to_generate
                 chains_left_to_save -= chains_save
 
-            # print(f"Computing sampling metrics", ' ...')
-            # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
-            # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+            print(f"Computing sampling metrics", ' ...')
+            valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
+            print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
 
-            current_path = os.getcwd()
-            result_path = os.path.join(current_path,
-                                       f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
+            # current_path = os.getcwd()
+            # result_path = os.path.join(current_path,
+                                    #    f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
             # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
-            # self.sampling_metrics.reset()
+            self.sampling_metrics.reset()
 
     def on_test_epoch_start(self) -> None:
         print("Starting test...")
@@ -262,8 +276,8 @@ class Graph_DiT(pl.LightningModule):
 
     @torch.no_grad()
     def test_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
 
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
@@ -277,6 +291,8 @@ class Graph_DiT(pl.LightningModule):
         """ Measure likelihood on a test set and compute stability metrics. """
         metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
                    self.test_X_logp.compute(), self.test_E_logp.compute()]
+        with open("test-metrics.csv", "a") as f:
+            f.write(f"{self.current_epoch}, {metrics[0]}, {metrics[1]}, {metrics[2]}, {metrics[3]}, {metrics[4]}\n")
 
         print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
               f"Test Edge type KL: {metrics[2] :.2f}")
@@ -433,10 +449,12 @@ class Graph_DiT(pl.LightningModule):
 
         # Sample a timestep t.
         # When evaluating, the loss for t=0 is computed separately
+        # print(f"apply_noise X shape: {X.shape}, E shape: {E.shape}, y shape: {y.shape}, node_mask shape: {node_mask.shape}")
         lowest_t = 0 if self.training else 1
         t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1)
         s_int = t_int - 1
 
+
         t_float = t_int / self.T
         s_float = s_int / self.T
 
@@ -444,10 +462,23 @@ class Graph_DiT(pl.LightningModule):
         beta_t = self.noise_schedule(t_normalized=t_float)                         # (bs, 1)
         alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)      # (bs, 1)
         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)      # (bs, 1)
+        # print(f"alpha_s_bar: {alpha_s_bar.shape}, alpha_t_bar: {alpha_t_bar.shape}")
 
         Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)  # (bs, dx_in, dx_out), (bs, de_in, de_out)
+        # print(f"X shape: {X.shape}, E shape: {E.shape}, node_mask shape: {node_mask.shape}")
+        # print(f"Qtb shape: {Qtb.X.shape}")
+        """
+        X shape: torch.Size([1200, 8]),
+        E shape: torch.Size([1200, 8, 8]),
+        y shape: torch.Size([1200, 1]),
+        node_mask shape: torch.Size([1200, 8])
+        alpha_s_bar: torch.Size([1200, 1]), alpha_t_bar: torch.Size([1200, 1])
+        """
 
+        # print(X.shape)
         bs, n, d = X.shape
+        E = E[..., :2]
+        # bs, n = X.shape
         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
         prob_all = X_all @ Qtb.X
         probX = prob_all[:, :, :self.Xdim_output]
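Note on this hunk: the new `E = E[..., :2]` keeps only the first two channels of E's last dimension before E is flattened and multiplied by the transition matrix; read together with the pasted debug shapes, this looks like a shape fix so that `torch.cat([X, E.reshape(bs, n, -1)], dim=-1)` matches the width Qtb.X expects, though that interpretation is a guess from the diff alone. A toy sketch of the flatten-multiply-split pattern the unchanged lines implement (all sizes and the joint transition matrix below are assumptions for illustration, not values from the repo):

    import torch

    bs, n, dx, de = 4, 8, 8, 2   # toy sizes: batch, nodes, node classes, edge classes

    X = torch.rand(bs, n, dx)                         # per-node class scores
    E = torch.rand(bs, n, n, de)                      # per-edge class scores
    Qtb_X = torch.rand(bs, dx + n * de, dx + n * de)  # assumed joint transition, as the cat below implies

    X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)   # (bs, n, dx + n*de)
    prob_all = X_all @ Qtb_X                                # batched matmul -> (bs, n, dx + n*de)
    probX = prob_all[:, :, :dx]                             # node-class probabilities
    probE = prob_all[:, :, dx:].reshape(bs, n, n, de)       # edge-class probabilities

    print(probX.shape, probE.shape)  # torch.Size([4, 8, 8]) torch.Size([4, 8, 8, 2])

If E really arrives with more than two channels in its last dimension, slicing rather than re-encoding silently discards the extra classes, which is worth double-checking against the dataset's edge vocabulary.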
@@ -457,6 +488,7 @@ class Graph_DiT(pl.LightningModule):
 
         X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
         E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
+        # print(f"X.shape: {X.shape}, X_t shape: {X_t.shape}, E.shape: {E.shape},  E_t shape: {E_t.shape}")
         assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
 
         y_t = y