find the guidance part
This commit is contained in:
		| @@ -609,6 +609,7 @@ class Graph_DiT(pl.LightningModule): | ||||
|             Qt = self.transition_model.get_Qt(beta_t, self.device) | ||||
|  | ||||
|             Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1) | ||||
|             # p(x_0|x_t) | ||||
|             p_s_and_t_given_0 = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=Xt_all, | ||||
|                                                                                             Qt=Qt.X, | ||||
|                                                                                             Qsb=Qsb.X, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user