diff --git a/graph_dit/main.py b/graph_dit/main.py
index 5e16301..d0ed8df 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -158,9 +158,9 @@ def test(cfg: DictConfig):
         total_limit=cfg.general.number_checkpoint_limit,
     )
     accelerator = Accelerator(
-        mixed_precision=cfg.mixed_precision,
+        mixed_precision='no',
         project_config=accelerator_config,
-        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.n_epochs,
+        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
     )

     set_seed(cfg.train.seed, device_specific=True)
@@ -203,28 +203,35 @@ def test(cfg: DictConfig):
     # graph_dit_model.model.requires_grad_(True)
     import torch.nn.functional as F
     optimizer = graph_dit_model.configure_optimizers()
+    train_dataloader = accelerator.prepare(datamodule.train_dataloader())
+    optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)

     # start training
     for epoch in range(cfg.train.n_epochs):
         graph_dit_model.train()  # set the model to training mode
-        for batch_data in datamodule.train_dataloader:  # fetch one batch from the dataloader
-            data_x = F.one_hot(batch_data.x, num_classes=12).float()[:, graph_dit_model.active_index]  # node features
-            data_edge_attr = F.one_hot(batch_data.edge_attr, num_classes=2).float()  # edge features
-
-            # convert to dense format and pass to Graph_DiT
-            dense_data, node_mask = utils.to_dense(data_x, batch_data.edge_index, data_edge_attr, batch_data.batch, graph_dit_model.max_n_nodes)
+        print(f"Epoch {epoch}", end="\n")
+        for data in train_dataloader:  # fetch one batch from the dataloader
+            data.to(accelerator.device)
+            data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
+            data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
+            dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
             dense_data = dense_data.mask(node_mask)
-
-            X, E = dense_data.X, dense_data.E  # node and edge features
-            y = batch_data.y  # labels
-
-            # forward pass and loss computation
-            pred = graph_dit_model(dense_data)  # feed into the Graph_DiT model
-            loss = graph_dit_model.train_loss(pred, X, E, y, node_mask)
-
-            # optimization step
-            optimizer.zero_grad()
-            loss.backward()
+            X, E = dense_data.X, dense_data.E
+            noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
+            pred = graph_dit_model.forward(noisy_data)
+            loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
+                                              true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
+                                              log=epoch % graph_dit_model.log_every_steps == 0)
+            # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
+            graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
+                                          log=epoch % graph_dit_model.log_every_steps == 0)
+            graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
+            print(f"training loss: {loss}")
+            with open("training-loss.csv", "a") as f:
+                f.write(f"{loss}, {epoch}\n")
+            accelerator.backward(loss)
             optimizer.step()
+            optimizer.zero_grad()
+        # return {'loss': loss}

     # start sampling
@@ -248,6 +255,9 @@ def test(cfg: DictConfig):
             )
             samples.append(samples_batch)

+    # save samples
+    print("Samples:")
+    print(samples)

     # trainer = Trainer(
     #     gradient_clip_val=cfg.train.clip_grad,