diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index 340fb3e..e056c51 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -239,8 +239,8 @@ class Graph_DiT(pl.LightningModule):
                    self.val_X_logp.compute(), self.val_E_logp.compute()]
 
-        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
-            print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
+        # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+        print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
               f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best : %.2f\n' % (metrics[0], self.best_val_nll))
         with open("validation-metrics.csv", "a") as f:
             # save the metrics as csv file
diff --git a/graph_dit/main.py b/graph_dit/main.py
index 12677d5..33fd0dd 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -242,6 +242,12 @@ def test(cfg: DictConfig):
             optimizer.step()
             optimizer.zero_grad()
             # return {'loss': loss}
+        if epoch % cfg.train.check_val_every_n_epoch == 0:
+            print(f'print validation loss')
+            graph_dit_model.eval()
+            graph_dit_model.on_validation_epoch_start()
+            graph_dit_model.validation_step(data, epoch)
+            graph_dit_model.on_validation_epoch_end()
 
     # start testing
     print("start testing")
@@ -281,6 +287,53 @@ def test(cfg: DictConfig):
                reward = 1.0
            rewards.append(reward)
        return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
+    # while samples_left_to_generate > 0:
+    #     print(f'samples left to generate: {samples_left_to_generate}/'
+    #           f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+    #     bs = 1 * cfg.train.batch_size
+    #     to_generate = min(samples_left_to_generate, bs)
+    #     to_save = min(samples_left_to_save, bs)
+    #     chains_save = min(chains_left_to_save, bs)
+    #     # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+    #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+
+    #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+    #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+    #     samples = samples + cur_sample
+    #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+
+    #     samples_with_log_probs.append((cur_sample, log_probs, reward))
+
+    #     all_ys.append(batch_y)
+    #     batch_id += to_generate
+
+    #     samples_left_to_save -= to_save
+    #     samples_left_to_generate -= to_generate
+    #     chains_left_to_save -= chains_save
+
+    # print(f"final Computing sampling metrics...")
+    # graph_dit_model.sampling_metrics.reset()
+    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    # graph_dit_model.sampling_metrics.reset()
+    # print(f"Done.")
+
+    # # save samples
+    # print("Samples:")
+    # print(samples)
+
+    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
+    # samples, log_probs, rewards = samples_with_log_probs[perm]
+    # samples = list(samples)
+    # log_probs = list(log_probs)
+    # for i in range(len(log_probs)):
+    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
+    # print(f'log_probs: {log_probs[:5]}')
+    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
+    # rewards = list(rewards)
+    # log_probs = torch.cat(log_probs, dim=0)
+    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    # old_log_probs = log_probs.clone()
+    old_log_probs = None
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
               f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
@@ -289,14 +342,34 @@ def test(cfg: DictConfig):
         to_save = min(samples_left_to_save, bs)
         chains_save = min(chains_left_to_save, bs)
         # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                                             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-        samples = samples + cur_sample
-        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        # batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
 
-        samples_with_log_probs.append((cur_sample, log_probs, reward))
+        # cur_sample, old_log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+        #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+        # samples = samples + cur_sample
+        # reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        # advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+        with accelerator.accumulate(graph_dit_model):
+            batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+            new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+            samples = samples + new_samples
+            reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
+            advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+            if old_log_probs is None:
+                old_log_probs = log_probs.clone()
+            ratio = torch.exp(log_probs - old_log_probs)
+            unclipped_loss = -advantages * ratio
+            clipped_loss = -advantages * torch.clamp(ratio,
+                                                     1.0 - cfg.ppo.clip_param,
+                                                     1.0 + cfg.ppo.clip_param)
+            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+            accelerator.backward(loss)
+            optimizer.step()
+            optimizer.zero_grad()
+
+
+        samples_with_log_probs.append((new_samples, log_probs, reward))
 
         all_ys.append(batch_y)
         batch_id += to_generate
@@ -304,6 +377,7 @@ def test(cfg: DictConfig):
         samples_left_to_save -= to_save
         samples_left_to_generate -= to_generate
         chains_left_to_save -= chains_save
+        # break
 
     print(f"final Computing sampling metrics...")
     graph_dit_model.sampling_metrics.reset()
@@ -315,47 +389,46 @@ def test(cfg: DictConfig):
     print("Samples:")
     print(samples)
 
-    perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
-    samples, log_probs, rewards = samples_with_log_probs[perm]
-    samples = list(samples)
-    log_probs = list(log_probs)
-    for i in range(len(log_probs)):
-        log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
-    print(f'log_probs: {log_probs[:5]}')
-    print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
-    rewards = list(rewards)
-    log_probs = torch.cat(log_probs, dim=0)
-    print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
-    old_log_probs = log_probs.clone()
+    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
+    # samples, log_probs, rewards = samples_with_log_probs[perm]
+    # samples = list(samples)
+    # log_probs = list(log_probs)
+    # for i in range(len(log_probs)):
+    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
+    # print(f'log_probs: {log_probs[:5]}')
+    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
+    # rewards = list(rewards)
+    # log_probs = torch.cat(log_probs, dim=0)
+    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
+    # old_log_probs = log_probs.clone()
+    # # multi metrics range
+    # # reward hacking hiking
+    # for inner_epoch in range(cfg.train.n_epochs):
+    #     # print(f'rewards: {rewards.shape}') # torch.Size([1000])
+    #     print(f'rewards: {rewards[:5]}')
+    #     print(f'len rewards: {len(rewards)}')
+    #     print(f'type rewards: {type(rewards)}')
+    #     if len(rewards) > 1 and isinstance(rewards, list):
+    #         rewards = torch.cat(rewards, dim=0)
+    #     elif len(rewards) == 1 and isinstance(rewards, list):
+    #         rewards = rewards[0]
+    #     # print(f'rewards: {rewards.shape}')
+    #     advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
+    #     print(f'advantages: {advantages.shape}')
+    #     with accelerator.accumulate(graph_dit_model):
+    #         ratio = torch.exp(log_probs - old_log_probs)
+    #         unclipped_loss = -advantages * ratio
+    #         # z-score normalization
+    #         clipped_loss = -advantages * torch.clamp(ratio,
+    #                                                  1.0 - cfg.ppo.clip_param,
+    #                                                  1.0 + cfg.ppo.clip_param)
+    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+    #         accelerator.backward(loss)
+    #         optimizer.step()
+    #         optimizer.zero_grad()
 
-    # multi metrics range
-    # reward hacking hiking
-    for inner_epoch in range(cfg.train.n_epochs):
-        # print(f'rewards: {rewards.shape}') # torch.Size([1000])
-        print(f'rewards: {rewards[:5]}')
-        print(f'len rewards: {len(rewards)}')
-        print(f'type rewards: {type(rewards)}')
-        if len(rewards) > 1 and isinstance(rewards, list):
-            rewards = torch.cat(rewards, dim=0)
-        elif len(rewards) == 1 and isinstance(rewards, list):
-            rewards = rewards[0]
-        # print(f'rewards: {rewards.shape}')
-        advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
-        print(f'advantages: {advantages.shape}')
-        with accelerator.accumulate(graph_dit_model):
-            ratio = torch.exp(log_probs - old_log_probs)
-            unclipped_loss = -advantages * ratio
-            # z-score normalization
-            clipped_loss = -advantages * torch.clamp(ratio,
-                                                     1.0 - cfg.ppo.clip_param,
-                                                     1.0 + cfg.ppo.clip_param)
-            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
-
-        accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
-        print(f"loss: {loss.item()}, epoch: {inner_epoch}")
+        # accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
+        # print(f"loss: {loss.item()}, epoch: {inner_epoch}")
 
     # trainer = Trainer(
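Note on the new update step (illustrative, not part of the patch): the block under accelerator.accumulate(graph_dit_model) applies a PPO-style clipped surrogate loss to the per-sample log-probabilities returned by sample_batch, using z-score-normalized rewards as advantages. The minimal sketch below isolates that loss computation; the function name, tensor shapes, and the default clip value are assumptions for illustration (the patch reads the clip value from cfg.ppo.clip_param).

import torch

def ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
    # Advantages via z-score normalization of the rewards, as in the patch.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    # Importance ratio between the current policy and the frozen reference.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Pessimistic surrogate: element-wise max of the two negated terms, then mean.
    return torch.max(unclipped, clipped).mean()

# Example with dummy per-sample tensors (shapes assumed, not taken from the repo):
log_probs = torch.randn(8, requires_grad=True)
old_log_probs = log_probs.detach().clone()
rewards = torch.rand(8)
loss = ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2)
loss.backward()

Because the patch initializes old_log_probs from the first batch's log_probs and never refreshes it, the ratio is 1 for that first batch and the clipping only constrains the updates on later batches.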