now we add reward wait to test

This commit is contained in:
mhz 2024-09-15 22:21:09 +02:00
parent 1ad520d248
commit c867aef5a6

View File

@ -312,6 +312,7 @@ def test(cfg: DictConfig):
# reward = 1.0 # reward = 1.0
# rewards.append(reward) # rewards.append(reward)
return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device) return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
old_log_probs = None
while samples_left_to_generate > 0: while samples_left_to_generate > 0:
print(f'samples left to generate: {samples_left_to_generate}/' print(f'samples left to generate: {samples_left_to_generate}/'
f'{cfg.general.final_model_samples_to_generate}', end='', flush=True) f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
@ -327,6 +328,15 @@ def test(cfg: DictConfig):
samples = samples + cur_sample samples = samples + cur_sample
reward = graph_reward_fn(cur_sample, device=graph_dit_model.device) reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6) advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
if old_log_probs is None:
old_log_probs = log_probs.clone()
ratio = torch.exp(log_probs - old_log_probs)
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
samples_with_log_probs.append((cur_sample, log_probs, reward)) samples_with_log_probs.append((cur_sample, log_probs, reward))