From 91d4e3c7ad15d09c619bfdaa320c3f8426dab916 Mon Sep 17 00:00:00 2001
From: mhz
Date: Mon, 16 Sep 2024 22:45:12 +0200
Subject: [PATCH] try to get the original perf

---
 graph_dit/diffusion_model.py |  11 +-
 graph_dit/main.py            | 283 ++++++++++++-----------------------
 2 files changed, 105 insertions(+), 189 deletions(-)

diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index 310131c..c246862 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -195,15 +195,18 @@ class Graph_DiT(pl.LightningModule):
         # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim))
 
     def on_train_epoch_start(self) -> None:
-        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
-            print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
+        # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+        if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
+            # print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs))
+            print("Starting train epoch {}/{}...".format(self.current_epoch, self.cfg.train.n_epochs))
         self.start_epoch_time = time.time()
         self.train_loss.reset()
         self.train_metrics.reset()
 
     def on_train_epoch_end(self) -> None:
-        if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+        # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]:
+        if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]:
             log = True
         else:
             log = False
@@ -601,8 +604,8 @@ class Graph_DiT(pl.LightningModule):
 
         assert (E == torch.transpose(E, 1, 2)).all()
 
-        # total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
         total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
+        # total_log_probs = torch.zeros([self.cfg.general.samples_to_generate,10], device=self.device)
 
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
         for s_int in reversed(range(0, self.T)):
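A note on the two hooks changed above: once main.py drives the module by hand (see below), there is no attached pl.Trainer, which is presumably why self.trainer.max_epochs is swapped for self.cfg.train.n_epochs. The milestone test itself stays fragile either way: current_epoch / n_epochs in [0.25, 0.5, 0.75, 1.0] fires only when the division lands exactly on one of those floats, and 1.0 is never reached because epochs run from 0 to n_epochs - 1. A sketch of a sturdier check with the same 25/50/75/100% intent; is_milestone is an illustrative helper, not repo code:

    # Illustrative helper (not in the repo): milestone check without exact
    # float equality. current_epoch / n_epochs mirror the patched hooks.
    def is_milestone(current_epoch: int, n_epochs: int) -> bool:
        # Integer arithmetic hits 25/50/75/100% even when n_epochs % 4 != 0.
        milestones = {max(1, (n_epochs * q) // 4) for q in (1, 2, 3, 4)}
        return (current_epoch + 1) in milestones

    # Example: n_epochs=10 gives milestones {2, 5, 7, 10}; the original float
    # test would fire only at epoch 5 (5 / 10 == 0.5) and never at 1.0.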
diff --git a/graph_dit/main.py b/graph_dit/main.py
index 7023a24..586ee1b 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -161,7 +161,8 @@ def test(cfg: DictConfig):
     accelerator = Accelerator(
         mixed_precision='no',
         project_config=accelerator_config,
-        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
+        # gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
+        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
     )
 
     # Debug: confirm available devices
@@ -219,29 +220,34 @@ def test(cfg: DictConfig):
     for epoch in range(cfg.train.n_epochs):
         graph_dit_model.train()  # put the model in training mode
         print(f"Epoch {epoch}", end="\n")
+        graph_dit_model.on_train_epoch_start()
        for data in train_dataloader:  # fetch one batch from the dataloader
-            data.to(accelerator.device)
-            data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
-            data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
-            dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
-            dense_data = dense_data.mask(node_mask)
-            X, E = dense_data.X, dense_data.E
-            noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
-            pred = graph_dit_model.forward(noisy_data)
-            loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
-                                              true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
-                                              log=epoch % graph_dit_model.log_every_steps == 0)
-            # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
-            graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
-                                          log=epoch % graph_dit_model.log_every_steps == 0)
-            graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
-            print(f"training loss: {loss}")
-            with open("training-loss.csv", "a") as f:
-                f.write(f"{loss}, {epoch}\n")
+            # data.to(accelerator.device)
+            # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
+            # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
+            # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
+            # dense_data = dense_data.mask(node_mask)
+            # X, E = dense_data.X, dense_data.E
+            # noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
+            # pred = graph_dit_model.forward(noisy_data)
+            # loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
+            #                                   true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
+            #                                   log=epoch % graph_dit_model.log_every_steps == 0)
+            # # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
+            # graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
+            #                               log=epoch % graph_dit_model.log_every_steps == 0)
+            # graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True)
+            # print(f"training loss: {loss}")
+            # with open("training-loss.csv", "a") as f:
+            #     f.write(f"{loss}, {epoch}\n")
+            loss = graph_dit_model.training_step(data, epoch)
+            loss = loss['loss']
+
             accelerator.backward(loss)
             optimizer.step()
             optimizer.zero_grad()
             # return {'loss': loss}
+        graph_dit_model.on_train_epoch_end()
 
         if epoch % cfg.train.check_val_every_n_epoch == 0:
             print(f'print validation loss')
             graph_dit_model.eval()
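The rewritten loop above delegates the forward pass and loss to the module's own training_step, so the manual Accelerate driver cannot drift from the Lightning implementation, and it explicitly calls the on_train_epoch_* hooks that pl.Trainer would otherwise invoke. A minimal sketch of the pattern, assuming (as the patch does) that training_step returns a dict with a 'loss' entry and that model, optimizer, dataloader, and n_epochs are already set up:

    # Minimal sketch: driving a LightningModule by hand under Accelerate.
    # Assumes training_step returns {'loss': tensor} and that the epoch hooks
    # do not depend on an attached pl.Trainer. model/optimizer/dataloader/
    # n_epochs stand in for graph_dit_model etc. in main.py.
    from accelerate import Accelerator

    accelerator = Accelerator()
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch in range(n_epochs):
        model.train()
        model.on_train_epoch_start()           # normally called by pl.Trainer
        for batch in dataloader:
            out = model.training_step(batch, epoch)  # patch passes epoch as batch_idx
            accelerator.backward(out['loss'])  # instead of loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        model.on_train_epoch_end()

Routing backward through accelerator.backward keeps gradient scaling and accumulation correct if mixed precision or gradient_accumulation_steps are re-enabled later.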
@@ -253,126 +259,69 @@ def test(cfg: DictConfig):
     print("start testing")
     graph_dit_model.eval()
     test_dataloader = accelerator.prepare(datamodule.test_dataloader())
+    graph_dit_model.on_test_epoch_start()
     for data in test_dataloader:
-        data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
-        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
+        nll = graph_dit_model.test_step(data, epoch)
+        # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
+        # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
 
-        dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
-        dense_data = dense_data.mask(node_mask)
-        noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
-        pred = graph_dit_model.forward(noisy_data)
-        nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
-        graph_dit_model.test_y_collection.append(data.y)
+        # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
+        # dense_data = dense_data.mask(node_mask)
+        # noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
+        # pred = graph_dit_model.forward(noisy_data)
+        # nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
+        # graph_dit_model.test_y_collection.append(data.y)
         print(f'test loss: {nll}')
+
+    graph_dit_model.on_test_epoch_end()
 
     # start sampling
-    samples_left_to_generate = cfg.general.final_model_samples_to_generate
-    samples_left_to_save = cfg.general.final_model_samples_to_save
-    chains_left_to_save = cfg.general.final_model_chains_to_save
+    # samples_left_to_generate = cfg.general.final_model_samples_to_generate
+    # samples_left_to_save = cfg.general.final_model_samples_to_save
+    # chains_left_to_save = cfg.general.final_model_chains_to_save
 
-    samples, all_ys, batch_id = [], [], 0
-    samples_with_log_probs = []
-    test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
-    num_examples = test_y_collection.size(0)
-    if cfg.general.final_model_samples_to_generate > num_examples:
-        ratio = cfg.general.final_model_samples_to_generate // num_examples
-        test_y_collection = test_y_collection.repeat(ratio+1, 1)
-        num_examples = test_y_collection.size(0)
+    # samples, all_ys, batch_id = [], [], 0
+    # samples_with_log_probs = []
+    # test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
+    # num_examples = test_y_collection.size(0)
+    # if cfg.general.final_model_samples_to_generate > num_examples:
+    #     ratio = cfg.general.final_model_samples_to_generate // num_examples
+    #     test_y_collection = test_y_collection.repeat(ratio+1, 1)
+    #     num_examples = test_y_collection.size(0)
 
     # Normal reward function
-    from nas_201_api import NASBench201API as API
-    api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
-    def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
-        rewards = []
-        if reward_model == 'swap':
-            import csv
-            with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
-                reader = csv.reader(f)
-                header = next(reader)
-                data = [row for row in reader]
-                swap_scores = [float(row[0]) for row in data]
-            for graph in graphs:
-                node_tensor = graph[0]
-                node = node_tensor.cpu().numpy().tolist()
+    # from nas_201_api import NASBench201API as API
+    # api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+    # def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
+    #     rewards = []
+    #     if reward_model == 'swap':
+    #         import csv
+    #         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+    #             reader = csv.reader(f)
+    #             header = next(reader)
+    #             data = [row for row in reader]
+    #             swap_scores = [float(row[0]) for row in data]
+    #         for graph in graphs:
+    #             node_tensor = graph[0]
+    #             node = node_tensor.cpu().numpy().tolist()
 
-                def nodes_to_arch_str(nodes):
-                    num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
-                    nodes_str = [num_to_op[node] for node in nodes]
-                    arch_str = '|' + nodes_str[1] + '~0|+' + \
-                               '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\
-                               '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
-                    return arch_str
+    #             def nodes_to_arch_str(nodes):
+    #                 num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+    #                 nodes_str = [num_to_op[node] for node in nodes]
+    #                 arch_str = '|' + nodes_str[1] + '~0|+' + \
+    #                            '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\
+    #                            '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
+    #                 return arch_str
 
-                arch_str = nodes_to_arch_str(node)
-                reward = swap_scores[api.query_index_by_arch(arch_str)]
-                rewards.append(reward)
+    #             arch_str = nodes_to_arch_str(node)
+    #             reward = swap_scores[api.query_index_by_arch(arch_str)]
+    #             rewards.append(reward)
 
-        # for graph in graphs:
-        #     reward = 1.0
-        #     rewards.append(reward)
-        return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
-    old_log_probs = None
-    while samples_left_to_generate > 0:
-        print(f'samples left to generate: {samples_left_to_generate}/'
-              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-        bs = 1 * cfg.train.batch_size
-        to_generate = min(samples_left_to_generate, bs)
-        to_save = min(samples_left_to_save, bs)
-        chains_save = min(chains_left_to_save, bs)
-        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-
-        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                                             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-        samples = samples + cur_sample
-        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
-        advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-        if old_log_probs is None:
-            old_log_probs = log_probs.clone()
-        ratio = torch.exp(log_probs - old_log_probs)
-        unclipped_loss = -advantages * ratio
-        clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
-        loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-        accelerator.backward(loss)
-        optimizer.step()
-        optimizer.zero_grad()
-
-
-        samples_with_log_probs.append((cur_sample, log_probs, reward))
-
-        all_ys.append(batch_y)
-        batch_id += to_generate
-
-        samples_left_to_save -= to_save
-        samples_left_to_generate -= to_generate
-        chains_left_to_save -= chains_save
-
-    print(f"final Computing sampling metrics...")
-    graph_dit_model.sampling_metrics.reset()
-    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    graph_dit_model.sampling_metrics.reset()
-    print(f"Done.")
-
-    # save samples
-    print("Samples:")
-    print(samples)
-
-    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
-    # samples, log_probs, rewards = samples_with_log_probs[perm]
-    # samples = list(samples)
-    # log_probs = list(log_probs)
-    # for i in range(len(log_probs)):
-    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
-    # print(f'log_probs: {log_probs[:5]}')
-    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
-    # rewards = list(rewards)
-    # log_probs = torch.cat(log_probs, dim=0)
-    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
-    # old_log_probs = log_probs.clone()
-
-    # ===
-
+    #     # for graph in graphs:
+    #     #     reward = 1.0
+    #     #     rewards.append(reward)
+    #     return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
     # old_log_probs = None
     # while samples_left_to_generate > 0:
     #     print(f'samples left to generate: {samples_left_to_generate}/'
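The reward function deleted above (and re-added in commented form) scores a sampled cell by turning its per-node operation indices into a NAS-Bench-201 architecture string and looking up a precomputed SWAP score; query_index_by_arch is the real NAS-Bench-201 API method used for the lookup. A standalone sketch of the string conversion, with the op vocabulary and the |op~input| layout taken from the patch; the 8-node encoding assumption and the example call are illustrative:

    # Sketch of the node-ops -> NAS-Bench-201 arch-string mapping used by the
    # commented-out reward function. Vocabulary and layout come from the patch;
    # treating indices 0 and 7 as the input/output nodes is an assumption.
    NUM_TO_OP = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3',
                 'skip_connect', 'none', 'output']

    def nodes_to_arch_str(nodes):
        ops = [NUM_TO_OP[n] for n in nodes]
        # Six edge ops (indices 1..6) laid out as the three NAS-Bench-201 stages.
        return ('|' + ops[1] + '~0|+'
                '|' + ops[2] + '~0|' + ops[3] + '~1|+'
                '|' + ops[4] + '~0|' + ops[5] + '~1|' + ops[6] + '~2|')

    # Example (illustrative): a cell whose six edges are all skip_connect.
    print(nodes_to_arch_str([0, 4, 4, 4, 4, 4, 4, 6]))
    # |skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|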
@@ -381,27 +330,28 @@ def test(cfg: DictConfig):
     #     to_generate = min(samples_left_to_generate, bs)
     #     to_save = min(samples_left_to_save, bs)
     #     chains_save = min(chains_left_to_save, bs)
+    #     # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+    #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
 
-    #     with accelerator.accumulate(graph_dit_model):
-    #         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-    #         new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-    #         samples = samples + new_samples
-    #         reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
-    #         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-    #         if old_log_probs is None:
-    #             old_log_probs = log_probs.clone()
-    #         ratio = torch.exp(log_probs - old_log_probs)
-    #         unclipped_loss = -advantages * ratio
-    #         clipped_loss = -advantages * torch.clamp(ratio,
-    #                                                  1.0 - cfg.ppo.clip_param,
-    #                                                  1.0 + cfg.ppo.clip_param)
-    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-    #         accelerator.backward(loss)
-    #         optimizer.step()
-    #         optimizer.zero_grad()
+    #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+    #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+    #     log_probs = torch.sum(log_probs, dim=-1).unsqueeze(1)
+    #     samples = samples + cur_sample
+    #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+    #     advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+    #     print(f'reward: {reward.shape}, advantages: {advantages.shape}, log_probs: {log_probs.shape}, cur_sample: {len(cur_sample)}')
+    #     if old_log_probs is None:
+    #         old_log_probs = log_probs.clone()
+    #     ratio = torch.exp(log_probs - old_log_probs)
+    #     unclipped_loss = -advantages * ratio
+    #     clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
+    #     loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+    #     accelerator.backward(loss)
+    #     optimizer.step()
+    #     optimizer.zero_grad()
 
-    #         samples_with_log_probs.append((new_samples, log_probs, reward))
+    #     samples_with_log_probs.append((cur_sample, log_probs, reward))
 
     #     all_ys.append(batch_y)
     #     batch_id += to_generate
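Both variants of the sampling loop (the active one deleted in the @@ -253 hunk and the commented copy edited here) optimize the PPO clipped surrogate: rewards are z-scored into advantages, and the new-to-old log-probability ratio is clamped to [1 - clip_param, 1 + clip_param]. A self-contained sketch of that objective; clip_param mirrors cfg.ppo.clip_param (0.2 is only a conventional default), and the inputs are per-sample tensors as in the shapes printed by the debug line above:

    # Sketch of the PPO clipped surrogate used by the commented-out loops.
    import torch

    def ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
        # z-score rewards into advantages, as the patch does
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped = -advantages * ratio
        clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
        # elementwise max of the two negated terms = pessimistic clipped bound
        return torch.max(unclipped, clipped).mean()

One caveat: a standard PPO implementation detaches the reference policy, i.e. old_log_probs = log_probs.detach(); a plain .clone(), as in the loop above, keeps old_log_probs in the autograd graph.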
@@ -409,7 +359,6 @@ def test(cfg: DictConfig):
     #     samples_left_to_save -= to_save
     #     samples_left_to_generate -= to_generate
     #     chains_left_to_save -= chains_save
-    #     # break
 
     # print(f"final Computing sampling metrics...")
     # graph_dit_model.sampling_metrics.reset()
@@ -421,46 +370,10 @@ def test(cfg: DictConfig):
     # print("Samples:")
     # print(samples)
 
-    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
-    # samples, log_probs, rewards = samples_with_log_probs[perm]
-    # samples = list(samples)
-    # log_probs = list(log_probs)
-    # for i in range(len(log_probs)):
-    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
-    # print(f'log_probs: {log_probs[:5]}')
-    # print(f'log_probs: {log_probs[0].shape}') # torch.Size([1])
-    # rewards = list(rewards)
-    # log_probs = torch.cat(log_probs, dim=0)
-    # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
-    # old_log_probs = log_probs.clone()
-    # # multi metrics range
-    # # reward hacking hiking
-    # for inner_epoch in range(cfg.train.n_epochs):
-    #     # print(f'rewards: {rewards.shape}') # torch.Size([1000])
-    #     print(f'rewards: {rewards[:5]}')
-    #     print(f'len rewards: {len(rewards)}')
-    #     print(f'type rewards: {type(rewards)}')
-    #     if len(rewards) > 1 and isinstance(rewards, list):
-    #         rewards = torch.cat(rewards, dim=0)
-    #     elif len(rewards) == 1 and isinstance(rewards, list):
-    #         rewards = rewards[0]
-    #     # print(f'rewards: {rewards.shape}')
-    #     advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
-    #     print(f'advantages: {advantages.shape}')
-    #     with accelerator.accumulate(graph_dit_model):
-    #         ratio = torch.exp(log_probs - old_log_probs)
-    #         unclipped_loss = -advantages * ratio
-    #         # z-score normalization
-    #         clipped_loss = -advantages * torch.clamp(ratio,
-    #                                                  1.0 - cfg.ppo.clip_param,
-    #                                                  1.0 + cfg.ppo.clip_param)
-    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-    #         accelerator.backward(loss)
-    #         optimizer.step()
-    #         optimizer.zero_grad()
+    # ========================
+
 
-    #         accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
-    #         print(f"loss: {loss.item()}, epoch: {inner_epoch}")
+    # trainer = Trainer(