diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py index dbd64f2..8f21f7e 100644 --- a/graph_dit/datasets/dataset.py +++ b/graph_dit/datasets/dataset.py @@ -359,15 +359,15 @@ def new_graphs_to_json(graphs, filename): node_name_list = [] node_count_list = [] + node_name_list.append('*') for op_name in op_type: node_name_list.append(op_name) node_count_list.append(0) - node_name_list.append('*') node_count_list.append(0) n_nodes_per_graph = [0] * num_graph - edge_count_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + edge_count_list = [0, 0] valencies = [0] * (len(op_type) + 1) transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) @@ -388,16 +388,16 @@ def new_graphs_to_json(graphs, filename): for op in ops: node = op - if node == '*': - node_count_list[-1] += 1 - cur_node_count_arr[-1] += 1 - else: - node_count_list[op_type[node]] += 1 - cur_node_count_arr[op_type[node]] += 1 - try: - valencies[int(op_type[node])] += 1 - except: - print('int(op_type[node])', int(op_type[node])) + # if node == '*': + # node_count_list[-1] += 1 + # cur_node_count_arr[-1] += 1 + # else: + node_count_list[node] += 1 + cur_node_count_arr[node] += 1 + try: + valencies[node] += 1 + except: + print('int(op_type[node])', int(node)) transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2)) for i in range(n_node): @@ -406,8 +406,8 @@ def new_graphs_to_json(graphs, filename): continue start_node, end_node = i, j - start_index = op_type[ops[start_node]] - end_index = op_type[ops[end_node]] + start_index = ops[start_node] + end_index = ops[end_node] bond_index = 1 edge_count_list[bond_index] += 2 @@ -418,7 +418,7 @@ def new_graphs_to_json(graphs, filename): edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2 cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2 - print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}") + # print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}") cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2 transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1) assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0 @@ -460,7 +460,7 @@ def new_graphs_to_json(graphs, filename): 'transition_E': transition_E.tolist(), } - with open(f'{filename}.meta.json', 'w') as f: + with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f: json.dump(meta_dict, f) return meta_dict @@ -683,15 +683,41 @@ class Dataset(InMemoryDataset): active_nodes = set() for i in range(len_data): arch_info = self.api.query_meta_info_by_index(i) + results = self.api.query_by_index(i, 'cifar100') nodes, edges = parse_architecture_string(arch_info.arch_str) adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) for op in ops: if op not in active_nodes: active_nodes.add(op) + graph_list.append({ "adj_matrix": adj_matrix, "ops": ops, - "idx": i + "idx": i, + "train": [{ + "iepoch": result.get_train()['iepoch'], + "loss": result.get_train()['loss'], + "accuracy": result.get_train()['accuracy'], + "cur_time": result.get_train()['cur_time'], + "all_time": result.get_train()['all_time'], + "seed": seed, + }for seed, result in results.items()], + "valid": [{ + "iepoch": result.get_eval('x-valid')['iepoch'], + "loss": result.get_eval('x-valid')['loss'], + "accuracy": result.get_eval('x-valid')['accuracy'], + "cur_time": result.get_eval('x-valid')['cur_time'], + "all_time": result.get_eval('x-valid')['all_time'], + "seed": seed, + }for seed, result in results.items()], + "test": [{ + "iepoch": result.get_eval('x-test')['iepoch'], + "loss": result.get_eval('x-test')['loss'], + "accuracy": result.get_eval('x-test')['accuracy'], + "cur_time": result.get_eval('x-test')['cur_time'], + "all_time": result.get_eval('x-test')['all_time'], + "seed": seed, + }for seed, result in results.items()] }) data = graph_to_graph_data((adj_matrix, ops)) data_list.append(data) @@ -925,8 +951,9 @@ class DataInfos(AbstractDatasetInfos): adj_ops_pairs = [] for item in data: - adj_matrix = np.array(item['adjacency_matrix']) - ops = item['operations'] + adj_matrix = np.array(item['adj_matrix']) + ops = item['ops'] + ops = [op_type[op] for op in ops] adj_ops_pairs.append((adj_matrix, ops)) return adj_ops_pairs @@ -944,7 +971,7 @@ class DataInfos(AbstractDatasetInfos): # ops_type[op] = len(ops_type) # len_ops.add(len(ops)) # graphs.append((adj_matrix, ops)) - graphs = read_adj_ops_from_json(f'nasbench-201.meta.json') + graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') # check first five graphs for i in range(5): @@ -1158,7 +1185,7 @@ def compute_meta(root, source_name, train_index, test_index): 'transition_E': tansition_E.tolist(), } - with open(f'{root}/{source_name}.meta.json', "w") as f: + with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f: json.dump(meta_dict, f) return meta_dict