try to deploy PPO policy

mhz
2024-09-09 23:50:10 +02:00
parent 297261d666
commit 97fbdf91c7
2 changed files with 58 additions and 7 deletions


@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()
-        total_log_probs = torch.zeros(batch_size, device=self.device)
+        total_log_probs = torch.zeros([1000,10], device=self.device)
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
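The `[1000,10]` buffer suggests keeping one log-probability per diffusion step (T = 1000) and per sample (batch of 10), rather than a single scalar per sample. A minimal sketch of that accumulation pattern, with assumed shapes and a stand-in for `sample_p_zs_given_zt` (nothing below is taken from the repo beyond the diff above):

```python
import torch

T, batch_size = 1000, 10  # assumed: diffusion steps x samples, matching [1000,10]

# One row per denoising step, one column per sample.
step_log_probs = torch.zeros(T, batch_size)

for s_int in reversed(range(T)):
    # Stand-in for sample_p_zs_given_zt: pretend each step returns the
    # log-probability of the transition it sampled, one value per sample.
    log_probs = torch.randn(batch_size)
    step_log_probs[s_int] = log_probs

# PPO needs per-step log-probs for its importance ratio; the trajectory
# log-prob is their sum over steps.
trajectory_log_probs = step_log_probs.sum(dim=0)  # (batch_size,)
```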
@@ -613,6 +613,8 @@ class Graph_DiT(pl.LightningModule):
             # Sample z_s
             sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
+            print(f'sampled_s.X shape: {sampled_s.X.shape}, sampled_s.E shape: {sampled_s.E.shape}')
+            print(f'log_probs shape: {log_probs.shape}')
             total_log_probs += log_probs
             # Sample
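For the `log_probs` returned by `sample_p_zs_given_zt`, the usual pattern in discrete graph diffusion is to evaluate the sampled categories under the predicted categorical distributions. A sketch under that assumption for the node part; `prob_X`, `n_nodes`, and `n_classes` are illustrative names and toy sizes, not the repo's:

```python
import torch

batch_size, n_nodes, n_classes = 10, 9, 5  # assumed toy shapes
prob_X = torch.softmax(torch.randn(batch_size, n_nodes, n_classes), dim=-1)

# Sample one node type per node, then keep the log-prob of that choice.
sampled = torch.multinomial(prob_X.reshape(-1, n_classes), 1)
sampled = sampled.reshape(batch_size, n_nodes)
log_prob_X = torch.log(
    torch.gather(prob_X, -1, sampled.unsqueeze(-1)).squeeze(-1)
)  # (batch_size, n_nodes); summing over nodes gives one value per sample
```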
@@ -688,8 +690,9 @@ class Graph_DiT(pl.LightningModule):
         log_prob_X = log_prob_X.sum(dim=-1)
         log_prob_E = log_prob_E.sum(dim=(1, 2))
+        print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}')
-        log_probs = log_prob_E + log_prob_X
+        # log_probs = log_prob_E + log_prob_X
+        log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1) # (batch_size, 2)
+        print(f'log_probs shape: {log_probs.shape}')
         ### Guidance
         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
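One thing the new prints will surface: `torch.cat` on two 1-D `(batch_size,)` tensors returns shape `(2 * batch_size,)`, not the `(batch_size, 2)` the inline comment expects; `torch.stack` gives the commented shape. A quick check with dummy values:

```python
import torch

batch_size = 10
log_prob_X = torch.randn(batch_size)  # dummy per-sample node log-probs
log_prob_E = torch.randn(batch_size)  # dummy per-sample edge log-probs

cat_version = torch.cat([log_prob_X, log_prob_E], dim=-1)      # shape (20,)
stack_version = torch.stack([log_prob_X, log_prob_E], dim=-1)  # shape (10, 2)
print(cat_version.shape, stack_version.shape)
```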