From 55ff19421d49ece71bff9aa5bbe8a5df278f1b85 Mon Sep 17 00:00:00 2001 From: mhz Date: Thu, 25 Jul 2024 22:09:03 +0200 Subject: [PATCH] add the idea of guidance --- graph_dit/datasets/dataset.py | 53 +++++++++++++++++++++++++++++++++++ graph_dit/diffusion_model.py | 36 +++++++++++++++++++++--- 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index 5004f2d..43a6e26 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -8,6 +8,7 @@ import os import os.path as osp import pathlib import json +import random import torch import torch.nn.functional as F @@ -49,6 +50,9 @@ op_type = { 'none': 5, 'output': 6, } + +num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] + class DataModule(AbstractDataModule): def __init__(self, cfg): self.datadir = cfg.dataset.datadir @@ -676,6 +680,52 @@ class Dataset(InMemoryDataset): data_list = [] len_data = len(self.api) + def check_valid_graph(nodes, edges): + if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]: + return False + if nodes[0] != 'input' or nodes[-1] != 'output': + return False + for i in range(0, len(nodes)): + if edges[i][i] == 1: + return False + for i in range(1, len(nodes) - 1): + if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output': + return False + for i in range(0, len(nodes)): + for j in range(i, len(nodes)): + if edges[i, j] == 1 and nodes[j] == 'input': + return False + for i in range(0, len(nodes)): + for j in range(i, len(nodes)): + if edges[i, j] == 1 and nodes[i] == 'output': + return False + flag = 0 + for i in range(0,len(nodes)): + if edges[i,-1] == 1: + flag = 1 + break + if flag == 0: return False + return True + + def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8,random_ratio=0.5): + nasbench_201_node_num = 8 + # random.seed(random_seed) + nodes_num = random.randint(min_nodes, max_nodes) + # print(f'arch_str: {arch_str}, \nmax_nodes: {max_nodes}, min_nodes: {min_nodes}, nodes_num: {nodes_num},random_seed: {random_seed},random_ratio: {random_ratio}') + add_num = nodes_num - nasbench_201_node_num + # ori_nodes, ori_edges = parse_architecture_string(arch_str) + add_nodes = [op for op in random.choices(num_to_op[1:-1], k=add_num)] + # print(add_nodes) + nodes = ori_nodes[:-1] + add_nodes + ['output'] + edges = np.zeros((nodes_num , nodes_num)) + edges[:6, :6] = ori_edges[:6, :6] + edges[0:8, -1] = ori_edges[0:8 , -1] + for i in range(0, nodes_num): + for j in range(max(7,i + 1), nodes_num): + rand = random.random() + if rand < random_ratio: + edges[i, j] = 1 + return nodes, edges def graph_to_graph_data(graph): ops = graph[1] @@ -746,6 +796,9 @@ class Dataset(InMemoryDataset): }) data = graph_to_graph_data((adj_matrix, ops)) data_list.append(data) + + # new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ops, ori_edges=adj_matrix, max_nodes=12, min_nodes=8, random_ratio=0.5) + # data_list.append(graph_to_graph_data((new_adj, new_ops))) pbar.update(1) for graph in graph_list: diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py index a5eab3d..7da917c 100644 --- a/graph_dit/diffusion_model.py +++ b/graph_dit/diffusion_model.py @@ -134,7 +134,7 @@ class Graph_DiT(pl.LightningModule): loss = self.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y, true_X=X, true_E=E, true_y=data.y, node_mask=node_mask, log=i % self.log_every_steps == 0) - + # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}') self.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E, log=i % self.log_every_steps == 0) self.log(f'loss', loss, batch_size=X.size(0), sync_dist=True) @@ -601,7 +601,8 @@ class Graph_DiT(pl.LightningModule): # Normalize predictions pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0 - pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0 + pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0 + # gradient # Retrieve transitions matrix Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device) @@ -629,25 +630,52 @@ class Graph_DiT(pl.LightningModule): prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1]) return prob_X, prob_E - + # diffusion nag: P_t(G_{t-1} |G_t, C) = P_t(G_{t-1} |G_t) + P_t(C | G_{t-1}, G_t) + # with condition = P_t(G_{t-1} |G_t, C) + # with condition = P_t(A_{t-1} |A_t, y) prob_X, prob_E = get_prob(noisy_data) ### Guidance if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True) - prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale + prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10) prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-10) + assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all() assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all() sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) + # sample multiple times and get the best score arch... + + sample_num = 100 + best_arch = None + best_score = -1e8 + + for i in range(sample_num): + sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) + score = get_score(sampled_s) + if score > best_score: + best_score = score + best_arch = sampled_s + X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float() E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float() + # NASWOT score + target_score = torch.tensor([3000.0]) + + # compute loss mse(cur_score - target_score) + + # loss backward = gradient + + # get prob.X, prob_E gradient + + # update prob.X prob_E with using gradient + assert (E_s == torch.transpose(E_s, 1, 2)).all() assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)