try to get the original perf
This commit is contained in:
@@ -195,15 +195,18 @@ class Graph_DiT(pl.LightningModule):
|
||||
# print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
|
||||
|
||||
def on_train_epoch_start(self) -> None:
|
||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
|
||||
# if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
# print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
|
||||
print("Starting train epoch {}/{}...".format(self.current_epoch, self.cfg.train.n_epochs))
|
||||
self.start_epoch_time = time.time()
|
||||
self.train_loss.reset()
|
||||
self.train_metrics.reset()
|
||||
|
||||
def on_train_epoch_end(self) -> None:
|
||||
|
||||
if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
# if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
|
||||
log = True
|
||||
else:
|
||||
log = False
|
||||
@@ -601,8 +604,8 @@ class Graph_DiT(pl.LightningModule):
|
||||
|
||||
assert (E == torch.transpose(E, 1, 2)).all()
|
||||
|
||||
# total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
|
||||
total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
|
||||
# total_log_probs = torch.zeros([self.cfg.general.samples_to_generate,10], device=self.device)
|
||||
|
||||
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
|
||||
for s_int in reversed(range(0, self.T)):
|
||||
|
||||
Reference in New Issue
Block a user