diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index e056c51..310131c 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -601,6 +601,7 @@ class Graph_DiT(pl.LightningModule):
 
         assert (E == torch.transpose(E, 1, 2)).all()
 
+        # total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
         total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
 
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
diff --git a/graph_dit/main.py b/graph_dit/main.py
index 33fd0dd..9901522 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -281,45 +281,70 @@ def test(cfg: DictConfig):
     num_examples = test_y_collection.size(0)
 
     # Normal reward function
-    def graph_reward_fn(graphs, true_graphs=None, device=None):
+    from nas_201_api import NASBench201API as API
+    api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+    def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
         rewards = []
+        if reward_model == 'swap':
+            import csv
+            with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+                reader = csv.reader(f)
+                header = next(reader)
+                data = [row for row in reader]
+                swap_scores = [float(row[0]) for row in data]
+            for graph in graphs:
+                node_tensor = graph[0]
+                node = node_tensor.cpu().numpy().tolist()
+
+                def nodes_to_arch_str(nodes):
+                    num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+                    nodes_str = [num_to_op[node] for node in nodes]
+                    arch_str = '|' + nodes_str[1] + '~0|+' + \
+                        '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\
+                        '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
+                    return arch_str
+
+                arch_str = nodes_to_arch_str(node)
+                reward = swap_scores[api.query_index_by_arch(arch_str)]
+                rewards.append(reward)
+
-        for graph in graphs:
-            reward = 1.0
-            rewards.append(reward)
+        if reward_model != 'swap':
+            for graph in graphs:
+                rewards.append(1.0)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
 
-    # while samples_left_to_generate > 0:
-    #     print(f'samples left to generate: {samples_left_to_generate}/'
-    #           f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-    #     bs = 1 * cfg.train.batch_size
-    #     to_generate = min(samples_left_to_generate, bs)
-    #     to_save = min(samples_left_to_save, bs)
-    #     chains_save = min(chains_left_to_save, bs)
-    #     # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-    #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+    while samples_left_to_generate > 0:
+        print(f'samples left to generate: {samples_left_to_generate}/'
+              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+        bs = 1 * cfg.train.batch_size
+        to_generate = min(samples_left_to_generate, bs)
+        to_save = min(samples_left_to_save, bs)
+        chains_save = min(chains_left_to_save, bs)
+        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
 
-    #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-    #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-    #     samples = samples + cur_sample
-    #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                                             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+        samples = samples + cur_sample
+        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
 
-    #     samples_with_log_probs.append((cur_sample, log_probs, reward))
+        samples_with_log_probs.append((cur_sample, log_probs, reward))
 
-    #     all_ys.append(batch_y)
-    #     batch_id += to_generate
+        all_ys.append(batch_y)
+        batch_id += to_generate
 
-    #     samples_left_to_save -= to_save
-    #     samples_left_to_generate -= to_generate
-    #     chains_left_to_save -= chains_save
+        samples_left_to_save -= to_save
+        samples_left_to_generate -= to_generate
+        chains_left_to_save -= chains_save
 
-    # print(f"final Computing sampling metrics...")
-    # graph_dit_model.sampling_metrics.reset()
-    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    # graph_dit_model.sampling_metrics.reset()
-    # print(f"Done.")
+    print(f"final Computing sampling metrics...")
+    graph_dit_model.sampling_metrics.reset()
+    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    graph_dit_model.sampling_metrics.reset()
+    print(f"Done.")
 
-    # # save samples
-    # print("Samples:")
-    # print(samples)
+    # save samples
+    print("Samples:")
+    print(samples)
 
     # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
     # samples, log_probs, rewards = samples_with_log_probs[perm]
@@ -333,61 +358,56 @@ def test(cfg: DictConfig):
     # log_probs = torch.cat(log_probs, dim=0)
     # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
     # old_log_probs = log_probs.clone()
-    old_log_probs = None
-    while samples_left_to_generate > 0:
-        print(f'samples left to generate: {samples_left_to_generate}/'
-              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-        bs = 1 * cfg.train.batch_size
-        to_generate = min(samples_left_to_generate, bs)
-        to_save = min(samples_left_to_save, bs)
-        chains_save = min(chains_left_to_save, bs)
-        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-        # batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+    # ===
 
-        # cur_sample, old_log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-        #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-        # samples = samples + cur_sample
-        # reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
-        # advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-        with accelerator.accumulate(graph_dit_model):
-            batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-            new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-            samples = samples + new_samples
-            reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
-            advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-            if old_log_probs is None:
-                old_log_probs = log_probs.clone()
-            ratio = torch.exp(log_probs - old_log_probs)
-            unclipped_loss = -advantages * ratio
-            clipped_loss = -advantages * torch.clamp(ratio,
-                                                     1.0 - cfg.ppo.clip_param,
-                                                     1.0 + cfg.ppo.clip_param)
-            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
+    # old_log_probs = None
+    # while samples_left_to_generate > 0:
+    #     print(f'samples left to generate: {samples_left_to_generate}/'
+    #           f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+    #     bs = 1 * cfg.train.batch_size
+    #     to_generate = min(samples_left_to_generate, bs)
+    #     to_save = min(samples_left_to_save, bs)
+    #     chains_save = min(chains_left_to_save, bs)
+
+    #     with accelerator.accumulate(graph_dit_model):
+    #         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+    #         new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+    #         samples = samples + new_samples
+    #         reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
+    #         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+    #         if old_log_probs is None:
+    #             old_log_probs = log_probs.clone()
+    #         ratio = torch.exp(log_probs - old_log_probs)
+    #         unclipped_loss = -advantages * ratio
+    #         clipped_loss = -advantages * torch.clamp(ratio,
+    #                                                  1.0 - cfg.ppo.clip_param,
+    #                                                  1.0 + cfg.ppo.clip_param)
+    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+    #         accelerator.backward(loss)
+    #         optimizer.step()
+    #         optimizer.zero_grad()
 
-        samples_with_log_probs.append((new_samples, log_probs, reward))
+    #         samples_with_log_probs.append((new_samples, log_probs, reward))
 
-        all_ys.append(batch_y)
-        batch_id += to_generate
+    #         all_ys.append(batch_y)
+    #         batch_id += to_generate
 
-        samples_left_to_save -= to_save
-        samples_left_to_generate -= to_generate
-        chains_left_to_save -= chains_save
-        # break
+    #         samples_left_to_save -= to_save
+    #         samples_left_to_generate -= to_generate
+    #         chains_left_to_save -= chains_save
+    #         # break
 
-    print(f"final Computing sampling metrics...")
-    graph_dit_model.sampling_metrics.reset()
-    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    graph_dit_model.sampling_metrics.reset()
-    print(f"Done.")
+    # print(f"final Computing sampling metrics...")
+    # graph_dit_model.sampling_metrics.reset()
+    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    # graph_dit_model.sampling_metrics.reset()
+    # print(f"Done.")
 
-    # save samples
-    print("Samples:")
-    print(samples)
+    # # save samples
+    # print("Samples:")
+    # print(samples)
 
     # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
     # samples, log_probs, rewards = samples_with_log_probs[perm]