try to get the original perf

mhz
2024-09-16 22:45:12 +02:00
parent c867aef5a6
commit 91d4e3c7ad
2 changed files with 105 additions and 189 deletions

@@ -195,15 +195,18 @@ class Graph_DiT(pl.LightningModule):
# print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
def on_train_epoch_start(self) -> None:
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
# if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
# print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
print("Starting train epoch {}/{}...".format(self.current_epoch, self.cfg.train.n_epochs))
self.start_epoch_time = time.time()
self.train_loss.reset()
self.train_metrics.reset()
def on_train_epoch_end(self) -> None:
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
# if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
log = True
else:
log = False
@@ -601,8 +604,8 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()

-        # total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
+        total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
+        # total_log_probs = torch.zeros([self.cfg.general.samples_to_generate,10], device=self.device)

         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
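
A note on the milestone test in the first hunk: comparing a float ratio against exact fractions only fires when current_epoch lands exactly on 25/50/75/100 percent of n_epochs. A minimal standalone sketch of the same check (the function name and example values are illustrative, not from the repository):

def is_milestone_epoch(current_epoch: int, n_epochs: int) -> bool:
    # Mirrors the diff's condition: true only when the ratio lands
    # exactly on a quarter fraction, e.g. epoch 25 of 100.
    return current_epoch / n_epochs in [0.25, 0.5, 0.75, 1.0]

if __name__ == "__main__":
    print([e for e in range(101) if is_milestone_epoch(e, 100)])  # [25, 50, 75, 100]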
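
The second hunk sits at the top of the reverse-diffusion sampler: total_log_probs is a per-sample accumulator sized by the number of generated samples, and the loop walks the chain backwards. A hedged sketch of the loop's index bookkeeping, where denoise_step is a hypothetical stand-in for the repository's p(z_s | z_t) sampling call:

import torch

T = 4                    # assumed step count; the real value is self.T
z = torch.randn(2, 10)   # placeholder batch; real shapes come from the graph model
for s_int in reversed(range(0, T)):
    s, t = s_int, s_int + 1      # s = t - 1, so t runs T, ..., 1
    # z = denoise_step(z, s, t)  # one p(z_s | z_t) sampling step (assumed name)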