need to update the model

Author: mhz
Date:   2024-09-10 16:57:42 +02:00
Parent: 97fbdf91c7
Commit: 0c60171c71
2 changed files with 23 additions and 7 deletions


@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()
-        total_log_probs = torch.zeros([1000,10], device=self.device)
+        total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
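
A minimal standalone sketch of the buffer this hunk resizes, assuming n_samples and n_chunks stand in for self.cfg.general.final_model_samples_to_generate and the hard-coded second dimension of 10; the old [1000, 10] shape only worked when exactly 1000 samples were generated:

import torch

n_samples = 500   # stand-in for cfg.general.final_model_samples_to_generate
n_chunks = 10     # stand-in for the hard-coded second dimension

# Preallocate per-sample log-prob slots sized from config, not a magic number.
total_log_probs = torch.zeros([n_samples, n_chunks])
for t in range(n_chunks):
    # Placeholder for the per-step log p(z_s | z_t) the sampler would produce.
    total_log_probs[:, t] = torch.randn(n_samples)

assert total_log_probs.shape == (n_samples, n_chunks)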


@ -273,12 +273,14 @@ def test(cfg: DictConfig):
ratio = cfg.general.final_model_samples_to_generate // num_examples ratio = cfg.general.final_model_samples_to_generate // num_examples
test_y_collection = test_y_collection.repeat(ratio+1, 1) test_y_collection = test_y_collection.repeat(ratio+1, 1)
num_examples = test_y_collection.size(0) num_examples = test_y_collection.size(0)
# Normal reward function
def graph_reward_fn(graphs, true_graphs=None, device=None): def graph_reward_fn(graphs, true_graphs=None, device=None):
rewards = [] rewards = []
for graph in graphs: for graph in graphs:
reward = 1.0 reward = 1.0
rewards.append(reward) rewards.append(reward)
return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device) return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
while samples_left_to_generate > 0: while samples_left_to_generate > 0:
print(f'samples left to generate: {samples_left_to_generate}/' print(f'samples left to generate: {samples_left_to_generate}/'
f'{cfg.general.final_model_samples_to_generate}', end='', flush=True) f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
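
The reward function above as a self-contained sketch that runs in isolation (graphs are dummy objects; every graph gets a constant reward of 1.0). The added requires_grad=True makes the returned tensor track gradients; whether the PPO loop actually backpropagates through rewards depends on the surrounding training code:

import torch

def graph_reward_fn(graphs, true_graphs=None, device=None):
    # Constant placeholder reward per graph, matching the diff above.
    rewards = [1.0 for _ in graphs]
    return torch.tensor(rewards, dtype=torch.float32,
                        requires_grad=True).unsqueeze(0).to(device)

r = graph_reward_fn([object(), object(), object()], device='cpu')
print(r.shape, r.requires_grad)  # torch.Size([1, 3]) True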
@@ -317,19 +319,33 @@ def test(cfg: DictConfig):
     samples, log_probs, rewards = samples_with_log_probs[perm]
     samples = list(samples)
     log_probs = list(log_probs)
+    for i in range(len(log_probs)):
+        log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
     print(f'log_probs: {log_probs[:5]}')
-    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1000])
+    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
     rewards = list(rewards)
+    log_probs = torch.cat(log_probs, dim=0)
+    print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    old_log_probs = log_probs.clone()
+    # multi metrics range
+    # reward hacking
     for inner_epoch in range(cfg.train.n_epochs):
-        # print(f'rewards: {rewards[0].shape}') # torch.Size([1000])
-        rewards = torch.cat(rewards, dim=0)
-        print(f'rewards: {rewards.shape}')
+        # print(f'rewards: {rewards.shape}') # torch.Size([1000])
+        print(f'rewards: {rewards[:5]}')
+        print(f'len rewards: {len(rewards)}')
+        print(f'type rewards: {type(rewards)}')
+        if len(rewards) > 1 and isinstance(rewards, list):
+            rewards = torch.cat(rewards, dim=0)
+        elif len(rewards) == 1 and isinstance(rewards, list):
+            rewards = rewards[0]
+        # print(f'rewards: {rewards.shape}')
         advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
-        old_log_probs = log_probs.copy()
+        print(f'advantages: {advantages.shape}')
         with accelerator.accumulate(graph_dit_model):
             ratio = torch.exp(log_probs - old_log_probs)
             unclipped_loss = -advantages * ratio
+            # z-score normalization
             clipped_loss = -advantages * torch.clamp(ratio,
                                                      1.0 - cfg.ppo.clip_param,
                                                      1.0 + cfg.ppo.clip_param)
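
Putting the hunk together: a minimal, self-contained sketch of the PPO clipped-surrogate step this loop builds up to, with dummy tensors in place of the sampler outputs and clip_param standing in for cfg.ppo.clip_param (the final max/mean reduction is the standard PPO objective, not shown in the diff itself):

import torch

clip_param = 0.2
log_probs = torch.randn(1000, 1)             # per-sample log-probs, summed over steps
old_log_probs = log_probs.detach().clone()   # snapshot taken before the inner epochs
rewards = torch.rand(1000)                   # placeholder rewards

# z-score normalization of rewards into advantages
advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
ratio = torch.exp(log_probs - old_log_probs).squeeze(-1)
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
loss = torch.max(unclipped_loss, clipped_loss).mean()
print(loss)  # scalar PPO surrogate loss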