need to update the model
parent 97fbdf91c7
commit 0c60171c71
@@ -601,7 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()

-        total_log_probs = torch.zeros([1000,10], device=self.device)
+        total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)

         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
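For context, a minimal standalone sketch of what the assert (E == torch.transpose(E, 1, 2)).all() context line enforces: the batched edge tensor must be symmetric across its two node dimensions, i.e. the generated graphs are undirected. The shapes below are illustrative, and any trailing edge-type dimension is omitted.

import torch

# Illustrative only: E is (batch, n_nodes, n_nodes); an edge-type dimension is omitted here.
E = torch.zeros(2, 4, 4)
E[:, 0, 1] = 1.0
E[:, 1, 0] = 1.0                                 # mirror the edge so the tensor stays symmetric
assert (E == torch.transpose(E, 1, 2)).all()     # same check as in the hunk above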
@@ -273,12 +273,14 @@ def test(cfg: DictConfig):
     ratio = cfg.general.final_model_samples_to_generate // num_examples
     test_y_collection = test_y_collection.repeat(ratio+1, 1)
     num_examples = test_y_collection.size(0)
+
+    # Normal reward function
     def graph_reward_fn(graphs, true_graphs=None, device=None):
         rewards = []
         for graph in graphs:
             reward = 1.0
             rewards.append(reward)
-        return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)
+        return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
               f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
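A minimal usage sketch of the placeholder reward function above (variable names are illustrative; the placeholder never inspects the graph objects, so any list of stand-ins works):

# Illustrative only: assumes graph_reward_fn from the hunk above is in scope.
generated_graphs = [None, None, None]            # stand-ins for sampled graphs
rewards = graph_reward_fn(generated_graphs, device='cpu')
print(rewards.shape)                             # torch.Size([1, 3]); every reward is 1.0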
@@ -317,19 +319,33 @@ def test(cfg: DictConfig):
     samples, log_probs, rewards = samples_with_log_probs[perm]
     samples = list(samples)
     log_probs = list(log_probs)
+    for i in range(len(log_probs)):
+        log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
     print(f'log_probs: {log_probs[:5]}')
-    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1000])
+    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
     rewards = list(rewards)
+    log_probs = torch.cat(log_probs, dim=0)
+    print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    old_log_probs = log_probs.clone()

+    # multi metrics range
+    # reward hacking hiking
     for inner_epoch in range(cfg.train.n_epochs):
-        # print(f'rewards: {rewards[0].shape}') # torch.Size([1000])
-        rewards = torch.cat(rewards, dim=0)
-        print(f'rewards: {rewards.shape}')
+        # print(f'rewards: {rewards.shape}') # torch.Size([1000])
+        print(f'rewards: {rewards[:5]}')
+        print(f'len rewards: {len(rewards)}')
+        print(f'type rewards: {type(rewards)}')
+        if len(rewards) > 1 and isinstance(rewards, list):
+            rewards = torch.cat(rewards, dim=0)
+        elif len(rewards) == 1 and isinstance(rewards, list):
+            rewards = rewards[0]
+        # print(f'rewards: {rewards.shape}')
         advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
-        old_log_probs = log_probs.copy()
+        print(f'advantages: {advantages.shape}')
         with accelerator.accumulate(graph_dit_model):
             ratio = torch.exp(log_probs - old_log_probs)
             unclipped_loss = -advantages * ratio
+            # z-score normalization
             clipped_loss = -advantages * torch.clamp(ratio,
                                                      1.0 - cfg.ppo.clip_param,
                                                      1.0 + cfg.ppo.clip_param)
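The block above is the standard PPO clipped surrogate: advantages are z-score normalized rewards, the importance-sampling ratio is exp(log_probs - old_log_probs), and the ratio is clipped to [1 - clip_param, 1 + clip_param]. A self-contained sketch of the computation with toy tensors follows; the final reduction (elementwise max of the two losses, then the mean) is an assumption, since the hunk ends before that step.

import torch

# Toy PPO clipped-surrogate sketch mirroring the diff; clip_param stands in for cfg.ppo.clip_param.
clip_param = 0.2
rewards = torch.tensor([1.0, 0.5, 2.0, 1.5])
log_probs = torch.log(torch.rand(4)).requires_grad_(True)   # current per-sample log-probs
old_log_probs = log_probs.detach().clone()                  # snapshot taken before the inner epochs

# z-score normalization of rewards into advantages, as in the diff
advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)

ratio = torch.exp(log_probs - old_log_probs)                # importance-sampling ratio
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)

# Assumed final step: take the worse (larger) of the two losses per sample, then average.
loss = torch.max(unclipped_loss, clipped_loss).mean()
loss.backward()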