try to deploy PPO policy
@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()
 
-        total_log_probs = torch.zeros(batch_size, device=self.device)
+        total_log_probs = torch.zeros([1000,10], device=self.device)
 
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
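Note: the new buffer hard-codes its shape. A shape-agnostic way to keep per-step log-probabilities (a sketch only, with illustrative batch_size and T values, not the repo's code) is to collect them in a list and stack once the reverse loop finishes:

import torch

# Hypothetical sketch: collect one log-prob tensor per denoising step and
# stack afterwards, instead of pre-allocating torch.zeros([1000,10]).
batch_size, T = 4, 10                         # illustrative values only
step_log_probs = []
for s_int in reversed(range(0, T)):
    log_probs = torch.randn(batch_size)       # stand-in for the sampler's output
    step_log_probs.append(log_probs)
total_log_probs = torch.stack(step_log_probs, dim=1)   # (batch_size, T)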
@@ -613,6 +613,8 @@ class Graph_DiT(pl.LightningModule):
             # Sample z_s
             sampled_s, discrete_sampled_s, log_probs= self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
+            print(f'sampled_s.X shape: {sampled_s.X.shape}, sampled_s.E shape: {sampled_s.E.shape}')
+            print(f'log_probs shape: {log_probs.shape}')
             total_log_probs += log_probs
 
         # Sample
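The accumulated trajectory log-probabilities are what a PPO-style update consumes later: the ratio between the log-prob under the current policy and the one recorded at sampling time is clipped and weighted by an advantage. A toy stand-in (not the repo's training loop; all tensors below are random placeholders):

import torch

# Clipped PPO objective on per-graph trajectory log-probs (toy example).
old_log_probs = torch.randn(4)                         # recorded while sampling
new_log_probs = old_log_probs + 0.05 * torch.randn(4)  # re-evaluated by the policy
advantages = torch.randn(4)
ratio = torch.exp(new_log_probs - old_log_probs)
clip_eps = 0.2
ppo_loss = -torch.min(ratio * advantages,
                      torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages).mean()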
@@ -688,8 +690,9 @@ class Graph_DiT(pl.LightningModule):
         log_prob_X = log_prob_X.sum(dim=-1)
         log_prob_E = log_prob_E.sum(dim=(1, 2))
         print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}')
-        log_probs = log_prob_E + log_prob_X
+        # log_probs = log_prob_E + log_prob_X
+        log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1) # (batch_size, 2)
+        print(f'log_probs shape: {log_probs.shape}')
 
         ### Guidance
         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
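One shape note on this hunk: if log_prob_X and log_prob_E are each (batch_size,) after the reductions, torch.cat along dim=-1 yields a (2 * batch_size,) tensor, whereas torch.stack would give the (batch_size, 2) shape named in the inline comment. A toy sketch of both reductions (names and sizes are illustrative, not taken from the repo):

import torch
from torch.distributions import Categorical

# Per-graph log-probs for sampled node and edge categories (toy shapes).
bs, n, dx, de = 2, 5, 4, 3
prob_X = torch.softmax(torch.randn(bs, n, dx), dim=-1)      # node class probs
prob_E = torch.softmax(torch.randn(bs, n, n, de), dim=-1)   # edge class probs
X_s = Categorical(probs=prob_X).sample()                    # (bs, n)
E_s = Categorical(probs=prob_E).sample()                    # (bs, n, n)
log_prob_X = Categorical(probs=prob_X).log_prob(X_s).sum(dim=-1)      # (bs,)
log_prob_E = Categorical(probs=prob_E).log_prob(E_s).sum(dim=(1, 2))  # (bs,)
summed = log_prob_X + log_prob_E                            # (bs,)
stacked = torch.stack([log_prob_X, log_prob_E], dim=-1)     # (bs, 2)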