From 0c60171c71a1b6da4427fead00456d429ab44908 Mon Sep 17 00:00:00 2001
From: mhz
Date: Tue, 10 Sep 2024 16:57:42 +0200
Subject: [PATCH] need to update the model

---
 graph_dit/diffusion_model.py |  2 +-
 graph_dit/main.py            | 28 ++++++++++++++++++++++------
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index 6c7c5ee..340fb3e 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
 
         assert (E == torch.transpose(E, 1, 2)).all()
 
-        total_log_probs = torch.zeros([1000,10], device=self.device)
+        total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
 
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
diff --git a/graph_dit/main.py b/graph_dit/main.py
index 1c35d92..12677d5 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -273,12 +273,14 @@ def test(cfg: DictConfig):
             ratio = cfg.general.final_model_samples_to_generate // num_examples
             test_y_collection = test_y_collection.repeat(ratio+1, 1)
             num_examples = test_y_collection.size(0)
+
+        # Normal reward function
         def graph_reward_fn(graphs, true_graphs=None, device=None):
             rewards = []
             for graph in graphs:
                 reward = 1.0
                 rewards.append(reward)
-            return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)
+            return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
         while samples_left_to_generate > 0:
             print(f'samples left to generate: {samples_left_to_generate}/'
                   f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
@@ -317,19 +319,33 @@ def test(cfg: DictConfig):
         samples, log_probs, rewards = samples_with_log_probs[perm]
         samples = list(samples)
         log_probs = list(log_probs)
+        for i in range(len(log_probs)):
+            log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
         print(f'log_probs: {log_probs[:5]}')
-        print(f'log_probs: {log_probs[0].shape}') # torch.Size([1000])
+        print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
         rewards = list(rewards)
+        log_probs = torch.cat(log_probs, dim=0)
+        print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+        old_log_probs = log_probs.clone()
+        # multi metrics range
+        # reward hacking hiking
         for inner_epoch in range(cfg.train.n_epochs):
-            # print(f'rewards: {rewards[0].shape}') # torch.Size([1000])
-            rewards = torch.cat(rewards, dim=0)
-            print(f'rewards: {rewards.shape}')
+            # print(f'rewards: {rewards.shape}') # torch.Size([1000])
+            print(f'rewards: {rewards[:5]}')
+            print(f'len rewards: {len(rewards)}')
+            print(f'type rewards: {type(rewards)}')
+            if len(rewards) > 1 and isinstance(rewards, list):
+                rewards = torch.cat(rewards, dim=0)
+            elif len(rewards) == 1 and isinstance(rewards, list):
+                rewards = rewards[0]
+            # print(f'rewards: {rewards.shape}')
             advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
-            old_log_probs = log_probs.copy()
+            print(f'advantages: {advantages.shape}')
             with accelerator.accumulate(graph_dit_model):
                 ratio = torch.exp(log_probs - old_log_probs)
                 unclipped_loss = -advantages * ratio
+                # z-score normalization
                 clipped_loss = -advantages * torch.clamp(ratio,
                                                          1.0 - cfg.ppo.clip_param,
                                                          1.0 + cfg.ppo.clip_param)
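
Note: the main.py changes above amount to a standard PPO clipped-surrogate step: per-graph log-probabilities are summed, rewards are z-scored into advantages, and the policy ratio is clamped by cfg.ppo.clip_param. The sketch below is an illustrative restatement of that step under those assumptions, not code from the patch; the function name ppo_clipped_loss, the tensor shapes, and the default clip value are hypothetical.

# Minimal sketch (not the repository's code) of the PPO-style update used in
# the main.py hunk above. Only clip_param mirrors cfg.ppo.clip_param from the
# patch; the rest is an illustrative assumption.
import torch

def ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
    # Advantages via z-score normalization of per-sample rewards,
    # mirroring the patch's (rewards - mean) / (std + 1e-6) line.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    # Importance ratio between the current policy and the sampling-time policy.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # PPO takes the elementwise maximum of the two surrogate losses, then averages.
    return torch.mean(torch.max(unclipped, clipped))

# Usage with dummy tensors, one summed log-prob and one reward per generated graph:
log_probs = torch.randn(1000, requires_grad=True)
old_log_probs = log_probs.detach().clone()
rewards = torch.ones(1000)
loss = ppo_clipped_loss(log_probs, old_log_probs, rewards)
loss.backward()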