need to update the model
This commit is contained in:
@@ -239,8 +239,8 @@ class Graph_DiT(pl.LightningModule):
|
||||
|
||||
self.val_X_logp.compute(), self.val_E_logp.compute()]
|
||||
|
||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
|
||||
# if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
|
||||
f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
|
||||
with open("validation-metrics.csv", "a") as f:
|
||||
# save the metrics as csv file
|
||||
|
||||
Reference in New Issue
Block a user