now we add reward wait to test
This commit is contained in:
parent
1ad520d248
commit
c867aef5a6
@ -312,6 +312,7 @@ def test(cfg: DictConfig):
|
||||
# reward = 1.0
|
||||
# rewards.append(reward)
|
||||
return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
|
||||
old_log_probs = None
|
||||
while samples_left_to_generate > 0:
|
||||
print(f'samples left to generate: {samples_left_to_generate}/'
|
||||
f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
|
||||
@ -327,6 +328,15 @@ def test(cfg: DictConfig):
|
||||
samples = samples + cur_sample
|
||||
reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
|
||||
advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
|
||||
if old_log_probs is None:
|
||||
old_log_probs = log_probs.clone()
|
||||
ratio = torch.exp(log_probs - old_log_probs)
|
||||
unclipped_loss = -advantages * ratio
|
||||
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
|
||||
loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
samples_with_log_probs.append((cur_sample, log_probs, reward))
|
||||
|
Loading…
Reference in New Issue
Block a user