train phase done

This commit is contained in:
mhz 2024-09-08 21:09:41 +02:00
parent 11d9697e06
commit 0c4b597dd2

View File

@ -158,9 +158,9 @@ def test(cfg: DictConfig):
total_limit=cfg.general.number_checkpoint_limit, total_limit=cfg.general.number_checkpoint_limit,
) )
accelerator = Accelerator( accelerator = Accelerator(
mixed_precision=cfg.mixed_precision, mixed_precision='no',
project_config=accelerator_config, project_config=accelerator_config,
gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.n_epochs, gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
) )
set_seed(cfg.train.seed, device_specific=True) set_seed(cfg.train.seed, device_specific=True)
@ -203,28 +203,35 @@ def test(cfg: DictConfig):
# graph_dit_model.model.requires_grad_(True) # graph_dit_model.model.requires_grad_(True)
import torch.nn.functional as F import torch.nn.functional as F
optimizer = graph_dit_model.configure_optimizers() optimizer = graph_dit_model.configure_optimizers()
train_dataloader = accelerator.prepare(datamodule.train_dataloader())
optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)
# start training # start training
for epoch in range(cfg.train.n_epochs): for epoch in range(cfg.train.n_epochs):
graph_dit_model.train() # 设置模型为训练模式 graph_dit_model.train() # 设置模型为训练模式
for batch_data in datamodule.train_dataloader: # 从数据加载器中获取一个批次的数据 print(f"Epoch {epoch}", end="\n")
data_x = F.one_hot(batch_data.x, num_classes=12).float()[:, graph_dit_model.active_index] # 节点特征 for data in train_dataloader: # 从数据加载器中获取一个批次的数据
data_edge_attr = F.one_hot(batch_data.edge_attr, num_classes=2).float() # 边特征 data.to(accelerator.device)
data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
# 转换为 dense 格式并传递给 Graph_DiT data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
dense_data, node_mask = utils.to_dense(data_x, batch_data.edge_index, data_edge_attr, batch_data.batch, graph_dit_model.max_n_nodes) dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
dense_data = dense_data.mask(node_mask) dense_data = dense_data.mask(node_mask)
X, E = dense_data.X, dense_data.E
X, E = dense_data.X, dense_data.E # 节点特征和边特征 noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
y = batch_data.y # 标签 pred = graph_dit_model.forward(noisy_data)
loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
# 前向传播和损失计算 true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
pred = graph_dit_model(dense_data) # 传入 Graph_DiT 模型 log=epoch % graph_dit_model.log_every_steps == 0)
loss = graph_dit_model.train_loss(pred, X, E, y, node_mask) # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
# 优化步骤 log=epoch % graph_dit_model.log_every_steps == 0)
optimizer.zero_grad() graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
loss.backward() print(f"training loss: {loss}")
with open("training-loss.csv", "a") as f:
f.write(f"{loss}, {epoch}\n")
accelerator.backward(loss)
optimizer.step() optimizer.step()
optimizer.zero_grad()
# return {'loss': loss}
# start sampling # start sampling
@ -248,6 +255,9 @@ def test(cfg: DictConfig):
) )
samples.append(samples_batch) samples.append(samples_batch)
# save samples
print("Samples:")
print(samples)
# trainer = Trainer( # trainer = Trainer(
# gradient_clip_val=cfg.train.clip_grad, # gradient_clip_val=cfg.train.clip_grad,