can run but need to test which pth is

This commit is contained in:
mhz 2024-09-15 22:18:56 +02:00
parent 74a629fdcc
commit 1ad520d248

View File

@ -282,7 +282,7 @@ def test(cfg: DictConfig):
# Normal reward function # Normal reward function
from nas_201_api import NASBench201API as API from nas_201_api import NASBench201API as API
api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth') api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'): def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
rewards = [] rewards = []
if reward_model == 'swap': if reward_model == 'swap':
@ -308,9 +308,9 @@ def test(cfg: DictConfig):
reward = swap_scores[api.query_index_by_arch(arch_str)] reward = swap_scores[api.query_index_by_arch(arch_str)]
rewards.append(reward) rewards.append(reward)
for graph in graphs: # for graph in graphs:
reward = 1.0 # reward = 1.0
rewards.append(reward) # rewards.append(reward)
return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device) return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
while samples_left_to_generate > 0: while samples_left_to_generate > 0:
print(f'samples left to generate: {samples_left_to_generate}/' print(f'samples left to generate: {samples_left_to_generate}/'
@ -326,6 +326,8 @@ def test(cfg: DictConfig):
keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps) keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
samples = samples + cur_sample samples = samples + cur_sample
reward = graph_reward_fn(cur_sample, device=graph_dit_model.device) reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
samples_with_log_probs.append((cur_sample, log_probs, reward)) samples_with_log_probs.append((cur_sample, log_probs, reward))