Review notes (text before the first "diff --git" line is ignored by patch tools):
  * This patch was recovered from a whitespace-collapsed copy. Indentation and
    the position of blank context lines are reconstructed and should be checked
    against the target file; the blob ids on the "index" line are stale.
  * new_graphs_to_json now exports 'node_type_dist' / 'edge_type_dist', the
    keys DataInfos reads (the old '*_list' names are kept as aliases).
  * Node names, node counts and transition_E share one index space:
    slot 0 is '*', slot v is the operation with op_type value v.
  * n_edge is counted from the adjacency matrix instead of len(ops).
  * The node-count histogram is sized by the largest graph, not by the
    number of graphs; empty input raises ValueError.
  * Removed the bare except, the vacuous assert and the per-graph debug print.

diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py
index 143f65c..4d017ac 100644
--- a/graph_dit/datasets/dataset.py
+++ b/graph_dit/datasets/dataset.py
@@ -39,6 +39,19 @@ op_to_atom = {
     'none': 'S', # Sulfur for no operation
     'output': 'He' # Helium for output
 }
+
+# Integer id of each NAS-Bench-201 operation. Id 0 is not used here;
+# new_graphs_to_json reserves it for the '*' wildcard.
+op_type = {
+    'nor_conv_1x1': 1,
+    'nor_conv_3x3': 2,
+    'avg_pool_3x3': 3,
+    'skip_connect': 4,
+    'output': 5,
+    'none': 6,
+    'input': 7
+}
+
 class DataModule(AbstractDataModule):
     def __init__(self, cfg):
         self.datadir = cfg.dataset.datadir
@@ -343,6 +356,129 @@ class DataModule_original(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader
 
+
+def new_graphs_to_json(graphs, filename):
+    """Compute NAS-Bench-201 dataset statistics and dump them to JSON.
+
+    Slot 0 of every per-type array is the wildcard '*'; slot ``v`` is the
+    operation whose ``op_type`` value is ``v``, so names, counts and
+    ``transition_E`` share one index space.
+
+    Args:
+        graphs: non-empty sequence whose items hold an adjacency matrix at
+            index 0 and the list of operation names at index 1.
+        filename: path prefix; ``<filename>.meta.json`` is written.
+
+    Returns:
+        The dict of statistics that was serialized.
+
+    Raises:
+        ValueError: if ``graphs`` is empty.
+    """
+    if not graphs:
+        raise ValueError('graphs must contain at least one graph')
+
+    source_name = "nasbench-201"
+    num_graph = len(graphs)
+    num_type = len(op_type) + 1
+    bond_index = 1  # the only "edge present" type; 0 means "no edge"
+
+    node_name_list = ['*'] * num_type
+    for op_name, op_index in op_type.items():
+        node_name_list[op_index] = op_name
+
+    def type_index(op):
+        return 0 if op == '*' else op_type[op]
+
+    node_count_list = [0] * num_type
+    valencies = [0] * num_type
+    edge_count_list = [0] * 10
+    transition_E = np.zeros((num_type, num_type, 2))
+
+    # Size the histogram by the largest graph (at least 51 bins, the
+    # length that is exported), not by the number of graphs.
+    max_n_nodes = max(len(graph[1]) for graph in graphs)
+    n_nodes_per_graph = [0] * max(51, max_n_nodes + 1)
+    n_edge_list = []
+
+    for graph in graphs:
+        adj, ops = graph[0], graph[1]
+        n_node = len(ops)
+        n_nodes_per_graph[n_node] += 1
+
+        cur_node_count_arr = np.zeros(num_type)
+        for op in ops:
+            index = type_index(op)
+            node_count_list[index] += 1
+            cur_node_count_arr[index] += 1
+            if op != '*':
+                valencies[index] += 1
+
+        transition_E_temp = np.zeros((num_type, num_type, 2))
+        n_edge = 0
+        for i in range(n_node):
+            for j in range(n_node):
+                if i == j or adj[i][j] == 0:
+                    continue
+                n_edge += 1
+                start_index = type_index(ops[i])
+                end_index = type_index(ops[j])
+                edge_count_list[bond_index] += 2
+                # Each connection is recorded in both directions.
+                for src, dst in ((start_index, end_index),
+                                 (end_index, start_index)):
+                    transition_E[src, dst, bond_index] += 2
+                    transition_E_temp[src, dst, bond_index] += 2
+        n_edge_list.append(n_edge)
+
+        # Ordered node pairs that are not connected.
+        edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
+        cur_tot_edge = np.outer(cur_node_count_arr, cur_node_count_arr) * 2
+        cur_tot_edge -= np.diag(cur_node_count_arr) * 2
+        transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1)
+
+    n_nodes_per_graph = np.array(n_nodes_per_graph) / np.sum(n_nodes_per_graph)
+    n_nodes_per_graph = n_nodes_per_graph.tolist()[:51]
+
+    node_count_list = np.array(node_count_list) / np.sum(node_count_list)
+    print('processed meta info: ------', filename, '------')
+    print('len node_count_list', len(node_count_list))
+    print('len node_name_list', len(node_name_list))
+    active_nodes = np.array(node_name_list)[node_count_list > 0].tolist()
+    node_count_list = node_count_list.tolist()
+
+    edge_count_list = (np.array(edge_count_list) / np.sum(edge_count_list)).tolist()
+    valencies = (np.array(valencies) / np.sum(valencies)).tolist()
+
+    # Type pairs never seen together get probability 1 on "no edge".
+    no_edge = np.sum(transition_E, axis=-1) == 0
+    transition_E[:, :, 0][no_edge] = 1
+    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
+
+    meta_dict = {
+        'source': source_name,
+        'num_graph': num_graph,
+        'n_nodes_per_graph': n_nodes_per_graph,
+        'max_n_nodes': max_n_nodes,
+        'max_n_edges': max(n_edge_list),
+        # Key names read by DataInfos.
+        'node_type_dist': node_count_list,
+        'edge_type_dist': edge_count_list,
+        # Earlier key names, kept as aliases.
+        'node_type_list': node_count_list,
+        'edge_type_list': edge_count_list,
+        'valencies': valencies,
+        'active_nodes': active_nodes,
+        'num_active_nodes': len(active_nodes),
+        'transition_E': transition_E.tolist(),
+    }
+
+    with open(f'{filename}.meta.json', 'w') as f:
+        json.dump(meta_dict, f)
+
+    return meta_dict
+
+
 def graphs_to_json(graphs, filename):
     bonds = {
         'nor_conv_1x1': 1,
@@ -490,7 +626,7 @@ def graphs_to_json(graphs, filename):
         'atom_type_dist': atom_count_list,
         'bond_type_dist': bond_count_list,
         'valencies': valencies,
-        'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
+        'active_nodes': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
         'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
         'transition_E': transition_E.tolist(),
     }
