try to update reward func

This commit is contained in:
mhz
2024-09-14 23:56:36 +02:00
parent 2ac17caa3c
commit 94fe13756f
2 changed files with 96 additions and 75 deletions

View File

@@ -601,6 +601,7 @@ class Graph_DiT(pl.LightningModule):
assert (E == torch.transpose(E, 1, 2)).all()
# total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.