need to update the model

This commit is contained in:
mhz
2024-09-12 23:40:42 +02:00
parent 0c60171c71
commit 2ac17caa3c
2 changed files with 121 additions and 48 deletions

View File

@@ -239,8 +239,8 @@ class Graph_DiT(pl.LightningModule):
self.val_X_logp.compute(), self.val_E_logp.compute()]
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
# if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
with open("validation-metrics.csv", "a") as f:
# save the metrics as csv file