need to update the model

This commit is contained in:
mhz
2024-09-10 16:57:42 +02:00
parent 97fbdf91c7
commit 0c60171c71
2 changed files with 23 additions and 7 deletions

View File

@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
assert (E == torch.transpose(E, 1, 2)).all()
total_log_probs = torch.zeros([1000,10], device=self.device)
total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
# Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
for s_int in reversed(range(0, self.T)):