diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index 656a695..4ba6dba 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -78,8 +78,8 @@ class Graph_DiT(pl.LightningModule):
             timesteps=cfg.model.diffusion_steps)
 
-        print("__init__")
-        print("dataset_info.node_types", self.dataset_info.node_types)
+        # print("__init__")
+        # print("dataset_info.node_types", self.dataset_info.node_types)
         # dataset_info.node_types tensor([7.4826e-01, 2.6870e-02, 9.3930e-02, 4.4959e-02, 5.2982e-03, 7.5689e-04, 5.3739e-03, 1.5138e-03, 7.5689e-05, 4.3143e-03, 6.8650e-02])
         x_marginals = self.dataset_info.node_types.float() / torch.sum(self.dataset_info.node_types.float())
@@ -123,8 +123,8 @@ class Graph_DiT(pl.LightningModule):
         return pred
 
     def training_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
 
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
@@ -138,6 +138,9 @@ class Graph_DiT(pl.LightningModule):
         self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                            log=i % self.log_every_steps == 0)
         self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
+        print(f"training loss: {loss}")
+        with open("training-loss.csv", "a") as f:
+            f.write(f"{loss}, {i}\n")
         return {'loss': loss}
 
@@ -150,7 +153,7 @@ class Graph_DiT(pl.LightningModule):
     def on_fit_start(self) -> None:
         self.train_iterations = self.trainer.datamodule.training_iterations
         print('on fit train iteration:', self.train_iterations)
-        print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
+        # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
 
     def on_train_epoch_start(self) -> None:
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
@@ -160,10 +163,12 @@ class Graph_DiT(pl.LightningModule):
         self.train_metrics.reset()
 
     def on_train_epoch_end(self) -> None:
+
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
             log = True
         else:
             log = False
+        log = True
         self.train_loss.log_epoch_metrics(self.current_epoch, self.start_epoch_time, log)
         self.train_metrics.log_epoch_metrics(self.current_epoch, log)
@@ -178,24 +183,33 @@ class Graph_DiT(pl.LightningModule):
     @torch.no_grad()
     def validation_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=10).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
-        dense_data = dense_data.mask(node_mask, collapse=True)
+        dense_data = dense_data.mask(node_mask, collapse=False)
         noisy_data = self.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
         pred = self.forward(noisy_data)
         nll = self.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=False)
         self.val_y_collection.append(data.y)
         self.log(f'valid_nll', nll, batch_size=data.x.size(0), sync_dist=True)
+        print(f'validation loss: {nll}, epoch: {self.current_epoch}')
         return {'loss': nll}
 
     def on_validation_epoch_end(self) -> None:
         metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
+                   self.val_X_logp.compute(), self.val_E_logp.compute()]
         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
             print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
                   f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
+        with open("validation-metrics.csv", "a") as f:
+            # save the metrics as a csv file
+            f.write(f"{self.current_epoch}, {metrics[0]}, {metrics[1]}, {metrics[2]}, {metrics[3]}, {metrics[4]}\n")
+
+
+        print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
+              f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
 
         # Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
         self.log("val/NLL", metrics[0], sync_dist=True)
@@ -241,15 +255,15 @@ class Graph_DiT(pl.LightningModule):
             samples_left_to_generate -= to_generate
             chains_left_to_save -= chains_save
 
-        # print(f"Computing sampling metrics", ' ...')
-        # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
-        # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+        print(f"Computing sampling metrics", ' ...')
+        valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
+        print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
 
-        current_path = os.getcwd()
-        result_path = os.path.join(current_path,
-                                   f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
+        # current_path = os.getcwd()
+        # result_path = os.path.join(current_path,
+        #                            f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
         # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
-        # self.sampling_metrics.reset()
+        self.sampling_metrics.reset()
 
     def on_test_epoch_start(self) -> None:
         print("Starting test...")
@@ -262,8 +276,8 @@ class Graph_DiT(pl.LightningModule):
     @torch.no_grad()
     def test_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=118).float()[:, self.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=5).float()
+        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
 
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask)
@@ -277,6 +291,8 @@ class Graph_DiT(pl.LightningModule):
         """ Measure likelihood on a test set and compute stability metrics. """
         metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
                    self.test_X_logp.compute(), self.test_E_logp.compute()]
+        with open("test-metrics.csv", "a") as f:
+            f.write(f"{self.current_epoch}, {metrics[0]}, {metrics[1]}, {metrics[2]}, {metrics[3]}, {metrics[4]}\n")
         print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
               f"Test Edge type KL: {metrics[2] :.2f}")
@@ -433,10 +449,12 @@ class Graph_DiT(pl.LightningModule):
         # Sample a timestep t.
         # When evaluating, the loss for t=0 is computed separately
+        # print(f"apply_noise X shape: {X.shape}, E shape: {E.shape}, y shape: {y.shape}, node_mask shape: {node_mask.shape}")
         lowest_t = 0 if self.training else 1
         t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1)
         s_int = t_int - 1
+
         t_float = t_int / self.T
         s_float = s_int / self.T
@@ -444,10 +462,23 @@ class Graph_DiT(pl.LightningModule):
         beta_t = self.noise_schedule(t_normalized=t_float)                         # (bs, 1)
         alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)      # (bs, 1)
         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)      # (bs, 1)
+        # print(f"alpha_s_bar: {alpha_s_bar.shape}, alpha_t_bar: {alpha_t_bar.shape}")
 
         Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)  # (bs, dx_in, dx_out), (bs, de_in, de_out)
+        # print(f"X shape: {X.shape}, E shape: {E.shape}, node_mask shape: {node_mask.shape}")
+        # print(f"Qtb shape: {Qtb.X.shape}")
+        """
+        X shape: torch.Size([1200, 8]),
+        E shape: torch.Size([1200, 8, 8]),
+        y shape: torch.Size([1200, 1]),
+        node_mask shape: torch.Size([1200, 8])
+        alpha_s_bar: torch.Size([1200, 1]), alpha_t_bar: torch.Size([1200, 1])
+        """
+        # print(X.shape)
         bs, n, d = X.shape
+        E = E[..., :2]
+        # bs, n = X.shape
         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
         prob_all = X_all @ Qtb.X
         probX = prob_all[:, :, :self.Xdim_output]
@@ -457,6 +488,7 @@ class Graph_DiT(pl.LightningModule):
         X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
         E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
+        # print(f"X.shape: {X.shape}, X_t shape: {X_t.shape}, E.shape: {E.shape}, E_t shape: {E_t.shape}")
         assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
         y_t = y
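
Notes on the patch:

The last two hunks touch `apply_noise`, where the one-hot nodes `X` and the flattened one-hot edges `E` are concatenated and pushed through a single transition matrix (`prob_all = X_all @ Qtb.X`) before the noised state is sampled and re-one-hot-encoded. Below is a minimal, self-contained sketch of that mechanism for the node part only. It assumes the marginal transition Q̄_t = ᾱ_t·I + (1 − ᾱ_t)·1mᵀ that the `x_marginals` computation in `__init__` suggests; `qt_bar` and all shapes here are illustrative stand-ins, not the repo's `transition_model.get_Qt_bar`.

```python
import torch
import torch.nn.functional as F

def qt_bar(alpha_bar: torch.Tensor, marginals: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for transition_model.get_Qt_bar:
    Qbar_t = alpha_bar * I + (1 - alpha_bar) * 1 m^T, shape (bs, d, d)."""
    bs, d = alpha_bar.size(0), marginals.numel()
    eye = torch.eye(d).unsqueeze(0)              # (1, d, d)
    m = marginals.view(1, 1, d).expand(1, d, d)  # every row equals the marginals
    a = alpha_bar.view(bs, 1, 1)
    return a * eye + (1.0 - a) * m               # rows still sum to 1

bs, n, d = 4, 8, 8                               # toy sizes, not the real config
marginals = torch.full((d,), 1.0 / d)            # stand-in for x_marginals
X = F.one_hot(torch.randint(0, d, (bs, n)), num_classes=d).float()
alpha_bar = torch.rand(bs, 1)                    # from the noise schedule
probX = X @ qt_bar(alpha_bar, marginals)         # (bs, n, d), row-stochastic
X_t = F.one_hot(probX.reshape(-1, d).multinomial(1).view(bs, n), num_classes=d)
```

Because each row of Q̄_t sums to 1, `probX` stays a valid categorical distribution per node and can be sampled directly, which is what produces `sampled_t` (and then `X_t`/`E_t`) in the patched code.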
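The CSV writes added in `training_step`, `on_validation_epoch_end`, and `on_test_epoch_end` append one unlabeled row per call; with `sync_dist=True` the module is presumably run under DDP, where every rank would append to the same file. A possible variant, sketched as a self-contained helper (the header names and the rank-0 guard are assumptions, not existing repo API):

```python
import csv
import os
import torch

def append_metrics_row(path, epoch, values, is_global_zero=True):
    """Append one labeled row of epoch metrics; only the rank-0 process writes."""
    if not is_global_zero:            # mirrors Lightning's trainer.is_global_zero
        return
    new_file = not os.path.exists(path)
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if new_file:                  # header once, so the CSV is self-describing
            writer.writerow(["epoch", "nll", "X_kl", "E_kl", "X_logp", "E_logp"])
        writer.writerow([epoch] + [v.item() if torch.is_tensor(v) else v for v in values])

# e.g. in on_validation_epoch_end:
# append_metrics_row("validation-metrics.csv", self.current_epoch, metrics,
#                    self.trainer.is_global_zero)
```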