update a small problem
This commit is contained in:
parent
0fc6f6e686
commit
be8bb16f61
@ -359,15 +359,15 @@ def new_graphs_to_json(graphs, filename):
|
|||||||
|
|
||||||
node_name_list = []
|
node_name_list = []
|
||||||
node_count_list = []
|
node_count_list = []
|
||||||
|
node_name_list.append('*')
|
||||||
|
|
||||||
for op_name in op_type:
|
for op_name in op_type:
|
||||||
node_name_list.append(op_name)
|
node_name_list.append(op_name)
|
||||||
node_count_list.append(0)
|
node_count_list.append(0)
|
||||||
|
|
||||||
node_name_list.append('*')
|
|
||||||
node_count_list.append(0)
|
node_count_list.append(0)
|
||||||
n_nodes_per_graph = [0] * num_graph
|
n_nodes_per_graph = [0] * num_graph
|
||||||
edge_count_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
edge_count_list = [0, 0]
|
||||||
valencies = [0] * (len(op_type) + 1)
|
valencies = [0] * (len(op_type) + 1)
|
||||||
transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
|
transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
|
||||||
|
|
||||||
@ -388,16 +388,16 @@ def new_graphs_to_json(graphs, filename):
|
|||||||
|
|
||||||
for op in ops:
|
for op in ops:
|
||||||
node = op
|
node = op
|
||||||
if node == '*':
|
# if node == '*':
|
||||||
node_count_list[-1] += 1
|
# node_count_list[-1] += 1
|
||||||
cur_node_count_arr[-1] += 1
|
# cur_node_count_arr[-1] += 1
|
||||||
else:
|
# else:
|
||||||
node_count_list[op_type[node]] += 1
|
node_count_list[node] += 1
|
||||||
cur_node_count_arr[op_type[node]] += 1
|
cur_node_count_arr[node] += 1
|
||||||
try:
|
try:
|
||||||
valencies[int(op_type[node])] += 1
|
valencies[node] += 1
|
||||||
except:
|
except:
|
||||||
print('int(op_type[node])', int(op_type[node]))
|
print('int(op_type[node])', int(node))
|
||||||
|
|
||||||
transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
|
transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
|
||||||
for i in range(n_node):
|
for i in range(n_node):
|
||||||
@ -406,8 +406,8 @@ def new_graphs_to_json(graphs, filename):
|
|||||||
continue
|
continue
|
||||||
start_node, end_node = i, j
|
start_node, end_node = i, j
|
||||||
|
|
||||||
start_index = op_type[ops[start_node]]
|
start_index = ops[start_node]
|
||||||
end_index = op_type[ops[end_node]]
|
end_index = ops[end_node]
|
||||||
bond_index = 1
|
bond_index = 1
|
||||||
edge_count_list[bond_index] += 2
|
edge_count_list[bond_index] += 2
|
||||||
|
|
||||||
@ -418,7 +418,7 @@ def new_graphs_to_json(graphs, filename):
|
|||||||
|
|
||||||
edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
|
edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
|
||||||
cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2
|
cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2
|
||||||
print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
|
# print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
|
||||||
cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2
|
cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2
|
||||||
transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1)
|
transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1)
|
||||||
assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0
|
assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0
|
||||||
@ -460,7 +460,7 @@ def new_graphs_to_json(graphs, filename):
|
|||||||
'transition_E': transition_E.tolist(),
|
'transition_E': transition_E.tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(f'{filename}.meta.json', 'w') as f:
|
with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
|
||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
|
|
||||||
return meta_dict
|
return meta_dict
|
||||||
@ -683,15 +683,41 @@ class Dataset(InMemoryDataset):
|
|||||||
active_nodes = set()
|
active_nodes = set()
|
||||||
for i in range(len_data):
|
for i in range(len_data):
|
||||||
arch_info = self.api.query_meta_info_by_index(i)
|
arch_info = self.api.query_meta_info_by_index(i)
|
||||||
|
results = self.api.query_by_index(i, 'cifar100')
|
||||||
nodes, edges = parse_architecture_string(arch_info.arch_str)
|
nodes, edges = parse_architecture_string(arch_info.arch_str)
|
||||||
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
|
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
|
||||||
for op in ops:
|
for op in ops:
|
||||||
if op not in active_nodes:
|
if op not in active_nodes:
|
||||||
active_nodes.add(op)
|
active_nodes.add(op)
|
||||||
|
|
||||||
graph_list.append({
|
graph_list.append({
|
||||||
"adj_matrix": adj_matrix,
|
"adj_matrix": adj_matrix,
|
||||||
"ops": ops,
|
"ops": ops,
|
||||||
"idx": i
|
"idx": i,
|
||||||
|
"train": [{
|
||||||
|
"iepoch": result.get_train()['iepoch'],
|
||||||
|
"loss": result.get_train()['loss'],
|
||||||
|
"accuracy": result.get_train()['accuracy'],
|
||||||
|
"cur_time": result.get_train()['cur_time'],
|
||||||
|
"all_time": result.get_train()['all_time'],
|
||||||
|
"seed": seed,
|
||||||
|
}for seed, result in results.items()],
|
||||||
|
"valid": [{
|
||||||
|
"iepoch": result.get_eval('x-valid')['iepoch'],
|
||||||
|
"loss": result.get_eval('x-valid')['loss'],
|
||||||
|
"accuracy": result.get_eval('x-valid')['accuracy'],
|
||||||
|
"cur_time": result.get_eval('x-valid')['cur_time'],
|
||||||
|
"all_time": result.get_eval('x-valid')['all_time'],
|
||||||
|
"seed": seed,
|
||||||
|
}for seed, result in results.items()],
|
||||||
|
"test": [{
|
||||||
|
"iepoch": result.get_eval('x-test')['iepoch'],
|
||||||
|
"loss": result.get_eval('x-test')['loss'],
|
||||||
|
"accuracy": result.get_eval('x-test')['accuracy'],
|
||||||
|
"cur_time": result.get_eval('x-test')['cur_time'],
|
||||||
|
"all_time": result.get_eval('x-test')['all_time'],
|
||||||
|
"seed": seed,
|
||||||
|
}for seed, result in results.items()]
|
||||||
})
|
})
|
||||||
data = graph_to_graph_data((adj_matrix, ops))
|
data = graph_to_graph_data((adj_matrix, ops))
|
||||||
data_list.append(data)
|
data_list.append(data)
|
||||||
@ -925,8 +951,9 @@ class DataInfos(AbstractDatasetInfos):
|
|||||||
|
|
||||||
adj_ops_pairs = []
|
adj_ops_pairs = []
|
||||||
for item in data:
|
for item in data:
|
||||||
adj_matrix = np.array(item['adjacency_matrix'])
|
adj_matrix = np.array(item['adj_matrix'])
|
||||||
ops = item['operations']
|
ops = item['ops']
|
||||||
|
ops = [op_type[op] for op in ops]
|
||||||
adj_ops_pairs.append((adj_matrix, ops))
|
adj_ops_pairs.append((adj_matrix, ops))
|
||||||
|
|
||||||
return adj_ops_pairs
|
return adj_ops_pairs
|
||||||
@ -944,7 +971,7 @@ class DataInfos(AbstractDatasetInfos):
|
|||||||
# ops_type[op] = len(ops_type)
|
# ops_type[op] = len(ops_type)
|
||||||
# len_ops.add(len(ops))
|
# len_ops.add(len(ops))
|
||||||
# graphs.append((adj_matrix, ops))
|
# graphs.append((adj_matrix, ops))
|
||||||
graphs = read_adj_ops_from_json(f'nasbench-201.meta.json')
|
graphs = read_adj_ops_from_json(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
|
||||||
|
|
||||||
# check first five graphs
|
# check first five graphs
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
@ -1158,7 +1185,7 @@ def compute_meta(root, source_name, train_index, test_index):
|
|||||||
'transition_E': tansition_E.tolist(),
|
'transition_E': tansition_E.tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(f'{root}/{source_name}.meta.json', "w") as f:
|
with open(f'/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
|
||||||
json.dump(meta_dict, f)
|
json.dump(meta_dict, f)
|
||||||
|
|
||||||
return meta_dict
|
return meta_dict
|
||||||
|
Loading…
Reference in New Issue
Block a user