diff --git a/graph_dit/datasets/dataset.py b/graph_dit/datasets/dataset.py
index 0c12a1f..38d7520 100644
--- a/graph_dit/datasets/dataset.py
+++ b/graph_dit/datasets/dataset.py
@@ -21,7 +21,7 @@ from sklearn.model_selection import train_test_split
 import utils as utils
 from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
 from diffusion.distributions import DistributionNodes
-
+from nas_201_api import NASBench201API
 
 bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
 
 class DataModule(AbstractDataModule):
@@ -48,17 +48,17 @@ class DataModule(AbstractDataModule):
         # Load the dataset into memory.
         # The dataset has a target property, a root path, and a transform.
         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
-        print("len dataset", len(dataset))
-        def print_data(dataset):
-            print("dataset", dataset)
-            print("dataset keys", dataset.keys)
-            print("dataset x", dataset.x)
-            print("dataset edge_index", dataset.edge_index)
-            print("dataset edge_attr", dataset.edge_attr)
-            print("dataset y", dataset.y)
-            print("")
-        print_data(dataset=dataset[0])
-        print_data(dataset=dataset[1])
+        # print("len dataset", len(dataset))
+        # def print_data(dataset):
+        #     print("dataset", dataset)
+        #     print("dataset keys", dataset.keys)
+        #     print("dataset x", dataset.x)
+        #     print("dataset edge_index", dataset.edge_index)
+        #     print("dataset edge_attr", dataset.edge_attr)
+        #     print("dataset y", dataset.y)
+        #     print("")
+        # print_data(dataset=dataset[0])
+        # print_data(dataset=dataset[1])
 
         if len(self.task.split('-')) == 2:
@@ -155,7 +155,30 @@ class Dataset(InMemoryDataset):
     def processed_file_names(self):
         return [f'{self.source}.pt']
 
+    @staticmethod
+    def create_adj_matrix_and_ops(nodes, edges):
+        # Build a directed adjacency matrix over the parsed op nodes.
+        num_nodes = len(nodes)
+        adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
+        for (src, dst) in edges:
+            adj_matrix[src][dst] = 1
+        return adj_matrix, nodes
+
+    @staticmethod
+    def parse_architecture_string(arch_str):
+        # Split a NAS-Bench-201 architecture string into op nodes and edges.
+        steps = arch_str.split('+')
+        nodes = ['input']  # Start with the input node
+        edges = []
+        for i, step in enumerate(steps):
+            step = step.strip('|').split('|')
+            for node in step:
+                op, idx = node.split('~')
+                edges.append((int(idx), i + 1))  # i + 1 because index 0 is the input node
+                nodes.append(op)
+        nodes.append('output')  # Add the output node
+        return nodes, edges
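+
+    # Worked example (assuming a NAS-Bench-201-style string); the output below
+    # follows directly from the parsing logic above:
+    #   parse_architecture_string(
+    #       '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|')
+    #   returns
+    #   nodes = ['input', 'avg_pool_3x3', 'nor_conv_1x1', 'skip_connect',
+    #            'nor_conv_1x1', 'skip_connect', 'skip_connect', 'output']
+    #   edges = [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]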
+
     def process(self):
+        # return
+        api = NASBench201API('./NAS-Bench-201-v1_1-096897.pth')
         RDLogger.DisableLog('rdApp.*')
         data_path = osp.join(self.raw_dir, self.raw_file_names[0])
         data_df = pd.read_csv(data_path)
@@ -200,26 +223,65 @@ class Dataset(InMemoryDataset):
 
             return data, active_atoms
 
         # Loop through every row in the DataFrame and apply the function
+        # data_list = []
+
+        # len_data = len(data_df)
+        len_data = 15625  # number of architectures in NAS-Bench-201
        data_list = []
-        len_data = len(data_df)
-        with tqdm(total=len_data) as pbar:
-            # --- data processing start ---
-            active_atoms = set()
-            for i, (sms, df_row) in enumerate(data_df.iterrows()):
-                if i == sms:
-                    sms = df_row['smiles']
-                mol = Chem.MolFromSmiles(sms, sanitize=False)
-                if len(self.target_prop.split('-')) == 2:
-                    target1, target2 = self.target_prop.split('-')
-                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
-                elif len(self.target_prop.split('-')) == 3:
-                    target1, target2, target3 = self.target_prop.split('-')
-                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
-                else:
-                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
-                active_atoms.update(cur_active_atoms)
-                data_list.append(data)
-                pbar.update(1)
+        bonds = {
+            'nor_conv_1x1': 1,
+            'nor_conv_3x3': 2,
+            'avg_pool_3x3': 3,
+            'skip_connect': 4,
+            'input': 0,
+            'output': 5,
+            'none': 6,  # added so 'none' ops don't raise a KeyError; the label value is an assumption
+        }
+
+        def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
+            nodes, edges = Dataset.parse_architecture_string(arch_str)
+            node_labels = [bonds[node] for node in nodes]  # Replace with an appropriate encoding if necessary
+            x = torch.LongTensor(node_labels)
+
+            edges_list = [(start, end) for start, end in edges]
+            edge_type = [bonds[nodes[end]] for start, end in edges]  # Example: using the end node's type as the edge type
+            edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
+            edge_type = torch.tensor(edge_type, dtype=torch.long)
+            edge_attr = edge_type.view(-1, 1)
+
+            if target3 is not None:
+                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
+            elif target2 is not None:
+                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
+            else:
+                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
+
+            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
+            return data, nodes
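+
+        # Shape sketch for an 8-node cell with 6 edges (e.g. the worked example
+        # above); shapes follow from the construction in arch_to_graph:
+        #   data.x:          LongTensor([0, 3, 1, 4, 1, 4, 4, 5])  (node labels)
+        #   data.edge_index: shape [2, 6]
+        #   data.edge_attr:  shape [6, 1]
+        #   data.y:          shape [1, 3] for (sa, sc, target)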
+
+        # --- data processing start ---
+        # active_atoms = set()
+        # for i, (sms, df_row) in enumerate(data_df.iterrows()):
+        #     if i == sms:
+        #         sms = df_row['smiles']
+        #     mol = Chem.MolFromSmiles(sms, sanitize=False)
+        #     if len(self.target_prop.split('-')) == 2:
+        #         target1, target2 = self.target_prop.split('-')
+        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
+        #     elif len(self.target_prop.split('-')) == 3:
+        #         target1, target2, target3 = self.target_prop.split('-')
+        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
+        #     else:
+        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
+        #     active_atoms.update(cur_active_atoms)
+        #     data_list.append(data)
+        #     pbar.update(1)
+        for arch_index in range(len_data):
+            arch_info = api.get_arch(arch_index)
+            arch_str = arch_info['arch_str']
+
+            nodes, edges = Dataset.parse_architecture_string(arch_str)
+            adj_matrix, nodes = Dataset.create_adj_matrix_and_ops(nodes, edges)
+
+            # NOTE: placeholder property values (0.0); replace with real targets
+            # (e.g., accuracy queried from the API) when available.
+            data, _ = arch_to_graph(arch_str, 0.0, 0.0, 0.0)
+            data_list.append(data)
 
         torch.save(self.collate(data_list), self.processed_paths[0])
@@ -234,8 +296,10 @@ class DataInfos(AbstractDatasetInfos):
             'N2': 'regression',
             'CO2': 'regression',
         }
+
         task_name = cfg.dataset.task_name
         self.task = task_name
+        print(self.task)
         self.task_type = tasktype_dict.get(task_name, "regression")
         self.ensure_connected = cfg.model.ensure_connected
@@ -409,6 +473,181 @@ def compute_meta(root, source_name, train_index, test_index):
 
     return meta_dict
 
+op_to_atom = {
+    'input': 'Si',          # Silicon for input
+    'nor_conv_1x1': 'C',    # Carbon for 1x1 convolution
+    'nor_conv_3x3': 'N',    # Nitrogen for 3x3 convolution
+    'avg_pool_3x3': 'O',    # Oxygen for 3x3 average pooling
+    'skip_connect': 'P',    # Phosphorus for skip connection
+    'none': 'S',            # Sulfur for no operation
+    'output': 'He'          # Helium for output
+}
+
+def get_sample_nasbench_graph():
+    adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
+                        [0, 0, 0, 1, 0, 1, 0, 0],
+                        [0, 0, 0, 0, 0, 0, 1, 0],
+                        [0, 0, 0, 0, 0, 0, 1, 0],
+                        [0, 0, 0, 0, 0, 0, 0, 1],
+                        [0, 0, 0, 0, 0, 0, 0, 1],
+                        [0, 0, 0, 0, 0, 0, 0, 1],
+                        [0, 0, 0, 0, 0, 0, 0, 0]])
+    # One op per graph node; the second 'nor_conv_1x1' is assumed here so that
+    # len(ops) matches the 8-node adjacency matrix above.
+    ops = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect',
+           'nor_conv_1x1', 'none', 'output']
+    return adj_mat, ops
+
+def nasbench_to_molecule(adj_mat, ops):
+    mol = Chem.RWMol()  # Create a new editable molecule
+    atom_map = {}  # Map to keep track of node-to-atom mapping
+
+    # Add atoms to the molecule
+    for i, op in enumerate(ops):
+        atom_type = op_to_atom.get(op, 'C')  # Default to Carbon if the operation is not found
+        atom = Chem.Atom(atom_type)  # Create an atom of the specified type
+        idx = mol.AddAtom(atom)
+        atom_map[i] = idx
+
+    # Add bonds to the molecule
+    for i in range(adj_mat.shape[0]):
+        for j in range(adj_mat.shape[1]):
+            if adj_mat[i, j] == 1:
+                mol.AddBond(atom_map[i], atom_map[j], Chem.rdchem.BondType.SINGLE)
+
+    return mol
+
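+# A quick usage sketch (hypothetical; for manual inspection only):
+#   adj_mat, ops = get_sample_nasbench_graph()
+#   mol = nasbench_to_molecule(adj_mat, ops)
+#   print(mol.GetNumAtoms(), mol.GetNumBonds())  # -> 8 10
+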
+def compute_meta_graph(root, source_name, train_index, test_index):
+    # initialize the periodic table
+    # 118 elements + 1 for *
+    # Initializes arrays to count the number of atoms per molecule, bond types,
+    # valencies, and transition probabilities between atom types.
+    pt = Chem.GetPeriodicTable()
+    atom_name_list = []
+    atom_count_list = []
+    for i in range(2, 119):
+        atom_name_list.append(pt.GetElementSymbol(i))
+        atom_count_list.append(0)
+    atom_name_list.append('*')
+    atom_count_list.append(0)
+    n_atoms_per_mol = [0] * 500
+    bond_count_list = [0, 0, 0, 0, 0]
+    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
+    valencies = [0] * 500
+    tansition_E = np.zeros((118, 118, 5))
+
+    # Load the data from the source file
+    filename = f'{source_name}.csv.gz'
+    df = pd.read_csv(f'{root}/{filename}')
+    all_index = list(range(len(df)))
+    non_test_index = list(set(all_index) - set(test_index))
+    df = df.iloc[non_test_index]
+    # extract the smiles from the dataframe
+    tot_smiles = df['smiles'].tolist()
+
+    n_atom_list = []
+    n_bond_list = []
+    for i, sms in enumerate(tot_smiles):
+        try:
+            mol = Chem.MolFromSmiles(sms)
+        except:
+            continue
+
+        n_atom = mol.GetNumHeavyAtoms()
+        n_bond = mol.GetNumBonds()
+        n_atom_list.append(n_atom)
+        n_bond_list.append(n_bond)
+
+        n_atoms_per_mol[n_atom] += 1
+        cur_atom_count_arr = np.zeros(118)
+
+        for atom in mol.GetAtoms():
+            symbol = atom.GetSymbol()
+            if symbol == 'H':
+                continue
+            elif symbol == '*':
+                atom_count_list[-1] += 1
+                cur_atom_count_arr[-1] += 1
+            else:
+                atom_count_list[atom.GetAtomicNum()-2] += 1
+                cur_atom_count_arr[atom.GetAtomicNum()-2] += 1
+            try:
+                valencies[int(atom.GetExplicitValence())] += 1
+            except:
+                print('src', source_name, 'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))
+
+        tansition_E_temp = np.zeros((118, 118, 5))
+        for bond in mol.GetBonds():
+            start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
+            if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
+                continue
+
+            if start_atom.GetSymbol() == '*':
+                start_index = 117
+            else:
+                start_index = start_atom.GetAtomicNum() - 2
+            if end_atom.GetSymbol() == '*':
+                end_index = 117
+            else:
+                end_index = end_atom.GetAtomicNum() - 2
+
+            bond_type = bond.GetBondType()
+            bond_index = bond_type_to_index[bond_type]
+            bond_count_list[bond_index] += 2
+
+            # Update the transition matrix
+            # The transition matrix is symmetric, so we update both directions
+            # We also update the temporary transition matrix to check for errors
+            # in the atom count
+            tansition_E[start_index, end_index, bond_index] += 2
+            tansition_E[end_index, start_index, bond_index] += 2
+            tansition_E_temp[start_index, end_index, bond_index] += 2
+            tansition_E_temp[end_index, start_index, bond_index] += 2
+
+        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
+        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2  # 118 * 118
+        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2  # 118 * 118
+        tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1)
+        assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
+
+    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
+    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
+
+    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
+    print('processed meta info: ------', filename, '------')
+    print('len atom_count_list', len(atom_count_list))
+    print('len atom_name_list', len(atom_name_list))
+    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
+    active_atoms = active_atoms.tolist()
+    atom_count_list = atom_count_list.tolist()
+
+    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
+    bond_count_list = bond_count_list.tolist()
+    valencies = np.array(valencies) / np.sum(valencies)
+    valencies = valencies.tolist()
+
+    no_edge = np.sum(tansition_E, axis=-1) == 0
+    first_elt = tansition_E[:, :, 0]
+    first_elt[no_edge] = 1
+    tansition_E[:, :, 0] = first_elt
+
+    tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True)
+
+    meta_dict = {
+        'source': source_name,
+        'num_graph': len(n_atom_list),
+        'n_atoms_per_mol_dist': n_atoms_per_mol,
+        'max_node': max(n_atom_list),
+        'max_bond': max(n_bond_list),
+        'atom_type_dist': atom_count_list,
+        'bond_type_dist': bond_count_list,
+        'valencies': valencies,
+        'active_atoms': active_atoms,
+        'num_atom_type': len(active_atoms),
+        'transition_E': tansition_E.tolist(),
+    }
+
+    with open(f'{root}/{source_name}.meta.json', "w") as f:
+        json.dump(meta_dict, f)
+
+    return meta_dict
 
 if __name__ == "__main__":
     pass
\ No newline at end of file