diff --git a/graph_dit/main.py b/graph_dit/main.py
index 9901522..269c008 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -282,7 +282,7 @@ def test(cfg: DictConfig):
     # Normal reward function
     from nas_201_api import NASBench201API as API
-    api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+    api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
     def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
         rewards = []
         if reward_model == 'swap':
@@ -308,9 +308,9 @@ def test(cfg: DictConfig):
                 reward = swap_scores[api.query_index_by_arch(arch_str)]
                 rewards.append(reward)
-            for graph in graphs:
-                reward = 1.0
-                rewards.append(reward)
+            # for graph in graphs:
+            #     reward = 1.0
+            #     rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
@@ -326,6 +326,8 @@ def test(cfg: DictConfig):
             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
         samples = samples + cur_sample
         reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+        samples_with_log_probs.append((cur_sample, log_probs, reward))
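A minimal sketch of the advantage normalization added in the last hunk, assuming `graph_reward_fn` returns rewards as a `(1, batch_size)` tensor; the helper name `normalize_advantages` and the sample reward values are hypothetical, not from the repository.

```python
import torch

def normalize_advantages(reward: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Standardize rewards to roughly zero mean and unit variance across the batch.

    The eps term matches the 1e-6 in the diff and guards against division by zero
    when all rewards in the batch are identical.
    """
    return (reward - torch.mean(reward)) / (torch.std(reward) + eps)

# Hypothetical SWAP scores for four sampled architectures, shaped (1, batch_size).
reward = torch.tensor([[0.62, 0.71, 0.55, 0.68]])
advantages = normalize_advantages(reward)
print(advantages)  # higher-reward samples get positive advantages, lower get negative
```

Note that the diff computes `advantages` but still appends the raw `reward` (not the normalized value) to `samples_with_log_probs`; whether the advantages are consumed elsewhere is not visible in this hunk.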