From 7149b49a395a1cc78ca3c4c98461e5d765c9baea Mon Sep 17 00:00:00 2001
From: mhz
Date: Tue, 13 Aug 2024 09:42:51 +0200
Subject: [PATCH] update the flex data code

---
 graph_dit/datasets/dataset.py |  63 ++++++++---
 graph_dit/diffusion_model.py  | 204 +++++++++++++++++-----------------
 2 files changed, 152 insertions(+), 115 deletions(-)

diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py
index da4972e..8fe95ad 100644
--- a/graph_dit/datasets/dataset.py
+++ b/graph_dit/datasets/dataset.py
@@ -70,7 +70,7 @@ class DataModule(AbstractDataModule):
         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         # except NameError:
         #     base_path = pathlib.Path(os.getcwd()).parent[2]
-        base_path = '/home/stud/hanzhang/nasbenchDiT'
+        base_path = '/nfs/data3/hanzhang/nasbenchDiT'
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path

@@ -408,6 +408,7 @@ def new_graphs_to_json(graphs, filename):
         adj = graph[0]

         n_node = len(ops)
+        print(n_node)
         n_edge = len(ops)
         n_node_list.append(n_node)
         n_edge_list.append(n_edge)
@@ -489,7 +490,7 @@ def new_graphs_to_json(graphs, filename):
         'transition_E': transition_E.tolist(),
     }

-    with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
+    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
         json.dump(meta_dict, f)

     return meta_dict
@@ -655,7 +656,7 @@ def graphs_to_json(graphs, filename):
 class Dataset(InMemoryDataset):
     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
         self.target_prop = target_prop
-        source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         self.source = source
         # self.api = API(source)  # Initialize NAS-Bench-201 API
         # print('API loaded')
@@ -676,7 +677,7 @@ class Dataset(InMemoryDataset):
         return [f'{self.source}.pt']

     def process(self):
-        source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         # self.api = API(source)

         data_list = []
@@ -712,6 +713,7 @@ class Dataset(InMemoryDataset):

         def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8,random_ratio=0.5):
             # print(ori_nodes)
             # print(ori_edges)
+            ori_edges = np.array(ori_edges)
             # ori_nodes = np.array(ori_nodes)
             nasbench_201_node_num = 8
@@ -720,8 +722,13 @@ class Dataset(InMemoryDataset):
             # print(f'arch_str: {arch_str}, \nmax_nodes: {max_nodes}, min_nodes: {min_nodes}, nodes_num: {nodes_num},random_seed: {random_seed},random_ratio: {random_ratio}')
             add_num = nodes_num - nasbench_201_node_num
             # ori_nodes, ori_edges = parse_architecture_string(arch_str)
-            add_nodes = [op for op in random.choices(num_to_op[1:-1], k=add_num)]
+            add_nodes = []
+            print(f'add_num: {add_num}')
+            for i in range(add_num):
+                add_nodes.append(random.choice(num_to_op[1:-1]))
             # print(add_nodes)
+            print(f'ori_nodes[:-1]: {ori_nodes[:-1]}, add_nodes: {add_nodes}')
+            print(f'len(ori_nodes[:-1]): {len(ori_nodes[:-1])}, len(add_nodes): {len(add_nodes)}')
             nodes = ori_nodes[:-1] + add_nodes + ['output']
             edges = np.zeros((nodes_num , nodes_num))
             edges[:6, :6] = ori_edges[:6, :6]
@@ -731,6 +738,11 @@ class Dataset(InMemoryDataset):
                     rand = random.random()
                     if rand < random_ratio:
                         edges[i, j] = 1
+            if nodes_num < max_nodes:
+                edges = np.pad(edges, ((0, max_nodes - nodes_num), (0, max_nodes - nodes_num)), 'constant', constant_values=0)
+                while len(nodes) < max_nodes:
+                    nodes.append('none')
+            print(f'edges size: {edges.shape}, nodes size: {len(nodes)}')
             return edges,nodes

         def get_nasbench_201_val(idx):
@@ -766,10 +778,12 @@ class Dataset(InMemoryDataset):

         with tqdm(total = len_data) as pbar:
             active_nodes = set()
-            file_path = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+            file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
             with open(file_path, 'r') as f:
                 graph_list = json.load(f)
             i = 0
+            flex_graph_list = []
+            flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
             for graph in graph_list:
                 # arch_info = self.api.query_meta_info_by_index(i)
                 # results = self.api.query_by_index(i, 'cifar100')
@@ -784,6 +798,16 @@ class Dataset(InMemoryDataset):
                         active_nodes.add(op)

                 data = graph_to_graph_data((adj_matrix, ops))
+                # with open(flex_graph_path, 'a') as f:
+                #     flex_graph = {
+                #         'adj_matrix': adj_matrix,
+                #         'ops': ops,
+                #     }
+                #     json.dump(flex_graph, f)
+                flex_graph_list.append({
+                    'adj_matrix': adj_matrix,
+                    'ops': ops,
+                })
                 if i < 3:
                     print(f"i={i}, data={data}")
                     with open(f'{i}.json', 'w') as f:
@@ -792,7 +816,17 @@ class Dataset(InMemoryDataset):
                         f.write(str(data.edge_attr))
                 data_list.append(data)

-                new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=8, random_ratio=0.5)
+                new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=9, random_ratio=0.5)
+                flex_graph_list.append({
+                    'adj_matrix': new_adj.tolist(),
+                    'ops': new_ops,
+                })
+                # with open(flex_graph_path, 'w') as f:
+                #     flex_graph = {
+                #         'adj_matrix': new_adj.tolist(),
+                #         'ops': new_ops,
+                #     }
+                #     json.dump(flex_graph, f)
                 data_list.append(graph_to_graph_data((new_adj, new_ops)))

                 # graph_list.append({
@@ -838,6 +872,8 @@ class Dataset(InMemoryDataset):
                 graph['ops'] = ops
         with open(f'nasbench-201-graph.json', 'w') as f:
             json.dump(graph_list, f)
+        with open(flex_graph_path, 'w') as f:
+            json.dump(flex_graph_list, f)

         torch.save(self.collate(data_list), self.processed_paths[0])
@@ -1034,8 +1070,8 @@ def parse_architecture_string(arch_str, padding=0):
             assert idx == steps_coding[cont]
             cont += 1
         nodes.append(n)
-    ori_nodes = nodes.copy()
     nodes.append('output')  # Add output node
+    ori_nodes = nodes.copy()
     if padding > 0:
         for i in range(padding):
             nodes.append('none')
@@ -1048,7 +1084,7 @@ def parse_architecture_string(arch_str, padding=0):
     # print(nodes)
     # print(adj_mat)
     # print(len(adj_mat))
-
+    # print(f'len(ori_nodes): {len(ori_nodes)}, len(nodes): {len(nodes)}')
     return nodes, adj_mat, ori_nodes, ori_adj_mat

 def create_adj_matrix_and_ops(nodes, edges):
@@ -1091,6 +1127,7 @@ class DataInfos(AbstractDatasetInfos):

         adj_ops_pairs = []
         for item in data:
+            print(item)
             adj_matrix = np.array(item['adj_matrix'])
             ops = item['ops']
             ops = [op_type[op] for op in ops]
@@ -1111,12 +1148,12 @@ class DataInfos(AbstractDatasetInfos):
         #         ops_type[op] = len(ops_type)
         #     len_ops.add(len(ops))
         #     graphs.append((adj_matrix, ops))
-        graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
+        graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')

         # check first five graphs
         for i in range(5):
             print(f'graph {i} : {graphs[i]}')
-        print(f'ops_type: {ops_type}')
+        # print(f'ops_type: {ops_type}')

         meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
         self.base_path = base_path
@@ -1325,11 +1362,11 @@ def compute_meta(root, source_name, train_index, test_index):
         'transition_E': tansition_E.tolist(),
     }

-    with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
+    with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
         json.dump(meta_dict, f)

     return meta_dict

 if __name__ == "__main__":
-    dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
+    dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)

diff --git a/graph_dit/diffusion_model.py b/graph_dit/diffusion_model.py
index d286c71..79cee7d 100644
--- a/graph_dit/diffusion_model.py
+++ b/graph_dit/diffusion_model.py
@@ -3,9 +3,9 @@ import torch.nn.functional as F
 import pytorch_lightning as pl
 import time
 import os
-from naswot.score_networks import get_nasbench201_nodes_score
-from naswot import nasspace
-from naswot import datasets
+# from naswot.score_networks import get_nasbench201_nodes_score
+# from naswot import nasspace
+# from naswot import datasets
 from models.transformer import Denoiser
 from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition

@@ -41,7 +41,7 @@ class Graph_DiT(pl.LightningModule):
         self.args.batch_size = 128
         self.args.GPU = '0'
         self.args.dataset = 'cifar10-valid'
-        self.args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         self.args.data_loc = '../cifardata/'
         self.args.seed = 777
         self.args.init = ''
@@ -59,10 +59,10 @@ class Graph_DiT(pl.LightningModule):
         if 'valid' in self.args.dataset:
             self.args.dataset = self.args.dataset.replace('-valid', '')
         print('graph_dit starts to get searchspace of nasbench201')
-        self.searchspace = nasspace.get_search_space(self.args)
+        # self.searchspace = nasspace.get_search_space(self.args)
         print('searchspace of nasbench201 is obtained')
         print('graphdit starts to get train_loader')
-        self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
+        # self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
         print('train_loader is obtained')

         self.cfg = cfg
@@ -162,7 +162,7 @@ class Graph_DiT(pl.LightningModule):
         return pred

     def training_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()

         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
@@ -222,7 +222,7 @@ class Graph_DiT(pl.LightningModule):

     @torch.no_grad()
     def validation_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
         dense_data = dense_data.mask(node_mask, collapse=False)
@@ -315,7 +315,7 @@ class Graph_DiT(pl.LightningModule):

     @torch.no_grad()
     def test_step(self, data, i):
-        data_x = F.one_hot(data.x, num_classes=8).float()[:, self.active_index]
+        data_x = F.one_hot(data.x, num_classes=12).float()[:, self.active_index]
         data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()

         dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, self.max_n_nodes)
@@ -686,120 +686,120 @@ class Graph_DiT(pl.LightningModule):
         assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
         assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

-        # sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
+        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())

         # sample multiple times and get the best score arch...
-        num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
-        op_type = {
-            'input': 0,
-            'nor_conv_1x1': 1,
-            'nor_conv_3x3': 2,
-            'avg_pool_3x3': 3,
-            'skip_connect': 4,
-            'none': 5,
-            'output': 6,
-        }
-        def check_valid_graph(nodes, edges):
-            nodes = [num_to_op[i] for i in nodes]
-            if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
-                return False
-            if nodes[0] != 'input' or nodes[-1] != 'output':
-                return False
-            for i in range(0, len(nodes)):
-                if edges[i][i] == 1:
-                    return False
-            for i in range(1, len(nodes) - 1):
-                if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
-                    return False
-            for i in range(0, len(nodes)):
-                for j in range(i, len(nodes)):
-                    if edges[i, j] == 1 and nodes[j] == 'input':
-                        return False
-            for i in range(0, len(nodes)):
-                for j in range(i, len(nodes)):
-                    if edges[i, j] == 1 and nodes[i] == 'output':
-                        return False
-            flag = 0
-            for i in range(0,len(nodes)):
-                if edges[i,-1] == 1:
-                    flag = 1
-                    break
-            if flag == 0: return False
-            return True
+        # num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
+        # op_type = {
+        #     'input': 0,
+        #     'nor_conv_1x1': 1,
+        #     'nor_conv_3x3': 2,
+        #     'avg_pool_3x3': 3,
+        #     'skip_connect': 4,
+        #     'none': 5,
+        #     'output': 6,
+        # }
+        # def check_valid_graph(nodes, edges):
+        #     nodes = [num_to_op[i] for i in nodes]
+        #     if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
+        #         return False
+        #     if nodes[0] != 'input' or nodes[-1] != 'output':
+        #         return False
+        #     for i in range(0, len(nodes)):
+        #         if edges[i][i] == 1:
+        #             return False
+        #     for i in range(1, len(nodes) - 1):
+        #         if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
+        #             return False
+        #     for i in range(0, len(nodes)):
+        #         for j in range(i, len(nodes)):
+        #             if edges[i, j] == 1 and nodes[j] == 'input':
+        #                 return False
+        #     for i in range(0, len(nodes)):
+        #         for j in range(i, len(nodes)):
+        #             if edges[i, j] == 1 and nodes[i] == 'output':
+        #                 return False
+        #     flag = 0
+        #     for i in range(0,len(nodes)):
+        #         if edges[i,-1] == 1:
+        #             flag = 1
+        #             break
+        #     if flag == 0: return False
+        #     return True

-        class Args:
-            pass
+        # class Args:
+        #     pass

-        def get_score(sampled_s):
-            x_list = sampled_s.X.unbind(dim=0)
-            e_list = sampled_s.E.unbind(dim=0)
-            valid_rlt = [check_valid_graph(x_list[i].cpu().numpy(), e_list[i].cpu().numpy()) for i in range(len(x_list))]
-            from graph_dit.naswot.naswot.score_networks import get_nasbench201_nodes_score
-            score = []
+        # def get_score(sampled_s):
+        #     x_list = sampled_s.X.unbind(dim=0)
+        #     e_list = sampled_s.E.unbind(dim=0)
+        #     valid_rlt = [check_valid_graph(x_list[i].cpu().numpy(), e_list[i].cpu().numpy()) for i in range(len(x_list))]
+        #     from graph_dit.naswot.naswot.score_networks import get_nasbench201_nodes_score
+        #     score = []
-            for i in range(len(x_list)):
-                if valid_rlt[i]:
-                    nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
-                    # edges = e_list[i].cpu().numpy()
-                    score.append(get_nasbench201_nodes_score(nodes,train_loader=self.train_loader,searchspace=self.searchspace,device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") , args=self.args))
-                else:
-                    score.append(-1)
-            return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)
+        #     for i in range(len(x_list)):
+        #         if valid_rlt[i]:
+        #             nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
+        #             # edges = e_list[i].cpu().numpy()
+        #             score.append(get_nasbench201_nodes_score(nodes,train_loader=self.train_loader,searchspace=self.searchspace,device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") , args=self.args))
+        #         else:
+        #             score.append(-1)
+        #     return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)

-        sample_num = 10
-        best_arch = None
-        best_score_int = -1e8
-        score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8
+        # sample_num = 10
+        # best_arch = None
+        # best_score_int = -1e8
+        # score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8

-        for i in range(sample_num):
-            sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
-            score = get_score(sampled_s)
-            print(f'score: {score}')
-            print(f'score.shape: {score.shape}')
-            print(f'torch.sum(score): {torch.sum(score)}')
-            sum_score = torch.sum(score)
-            print(f'sum_score: {sum_score}')
-            if sum_score > best_score_int:
-                best_score_int = sum_score
-                best_score = score
-                best_arch = sampled_s
+        # for i in range(sample_num):
+        #     sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
+        #     score = get_score(sampled_s)
+        #     print(f'score: {score}')
+        #     print(f'score.shape: {score.shape}')
+        #     print(f'torch.sum(score): {torch.sum(score)}')
+        #     sum_score = torch.sum(score)
+        #     print(f'sum_score: {sum_score}')
+        #     if sum_score > best_score_int:
+        #         best_score_int = sum_score
+        #         best_score = score
+        #         best_arch = sampled_s

         # print(f'prob_X: {prob_X.shape}, prob_E: {prob_E.shape}')
         # best_arch = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
-        # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
-        # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
-        print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2
+        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
+        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
+        # print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2

-        print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
-        X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
-        E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
-        print(f'X_s: {X_s}, E_s: {E_s}')
+        # print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
+        # X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
+        # E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
+        # print(f'X_s: {X_s}, E_s: {E_s}')

-        # NASWOT score
-        target_score = torch.ones(100, requires_grad=True) * 2000.0
-        target_score = target_score.to(X_s.device)
+        # # NASWOT score
+        # target_score = torch.ones(100, requires_grad=True) * 2000.0
+        # target_score = target_score.to(X_s.device)
-        # compute loss mse(cur_score - target_score)
-        mse_loss = torch.nn.MSELoss()
-        print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
-        print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
-        loss = mse_loss(best_score, target_score)
-        loss.backward(retain_graph=True)
+        # # compute loss mse(cur_score - target_score)
+        # mse_loss = torch.nn.MSELoss()
+        # print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
+        # print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
+        # loss = mse_loss(best_score, target_score)
+        # loss.backward(retain_graph=True)

         # loss backward = gradient
         # get prob.X, prob_E gradient
-        x_grad = pred.X.grad
-        e_grad = pred.E.grad
+        # x_grad = pred.X.grad
+        # e_grad = pred.E.grad

-        beta_ratio = 0.5
-        # x_current = pred.X - beta_ratio * x_grad
-        # e_current = pred.E - beta_ratio * e_grad
-        E_s = pred.X - beta_ratio * x_grad
-        X_s = pred.E - beta_ratio * e_grad
+        # beta_ratio = 0.5
+        # # x_current = pred.X - beta_ratio * x_grad
+        # # e_current = pred.E - beta_ratio * e_grad
+        # E_s = pred.X - beta_ratio * x_grad
+        # X_s = pred.E - beta_ratio * e_grad

         # update prob.X prob_E with using gradient
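
Note on the padding scheme (editor's sketch, not part of the patch): generate_flex_adj_mat grows a NAS-Bench-201 cell beyond its fixed 8 nodes by sampling extra ops from num_to_op[1:-1], then zero-pads the adjacency matrix and appends 'none' ops until every graph has exactly max_nodes entries; the num_classes=8 -> num_classes=12 change in diffusion_model.py appears to keep the node one-hot width in step with max_nodes=12. A minimal, self-contained sketch of just the padding step, assuming only NumPy; the function name pad_to_max_nodes is illustrative and does not exist in the repo:

    import numpy as np

    def pad_to_max_nodes(nodes, edges, max_nodes=12):
        """Zero-pad `edges` and extend `nodes` with 'none' ops so the
        graph has exactly max_nodes nodes, mirroring the patch's logic."""
        n = len(nodes)
        if n < max_nodes:
            pad = max_nodes - n
            # Pad rows and columns of the adjacency matrix with zeros.
            edges = np.pad(edges, ((0, pad), (0, pad)), 'constant', constant_values=0)
            # Fill the op list with placeholder 'none' nodes.
            nodes = nodes + ['none'] * pad
        return edges, nodes

    # Example: a 4-node chain padded up to 12 nodes.
    ops = ['input', 'nor_conv_3x3', 'skip_connect', 'output']
    adj = np.zeros((4, 4))
    adj[0, 1] = adj[1, 2] = adj[2, 3] = 1
    adj, ops = pad_to_max_nodes(ops, adj)
    assert adj.shape == (12, 12) and len(ops) == 12

Padding every graph to a fixed size lets variable-size cells share one dense tensor layout, at the cost of carrying inert 'none' nodes through the model.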