From c867aef5a69f77f5d437338179b21f29c4f728f4 Mon Sep 17 00:00:00 2001
From: mhz
Date: Sun, 15 Sep 2024 22:21:09 +0200
Subject: [PATCH] now we add reward wait to test

---
 graph_dit/main.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/graph_dit/main.py b/graph_dit/main.py
index 269c008..7023a24 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -312,6 +312,7 @@ def test(cfg: DictConfig):
         # reward = 1.0
         # rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
+    old_log_probs = None
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
               f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
@@ -327,6 +328,15 @@
         samples = samples + cur_sample
         reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+        if old_log_probs is None:
+            old_log_probs = log_probs.clone()
+        ratio = torch.exp(log_probs - old_log_probs)
+        unclipped_loss = -advantages * ratio
+        clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
+        loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+        accelerator.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()
         samples_with_log_probs.append((cur_sample, log_probs, reward))
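
For context, the nine added lines fold a PPO-style clipped-surrogate update into the sampling loop: rewards for the current batch are normalized into advantages, the log-probability ratio against a stored reference (old_log_probs) is computed, and the larger of the unclipped and clipped negated surrogates is averaged into the loss before an optimizer step. The sketch below isolates that loss as a standalone, runnable snippet. It is illustrative only: ppo_clipped_loss and the toy tensors are not part of the patch, and the default clip_param=0.2 is an assumption (the patch reads it from cfg.ppo.clip_param); the backward/step bookkeeping with accelerator and optimizer is omitted.

    import torch

    def ppo_clipped_loss(log_probs, old_log_probs, advantages, clip_param=0.2):
        # ratio = pi_new / pi_old, computed in log space for stability.
        ratio = torch.exp(log_probs - old_log_probs)
        # Negated PPO surrogates; taking the elementwise max of the negated
        # terms is equivalent to -min(ratio * A, clip(ratio) * A).
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
        return torch.mean(torch.max(unclipped_loss, clipped_loss))

    if __name__ == "__main__":
        # Toy batch: current-policy log-probs, a detached copy as the reference,
        # and rewards normalized into advantages as in the patch.
        torch.manual_seed(0)
        log_probs = torch.randn(8, requires_grad=True)
        old_log_probs = log_probs.detach().clone()  # first iteration: ratio == 1
        rewards = torch.rand(8)
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
        loss = ppo_clipped_loss(log_probs, old_log_probs, advantages)
        loss.backward()
        print(loss.item())

Note that in the patch old_log_probs is captured once, from the first generated batch, so every later batch is measured against that single reference rather than against the log-probs from the immediately preceding update; whether that is intended is not clear from the diff itself.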