nasbench.ipynb
@@ -21,7 +21,7 @@ from sklearn.model_selection import train_test_split
 import utils as utils
 from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
 from diffusion.distributions import DistributionNodes
-from nas_201_api import NASBench201API
 bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
 
 class DataModule(AbstractDataModule):
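Note: this deleted import was the module's only use of the NAS-Bench-201 API; the experiments presumably move to the new graph_dit/test_nasbench.ipynb added in this commit. A sketch of what the import enabled, reassembled from the removed hunks below (the checkpoint path and the `get_arch` call are copied from the removed lines, not verified against the `nas_201_api` package):

```python
from nas_201_api import NASBench201API

# Sketch only: the loop shape mirrors the removed process() hunk.
# NAS-Bench-201 holds 15,625 cells, matching the hard-coded
# len_data = 15625 in the removed code.
api = NASBench201API('./NAS-Bench-201-v1_1-096897.pth')
for arch_index in range(15625):
    arch_str = api.get_arch(arch_index)['arch_str']  # call copied from the removed loop
```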
@@ -48,17 +48,17 @@ class DataModule(AbstractDataModule):
         # Load the dataset to the memory
         # Dataset has target property, root path, and transform
         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
-        # print("len dataset", len(dataset))
-        # def print_data(dataset):
-        #     print("dataset", dataset)
-        #     print("dataset keys", dataset.keys)
-        #     print("dataset x", dataset.x)
-        #     print("dataset edge_index", dataset.edge_index)
-        #     print("dataset edge_attr", dataset.edge_attr)
-        #     print("dataset y", dataset.y)
-        #     print("")
-        # print_data(dataset=dataset[0])
-        # print_data(dataset=dataset[1])
+        print("len dataset", len(dataset))
+        def print_data(dataset):
+            print("dataset", dataset)
+            print("dataset keys", dataset.keys)
+            print("dataset x", dataset.x)
+            print("dataset edge_index", dataset.edge_index)
+            print("dataset edge_attr", dataset.edge_attr)
+            print("dataset y", dataset.y)
+            print("")
+        print_data(dataset=dataset[0])
+        print_data(dataset=dataset[1])
 
 
         if len(self.task.split('-')) == 2:
@@ -155,30 +155,7 @@ class Dataset(InMemoryDataset):
     def processed_file_names(self):
         return [f'{self.source}.pt']
 
-    def create_adj_matrix_and_ops(nodes, edges):
-        num_nodes = len(nodes)
-        adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
-        for (src, dst) in edges:
-            adj_matrix[src][dst] = 1
-        return adj_matrix, nodes
-    
-    def parse_architecture_string(arch_str):
-        print(arch_str)
-        steps = arch_str.split('+')
-        nodes = ['input']  # Start with input node
-        edges = []
-        for i, step in enumerate(steps):
-            step = step.strip('|').split('|')
-            for node in step:
-                op, idx = node.split('~')
-                edges.append((int(idx), i+1))  # i+1 because 0 is input node
-                nodes.append(op)
-        nodes.append('output')  # Add output node
-        return nodes, edges
-    
     def process(self):
-        # return
-        api = NASBench201API('./NAS-Bench-201-v1_1-096897.pth')
         RDLogger.DisableLog('rdApp.*')
         data_path = osp.join(self.raw_dir, self.raw_file_names[0])
         data_df = pd.read_csv(data_path)
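For reference, the removed `parse_architecture_string` walks NAS-Bench-201 genotype strings. A self-contained sketch of the format and parsing, with an illustrative genotype (note the removed helper appends one node per operation, so its node list grows beyond the four cell nodes indexed here):

```python
import numpy as np

# Illustrative NAS-Bench-201 genotype: '+' separates the cell's internal
# nodes; each '|op~j|' entry is an edge from node j applying operation op.
arch_str = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|skip_connect~0|avg_pool_3x3~1|nor_conv_3x3~2|'

edges = []  # (src, dst, op) triples over the four cell nodes
for dst, step in enumerate(arch_str.split('+'), start=1):
    for entry in step.strip('|').split('|'):
        op, src = entry.split('~')
        edges.append((int(src), dst, op))

# Dense adjacency, in the spirit of the removed create_adj_matrix_and_ops
adj = np.zeros((4, 4), dtype=int)
for src, dst, _ in edges:
    adj[src, dst] = 1

print(edges[:2])  # [(0, 1, 'nor_conv_3x3'), (0, 2, 'skip_connect')]
```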
@@ -223,65 +200,26 @@ class Dataset(InMemoryDataset):
             return data, active_atoms
         
         # Loop through every row in the DataFrame and apply the function
-        # data_list = []
-
-        # len_data = len(data_df)
-        len_data = 15625
         data_list = []
-        bonds = {
-            'nor_conv_1x1': 1,
-            'nor_conv_3x3': 2,
-            'avg_pool_3x3': 3,
-            'skip_connect': 4,
-            'input': 0,
-            'output': 5
-        }
-
-        def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
-            nodes, edges = Dataset.parse_architecture_string(arch_str)
-            node_labels = [bonds[node] for node in nodes]  # Replace with appropriate encoding if necessary
-            x = torch.LongTensor(node_labels)
-
-            edges_list = [(start, end) for start, end in edges]
-            edge_type = [bonds[nodes[end]] for start, end in edges]  # Example: using end node type as edge type
-            edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
-            edge_type = torch.tensor(edge_type, dtype=torch.long)
-            edge_attr = edge_type.view(-1, 1)
-
-            if target3 is not None:
-                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
-            elif target2 is not None:
-                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
-            else:
-                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
-
-            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
-            return data, nodes
-        # --- data processing start ---
-        # active_atoms = set()
-        # for i, (sms, df_row) in enumerate(data_df.iterrows()):
-        #     if i == sms:
-        #         sms = df_row['smiles']
-        #     mol = Chem.MolFromSmiles(sms, sanitize=False)
-        #     if len(self.target_prop.split('-')) == 2:
-        #         target1, target2 = self.target_prop.split('-')
-        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
-        #     elif len(self.target_prop.split('-')) == 3:
-        #         target1, target2, target3 = self.target_prop.split('-')
-        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
-        #     else:
-        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
-        #     active_atoms.update(cur_active_atoms)
-        #     data_list.append(data)
-        #     pbar.update(1)
-        for arch_index in range(len_data):
-            arch_info = api.get_arch(arch_index)
-            arch_str = arch_info['arch_str']
-
-            nodes, edges = Dataset.parse_architecture_string(arch_str)
-            adj_matrix, nodes = Dataset.create_adj_matrix_and_ops(nodes, edges)
-
-            data, cur_active_atoms = graph
-
+        len_data = len(data_df)
+        with tqdm(total=len_data) as pbar:
+            # --- data processing start ---
+            active_atoms = set()
+            for i, (sms, df_row) in enumerate(data_df.iterrows()):
+                if i == sms:
+                    sms = df_row['smiles']
+                mol = Chem.MolFromSmiles(sms, sanitize=False)
+                if len(self.target_prop.split('-')) == 2:
+                    target1, target2 = self.target_prop.split('-')
+                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
+                elif len(self.target_prop.split('-')) == 3:
+                    target1, target2, target3 = self.target_prop.split('-')
+                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
+                else:
+                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
+                active_atoms.update(cur_active_atoms)
+                data_list.append(data)
+                pbar.update(1)
 
         torch.save(self.collate(data_list), self.processed_paths[0])
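Both the restored molecule loop and the removed `arch_to_graph` produce `torch_geometric` `Data` objects. A minimal shape sketch with dummy values, following the removed `bonds` encoding (input=0, nor_conv_1x1=1, nor_conv_3x3=2, output=5):

```python
import torch
from torch_geometric.data import Data

x = torch.LongTensor([0, 1, 2, 5])  # input, nor_conv_1x1, nor_conv_3x3, output
edge_index = torch.tensor([(0, 1), (1, 2), (2, 3)], dtype=torch.long).t().contiguous()  # [2, E]
edge_attr = torch.tensor([1, 2, 5], dtype=torch.long).view(-1, 1)  # end-node type as edge type
y = torch.tensor([0.5, 0.3, 0.9], dtype=torch.float).view(1, -1)   # [SA, SC, target]

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
print(data)  # Data(x=[4], edge_index=[2, 3], edge_attr=[3, 1], y=[1, 3])
```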
@@ -296,10 +234,8 @@ class DataInfos(AbstractDatasetInfos):
             'N2': 'regression',
             'CO2': 'regression',
         }
-
         task_name = cfg.dataset.task_name
         self.task = task_name
-        print(self.task)
         self.task_type = tasktype_dict.get(task_name, "regression")
         self.ensure_connected = cfg.model.ensure_connected
 
@@ -473,181 +409,6 @@ def compute_meta(root, source_name, train_index, test_index):
     
     return meta_dict
 
-op_to_atom = {
-    'input': 'Si',          # Silicon for input
-    'nor_conv_1x1': 'C',    # Carbon for 1x1 convolution
-    'nor_conv_3x3': 'N',    # Nitrogen for 3x3 convolution
-    'avg_pool_3x3': 'O',    # Oxygen for 3x3 average pooling
-    'skip_connect': 'P',    # Phosphorus for skip connection
-    'none': 'S',            # Sulfur for no operation
-    'output': 'He'          # Helium for output
-}
-
-def get_sample_nasbench_graph():
-    adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
-                        [0, 0, 0, 1, 0, 1, 0, 0],
-                        [0, 0, 0, 0, 0, 0, 1, 0],
-                        [0, 0, 0, 0, 0, 0, 1, 0],
-                        [0, 0, 0, 0, 0, 0, 0, 1],
-                        [0, 0, 0, 0, 0, 0, 0, 1],
-                        [0, 0, 0, 0, 0, 0, 0, 1],
-                        [0, 0, 0, 0, 0, 0, 0, 0]])
-    ops = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
-    return adj_mat, ops
-
-def nasbench_to_molecule(adj_mat, ops):
-    mol = Chem.RWMol()  # Create a new editable molecule
-    atom_map = {}       # Map to keep track of node to atom mapping
-    
-    # Add atoms to the molecule
-    for i, op in enumerate(ops):
-        atom_type = op_to_atom.get(op, 'C')  # Default to Carbon if operation not found
-        atom = Chem.Atom(atom_type)          # Create an atom of the specified type
-        idx = mol.AddAtom(atom)
-        atom_map[i] = idx
-
-    # Add bonds to the molecule
-    for i in range(adj_mat.shape[0]):
-        for j in range(adj_mat.shape[1]):
-            if adj_mat[i, j] == 1:
-                mol.AddBond(atom_map[i], atom_map[j], Chem.rdchem.BondType.SINGLE)
-
-    return mol
-
-def compute_meta_graph(root, source_name, train_index, test_index):
-    # Initialize the periodic table: 118 elements + 1 for '*'.
-    # Set up arrays that count atoms per molecule, bond types, valencies,
-    # and transition probabilities between atom types.
-    pt = Chem.GetPeriodicTable()
-    atom_name_list = []
-    atom_count_list = []
-    for i in range(2, 119):
-        atom_name_list.append(pt.GetElementSymbol(i))
-        atom_count_list.append(0)
-    atom_name_list.append('*')
-    atom_count_list.append(0)
-    n_atoms_per_mol = [0] * 500
-    bond_count_list = [0, 0, 0, 0, 0]
-    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
-    valencies = [0] * 500
-    tansition_E = np.zeros((118, 118, 5))
-    
-    # Load the data from the source file
-    filename = f'{source_name}.csv.gz'
-    df = pd.read_csv(f'{root}/{filename}')
-    all_index = list(range(len(df)))
-    non_test_index = list(set(all_index) - set(test_index))
-    df = df.iloc[non_test_index]
-    # Extract the SMILES strings from the dataframe
-    tot_smiles = df['smiles'].tolist()
-
-    n_atom_list = []
-    n_bond_list = []
-    for i, sms in enumerate(tot_smiles):
-        try:
-            mol = Chem.MolFromSmiles(sms)
-        except:
-            continue
-
-        n_atom = mol.GetNumHeavyAtoms()
-        n_bond = mol.GetNumBonds()
-        n_atom_list.append(n_atom)
-        n_bond_list.append(n_bond)
-
-        n_atoms_per_mol[n_atom] += 1
-        cur_atom_count_arr = np.zeros(118)
-
-        for atom in mol.GetAtoms():
-            symbol = atom.GetSymbol()
-            if symbol == 'H':
-                continue
-            elif symbol == '*':
-                atom_count_list[-1] += 1
-                cur_atom_count_arr[-1] += 1
-            else:
-                atom_count_list[atom.GetAtomicNum()-2] += 1
-                cur_atom_count_arr[atom.GetAtomicNum()-2] += 1
-                try:
-                    valencies[int(atom.GetExplicitValence())] += 1
-                except:
-                    print('src', source_name, 'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))
-        
-        tansition_E_temp = np.zeros((118, 118, 5))
-        for bond in mol.GetBonds():
-            start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
-            if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
-                continue
-            
-            if start_atom.GetSymbol() == '*':
-                start_index = 117
-            else:
-                start_index = start_atom.GetAtomicNum() - 2
-            if end_atom.GetSymbol() == '*':
-                end_index = 117
-            else:
-                end_index = end_atom.GetAtomicNum() - 2
-
-            bond_type = bond.GetBondType()
-            bond_index = bond_type_to_index[bond_type]
-            bond_count_list[bond_index] += 2
-
-            # Update the transition matrix. It is symmetric, so both
-            # directions are updated; a temporary copy is kept to check
-            # the atom counts for errors.
-            tansition_E[start_index, end_index, bond_index] += 2
-            tansition_E[end_index, start_index, bond_index] += 2
-            tansition_E_temp[start_index, end_index, bond_index] += 2
-            tansition_E_temp[end_index, start_index, bond_index] += 2
-
-        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
-        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2  # 118 * 118
-        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2  # 118 * 118
-        tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1)
-        assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
-    
-    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
-
-    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
-    print('processed meta info: ------', filename, '------')
-    print('len atom_count_list', len(atom_count_list))
-    print('len atom_name_list', len(atom_name_list))
-    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
-    active_atoms = active_atoms.tolist()
-    atom_count_list = atom_count_list.tolist()
-
-    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
-    bond_count_list = bond_count_list.tolist()
-    valencies = np.array(valencies) / np.sum(valencies)
-    valencies = valencies.tolist()
-
-    no_edge = np.sum(tansition_E, axis=-1) == 0
-    first_elt = tansition_E[:, :, 0]
-    first_elt[no_edge] = 1
-    tansition_E[:, :, 0] = first_elt
-
-    tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True)
-    
-    meta_dict = {
-        'source': source_name,
-        'num_graph': len(n_atom_list),
-        'n_atoms_per_mol_dist': n_atoms_per_mol,
-        'max_node': max(n_atom_list),
-        'max_bond': max(n_bond_list),
-        'atom_type_dist': atom_count_list,
-        'bond_type_dist': bond_count_list,
-        'valencies': valencies,
-        'active_atoms': active_atoms,
-        'num_atom_type': len(active_atoms),
-        'transition_E': tansition_E.tolist(),
-    }
-
-    with open(f'{root}/{source_name}.meta.json', "w") as f:
-        json.dump(meta_dict, f)
-    
-    return meta_dict
-
 if __name__ == "__main__":
     pass
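The removed `nasbench_to_molecule` encodes a cell DAG as an RDKit molecule through `op_to_atom`. Note the removed sample graph is inconsistent (`adj_mat` is 8×8 but `ops` lists only seven operations, so `nasbench_to_molecule` raises a `KeyError` on it); a self-contained sketch of the same encoding on a consistent three-node graph:

```python
import numpy as np
from rdkit import Chem

op_to_atom = {'input': 'Si', 'nor_conv_3x3': 'N', 'output': 'He'}  # subset of the removed map
adj = np.array([[0, 1, 0],
                [0, 0, 1],
                [0, 0, 0]])
ops = ['input', 'nor_conv_3x3', 'output']

mol = Chem.RWMol()
idx = [mol.AddAtom(Chem.Atom(op_to_atom[op])) for op in ops]
for i, j in zip(*np.nonzero(adj)):
    mol.AddBond(idx[i], idx[j], Chem.rdchem.BondType.SINGLE)

m = mol.GetMol()          # unsanitized, as in the removed helper
print(m.GetNumAtoms(), m.GetNumBonds())  # 3 2
```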
							
								
								
									
graph_dit/test_nasbench.ipynb (new file, 47037 lines added)
File diff suppressed because it is too large