train phase done
parent 11d9697e06
commit 0c4b597dd2
@@ -158,9 +158,9 @@ def test(cfg: DictConfig):
         total_limit=cfg.general.number_checkpoint_limit,
     )
     accelerator = Accelerator(
-        mixed_precision=cfg.mixed_precision,
+        mixed_precision='no',
         project_config=accelerator_config,
-        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.n_epochs,
+        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
     )
     set_seed(cfg.train.seed, device_specific=True)
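For reference, a minimal sketch of how these Accelerate pieces fit together, with placeholder literals standing in for the cfg.* values and a hypothetical project directory (the repo presumably builds accelerator_config with accelerate.utils.ProjectConfiguration):

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed

accelerator_config = ProjectConfiguration(
    project_dir="./checkpoints",    # hypothetical path, not from the repo
    total_limit=3,                  # placeholder for cfg.general.number_checkpoint_limit
)
accelerator = Accelerator(
    mixed_precision="no",           # plain fp32, matching the new line above
    project_config=accelerator_config,
    gradient_accumulation_steps=4,  # placeholder for cfg.train.gradient_accumulation_steps * cfg.train.n_epochs
)
set_seed(42, device_specific=True)  # derives a distinct, deterministic seed per process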
@@ -203,28 +203,35 @@ def test(cfg: DictConfig):
     # graph_dit_model.model.requires_grad_(True)
     import torch.nn.functional as F
     optimizer = graph_dit_model.configure_optimizers()
+    train_dataloader = accelerator.prepare(datamodule.train_dataloader())
     optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)
     # start training
     for epoch in range(cfg.train.n_epochs):
         graph_dit_model.train()  # set the model to training mode
-        for batch_data in datamodule.train_dataloader:  # fetch one batch from the dataloader
-            data_x = F.one_hot(batch_data.x, num_classes=12).float()[:, graph_dit_model.active_index]  # node features
-            data_edge_attr = F.one_hot(batch_data.edge_attr, num_classes=2).float()  # edge features
-
-            # convert to dense format and pass it to Graph_DiT
-            dense_data, node_mask = utils.to_dense(data_x, batch_data.edge_index, data_edge_attr, batch_data.batch, graph_dit_model.max_n_nodes)
+        print(f"Epoch {epoch}", end="\n")
+        for data in train_dataloader:  # fetch one batch from the dataloader
+            data.to(accelerator.device)
+            data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
+            data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
+            dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
+            dense_data = dense_data.mask(node_mask)
+
-            X, E = dense_data.X, dense_data.E  # node features and edge features
-            y = batch_data.y  # labels
-
-            # forward pass and loss computation
-            pred = graph_dit_model(dense_data)  # feed into the Graph_DiT model
-            loss = graph_dit_model.train_loss(pred, X, E, y, node_mask)
-
-            # optimization step
-            optimizer.zero_grad()
-            loss.backward()
+            X, E = dense_data.X, dense_data.E
+            noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
+            pred = graph_dit_model.forward(noisy_data)
+            loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
+                                              true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
+                                              log=epoch % graph_dit_model.log_every_steps == 0)
+            # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
+            graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
+                                          log=epoch % graph_dit_model.log_every_steps == 0)
-            graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
+            print(f"training loss: {loss}")
+            with open("training-loss.csv", "a") as f:
+                f.write(f"{loss}, {epoch}\n")
+            accelerator.backward(loss)
+            optimizer.step()
+            optimizer.zero_grad()
             # return {'loss': loss}

     # start sampling
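The utils.to_dense call in the loop above is repo-specific. As an assumption about what it computes, the one-hot-then-densify step can be approximated with torch_geometric's stock helpers; this sketch returns plain tensors rather than the repo's dense_data object with .X/.E/.mask:

import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj, to_dense_batch

def densify_one_hot(data, max_n_nodes, num_node_classes=12, num_edge_classes=2):
    # One-hot encode the integer node and edge labels, as in the loop above.
    x = F.one_hot(data.x, num_classes=num_node_classes).float()
    e = F.one_hot(data.edge_attr, num_classes=num_edge_classes).float()
    # Pack the ragged PyG batch into fixed-size tensors plus a node mask.
    X, node_mask = to_dense_batch(x, data.batch, max_num_nodes=max_n_nodes)                 # [B, N, dx], [B, N]
    E = to_dense_adj(data.edge_index, data.batch, edge_attr=e, max_num_nodes=max_n_nodes)   # [B, N, N, de]
    return X, E, node_mask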
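The core of the new loop is replacing loss.backward() with accelerator.backward(loss), which lets Accelerate route the backward pass (loss scaling under mixed precision, gradient sync in distributed runs). A runnable toy sketch of the pattern, with a linear layer standing in for graph_dit_model:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)               # toy stand-in for graph_dit_model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

model.train()
for step in range(10):
    x = torch.randn(32, 8, device=accelerator.device)
    loss = (model(x) - x).pow(2).mean()     # placeholder for graph_dit_model.train_loss(...)
    accelerator.backward(loss)              # instead of loss.backward()
    optimizer.step()
    optimizer.zero_grad()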
@@ -248,6 +255,9 @@ def test(cfg: DictConfig):
         )
         samples.append(samples_batch)
-
+    # save samples
+    print("Samples:")
+    print(samples)
+
     # trainer = Trainer(
     #     gradient_clip_val=cfg.train.clip_grad,
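The commented-out Trainer(gradient_clip_val=cfg.train.clip_grad) hints that gradient clipping was previously Lightning's job. In the manual Accelerate loop, the closest equivalent would be accelerator.clip_grad_norm_ between backward and step; a hedged sketch (clipped_step is a hypothetical helper, not in the repo):

from accelerate import Accelerator

def clipped_step(accelerator: Accelerator, model, optimizer, loss, max_norm: float):
    # Backward through Accelerate, clip, then step: the manual-loop
    # equivalent of Lightning's gradient_clip_val.
    accelerator.backward(loss)
    if accelerator.sync_gradients:          # skip clipping on accumulation-only steps
        accelerator.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad()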