can run but need to test which pth is correct
parent 74a629fdcc
commit 1ad520d248
@@ -282,7 +282,7 @@ def test(cfg: DictConfig):
     # Normal reward function
     from nas_201_api import NASBench201API as API
-    api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+    api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
     def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
         rewards = []
         if reward_model == 'swap':
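Note: the commit message says it is still unclear which of the two .pth locations is the right one. A minimal sketch, assuming nothing beyond the two candidate paths shown in the hunk above, for checking which file actually exists before constructing the API:

import os

# Candidate NAS-Bench-201 checkpoint paths taken from the hunk above.
candidates = [
    '/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth',
    '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth',
]
for path in candidates:
    # Report whether each candidate file is actually present on disk.
    print(path, '->', 'found' if os.path.exists(path) else 'missing')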
@@ -308,9 +308,9 @@ def test(cfg: DictConfig):
             reward = swap_scores[api.query_index_by_arch(arch_str)]
             rewards.append(reward)
 
-        for graph in graphs:
-            reward = 1.0
-            rewards.append(reward)
+        # for graph in graphs:
+        #     reward = 1.0
+        #     rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
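For context: the surviving 'swap' branch of graph_reward_fn turns an architecture string into a NAS-Bench-201 index via query_index_by_arch and uses it to index a precomputed SWAP score table. A minimal standalone sketch of that lookup; swap_scores and the example architecture string are illustrative assumptions, only the API calls appear in the diff:

from nas_201_api import NASBench201API as API
import torch

api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')  # new path from the diff
swap_scores = torch.load('swap_scores.pt')  # assumed: precomputed SWAP scores, one per NAS-Bench-201 index
arch_str = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|avg_pool_3x3~2|'  # example arch string
idx = api.query_index_by_arch(arch_str)  # benchmark index of this architecture
reward = swap_scores[idx]  # per-architecture reward used by graph_reward_fn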
@@ -326,6 +326,8 @@ def test(cfg: DictConfig):
             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
         samples = samples + cur_sample
         reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+
 
         samples_with_log_probs.append((cur_sample, log_probs, reward))
 
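The newly added line standardizes the rewards of the current batch to zero mean and roughly unit standard deviation (the 1e-6 keeps the division safe when all rewards are equal), the usual advantage estimate in policy-gradient-style updates. A minimal standalone sketch of the same computation, with made-up reward values:

import torch

reward = torch.tensor([[0.2, 0.5, 0.9, 0.4]])  # example reward tensor of shape (1, batch)
advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
print(advantages)  # zero-mean, roughly unit-std advantages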