Update the new graphs-to-JSON function
This commit is contained in:
parent
222470a43c
commit
df26eef77c
@ -39,6 +39,16 @@ op_to_atom = {
|
|||||||
'none': 'S', # Sulfur for no operation
|
'none': 'S', # Sulfur for no operation
|
||||||
'output': 'He' # Helium for output
|
'output': 'He' # Helium for output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Integer code assigned to each node label (operation name, plus the
# 'input'/'output' markers). Codes start at 1; insertion order is kept
# because callers iterate over this mapping.
op_type = dict(
    nor_conv_1x1=1,
    nor_conv_3x3=2,
    avg_pool_3x3=3,
    skip_connect=4,
    output=5,
    none=6,
    input=7,
)
|
||||||
class DataModule(AbstractDataModule):
|
class DataModule(AbstractDataModule):
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.datadir = cfg.dataset.datadir
|
self.datadir = cfg.dataset.datadir
|
||||||
@ -343,6 +353,121 @@ class DataModule_original(AbstractDataModule):
|
|||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
return self.test_loader
|
return self.test_loader
|
||||||
|
|
||||||
|
def _normalize_counts(counts):
    """Return *counts* scaled to sum to 1 (all zeros when the total is 0)."""
    counts = np.asarray(counts, dtype=float)
    total = counts.sum()
    if total > 0:
        return counts / total
    return np.zeros_like(counts)


def new_graphs_to_json(graphs, filename):
    """Compute dataset-level statistics for a list of graphs and save them as JSON.

    Every numeric per-node-type array produced here (``node_type_dist``,
    ``valencies``, both leading axes of ``transition_E``) is indexed by the
    integer code from the module-level ``op_type`` mapping, so the arrays
    stay aligned with each other and with ``active_nodes``.

    Args:
        graphs: Sequence of ``(adj, ops)`` pairs. ``ops`` is a sequence of
            node labels (keys of ``op_type``, or ``'*'`` as a wildcard) and
            ``adj`` must be indexable as ``adj[i][j]`` for
            ``i, j < len(ops)``; a non-zero off-diagonal entry is an edge.
        filename: Path prefix; the result is written to
            ``f'{filename}.meta.json'``.

    Returns:
        dict: The statistics that were written to disk. The node and edge
        type distributions are exposed under ``node_type_dist`` /
        ``edge_type_dist`` and, for backward compatibility, also under
        ``node_type_list`` / ``edge_type_list``.

    Raises:
        ValueError: If ``graphs`` is empty.
        KeyError: If a node label is neither ``'*'`` nor a key of ``op_type``.
    """
    source_name = "nasbench-201"
    wildcard = '*'
    no_edge_channel, edge_channel = 0, 1
    # The edge-type histogram keeps 10 entries as before; only the two
    # channels above are ever populated.
    n_edge_slots = 10
    # The node-count histogram is truncated to this many entries on output.
    max_hist_len = 51

    num_graph = len(graphs)
    if num_graph == 0:
        raise ValueError("graphs must contain at least one (adj, ops) pair")

    # One slot per op_type code. The wildcard gets slot 0 when the codes are
    # 1-based (leaving 0 free); otherwise an extra slot is appended for it,
    # so the wildcard never shares a slot with a real operation.
    used_codes = set(op_type.values())
    n_slots = max(used_codes) + 1
    if 0 in used_codes:
        wildcard_slot = n_slots
        n_slots += 1
    else:
        wildcard_slot = 0

    # node_name_list[k] is the label whose counts live at index k.
    node_name_list = [None] * n_slots
    for op_name, code in op_type.items():
        node_name_list[code] = op_name
    node_name_list[wildcard_slot] = wildcard

    node_counts = np.zeros(n_slots)
    valencies = np.zeros(n_slots)
    edge_counts = np.zeros(n_edge_slots)
    transition_E = np.zeros((n_slots, n_slots, 2))

    n_node_list = []
    n_edge_list = []

    for graph in graphs:
        adj = graph[0]
        ops = graph[1]

        slots = [wildcard_slot if op == wildcard else op_type[op] for op in ops]
        n_node = len(slots)

        graph_node_counts = np.zeros(n_slots)
        for slot in slots:
            graph_node_counts[slot] += 1
            # As before, the "valency" histogram is keyed by the node's type
            # code and skips wildcard nodes.
            if slot != wildcard_slot:
                valencies[slot] += 1
        node_counts += graph_node_counts

        # Each non-zero off-diagonal adjacency entry contributes 2 to both
        # (start, end) and (end, start), i.e. edges are treated as undirected.
        graph_transitions = np.zeros((n_slots, n_slots))
        n_edge = 0
        for i in range(n_node):
            for j in range(n_node):
                if i == j or adj[i][j] == 0:
                    continue
                n_edge += 1
                start_slot, end_slot = slots[i], slots[j]
                graph_transitions[start_slot, end_slot] += 2
                graph_transitions[end_slot, start_slot] += 2

        n_node_list.append(n_node)
        n_edge_list.append(n_edge)

        transition_E[:, :, edge_channel] += graph_transitions
        edge_counts[edge_channel] += 2 * n_edge
        # Ordered node pairs that are not connected.
        edge_counts[no_edge_channel] += n_node * (n_node - 1) - 2 * n_edge

        # Ordered node pairs per (type, type) cell, excluding self-pairs;
        # whatever is not an edge is counted as "no edge".
        pair_totals = 2 * np.outer(graph_node_counts, graph_node_counts)
        pair_totals -= 2 * np.diag(graph_node_counts)
        transition_E[:, :, no_edge_channel] += pair_totals - graph_transitions

    # Histogram of graph sizes, indexed by node count. It is at least
    # num_graph long (as before) but grows to fit the largest graph so a
    # small dataset cannot index out of range.
    size_hist = np.zeros(max(num_graph, max(n_node_list) + 1))
    for n_node in n_node_list:
        size_hist[n_node] += 1
    n_nodes_per_graph = _normalize_counts(size_hist).tolist()[:max_hist_len]

    node_type_dist = _normalize_counts(node_counts)
    print('processed meta info: ------', filename, '------')
    print('len node_count_list', len(node_type_dist))
    print('len node_name_list', len(node_name_list))
    active_nodes = [
        name for name, share in zip(node_name_list, node_type_dist) if share > 0
    ]
    node_type_dist = node_type_dist.tolist()

    edge_type_dist = _normalize_counts(edge_counts).tolist()
    valencies = _normalize_counts(valencies).tolist()

    # Type pairs that were never observed default to "no edge" so every
    # cell normalizes to a valid distribution.
    never_seen = transition_E.sum(axis=-1) == 0
    transition_E[never_seen, no_edge_channel] = 1
    transition_E = transition_E / transition_E.sum(axis=-1, keepdims=True)

    meta_dict = {
        'source': source_name,
        'num_graph': num_graph,
        'n_nodes_per_graph': n_nodes_per_graph,
        'max_n_nodes': max(n_node_list),
        'max_n_edges': max(n_edge_list),
        'node_type_dist': node_type_dist,
        'edge_type_dist': edge_type_dist,
        # Legacy key names, kept so existing readers keep working.
        'node_type_list': node_type_dist,
        'edge_type_list': edge_type_dist,
        'valencies': valencies,
        'active_nodes': active_nodes,
        'num_active_nodes': len(active_nodes),
        'transition_E': transition_E.tolist(),
    }

    with open(f'{filename}.meta.json', 'w') as f:
        json.dump(meta_dict, f)

    return meta_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def graphs_to_json(graphs, filename):
|
def graphs_to_json(graphs, filename):
|
||||||
bonds = {
|
bonds = {
|
||||||
'nor_conv_1x1': 1,
|
'nor_conv_1x1': 1,
|
||||||
@ -490,7 +615,7 @@ def graphs_to_json(graphs, filename):
|
|||||||
'atom_type_dist': atom_count_list,
|
'atom_type_dist': atom_count_list,
|
||||||
'bond_type_dist': bond_count_list,
|
'bond_type_dist': bond_count_list,
|
||||||
'valencies': valencies,
|
'valencies': valencies,
|
||||||
'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
|
'active_nodes': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
|
||||||
'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
|
'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
|
||||||
'transition_E': transition_E.tolist(),
|
'transition_E': transition_E.tolist(),
|
||||||
}
|
}
|
||||||
@ -503,10 +628,10 @@ class Dataset(InMemoryDataset):
|
|||||||
self.target_prop = target_prop
|
self.target_prop = target_prop
|
||||||
source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||||
self.source = source
|
self.source = source
|
||||||
super().__init__(root, transform, pre_transform, pre_filter)
|
|
||||||
print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt
|
|
||||||
self.api = API(source) # Initialize NAS-Bench-201 API
|
self.api = API(source) # Initialize NAS-Bench-201 API
|
||||||
print('API loaded')
|
print('API loaded')
|
||||||
|
super().__init__(root, transform, pre_transform, pre_filter)
|
||||||
|
print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt
|
||||||
print('Dataset initialized')
|
print('Dataset initialized')
|
||||||
self.data, self.slices = torch.load(self.processed_paths[0])
|
self.data, self.slices = torch.load(self.processed_paths[0])
|
||||||
self.data.edge_attr = self.data.edge_attr.squeeze()
|
self.data.edge_attr = self.data.edge_attr.squeeze()
|
||||||
@ -732,30 +857,35 @@ class DataInfos(AbstractDatasetInfos):
|
|||||||
arch_info = self.api.query_meta_info_by_index(i)
|
arch_info = self.api.query_meta_info_by_index(i)
|
||||||
nodes, edges = parse_architecture_string(arch_info.arch_str)
|
nodes, edges = parse_architecture_string(arch_info.arch_str)
|
||||||
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
|
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
|
||||||
if i < 5:
|
# if i < 5:
|
||||||
print("Adjacency Matrix:")
|
# print("Adjacency Matrix:")
|
||||||
print(adj_matrix)
|
# print(adj_matrix)
|
||||||
print("Operations List:")
|
# print("Operations List:")
|
||||||
print(ops)
|
# print(ops)
|
||||||
for op in ops:
|
for op in ops:
|
||||||
if op not in ops_type:
|
if op not in ops_type:
|
||||||
ops_type[op] = len(ops_type)
|
ops_type[op] = len(ops_type)
|
||||||
len_ops.add(len(ops))
|
len_ops.add(len(ops))
|
||||||
graphs.append((adj_matrix, ops))
|
graphs.append((adj_matrix, ops))
|
||||||
|
|
||||||
meta_dict = graphs_to_json(graphs, 'nasbench-201')
|
# check first five graphs
|
||||||
|
for i in range(5):
|
||||||
|
print(f'graph {i} : {graphs[i]}')
|
||||||
|
print(f'ops_type: {ops_type}')
|
||||||
|
|
||||||
|
meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
|
||||||
self.base_path = base_path
|
self.base_path = base_path
|
||||||
self.active_atoms = meta_dict['active_atoms']
|
self.active_nodes = meta_dict['active_nodes']
|
||||||
self.max_n_nodes = meta_dict['max_node']
|
self.max_n_nodes = meta_dict['max_n_nodes']
|
||||||
self.original_max_n_nodes = meta_dict['max_node']
|
self.original_max_n_nodes = meta_dict['max_n_nodes']
|
||||||
self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
|
self.n_nodes = torch.Tensor(meta_dict['n_nodes_per_graph'])
|
||||||
self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
|
self.edge_types = torch.Tensor(meta_dict['edge_type_dist'])
|
||||||
self.transition_E = torch.Tensor(meta_dict['transition_E'])
|
self.transition_E = torch.Tensor(meta_dict['transition_E'])
|
||||||
|
|
||||||
self.atom_decoder = meta_dict['active_atoms']
|
self.node_decoder = meta_dict['active_nodes']
|
||||||
node_types = torch.Tensor(meta_dict['atom_type_dist'])
|
node_types = torch.Tensor(meta_dict['node_type_dist'])
|
||||||
active_index = (node_types > 0).nonzero().squeeze()
|
active_index = (node_types > 0).nonzero().squeeze()
|
||||||
self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
|
self.node_types = torch.Tensor(meta_dict['node_type_dist'])[active_index]
|
||||||
self.nodes_dist = DistributionNodes(self.n_nodes)
|
self.nodes_dist = DistributionNodes(self.n_nodes)
|
||||||
self.active_index = active_index
|
self.active_index = active_index
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user