can run but need to test whtich pth is
This commit is contained in:
parent
74a629fdcc
commit
1ad520d248
@ -282,7 +282,7 @@ def test(cfg: DictConfig):
|
||||
|
||||
# Normal reward function
|
||||
from nas_201_api import NASBench201API as API
|
||||
api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')
|
||||
api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
|
||||
def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
|
||||
rewards = []
|
||||
if reward_model == 'swap':
|
||||
@ -308,9 +308,9 @@ def test(cfg: DictConfig):
|
||||
reward = swap_scores[api.query_index_by_arch(arch_str)]
|
||||
rewards.append(reward)
|
||||
|
||||
for graph in graphs:
|
||||
reward = 1.0
|
||||
rewards.append(reward)
|
||||
# for graph in graphs:
|
||||
# reward = 1.0
|
||||
# rewards.append(reward)
|
||||
return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
|
||||
while samples_left_to_generate > 0:
|
||||
print(f'samples left to generate: {samples_left_to_generate}/'
|
||||
@ -326,6 +326,8 @@ def test(cfg: DictConfig):
|
||||
keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
|
||||
samples = samples + cur_sample
|
||||
reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
|
||||
advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
|
||||
|
||||
|
||||
samples_with_log_probs.append((cur_sample, log_probs, reward))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user