Commit: try to update reward func
commit 94fe13756f, parent 2ac17caa3c
@@ -601,6 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()

-        # total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
+        total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)

         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
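The new buffer holds per-sample log-probabilities gathered while iteratively sampling the reverse chain (one row per generated sample, 10 slots each). As a rough illustration of where such values come from, here is a hypothetical stand-in for a single reverse step; the Categorical below is invented for the sketch and is not the model's actual transition:

    import torch

    # One reverse step, sketched: sample discrete node/edge states and record
    # the log-probability of the draw, which is what total_log_probs collects.
    probs = torch.softmax(torch.randn(4, 7), dim=-1)    # (batch, classes), made-up values
    dist = torch.distributions.Categorical(probs=probs)
    x_s = dist.sample()                                 # sampled states, shape (batch,)
    step_log_prob = dist.log_prob(x_s)                  # shape (batch,)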
@@ -281,45 +281,70 @@ def test(cfg: DictConfig):
     num_examples = test_y_collection.size(0)

     # Normal reward function
-    def graph_reward_fn(graphs, true_graphs=None, device=None):
+    from nas_201_api import NASBench201API as API
+    api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+    def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
         rewards = []
+        if reward_model == 'swap':
+            import csv
+            with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+                reader = csv.reader(f)
+                header = next(reader)
+                data = [row for row in reader]
+                swap_scores = [float(row[0]) for row in data]
+            for graph in graphs:
+                node_tensor = graph[0]
+                node = node_tensor.cpu().numpy().tolist()
+
+                def nodes_to_arch_str(nodes):
+                    num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+                    nodes_str = [num_to_op[node] for node in nodes]
+                    arch_str = '|' + nodes_str[1] + '~0|+' + \
+                               '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' + \
+                               '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
+                    return arch_str
+
+                arch_str = nodes_to_arch_str(node)
+                reward = swap_scores[api.query_index_by_arch(arch_str)]
+                rewards.append(reward)
+
         for graph in graphs:
             reward = 1.0
             rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
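To make the encoding concrete, here is a hypothetical usage sketch of nodes_to_arch_str and the SWAP-score lookup; the node list is invented for illustration (positions 1-6 hold the six cell-edge operations of a NAS-Bench-201 cell, position 0 is the input node):

    nodes = [0, 2, 2, 3, 4, 1, 5, 6]       # made-up sample, not from the commit
    arch = nodes_to_arch_str(nodes)
    # '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_1x1~1|none~2|'
    idx = api.query_index_by_arch(arch)     # NAS-Bench-201 index for this cell
    reward = swap_scores[idx]               # training-free SWAP score as the reward

Note that with the default reward_model='swap', the unconditional fallback loop above still runs as well, appending an extra 1.0 per graph to rewards. The same hunk then activates the previously commented sampling loop: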
-    # while samples_left_to_generate > 0:
-    #     print(f'samples left to generate: {samples_left_to_generate}/'
-    #           f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-    #     bs = 1 * cfg.train.batch_size
-    #     to_generate = min(samples_left_to_generate, bs)
-    #     to_save = min(samples_left_to_save, bs)
-    #     chains_save = min(chains_left_to_save, bs)
-    #     # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-    #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-
-    #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-    #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-    #     samples = samples + cur_sample
-    #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
-
-    #     samples_with_log_probs.append((cur_sample, log_probs, reward))
-
-    #     all_ys.append(batch_y)
-    #     batch_id += to_generate
-
-    #     samples_left_to_save -= to_save
-    #     samples_left_to_generate -= to_generate
-    #     chains_left_to_save -= chains_save
-
-    # print(f"final Computing sampling metrics...")
-    # graph_dit_model.sampling_metrics.reset()
-    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    # graph_dit_model.sampling_metrics.reset()
-    # print(f"Done.")
-
-    # # save samples
-    # print("Samples:")
-    # print(samples)
+    while samples_left_to_generate > 0:
+        print(f'samples left to generate: {samples_left_to_generate}/'
+              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+        bs = 1 * cfg.train.batch_size
+        to_generate = min(samples_left_to_generate, bs)
+        to_save = min(samples_left_to_save, bs)
+        chains_save = min(chains_left_to_save, bs)
+        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+
+        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                                             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+        samples = samples + cur_sample
+        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+
+        samples_with_log_probs.append((cur_sample, log_probs, reward))
+
+        all_ys.append(batch_y)
+        batch_id += to_generate
+
+        samples_left_to_save -= to_save
+        samples_left_to_generate -= to_generate
+        chains_left_to_save -= chains_save
+
+    print(f"final Computing sampling metrics...")
+    graph_dit_model.sampling_metrics.reset()
+    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    graph_dit_model.sampling_metrics.reset()
+    print(f"Done.")
+
+    # save samples
+    print("Samples:")
+    print(samples)

     # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
     # samples, log_probs, rewards = samples_with_log_probs[perm]
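The activated loop only collects (sample, log_probs, reward) tuples; it performs no gradient update. As a hypothetical sketch of how the collected tensors could later be flattened for a policy-gradient step (shapes follow the commented hint below that log_probs stacks to [1000, 1]; nothing here is in the commit):

    all_log_probs = torch.cat([lp for _, lp, _ in samples_with_log_probs], dim=0)  # (N, 1)
    all_rewards = torch.cat([r for _, _, r in samples_with_log_probs], dim=1)      # (1, N)

The next hunk deactivates the PPO update loop that previously ran at this point: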
@@ -333,61 +358,56 @@ def test(cfg: DictConfig):
     # log_probs = torch.cat(log_probs, dim=0)
     # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
     # old_log_probs = log_probs.clone()
-    old_log_probs = None
-    while samples_left_to_generate > 0:
-        print(f'samples left to generate: {samples_left_to_generate}/'
-              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-        bs = 1 * cfg.train.batch_size
-        to_generate = min(samples_left_to_generate, bs)
-        to_save = min(samples_left_to_save, bs)
-        chains_save = min(chains_left_to_save, bs)
-        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-
-        # batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-        # cur_sample, old_log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-        #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-        # samples = samples + cur_sample
-        # reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
-        # advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-        with accelerator.accumulate(graph_dit_model):
-            batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-            new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-            samples = samples + new_samples
-            reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
-            advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-            if old_log_probs is None:
-                old_log_probs = log_probs.clone()
-            ratio = torch.exp(log_probs - old_log_probs)
-            unclipped_loss = -advantages * ratio
-            clipped_loss = -advantages * torch.clamp(ratio,
-                                                     1.0 - cfg.ppo.clip_param,
-                                                     1.0 + cfg.ppo.clip_param)
-            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
-
-        samples_with_log_probs.append((new_samples, log_probs, reward))
-
-        all_ys.append(batch_y)
-        batch_id += to_generate
-
-        samples_left_to_save -= to_save
-        samples_left_to_generate -= to_generate
-        chains_left_to_save -= chains_save
-        # break
-
-    print(f"final Computing sampling metrics...")
-    graph_dit_model.sampling_metrics.reset()
-    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    graph_dit_model.sampling_metrics.reset()
-    print(f"Done.")
-
-    # save samples
-    print("Samples:")
-    print(samples)
+    # ===
+    # old_log_probs = None
+    # while samples_left_to_generate > 0:
+    #     print(f'samples left to generate: {samples_left_to_generate}/'
+    #           f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+    #     bs = 1 * cfg.train.batch_size
+    #     to_generate = min(samples_left_to_generate, bs)
+    #     to_save = min(samples_left_to_save, bs)
+    #     chains_save = min(chains_left_to_save, bs)
+
+    #     with accelerator.accumulate(graph_dit_model):
+    #         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+    #         new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+    #         samples = samples + new_samples
+    #         reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
+    #         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+    #         if old_log_probs is None:
+    #             old_log_probs = log_probs.clone()
+    #         ratio = torch.exp(log_probs - old_log_probs)
+    #         unclipped_loss = -advantages * ratio
+    #         clipped_loss = -advantages * torch.clamp(ratio,
+    #                                                  1.0 - cfg.ppo.clip_param,
+    #                                                  1.0 + cfg.ppo.clip_param)
+    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+    #         accelerator.backward(loss)
+    #         optimizer.step()
+    #         optimizer.zero_grad()
+
+    #     samples_with_log_probs.append((new_samples, log_probs, reward))
+
+    #     all_ys.append(batch_y)
+    #     batch_id += to_generate
+
+    #     samples_left_to_save -= to_save
+    #     samples_left_to_generate -= to_generate
+    #     chains_left_to_save -= chains_save
+    #     # break
+
+    # print(f"final Computing sampling metrics...")
+    # graph_dit_model.sampling_metrics.reset()
+    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    # graph_dit_model.sampling_metrics.reset()
+    # print(f"Done.")
+
+    # # save samples
+    # print("Samples:")
+    # print(samples)

     # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
     # samples, log_probs, rewards = samples_with_log_probs[perm]
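For reference, the deactivated block is a standard PPO clipped-surrogate update. A self-contained sketch of the same objective with made-up tensors (clip_param=0.2 stands in for cfg.ppo.clip_param; none of the names below come from the repository):

    import torch

    def ppo_clipped_loss(log_probs, old_log_probs, advantages, clip_param=0.2):
        # Importance ratio between the current and the behavior policy.
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped = -advantages * ratio
        clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
        # Pessimistic bound: take the larger per-sample loss before averaging.
        return torch.mean(torch.max(unclipped, clipped))

    log_probs = torch.randn(8, 1, requires_grad=True)              # made-up values
    old_log_probs = log_probs.detach() + 0.05 * torch.randn(8, 1)
    reward = torch.rand(8, 1)
    advantages = (reward - reward.mean()) / (reward.std() + 1e-6)
    ppo_clipped_loss(log_probs, old_log_probs, advantages).backward()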