now we add reward wait to test
commit c867aef5a6
parent 1ad520d248
@@ -312,6 +312,7 @@ def test(cfg: DictConfig):
             # reward = 1.0
             # rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
+    old_log_probs = None
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
               f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
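Because old_log_probs starts as None and is seeded from the first batch's log_probs in the hunk below, the very first clipped update runs with an importance ratio of exactly 1. A minimal illustration of that first-step behaviour (the values are hypothetical; it only assumes log_probs holds per-sample log-probabilities, as in the diff):

    import torch

    # Hypothetical stand-in for the first batch's per-sample log-probabilities.
    log_probs = torch.log(torch.tensor([0.2, 0.5, 0.3]))
    old_log_probs = log_probs.clone()            # reference policy == current policy
    ratio = torch.exp(log_probs - old_log_probs)
    print(ratio)                                 # tensor([1., 1., 1.])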
@@ -327,6 +328,15 @@ def test(cfg: DictConfig):
         samples = samples + cur_sample
         reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+        if old_log_probs is None:
+            old_log_probs = log_probs.clone()
+        ratio = torch.exp(log_probs - old_log_probs)
+        unclipped_loss = -advantages * ratio
+        clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
+        loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+        accelerator.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()


         samples_with_log_probs.append((cur_sample, log_probs, reward))
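The nine added lines are a per-batch PPO clipped-surrogate update. A minimal, self-contained sketch of the same step, assuming log_probs and old_log_probs are per-sample log-probabilities under the current and reference policies and clip_param stands in for cfg.ppo.clip_param:

    import torch

    def ppo_step(log_probs, old_log_probs, reward, optimizer, clip_param=0.2):
        # Whiten the rewards into advantages, as in the surrounding test() code.
        advantages = (reward - reward.mean()) / (reward.std() + 1e-6)
        # Importance ratio between the current policy and the frozen reference.
        ratio = torch.exp(log_probs - old_log_probs)
        # Pessimistic clipped surrogate: the max of the two negated terms
        # equals -min(ratio * A, clip(ratio) * A).
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
        loss = torch.max(unclipped_loss, clipped_loss).mean()
        loss.backward()          # the diff uses accelerator.backward(loss), Accelerate's wrapper
        optimizer.step()
        optimizer.zero_grad()
        return loss.detach()

    # Toy usage with a single learnable logit vector standing in for the policy.
    theta = torch.zeros(4, requires_grad=True)
    optimizer = torch.optim.SGD([theta], lr=0.1)
    log_probs = torch.log_softmax(theta, dim=0)
    old_log_probs = log_probs.detach()
    reward = torch.tensor([1.0, 0.0, 0.5, 0.2])
    ppo_step(log_probs, old_log_probs, reward, optimizer)

Clamping the ratio to [1 - clip_param, 1 + clip_param] keeps any single batch from pushing the sampling policy far from the one that produced the rewards; outside that band the gradient of the clipped term vanishes, so the size of the update is capped.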