From 94fe13756f31f789ba424786fd20a1264919c6e8 Mon Sep 17 00:00:00 2001
From: mhz <cxyoz@outlook.com>
Date: Sat, 14 Sep 2024 23:56:36 +0200
Subject: [PATCH] update reward function to use NAS-Bench-201 SWAP scores

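Replace the constant reward with a SWAP-score lookup: each sampled graph's
node operations are turned into a NAS-Bench-201 architecture string, the
string is resolved to its benchmark index through NASBench201API, and the
matching entry of swap_results.csv is used as the reward. The plain sampling
loop (sample_batch, reward, sampling metrics) is re-enabled, while the PPO
update loop is commented out for now.

A minimal sketch of the lookup, assuming the CSV's first column holds SWAP
scores ordered by NAS-Bench-201 index (file names shortened here; treat the
paths as placeholders):

    import csv
    from nas_201_api import NASBench201API as API

    api = API('NAS-Bench-201-v1_1-096897.pth')
    with open('swap_results.csv') as f:
        reader = csv.reader(f)
        next(reader)                                 # skip header row
        swap_scores = [float(row[0]) for row in reader]

    ops = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3',
           'skip_connect', 'none', 'output']

    def swap_reward(nodes):
        # nodes[1:7] are the six edge operations of one NAS-Bench-201 cell
        s = [ops[i] for i in nodes]
        arch = (f'|{s[1]}~0|+'
                f'|{s[2]}~0|{s[3]}~1|+'
                f'|{s[4]}~0|{s[5]}~1|{s[6]}~2|')
        return swap_scores[api.query_index_by_arch(arch)]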
---
 graph_dit/diffusion_model.py |   1 +
 graph_dit/main.py            | 170 +++++++++++++++++++----------------
 2 files changed, 96 insertions(+), 75 deletions(-)

diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index e056c51..310131c 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -601,6 +601,7 @@ class Graph_DiT(pl.LightningModule):
 
         assert (E == torch.transpose(E, 1, 2)).all()
 
+        # total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
         total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device)
 
         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
diff --git a/graph_dit/main.py b/graph_dit/main.py
index 33fd0dd..9901522 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -281,45 +281,70 @@ def test(cfg: DictConfig):
         num_examples = test_y_collection.size(0)
     
     # Normal reward function
-    def graph_reward_fn(graphs, true_graphs=None, device=None):
+    from nas_201_api import NASBench201API as API
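+    # Load the NAS-Bench-201 benchmark so sampled architectures can be queried by arch string.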
+    api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+    def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
         rewards = []
+        if reward_model == 'swap':
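+            # SWAP reward: map each sampled graph to its NAS-Bench-201 index and
+            # use the precomputed SWAP score for that architecture as the reward.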
+            import csv
+            with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
+                reader = csv.reader(f)
+                header = next(reader)
+                data = [row for row in reader]
+                swap_scores = [float(row[0]) for row in data]
+                for graph in graphs:
+                    node_tensor = graph[0]
+                    node = node_tensor.cpu().numpy().tolist()
+
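+                    # Build the NAS-Bench-201 arch string from the operation indices, e.g.
+                    # '|nor_conv_3x3~0|+|skip_connect~0|none~1|+|avg_pool_3x3~0|nor_conv_1x1~1|nor_conv_3x3~2|'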
+                    def nodes_to_arch_str(nodes):
+                        num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+                        nodes_str = [num_to_op[node] for node in nodes]
+                        arch_str = '|' + nodes_str[1] + '~0|+' + \
+                                '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\
+                                '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|' 
+                        return arch_str
+                    
+                    arch_str = nodes_to_arch_str(node)
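+                    # query_index_by_arch returns the architecture's benchmark index; the CSV
+                    # rows are assumed to be ordered by that same index.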
+                    reward = swap_scores[api.query_index_by_arch(arch_str)]
+                    rewards.append(reward)
+            # Return here so the constant-reward fallback below is not appended
+            # on top of the SWAP-based rewards.
+            return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
+
         for graph in graphs:
             reward = 1.0
             rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
-    # while samples_left_to_generate > 0:
-    #     print(f'samples left to generate: {samples_left_to_generate}/'
-    #         f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-    #     bs = 1 * cfg.train.batch_size
-    #     to_generate = min(samples_left_to_generate, bs)
-    #     to_save = min(samples_left_to_save, bs)
-    #     chains_save = min(chains_left_to_save, bs)
-    #     # batch_y = test_y_collection[batch_id : batch_id + to_generate]
-    #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+    while samples_left_to_generate > 0:
+        print(f'samples left to generate: {samples_left_to_generate}/'
+            f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+        bs = 1 * cfg.train.batch_size
+        to_generate = min(samples_left_to_generate, bs)
+        to_save = min(samples_left_to_save, bs)
+        chains_save = min(chains_left_to_save, bs)
+        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
+        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
 
-    #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-    #                                     keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-    #     samples = samples + cur_sample
-    #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
+        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
+                                        keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+        samples = samples + cur_sample
+        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
 
-    #     samples_with_log_probs.append((cur_sample, log_probs, reward))
+        samples_with_log_probs.append((cur_sample, log_probs, reward))
         
-    #     all_ys.append(batch_y)
-    #     batch_id += to_generate
+        all_ys.append(batch_y)
+        batch_id += to_generate
 
-    #     samples_left_to_save -= to_save
-    #     samples_left_to_generate -= to_generate
-    #     chains_left_to_save -= chains_save
+        samples_left_to_save -= to_save
+        samples_left_to_generate -= to_generate
+        chains_left_to_save -= chains_save
         
-    # print(f"final Computing sampling metrics...")
-    # graph_dit_model.sampling_metrics.reset()
-    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    # graph_dit_model.sampling_metrics.reset()
-    # print(f"Done.")
+    print("Computing final sampling metrics...")
+    graph_dit_model.sampling_metrics.reset()
+    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    graph_dit_model.sampling_metrics.reset()
+    print(f"Done.")
 
-    # # save samples
-    # print("Samples:")
-    # print(samples)
+    # save samples
+    print("Samples:")
+    print(samples)
 
     # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
     # samples, log_probs, rewards = samples_with_log_probs[perm]
@@ -333,61 +358,56 @@ def test(cfg: DictConfig):
     # log_probs = torch.cat(log_probs, dim=0)
     # print(f'log_probs: {log_probs.shape}') # torch.Size([1000, 1])
     # old_log_probs = log_probs.clone()
-    old_log_probs = None
-    while samples_left_to_generate > 0:
-        print(f'samples left to generate: {samples_left_to_generate}/'
-            f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
-        bs = 1 * cfg.train.batch_size
-        to_generate = min(samples_left_to_generate, bs)
-        to_save = min(samples_left_to_save, bs)
-        chains_save = min(chains_left_to_save, bs)
-        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
 
-        # batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+    # ===
 
-        # cur_sample, old_log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
-                                        # keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-        # samples = samples + cur_sample
-        # reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
-        # advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-        with accelerator.accumulate(graph_dit_model):
-            batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
-            new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
-            samples = samples + new_samples
-            reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
-            advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
-            if old_log_probs is None:
-                old_log_probs = log_probs.clone()
-            ratio = torch.exp(log_probs - old_log_probs)
-            unclipped_loss = -advantages * ratio
-            clipped_loss = -advantages * torch.clamp(ratio,
-                            1.0 - cfg.ppo.clip_param,
-                            1.0 + cfg.ppo.clip_param)
-            loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
+    # old_log_probs = None
+    # while samples_left_to_generate > 0:
+    #     print(f'samples left to generate: {samples_left_to_generate}/'
+    #         f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
+    #     bs = 1 * cfg.train.batch_size
+    #     to_generate = min(samples_left_to_generate, bs)
+    #     to_save = min(samples_left_to_save, bs)
+    #     chains_save = min(chains_left_to_save, bs)
+
+    #     with accelerator.accumulate(graph_dit_model):
+    #         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
+    #         new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
+    #         samples = samples + new_samples
+    #         reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
+    #         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+    #         if old_log_probs is None:
+    #             old_log_probs = log_probs.clone()
+    #         ratio = torch.exp(log_probs - old_log_probs)
+    #         unclipped_loss = -advantages * ratio
+    #         clipped_loss = -advantages * torch.clamp(ratio,
+    #                         1.0 - cfg.ppo.clip_param,
+    #                         1.0 + cfg.ppo.clip_param)
+    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+    #         accelerator.backward(loss)
+    #         optimizer.step()
+    #         optimizer.zero_grad()
 
 
-        samples_with_log_probs.append((new_samples, log_probs, reward))
+    #     samples_with_log_probs.append((new_samples, log_probs, reward))
         
-        all_ys.append(batch_y)
-        batch_id += to_generate
+    #     all_ys.append(batch_y)
+    #     batch_id += to_generate
 
-        samples_left_to_save -= to_save
-        samples_left_to_generate -= to_generate
-        chains_left_to_save -= chains_save
-        # break
+    #     samples_left_to_save -= to_save
+    #     samples_left_to_generate -= to_generate
+    #     chains_left_to_save -= chains_save
+    #     # break
         
-    print(f"final Computing sampling metrics...")
-    graph_dit_model.sampling_metrics.reset()
-    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
-    graph_dit_model.sampling_metrics.reset()
-    print(f"Done.")
+    # print(f"final Computing sampling metrics...")
+    # graph_dit_model.sampling_metrics.reset()
+    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
+    # graph_dit_model.sampling_metrics.reset()
+    # print(f"Done.")
 
-    # save samples
-    print("Samples:")
-    print(samples)
+    # # save samples
+    # print("Samples:")
+    # print(samples)
 
     # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
     # samples, log_probs, rewards = samples_with_log_probs[perm]