try to update reward func

mhz 2024-09-14 23:56:36 +02:00
parent 2ac17caa3c
commit 94fe13756f
2 changed files with 96 additions and 75 deletions


@@ -601,6 +601,7 @@ class Graph_DiT(pl.LightningModule):
         assert (E == torch.transpose(E, 1, 2)).all()
+        # total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
         total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
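The total_log_probs buffer that appears in this hunk is what the PPO-style update in the test script consumes: one log-probability per generated sample, tracked across the reverse sampling loop. Below is a minimal, repository-independent sketch of how such a buffer could be filled; the step count, tensor shapes, and the step_log_prob helper are illustrative assumptions, not code from this commit.

    # Minimal sketch (not part of this commit): accumulate per-step log-probs while
    # iteratively sampling p(z_s | z_t); shapes and step_log_prob() are assumptions.
    import torch

    num_samples = 32   # plays the role of cfg.general.final_model_samples_to_generate
    num_steps = 10     # assumed number of tracked sampling steps
    total_log_probs = torch.zeros(num_samples, num_steps)

    def step_log_prob(z_t, z_s, t):
        # Hypothetical stand-in for the model's log p(z_s | z_t) at step t.
        return -((z_s - z_t) ** 2).sum(dim=-1)

    z_t = torch.randn(num_samples, 8)               # placeholder latent state
    for t in reversed(range(num_steps)):
        z_s = z_t + 0.01 * torch.randn_like(z_t)    # placeholder denoising step
        total_log_probs[:, t] = step_log_prob(z_t, z_s, t)
        z_t = z_s

    # One log-probability per sample, as used by the ratio exp(log_probs - old_log_probs).
    log_probs = total_log_probs.sum(dim=1)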


@@ -281,45 +281,70 @@ def test(cfg: DictConfig):
     num_examples = test_y_collection.size(0)
     # Normal reward function
-    def graph_reward_fn(graphs, true_graphs=None, device=None):
+    from nas_201_api import NASBench201API as API
+    api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+    def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
         rewards = []
+        if reward_model == 'swap':
+            import csv
+            with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+                reader = csv.reader(f)
+                header = next(reader)
+                data = [row for row in reader]
+                swap_scores = [float(row[0]) for row in data]
+            for graph in graphs:
+                node_tensor = graph[0]
+                node = node_tensor.cpu().numpy().tolist()
+                def nodes_to_arch_str(nodes):
+                    num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+                    nodes_str = [num_to_op[node] for node in nodes]
+                    arch_str = '|' + nodes_str[1] + '~0|+' + \
+                               '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' + \
+                               '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
+                    return arch_str
+                arch_str = nodes_to_arch_str(node)
+                reward = swap_scores[api.query_index_by_arch(arch_str)]
+                rewards.append(reward)
         for graph in graphs:
             reward = 1.0
             rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
-    # while samples_left_to_generate > 0:
-    #     print(f'samples left to generate: {samples_left_to_generate}/'
-    #           f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-    #     bs = 1 * cfg.train.batch_size
-    #     to_generate = min(samples_left_to_generate, bs)
-    #     to_save = min(samples_left_to_save, bs)
-    #     chains_save = min(chains_left_to_save, bs)
-    #     # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-    #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-    #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-    #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-    #     samples = samples + cur_sample
-    #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
-    #     samples_with_log_probs.append((cur_sample, log_probs, reward))
-    #     all_ys.append(batch_y)
-    #     batch_id += to_generate
-    #     samples_left_to_save -= to_save
-    #     samples_left_to_generate -= to_generate
-    #     chains_left_to_save -= chains_save
-    # print(f"final Computing sampling metrics...")
-    # graph_dit_model.sampling_metrics.reset()
-    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    # graph_dit_model.sampling_metrics.reset()
-    # print(f"Done.")
-    # # save samples
-    # print("Samples:")
-    # print(samples)
+    while samples_left_to_generate > 0:
+        print(f'samples left to generate: {samples_left_to_generate}/'
+              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+        bs = 1 * cfg.train.batch_size
+        to_generate = min(samples_left_to_generate, bs)
+        to_save = min(samples_left_to_save, bs)
+        chains_save = min(chains_left_to_save, bs)
+        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                                             keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+        samples = samples + cur_sample
+        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        samples_with_log_probs.append((cur_sample, log_probs, reward))
+        all_ys.append(batch_y)
+        batch_id += to_generate
+        samples_left_to_save -= to_save
+        samples_left_to_generate -= to_generate
+        chains_left_to_save -= chains_save
+    print(f"final Computing sampling metrics...")
+    graph_dit_model.sampling_metrics.reset()
+    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    graph_dit_model.sampling_metrics.reset()
+    print(f"Done.")
+    # save samples
+    print("Samples:")
+    print(samples)
     # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
     # samples, log_probs, rewards = samples_with_log_probs[perm]
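The new graph_reward_fn above scores each sampled graph by converting its node labels into a NAS-Bench-201 architecture string and using that string to look up a precomputed SWAP score. A standalone sketch of that conversion follows; the example node sequence and the printed result are made up for illustration, and the final lookup against the NAS-Bench-201 API is shown only in comments.

    # Standalone sketch of the arch-string conversion used by graph_reward_fn.
    # The node sequence below is a made-up example.
    num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']

    def nodes_to_arch_str(nodes):
        ops = [num_to_op[n] for n in nodes]
        # NAS-Bench-201 cell string: |op~0|+|op~0|op~1|+|op~0|op~1|op~2|
        return (f'|{ops[1]}~0|+'
                f'|{ops[2]}~0|{ops[3]}~1|+'
                f'|{ops[4]}~0|{ops[5]}~1|{ops[6]}~2|')

    nodes = [0, 2, 4, 1, 3, 2, 5, 6]  # hypothetical sample: input, six edge ops, output
    arch_str = nodes_to_arch_str(nodes)
    print(arch_str)
    # |nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|avg_pool_3x3~0|nor_conv_3x3~1|none~2|

    # With the real API and score table, the reward is then looked up roughly as:
    #   index = api.query_index_by_arch(arch_str)
    #   reward = swap_scores[index]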
@@ -333,61 +358,56 @@ def test(cfg: DictConfig):
     # log_probs = torch.cat(log_probs, dim=0)
     # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
     # old_log_probs = log_probs.clone()
-    old_log_probs = None
-    while samples_left_to_generate > 0:
-        print(f'samples left to generate: {samples_left_to_generate}/'
-              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-        bs = 1 * cfg.train.batch_size
-        to_generate = min(samples_left_to_generate, bs)
-        to_save = min(samples_left_to_save, bs)
-        chains_save = min(chains_left_to_save, bs)
-        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-        # batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-        # cur_sample, old_log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-        #                                                          keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-        # samples = samples + cur_sample
-        # reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
-        # advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-        with accelerator.accumulate(graph_dit_model):
-            batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-            new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-            samples = samples + new_samples
-            reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
-            advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-            if old_log_probs is None:
-                old_log_probs = log_probs.clone()
-            ratio = torch.exp(log_probs - old_log_probs)
-            unclipped_loss = -advantages * ratio
-            clipped_loss = -advantages * torch.clamp(ratio,
-                                                     1.0 - cfg.ppo.clip_param,
-                                                     1.0 + cfg.ppo.clip_param)
-            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
-        samples_with_log_probs.append((new_samples, log_probs, reward))
-        all_ys.append(batch_y)
-        batch_id += to_generate
-        samples_left_to_save -= to_save
-        samples_left_to_generate -= to_generate
-        chains_left_to_save -= chains_save
-        # break
-    print(f"final Computing sampling metrics...")
-    graph_dit_model.sampling_metrics.reset()
-    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    graph_dit_model.sampling_metrics.reset()
-    print(f"Done.")
-    # save samples
-    print("Samples:")
-    print(samples)
+    # ===
+    # old_log_probs = None
+    # while samples_left_to_generate > 0:
+    #     print(f'samples left to generate: {samples_left_to_generate}/'
+    #           f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+    #     bs = 1 * cfg.train.batch_size
+    #     to_generate = min(samples_left_to_generate, bs)
+    #     to_save = min(samples_left_to_save, bs)
+    #     chains_save = min(chains_left_to_save, bs)
+    #     with accelerator.accumulate(graph_dit_model):
+    #         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+    #         new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+    #         samples = samples + new_samples
+    #         reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
+    #         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+    #         if old_log_probs is None:
+    #             old_log_probs = log_probs.clone()
+    #         ratio = torch.exp(log_probs - old_log_probs)
+    #         unclipped_loss = -advantages * ratio
+    #         clipped_loss = -advantages * torch.clamp(ratio,
+    #                                                  1.0 - cfg.ppo.clip_param,
+    #                                                  1.0 + cfg.ppo.clip_param)
+    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+    #         accelerator.backward(loss)
+    #         optimizer.step()
+    #         optimizer.zero_grad()
+    #     samples_with_log_probs.append((new_samples, log_probs, reward))
+    #     all_ys.append(batch_y)
+    #     batch_id += to_generate
+    #     samples_left_to_save -= to_save
+    #     samples_left_to_generate -= to_generate
+    #     chains_left_to_save -= chains_save
+    #     # break
+    # print(f"final Computing sampling metrics...")
+    # graph_dit_model.sampling_metrics.reset()
+    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    # graph_dit_model.sampling_metrics.reset()
+    # print(f"Done.")
+    # # save samples
+    # print("Samples:")
+    # print(samples)
     # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
     # samples, log_probs, rewards = samples_with_log_probs[perm]
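For reference, the block this commit comments out is a PPO-style clipped-surrogate update over the sampling log-probabilities. Here is a self-contained sketch of just the loss computation, with random stand-in tensors for rewards and log-probabilities; the clip value plays the role of cfg.ppo.clip_param, but the number used here is an assumption.

    # Self-contained sketch of the clipped PPO surrogate used (and now disabled) above.
    # Rewards and log-probs are random stand-ins; clip_param is an assumed value.
    import torch

    clip_param = 0.2
    reward = torch.rand(1, 16)                           # one reward per sampled graph
    log_probs = torch.randn(1, 16, requires_grad=True)   # current policy log-probs
    old_log_probs = log_probs.detach() + 0.1 * torch.randn(1, 16)

    # Normalize rewards into advantages, as in the commented-out block.
    advantages = (reward - reward.mean()) / (reward.std() + 1e-6)

    # Probability ratio between the current and the old sampling policy.
    ratio = torch.exp(log_probs - old_log_probs)

    # Clipped surrogate objective: take the pessimistic (larger) loss per sample.
    unclipped_loss = -advantages * ratio
    clipped_loss = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    loss = torch.max(unclipped_loss, clipped_loss).mean()

    loss.backward()  # gradients flow into log_probs, i.e. the policy side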