need to update the model

parent 0c60171c71
commit 2ac17caa3c
@@ -239,8 +239,8 @@ class Graph_DiT(pl.LightningModule):
                    self.val_X_logp.compute(), self.val_E_logp.compute()]

-        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+        # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
         print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
               f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
         with open("validation-metrics.csv", "a") as f:
             # save the metrics as csv file
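The hunk above stops gating the validation summary on epoch fraction, so it now runs on every validation pass, and the metrics are appended to validation-metrics.csv. A minimal sketch of that CSV append (the helper name and column layout are assumptions, not the repo's code):

import csv

def append_validation_metrics(epoch, metrics, path="validation-metrics.csv"):
    # metrics = [val_nll, atom_type_kl, edge_type_kl, ...] as in the print above;
    # one row per validation epoch, appended so earlier runs are kept.
    with open(path, "a", newline="") as f:
        csv.writer(f).writerow([epoch, *[float(m) for m in metrics]])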
@@ -242,6 +242,12 @@ def test(cfg: DictConfig):
            optimizer.step()
            optimizer.zero_grad()
            # return {'loss': loss}
+            if epoch % cfg.train.check_val_every_n_epoch == 0:
+                print(f'print validation loss')
+                graph_dit_model.eval()
+                graph_dit_model.on_validation_epoch_start()
+                graph_dit_model.validation_step(data, epoch)
+                graph_dit_model.on_validation_epoch_end()

    # start testing
    print("start testing")
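The block added here drives the LightningModule's validation hooks by hand, since this custom test() loop runs without a Lightning Trainer. A self-contained sketch of that pattern, assuming a val_loader dataloader and the hook names shown in the diff:

import torch

def run_validation_pass(model, val_loader):
    # Invoke the validation hooks directly because no Trainer orchestrates them here.
    model.eval()
    model.on_validation_epoch_start()
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            model.validation_step(batch, i)
    model.on_validation_epoch_end()
    model.train()

Note that the committed code instead passes the current training batch (data) and the epoch index to validation_step, and does not switch the model back to train mode afterwards.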
@@ -281,6 +287,53 @@ def test(cfg: DictConfig):
                reward = 1.0
            rewards.append(reward)
        return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
+    # while samples_left_to_generate > 0:
+    # print(f'samples left to generate: {samples_left_to_generate}/'
+    # f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+    # bs = 1 * cfg.train.batch_size
+    # to_generate = min(samples_left_to_generate, bs)
+    # to_save = min(samples_left_to_save, bs)
+    # chains_save = min(chains_left_to_save, bs)
+    # # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+    # batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+
+    # cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+    # keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+    # samples = samples + cur_sample
+    # reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+
+    # samples_with_log_probs.append((cur_sample, log_probs, reward))
+
+    # all_ys.append(batch_y)
+    # batch_id += to_generate
+
+    # samples_left_to_save -= to_save
+    # samples_left_to_generate -= to_generate
+    # chains_left_to_save -= chains_save
+
+    # print(f"final Computing sampling metrics...")
+    # graph_dit_model.sampling_metrics.reset()
+    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    # graph_dit_model.sampling_metrics.reset()
+    # print(f"Done.")
+
+    # # save samples
+    # print("Samples:")
+    # print(samples)
+
+    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
+    # samples, log_probs, rewards = samples_with_log_probs[perm]
+    # samples = list(samples)
+    # log_probs = list(log_probs)
+    # for i in range(len(log_probs)):
+    # log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
+    # print(f'log_probs: {log_probs[:5]}')
+    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
+    # rewards = list(rewards)
+    # log_probs = torch.cat(log_probs, dim=0)
+    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    # old_log_probs = log_probs.clone()
+    old_log_probs = None
    while samples_left_to_generate > 0:
        print(f'samples left to generate: {samples_left_to_generate}/'
              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
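The context lines at the top of this hunk come from graph_reward_fn, which scores each sampled graph and returns the batch of rewards as one tensor. A hedged sketch of that shape, with the per-graph scoring left as a placeholder exactly like the committed constant 1.0:

import torch

def graph_reward_fn(samples, device="cpu"):
    rewards = []
    for sample in samples:
        reward = 1.0  # placeholder score, e.g. a validity or property check
        rewards.append(reward)
    # One scalar per graph, shaped (1, batch), on the sampling device.
    return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)

requires_grad=True mirrors the committed code; with the clipped objective below, gradients flow through the log-probabilities rather than the reward, so the flag is likely unnecessary.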
@@ -289,14 +342,34 @@ def test(cfg: DictConfig):
        to_save = min(samples_left_to_save, bs)
        chains_save = min(chains_left_to_save, bs)
        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-
-        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                                             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-        samples = samples + cur_sample
-        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
-
-        samples_with_log_probs.append((cur_sample, log_probs, reward))
+        # batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+
+        # cur_sample, old_log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+        # keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+        # samples = samples + cur_sample
+        # reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        # advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+
+        with accelerator.accumulate(graph_dit_model):
+            batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+            new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+            samples = samples + new_samples
+            reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
+            advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+            if old_log_probs is None:
+                old_log_probs = log_probs.clone()
+            ratio = torch.exp(log_probs - old_log_probs)
+            unclipped_loss = -advantages * ratio
+            clipped_loss = -advantages * torch.clamp(ratio,
+                                                     1.0 - cfg.ppo.clip_param,
+                                                     1.0 + cfg.ppo.clip_param)
+            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+            accelerator.backward(loss)
+            optimizer.step()
+            optimizer.zero_grad()
+
+            samples_with_log_probs.append((new_samples, log_probs, reward))
+
        all_ys.append(batch_y)
        batch_id += to_generate
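The new with accelerator.accumulate(...) block applies a PPO-style clipped surrogate objective to the sampler's log-probabilities, with old_log_probs seeded from the first batch so the initial ratio is exactly 1. The same loss in isolation, as a minimal sketch (clip_param stands in for cfg.ppo.clip_param):

import torch

def ppo_clipped_loss(log_probs, old_log_probs, advantages, clip_param=0.2):
    # Probability ratio between the current policy and the one that generated the samples.
    ratio = torch.exp(log_probs - old_log_probs)
    # Unclipped and clipped surrogate terms, negated because we minimize.
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # PPO keeps the pessimistic (larger) of the two, averaged over the batch.
    return torch.mean(torch.max(unclipped, clipped))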
@@ -304,6 +377,7 @@ def test(cfg: DictConfig):
        samples_left_to_save -= to_save
        samples_left_to_generate -= to_generate
        chains_left_to_save -= chains_save
+        # break

    print(f"final Computing sampling metrics...")
    graph_dit_model.sampling_metrics.reset()
@@ -315,47 +389,46 @@ def test(cfg: DictConfig):
    print("Samples:")
    print(samples)

-    perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
-    samples, log_probs, rewards = samples_with_log_probs[perm]
-    samples = list(samples)
-    log_probs = list(log_probs)
-    for i in range(len(log_probs)):
-        log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
-    print(f'log_probs: {log_probs[:5]}')
-    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
-    rewards = list(rewards)
-    log_probs = torch.cat(log_probs, dim=0)
-    print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
-    old_log_probs = log_probs.clone()
-    # multi metrics range
-    # reward hacking hiking
-    for inner_epoch in range(cfg.train.n_epochs):
-        # print(f'rewards: {rewards.shape}') # torch.Size([1000])
-        print(f'rewards: {rewards[:5]}')
-        print(f'len rewards: {len(rewards)}')
-        print(f'type rewards: {type(rewards)}')
-        if len(rewards) > 1 and isinstance(rewards, list):
-            rewards = torch.cat(rewards, dim=0)
-        elif len(rewards) == 1 and isinstance(rewards, list):
-            rewards = rewards[0]
-        # print(f'rewards: {rewards.shape}')
-        advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
-        print(f'advantages: {advantages.shape}')
-        with accelerator.accumulate(graph_dit_model):
-            ratio = torch.exp(log_probs - old_log_probs)
-            unclipped_loss = -advantages * ratio
-            # z-score normalization
-            clipped_loss = -advantages * torch.clamp(ratio,
-                                                     1.0 - cfg.ppo.clip_param,
-                                                     1.0 + cfg.ppo.clip_param)
-            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
-
-            accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
-            print(f"loss: {loss.item()}, epoch: {inner_epoch}")
+    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
+    # samples, log_probs, rewards = samples_with_log_probs[perm]
+    # samples = list(samples)
+    # log_probs = list(log_probs)
+    # for i in range(len(log_probs)):
+    # log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
+    # print(f'log_probs: {log_probs[:5]}')
+    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
+    # rewards = list(rewards)
+    # log_probs = torch.cat(log_probs, dim=0)
+    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    # old_log_probs = log_probs.clone()
+    # # multi metrics range
+    # # reward hacking hiking
+    # for inner_epoch in range(cfg.train.n_epochs):
+    # # print(f'rewards: {rewards.shape}') # torch.Size([1000])
+    # print(f'rewards: {rewards[:5]}')
+    # print(f'len rewards: {len(rewards)}')
+    # print(f'type rewards: {type(rewards)}')
+    # if len(rewards) > 1 and isinstance(rewards, list):
+    # rewards = torch.cat(rewards, dim=0)
+    # elif len(rewards) == 1 and isinstance(rewards, list):
+    # rewards = rewards[0]
+    # # print(f'rewards: {rewards.shape}')
+    # advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
+    # print(f'advantages: {advantages.shape}')
+    # with accelerator.accumulate(graph_dit_model):
+    # ratio = torch.exp(log_probs - old_log_probs)
+    # unclipped_loss = -advantages * ratio
+    # # z-score normalization
+    # clipped_loss = -advantages * torch.clamp(ratio,
+    # 1.0 - cfg.ppo.clip_param,
+    # 1.0 + cfg.ppo.clip_param)
+    # loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+    # accelerator.backward(loss)
+    # optimizer.step()
+    # optimizer.zero_grad()
+
+    # accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
+    # print(f"loss: {loss.item()}, epoch: {inner_epoch}")


    # trainer = Trainer(
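The inner-epoch PPO loop that this last hunk comments out normalized the collected rewards into advantages with a z-score before applying the clipped objective, the same normalization the new per-batch loop now does inline. That step on its own, as a small sketch:

import torch

def normalize_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Zero-mean, unit-variance advantages; eps guards against a zero std
    # when every sample in the batch gets the same reward.
    return (rewards - rewards.mean()) / (rewards.std() + eps)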