From 99163a51508e72cc43b478a9d862cd4355a77a8c Mon Sep 17 00:00:00 2001
From: Hanzhang Ma
Date: Tue, 11 Jun 2024 17:48:25 +0200
Subject: [PATCH] Transfer the NAS-Bench-201 code from the Jupyter notebook
 into dataset.py

---
 graph_dit/datasets/dataset.py | 517 +++++++++++++++++++++++++++++++++-
 1 file changed, 504 insertions(+), 13 deletions(-)

diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py
index 0c12a1f..3719a8b 100644
--- a/graph_dit/datasets/dataset.py
+++ b/graph_dit/datasets/dataset.py
@@ -2,6 +2,8 @@ import sys
 sys.path.append('../')
 
+from nas_201_api import NASBench201API as API
+
 import os
 import os.path as osp
 import pathlib
@@ -24,7 +26,266 @@ from diffusion.distributions import DistributionNodes
 bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
 
+# Map each NAS-Bench-201 operation to a chemical element so that cells can be
+# consumed by the molecule-oriented pipeline.
+op_to_atom = {
+    'input': 'Si',         # Silicon for input
+    'nor_conv_1x1': 'C',   # Carbon for 1x1 convolution
+    'nor_conv_3x3': 'N',   # Nitrogen for 3x3 convolution
+    'avg_pool_3x3': 'O',   # Oxygen for 3x3 average pooling
+    'skip_connect': 'P',   # Phosphorus for skip connection
+    'none': 'S',           # Sulfur for no operation
+    'output': 'He'         # Helium for output
+}
+
 class DataModule(AbstractDataModule):
+    def __init__(self, cfg):
+        self.datadir = cfg.dataset.datadir
+        self.task = cfg.dataset.task_name
+        print("DataModule")
+        print("task", self.task)
+        print("datadir", self.datadir)
+        super().__init__(cfg)
+
+    def prepare_data(self) -> None:
+        target = getattr(self.cfg.dataset, 'guidance_target', None)
+        print("target", target)
+        # try:
+        #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
+        # except NameError:
+        #     base_path = pathlib.Path(os.getcwd()).parent[2]
+        base_path = '/home/stud/hanzhang/Graph-Dit'
+        root_path = os.path.join(base_path, self.datadir)
+        self.root_path = root_path
+
+        batch_size = self.cfg.train.batch_size
+        num_workers = self.cfg.train.num_workers
+        pin_memory = self.cfg.dataset.pin_memory
+
+        # Load the dataset into memory; it carries the target property,
+        # root path, and transform.
+        source = './NAS-Bench-201-v1_1-096897.pth'
+        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
+
+        # if len(self.task.split('-')) == 2:
+        #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
+        # else:
+        train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
+
+        self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
+        train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
+        if len(unlabeled_index) > 0:
+            train_index = torch.cat([train_index, unlabeled_index], dim=0)
+
+        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
+        self.train_dataset = train_dataset
+        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
+
+        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
+        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
+        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
+
+        self.training_iterations = len(train_dataset) // batch_size
+
+    def random_data_split(self, dataset):
+        # Labeled graphs come first; trailing entries whose target is NaN are
+        # treated as unlabeled and appended to the training set by the caller.
+        nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
+        labeled_len = len(dataset) - nan_count
+        full_idx = list(range(labeled_len))
+        train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
+        train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
+        train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
+        unlabeled_index = list(range(labeled_len, len(dataset)))
+        print(self.task, 'dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
+        return train_index, val_index, test_index, unlabeled_index
+
+    def fixed_split(self, dataset):
+        if self.task == 'O2-N2':
+            test_index = [42, 43, 92, 122, 197, 198, 251, 254, 257, 355, 511, 512, 549, 602, 603, 604]
+        else:
+            raise ValueError('Invalid task name: {}'.format(self.task))
+        full_idx = list(range(len(dataset)))
+        full_idx = list(set(full_idx) - set(test_index))
+        train_ratio = 0.8
+        train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
+        print(self.task, 'dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
+        return train_index, val_index, test_index, []
+
+    def get_train_smiles(self):
+        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
+
+    def get_data_split(self):
+        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
+
+    def example_batch(self):
+        return next(iter(self.val_loader))
+
+    def train_dataloader(self):
+        return self.train_loader
+
+    def val_dataloader(self):
+        return self.val_loader
+
+    def test_dataloader(self):
+        return self.test_loader
+
+def graphs_to_json(graphs, filename):
+    # Edge types are keyed by the operation of the edge's destination node.
+    bonds = {
+        'nor_conv_1x1': 1,
+        'nor_conv_3x3': 2,
+        'avg_pool_3x3': 3,
+        'skip_connect': 4,
+        'input': 7,
+        'output': 5,
+        'none': 6
+    }
+
+    source_name = "nas-bench-201"
+    num_graph = len(graphs)
+    pt = Chem.GetPeriodicTable()
+    # Track every element from He (Z=2) to Og (Z=118), plus a wildcard '*'.
+    atom_name_list = []
+    atom_count_list = []
+    for i in range(2, 119):
+        atom_name_list.append(pt.GetElementSymbol(i))
+        atom_count_list.append(0)
+    atom_name_list.append('*')
+    atom_count_list.append(0)
+    n_atoms_per_mol = [0] * 500
+    bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
+    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
+    valencies = [0] * 500
+    transition_E = np.zeros((118, 118, 5))
+
+    n_atom_list = []
+    n_bond_list = []
+    # graphs = [(adj_matrix, ops), ...]
+    for graph in graphs:
+        ops = graph[1]
+        adj = graph[0]
+        n_atom = len(ops)
+        n_bond = len(ops)
+        n_atom_list.append(n_atom)
+        n_bond_list.append(n_bond)
+
+        n_atoms_per_mol[n_atom] += 1
+        cur_atom_count_arr = np.zeros(118)
+
+        for op in ops:
+            symbol = op_to_atom[op]
+            if symbol == 'H':
+                continue
+            elif symbol == '*':
+                atom_count_list[-1] += 1
+                cur_atom_count_arr[-1] += 1
+            else:
+                atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
+                cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
+                try:
+                    valencies[int(pt.GetDefaultValence(symbol))] += 1
+                except Exception as e:
+                    print('could not record valence of', symbol, ':', e)
+        transition_E_temp = np.zeros((118, 118, 5))
+        for i in range(n_atom):
+            for j in range(n_atom):
+                if i == j or adj[i][j] == 0:
+                    continue
+                start_atom, end_atom = i, j
+                # Structural nodes contribute no edge statistics.
+                if ops[start_atom] == 'input' or ops[end_atom] == 'input':
+                    continue
+                if ops[start_atom] == 'output' or ops[end_atom] == 'output':
+                    continue
+                if ops[start_atom] == 'none' or ops[end_atom] == 'none':
+                    continue
+
+                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
+                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
+                bond_index = bonds[ops[end_atom]]
+                bond_count_list[bond_index] += 2
+
+                transition_E[start_index, end_index, bond_index] += 2
+                transition_E[end_index, start_index, bond_index] += 2
+                transition_E_temp[start_index, end_index, bond_index] += 2
+                transition_E_temp[end_index, start_index, bond_index] += 2
+
+        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
+        print(bond_count_list)
+        # All possible directed pairs per element pair, minus self-pairs;
+        # whatever is not an explicit edge is counted as "no bond" (type 0).
+        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2
+        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
+        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
+        assert (cur_tot_bond >= transition_E_temp.sum(axis=-1)).all(), \
+            'edge counts exceed the number of possible atom pairs'
+
+    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
+    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
+
+    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
+    print('processed meta info: ------', filename, '------')
+    print('len atom_count_list', len(atom_count_list))
+    print('len atom_name_list', len(atom_name_list))
+    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
+    active_atoms = active_atoms.tolist()
+    atom_count_list = atom_count_list.tolist()
+
+    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
+    bond_count_list = bond_count_list.tolist()
+    valencies = np.array(valencies) / np.sum(valencies)
+    valencies = valencies.tolist()
+
+    no_edge = np.sum(transition_E, axis=-1) == 0
+    for i in range(118):
+        for j in range(118):
+            if not no_edge[i][j]:
+                print(f'have an edge at i={i}, j={j}, transition_E[i][j]={transition_E[i][j]}')
+    first_elt = transition_E[:, :, 0]
+    first_elt[no_edge] = 1
+    transition_E[:, :, 0] = first_elt
+
+    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
+
+    # Report "no bond" probabilities that are neither exactly 0 nor 1.
+    for i in range(118):
+        for j in range(118):
+            if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
+                print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
+
+    meta_dict = {
+        'source': 'nasbench-201',
+        'num_graph': num_graph,
+        'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
+        'max_node': max(n_atom_list),
+        'max_bond': max(n_bond_list),
+        'atom_type_dist': atom_count_list,
+        'bond_type_dist': bond_count_list,
+        'valencies': valencies,
+        'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
+        'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
+        'transition_E': transition_E.tolist(),
+    }
+
+    with open(f'{filename}.meta.json', 'w') as f:
+        json.dump(meta_dict, f)
+    return meta_dict
+
+class DataModule_original(AbstractDataModule):
     def __init__(self, cfg):
         self.datadir = cfg.dataset.datadir
         self.task = cfg.dataset.task_name
@@ -48,18 +309,6 @@ class DataModule(AbstractDataModule):
         # Load the dataset to the memory
         # Dataset has target property, root path, and transform
         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
-        print("len dataset", len(dataset))
-        def print_data(dataset):
-            print("dataset", dataset)
-            print("dataset keys", dataset.keys)
-            print("dataset x", dataset.x)
-            print("dataset edge_index", dataset.edge_index)
-            print("dataset edge_attr", dataset.edge_attr)
-            print("dataset y", dataset.y)
-            print("")
-        print_data(dataset=dataset[0])
-        print_data(dataset=dataset[1])
-
 
         if len(self.task.split('-')) == 2:
             train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
@@ -138,8 +387,163 @@ class DataModule(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader
 
+def graphs_to_json(graphs, filename):
+    # Edge types are keyed by the operation of the edge's destination node.
+    bonds = {
+        'nor_conv_1x1': 1,
+        'nor_conv_3x3': 2,
+        'avg_pool_3x3': 3,
+        'skip_connect': 4,
+        'input': 7,
+        'output': 5,
+        'none': 6
+    }
+
+    source_name = "nas-bench-201"
+    num_graph = len(graphs)
+    pt = Chem.GetPeriodicTable()
+    # Track every element from He (Z=2) to Og (Z=118), plus a wildcard '*'.
+    atom_name_list = []
+    atom_count_list = []
+    for i in range(2, 119):
+        atom_name_list.append(pt.GetElementSymbol(i))
+        atom_count_list.append(0)
+    atom_name_list.append('*')
+    atom_count_list.append(0)
+    n_atoms_per_mol = [0] * 500
+    bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
+    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
+    valencies = [0] * 500
+    transition_E = np.zeros((118, 118, 5))
+
+    n_atom_list = []
+    n_bond_list = []
+    # graphs = [(adj_matrix, ops), ...]
+    for graph in graphs:
+        ops = graph[1]
+        adj = graph[0]
+        n_atom = len(ops)
+        n_bond = len(ops)
+        n_atom_list.append(n_atom)
+        n_bond_list.append(n_bond)
+
+        n_atoms_per_mol[n_atom] += 1
+        cur_atom_count_arr = np.zeros(118)
+
+        for op in ops:
+            symbol = op_to_atom[op]
+            if symbol == 'H':
+                continue
+            elif symbol == '*':
+                atom_count_list[-1] += 1
+                cur_atom_count_arr[-1] += 1
+            else:
+                atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
+                cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
+                try:
+                    valencies[int(pt.GetDefaultValence(symbol))] += 1
+                except Exception as e:
+                    print('could not record valence of', symbol, ':', e)
+        transition_E_temp = np.zeros((118, 118, 5))
+        for i in range(n_atom):
+            for j in range(n_atom):
+                if i == j or adj[i][j] == 0:
+                    continue
+                start_atom, end_atom = i, j
+                # Structural nodes contribute no edge statistics.
+                if ops[start_atom] == 'input' or ops[end_atom] == 'input':
+                    continue
+                if ops[start_atom] == 'output' or ops[end_atom] == 'output':
+                    continue
+                if ops[start_atom] == 'none' or ops[end_atom] == 'none':
+                    continue
+
+                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
+                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
+                bond_index = bonds[ops[end_atom]]
+                bond_count_list[bond_index] += 2
+
+                transition_E[start_index, end_index, bond_index] += 2
+                transition_E[end_index, start_index, bond_index] += 2
+                transition_E_temp[start_index, end_index, bond_index] += 2
+                transition_E_temp[end_index, start_index, bond_index] += 2
+
+        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
+        print(bond_count_list)
+        # All possible directed pairs per element pair, minus self-pairs;
+        # whatever is not an explicit edge is counted as "no bond" (type 0).
+        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2
+        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
+        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
+        assert (cur_tot_bond >= transition_E_temp.sum(axis=-1)).all(), \
+            'edge counts exceed the number of possible atom pairs'
+
+    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
+    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
+
+    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
+    print('processed meta info: ------', filename, '------')
+    print('len atom_count_list', len(atom_count_list))
+    print('len atom_name_list', len(atom_name_list))
+    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
+    active_atoms = active_atoms.tolist()
+    atom_count_list = atom_count_list.tolist()
+
+    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
+    bond_count_list = bond_count_list.tolist()
+    valencies = np.array(valencies) / np.sum(valencies)
+    valencies = valencies.tolist()
+
+    no_edge = np.sum(transition_E, axis=-1) == 0
+    for i in range(118):
+        for j in range(118):
+            if not no_edge[i][j]:
+                print(f'have an edge at i={i}, j={j}, transition_E[i][j]={transition_E[i][j]}')
+    first_elt = transition_E[:, :, 0]
+    first_elt[no_edge] = 1
+    transition_E[:, :, 0] = first_elt
+
+    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
+
+    # Report "no bond" probabilities that are neither exactly 0 nor 1.
+    for i in range(118):
+        for j in range(118):
+            if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
+                print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
+
+    meta_dict = {
+        'source': 'nasbench-201',
+        'num_graph': num_graph,
+        'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
+        'max_node': max(n_atom_list),
+        'max_bond': max(n_bond_list),
+        'atom_type_dist': atom_count_list,
+        'bond_type_dist': bond_count_list,
+        'valencies': valencies,
+        'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
+        'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
+        'transition_E': transition_E.tolist(),
+    }
+
+    with open(f'{filename}.meta.json', 'w') as f:
+        json.dump(meta_dict, f)
+    return meta_dict
+
-class Dataset(InMemoryDataset):
+class Dataset_origin(InMemoryDataset):
     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
         self.target_prop = target_prop
@@ -223,8 +627,95 @@ class Dataset(InMemoryDataset):
 
         torch.save(self.collate(data_list), self.processed_paths[0])
 
+def parse_architecture_string(arch_str):
+    print(arch_str)
+    steps = arch_str.split('+')
+    nodes = ['input']  # start with the input node
+    edges = []
+    for i, step in enumerate(steps):
+        step = step.strip('|').split('|')
+        for node in step:
+            op, idx = node.split('~')
+            edges.append((int(idx), i+1))  # i+1 because node 0 is the input
+            nodes.append(op)
+    nodes.append('output')  # add the output node
+    return nodes, edges
+
+def create_adj_matrix_and_ops(nodes, edges):
+    num_nodes = len(nodes)
+    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
+    for (src, dst) in edges:
+        adj_matrix[src][dst] = 1
+    return adj_matrix, nodes
+
 class DataInfos(AbstractDatasetInfos):
+    def __init__(self, datamodule, cfg):
+        tasktype_dict = {
+            'hiv_b': 'classification',
+            'bace_b': 'classification',
+            'bbbp_b': 'classification',
+            'O2': 'regression',
+            'N2': 'regression',
+            'CO2': 'regression',
+        }
+        task_name = cfg.dataset.task_name
+        self.task = task_name
+        self.task_type = tasktype_dict.get(task_name, "regression")
+        self.ensure_connected = cfg.model.ensure_connected
+
+        datadir = cfg.dataset.datadir
+        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
+        meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
+        data_root = os.path.join(base_path, datadir, 'raw')
+        graphs = []
+        length = 15625  # 5 candidate operations on 6 cell edges: 5 ** 6
+        ops_type = {}
+        len_ops = set()
+        api = API('../NAS-Bench-201-v1_0-e61699.pth')
+        for i in range(length):
+            arch_info = api.query_meta_info_by_index(i)
+            nodes, edges = parse_architecture_string(arch_info.arch_str)
+            adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
+            if i < 5:
+                print("Adjacency Matrix:")
+                print(adj_matrix)
+                print("Operations List:")
+                print(ops)
+            for op in ops:
+                if op not in ops_type:
+                    ops_type[op] = len(ops_type)
+            len_ops.add(len(ops))
+            graphs.append((adj_matrix, ops))
+
+        meta_dict = graphs_to_json(graphs, 'nasbench-201')
+
+        self.base_path = base_path
+        self.active_atoms = meta_dict['active_atoms']
+        self.max_n_nodes = meta_dict['max_node']
+        self.original_max_n_nodes = meta_dict['max_node']
+        self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
+        self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
+        self.transition_E = torch.Tensor(meta_dict['transition_E'])
+
+        self.atom_decoder = meta_dict['active_atoms']
+        node_types = torch.Tensor(meta_dict['atom_type_dist'])
+        active_index = (node_types > 0).nonzero().squeeze()
+        self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
+        self.nodes_dist = DistributionNodes(self.n_nodes)
+        self.active_index = active_index
+
+        val_len = 3 * self.original_max_n_nodes - 2
+        meta_val = torch.Tensor(meta_dict['valencies'])
+        self.valency_distribution = torch.zeros(val_len)
+        val_len = min(val_len, len(meta_val))
+        self.valency_distribution[:val_len] = meta_val[:val_len]
+        self.y_prior = None
+        self.train_ymin = []
+        self.train_ymax = []
+
+
+class DataInfos_origin(AbstractDatasetInfos):
     def __init__(self, datamodule, cfg):
         tasktype_dict = {
             'hiv_b': 'classification',
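
A minimal usage sketch of the new parse_architecture_string and create_adj_matrix_and_ops helpers, assuming dataset.py is importable under the path shown and using an illustrative (format-correct, but not benchmark-specific) architecture string:

    # Illustrative sketch; the import path and arch string are assumptions.
    from graph_dit.datasets.dataset import (
        parse_architecture_string, create_adj_matrix_and_ops)

    arch = '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|'
    nodes, edges = parse_architecture_string(arch)
    # nodes: ['input', 'avg_pool_3x3', 'nor_conv_1x1', 'skip_connect',
    #         'nor_conv_1x1', 'skip_connect', 'skip_connect', 'output']
    # edges: [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]
    adj, ops = create_adj_matrix_and_ops(nodes, edges)
    # adj is an 8x8 0/1 matrix. Note the asymmetry baked into the parser:
    # edge targets index the three cell steps (1..3), while nodes holds one
    # entry per operation, so rows 4..7 of adj stay empty.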
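
The op_to_atom table is what lets the molecule-oriented pipeline consume architectures: every operation node becomes an element symbol, so a cell is summarized like a small molecule. A hedged sketch of feeding one such graph through graphs_to_json; the output filename is an arbitrary choice, and the listed contents follow from the mapping above:

    from graph_dit.datasets.dataset import (
        parse_architecture_string, create_adj_matrix_and_ops, graphs_to_json)

    arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
    adj, ops = create_adj_matrix_and_ops(*parse_architecture_string(arch))
    meta = graphs_to_json([(adj, ops)], 'nasbench-201-demo')
    # Writes nasbench-201-demo.meta.json and returns the same dict:
    # meta['max_node'] == 8, and meta['active_atoms'] holds the symbols that
    # actually occur, e.g. 'N' (nor_conv_3x3), 'O' (avg_pool_3x3),
    # 'P' (skip_connect), plus 'Si'/'He' for the input/output nodes.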
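
DataInfos derives its statistics by enumerating all 15625 NAS-Bench-201 architectures through the nas_201_api package. A standalone sketch of that query, assuming a local copy of the benchmark checkpoint (the patch itself points at v1_0-e61699 in DataInfos but v1_1-096897 in DataModule):

    from nas_201_api import NASBench201API as API

    api = API('./NAS-Bench-201-v1_1-096897.pth')  # local path: an assumption
    info = api.query_meta_info_by_index(0)        # used the same way in the patch
    print(info.arch_str)                          # '|op~0|+|op~0|op~1|+|op~0|op~1|op~2|' layout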