need to update the model
This commit is contained in:
parent
97fbdf91c7
commit
0c60171c71
@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()

-        total_log_probs = torch.zeros([1000,10], device=self.device)
+        total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)

         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
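The change above replaces the hardcoded batch size 1000 with cfg.general.final_model_samples_to_generate. As context, a minimal sketch of the accumulation pattern such a [n_samples, 10] buffer supports during the reverse loop; the helper step_log_prob_fn and the step-to-bucket mapping are assumptions for illustration, not code from this repository:

import torch

def accumulate_step_log_probs(step_log_prob_fn, n_samples, T, device, n_buckets=10):
    # Buffer shaped [n_samples, n_buckets], mirroring
    # torch.zeros([final_model_samples_to_generate, 10]) above.
    total_log_probs = torch.zeros([n_samples, n_buckets], device=device)
    for s_int in reversed(range(0, T)):
        # step_log_prob_fn stands in for one reverse-diffusion step and
        # returns a [n_samples] tensor of log p(z_s | z_t).
        step_log_prob = step_log_prob_fn(s_int)
        bucket = s_int * n_buckets // T  # assumed mapping of steps to the 10 slots
        total_log_probs[:, bucket] += step_log_prob
    return total_log_probs

# Dummy usage: 8 samples, 500 steps, all step log-probs zero.
lp = accumulate_step_log_probs(lambda s: torch.zeros(8), n_samples=8, T=500, device='cpu')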
@@ -273,12 +273,14 @@ def test(cfg: DictConfig):
     ratio = cfg.general.final_model_samples_to_generate // num_examples
     test_y_collection = test_y_collection.repeat(ratio+1, 1)
     num_examples = test_y_collection.size(0)

     # Normal reward function
     def graph_reward_fn(graphs, true_graphs=None, device=None):
         rewards = []
         for graph in graphs:
             reward = 1.0
             rewards.append(reward)
-        return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)
+        return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
               f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
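graph_reward_fn above is a placeholder that scores every graph as 1.0. A hedged sketch of a drop-in replacement with the same signature; graph_property_score is hypothetical, and in standard PPO the reward is treated as a fixed signal (the policy gradient flows through the log-probabilities), so requires_grad on the reward tensor is usually not needed:

import torch

def graph_property_score(graph):
    # Hypothetical stand-in for a real objective (validity, synthesizability,
    # a predicted property value, ...); returns one scalar per graph.
    return 1.0

def graph_reward_fn(graphs, true_graphs=None, device=None):
    # Same signature as the placeholder in the diff, with the constant 1.0
    # swapped for a per-graph score.
    rewards = [float(graph_property_score(g)) for g in graphs]
    return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)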
@@ -317,19 +319,33 @@ def test(cfg: DictConfig):
     samples, log_probs, rewards = samples_with_log_probs[perm]
     samples = list(samples)
     log_probs = list(log_probs)
     for i in range(len(log_probs)):
         log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
     print(f'log_probs: {log_probs[:5]}')
-    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1000])
+    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
     rewards = list(rewards)
     log_probs = torch.cat(log_probs, dim=0)
     print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
     old_log_probs = log_probs.clone()

     # multi metrics range
     # reward hacking hiking
     for inner_epoch in range(cfg.train.n_epochs):
         # print(f'rewards: {rewards[0].shape}') # torch.Size([1000])
         rewards = torch.cat(rewards, dim=0)
         print(f'rewards: {rewards.shape}')
         # print(f'rewards: {rewards.shape}') # torch.Size([1000])
         print(f'rewards: {rewards[:5]}')
         print(f'len rewards: {len(rewards)}')
         print(f'type rewards: {type(rewards)}')
         if len(rewards) > 1 and isinstance(rewards, list):
             rewards = torch.cat(rewards, dim=0)
         elif len(rewards) == 1 and isinstance(rewards, list):
             rewards = rewards[0]
         # print(f'rewards: {rewards.shape}')
         advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
         old_log_probs = log_probs.copy()
         print(f'advantages: {advantages.shape}')
         with accelerator.accumulate(graph_dit_model):
             ratio = torch.exp(log_probs - old_log_probs)
             unclipped_loss = -advantages * ratio
             # z-score normalization
             clipped_loss = -advantages * torch.clamp(ratio,
                                                      1.0 - cfg.ppo.clip_param,
                                                      1.0 + cfg.ppo.clip_param)
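For reference, a self-contained sketch of the PPO-clip step this hunk builds toward: z-score advantages, the probability ratio against old_log_probs, and the clipped surrogate. The final reduction (elementwise maximum of the two loss terms, then the mean) follows the standard PPO-clip objective and is an assumption about how the terms are combined, not code from this diff:

import torch

def ppo_clip_loss(log_probs, old_log_probs, rewards, clip_param=0.2, eps=1e-6):
    # Z-score normalization of rewards into advantages, as in the diff.
    advantages = (rewards - rewards.mean()) / (rewards.std() + eps)
    # Probability ratio between the current policy and the sampling policy.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped_loss = -advantages * ratio
    clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Standard PPO-clip: keep the pessimistic (larger) loss term per sample.
    return torch.maximum(unclipped_loss, clipped_loss).mean()

# Example with 1000 samples, matching the torch.Size([1000]) shapes printed above.
log_probs = torch.randn(1000, requires_grad=True)
loss = ppo_clip_loss(log_probs, log_probs.detach(), torch.rand(1000))
loss.backward()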