@@ -503,10 +639,10 @@ class Dataset(InMemoryDataset):
         self.target_prop = target_prop
         source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         self.source = source
-        super().__init__(root, transform, pre_transform, pre_filter)
-        print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt
         self.api = API(source) # Initialize NAS-Bench-201 API
         print('API loaded')
+        super().__init__(root, transform, pre_transform, pre_filter)
+        print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt
         print('Dataset initialized')
         self.data, self.slices = torch.load(self.processed_paths[0])
         self.data.edge_attr = self.data.edge_attr.squeeze()
@@ -732,30 +868,35 @@ class DataInfos(AbstractDatasetInfos):
             arch_info = self.api.query_meta_info_by_index(i)
             nodes, edges = parse_architecture_string(arch_info.arch_str)
             adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
-            if i < 5:
-                print("Adjacency Matrix:")
-                print(adj_matrix)
-                print("Operations List:")
-                print(ops)
+            # if i < 5:
+            #     print("Adjacency Matrix:")
+            #     print(adj_matrix)
+            #     print("Operations List:")
+            #     print(ops)
             for op in ops:
                 if op not in ops_type:
                     ops_type[op] = len(ops_type)
             len_ops.add(len(ops))
             graphs.append((adj_matrix, ops))
 
-        meta_dict = graphs_to_json(graphs, 'nasbench-201')
+        # check first five graphs
+        for i in range(5):
+            print(f'graph {i} : {graphs[i]}')
+        print(f'ops_type: {ops_type}')
+
+        meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
 
         self.base_path = base_path
-        self.active_atoms = meta_dict['active_atoms']
-        self.max_n_nodes = meta_dict['max_node']
-        self.original_max_n_nodes = meta_dict['max_node']
-        self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
-        self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
+        self.active_nodes = meta_dict['active_nodes']
+        self.max_n_nodes = meta_dict['max_n_nodes']
+        self.original_max_n_nodes = meta_dict['max_n_nodes']
+        self.n_nodes = torch.Tensor(meta_dict['n_nodes_per_graph'])
+        self.edge_types = torch.Tensor(meta_dict['edge_type_dist'])
         self.transition_E = torch.Tensor(meta_dict['transition_E'])
 
-        self.atom_decoder = meta_dict['active_atoms']
-        node_types = torch.Tensor(meta_dict['atom_type_dist'])
+        self.node_decoder = meta_dict['active_nodes']
+        node_types = torch.Tensor(meta_dict['node_type_dist'])
         active_index = (node_types > 0).nonzero().squeeze()
-        self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
+        self.node_types = torch.Tensor(meta_dict['node_type_dist'])[active_index]
         self.nodes_dist = DistributionNodes(self.n_nodes)
         self.active_index = active_index