find the guidance part
This commit is contained in:
parent
0b9da26eda
commit
fcdd8efc4f
@ -609,6 +609,7 @@ class Graph_DiT(pl.LightningModule):
|
|||||||
Qt = self.transition_model.get_Qt(beta_t, self.device)
|
Qt = self.transition_model.get_Qt(beta_t, self.device)
|
||||||
|
|
||||||
Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
|
Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
|
||||||
|
# p(x_0|x_t)
|
||||||
p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all,
|
p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all,
|
||||||
Qt=Qt.X,
|
Qt=Qt.X,
|
||||||
Qsb=Qsb.X,
|
Qsb=Qsb.X,
|
||||||
|
Loading…
Reference in New Issue
Block a user