try to update reward func
This commit is contained in:
@@ -601,6 +601,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
|
||||
assert (E == torch.transpose(E, 1, 2)).all()
|
||||
|
||||
# total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
|
||||
total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
|
||||
|
||||
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
|
||||
|
||||
Reference in New Issue
Block a user