train phase done
parent 11d9697e06
commit 0c4b597dd2
@@ -158,9 +158,9 @@ def test(cfg: DictConfig):
         total_limit=cfg.general.number_checkpoint_limit,
     )
     accelerator = Accelerator(
-        mixed_precision=cfg.mixed_precision,
+        mixed_precision='no',
         project_config=accelerator_config,
-        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.n_epochs,
+        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
     )
     set_seed(cfg.train.seed, device_specific=True)
 
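For context, the rewritten constructor call follows the standard Hugging Face Accelerate setup. A minimal self-contained sketch of the same pattern, with illustrative stand-ins for the Hydra cfg.* fields used above:

    from accelerate import Accelerator
    from accelerate.utils import ProjectConfiguration, set_seed

    # Values below are illustrative stand-ins for the cfg.* fields above.
    accelerator_config = ProjectConfiguration(
        project_dir="./outputs",            # any writable directory
        total_limit=5,                      # cfg.general.number_checkpoint_limit
    )
    accelerator = Accelerator(
        mixed_precision="no",               # the commit pins full-precision training
        project_config=accelerator_config,
        gradient_accumulation_steps=1,      # cfg.train.gradient_accumulation_steps * cfg.train.n_epochs
    )
    set_seed(42, device_specific=True)      # each process derives a distinct seed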
@@ -203,28 +203,35 @@ def test(cfg: DictConfig):
     # graph_dit_model.model.requires_grad_(True)
     import torch.nn.functional as F
     optimizer = graph_dit_model.configure_optimizers()
+    train_dataloader = accelerator.prepare(datamodule.train_dataloader())
+    optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)
     # start training
     for epoch in range(cfg.train.n_epochs):
         graph_dit_model.train()  # put the model in training mode
-        for batch_data in datamodule.train_dataloader:  # fetch one batch from the data loader
-            data_x = F.one_hot(batch_data.x, num_classes=12).float()[:, graph_dit_model.active_index]  # node features
-            data_edge_attr = F.one_hot(batch_data.edge_attr, num_classes=2).float()  # edge features
-            # convert to dense format and pass to Graph_DiT
-            dense_data, node_mask = utils.to_dense(data_x, batch_data.edge_index, data_edge_attr, batch_data.batch, graph_dit_model.max_n_nodes)
+        print(f"Epoch {epoch}", end="\n")
+        for data in train_dataloader:  # fetch one batch from the data loader
+            data.to(accelerator.device)
+            data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
+            data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
+            dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
             dense_data = dense_data.mask(node_mask)
-            X, E = dense_data.X, dense_data.E  # node features and edge features
-            y = batch_data.y  # labels
-            # forward pass and loss computation
-            pred = graph_dit_model(dense_data)  # pass into the Graph_DiT model
-            loss = graph_dit_model.train_loss(pred, X, E, y, node_mask)
-            # optimization step
-            optimizer.zero_grad()
-            loss.backward()
+            X, E = dense_data.X, dense_data.E
+            noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
+            pred = graph_dit_model.forward(noisy_data)
+            loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
+                                              true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
+                                              log=epoch % graph_dit_model.log_every_steps == 0)
+            # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
+            graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
+                                          log=epoch % graph_dit_model.log_every_steps == 0)
+            graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
+            print(f"training loss: {loss}")
+            with open("training-loss.csv", "a") as f:
+                f.write(f"{loss}, {epoch}\n")
+            accelerator.backward(loss)
             optimizer.step()
+            optimizer.zero_grad()
+            # return {'loss': loss}
 
     # start sampling
 
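The rewritten loop follows the canonical Accelerate training step: objects go through accelerator.prepare, and accelerator.backward(loss) replaces loss.backward() so the library can manage device placement, gradient scaling, and accumulation. A minimal runnable sketch of that pattern (the model, optimizer, and data below are placeholders, not the project's API):

    import torch
    from accelerate import Accelerator

    # Placeholders: any nn.Module, optimizer, and DataLoader fit this pattern.
    accelerator = Accelerator()
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    dataset = torch.utils.data.TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    for epoch in range(2):
        model.train()
        for x, y in dataloader:            # prepared loaders yield device-placed batches
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)     # replaces loss.backward()
            optimizer.step()
            optimizer.zero_grad()

The explicit data.to(accelerator.device) in the hunk is a belt-and-braces step for the PyG batch object on top of what the prepared dataloader already does.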
@@ -248,6 +255,9 @@ def test(cfg: DictConfig):
         )
         samples.append(samples_batch)
 
+    # save samples
+    print("Samples:")
+    print(samples)
 
     # trainer = Trainer(
     #     gradient_clip_val=cfg.train.clip_grad,
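The added lines only print the accumulated samples; if they should also outlive the run, one minimal option (assuming the per-batch sample objects are picklable; the samples.pt path is illustrative) is:

    import torch

    # Illustrative: write the accumulated samples to disk alongside training-loss.csv.
    torch.save(samples, "samples.pt")      # reload later with torch.load("samples.pt")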