From fcdd8efc4f4ccdeaad6591dfdcdbc3161cd0d8d3 Mon Sep 17 00:00:00 2001 From: mhz Date: Tue, 16 Jul 2024 13:27:44 +0200 Subject: [PATCH] find the guidance part --- graph_dit/diffusion_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py index 4ba6dba..a5eab3d 100644 --- a/graph_dit/diffusion_model.py +++ b/graph_dit/diffusion_model.py @@ -609,6 +609,7 @@ class Graph_DiT(pl.LightningModule): Qt = self.transition_model.get_Qt(beta_t, self.device) Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1) + # p(x_0|x_t) p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X,