try to transfer the code from jupyter notebook to dataset.py
		| @@ -2,6 +2,8 @@ | ||||
| import sys | ||||
| sys.path.append('../')  | ||||
|  | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
| import os | ||||
| import os.path as osp | ||||
| import pathlib | ||||
| @@ -24,7 +26,266 @@ from diffusion.distributions import DistributionNodes | ||||
|  | ||||
| bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} | ||||
|  | ||||
| op_to_atom = { | ||||
|     'input': 'Si',          # Silicon for input | ||||
|     'nor_conv_1x1': 'C',   # Carbon for 1x1 convolution | ||||
|     'nor_conv_3x3': 'N',   # Nitrogen for 3x3 convolution | ||||
|     'avg_pool_3x3': 'O',   # Oxygen for 3x3 average pooling | ||||
|     'skip_connect': 'P',   # Phosphorus for skip connection | ||||
|     'none': 'S',           # Sulfur for no operation | ||||
|     'output': 'He'         # Helium for output | ||||
| } | ||||
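|  | ||||
| # Editor's sketch (illustrative, not part of this commit): one cell's op list | ||||
| # rendered as pseudo-element symbols so the RDKit-based statistics below can | ||||
| # be reused unchanged. | ||||
| #   example_cell = ['input', 'nor_conv_3x3', 'skip_connect', 'output'] | ||||
| #   [op_to_atom[op] for op in example_cell]  # -> ['Si', 'N', 'P', 'He'] | ||||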
| class DataModule(AbstractDataModule): | ||||
|     def __init__(self, cfg): | ||||
|         self.datadir = cfg.dataset.datadir | ||||
|         self.task = cfg.dataset.task_name | ||||
|         print("DataModule") | ||||
|         print("task", self.task) | ||||
|         print("datadir", self.datadir) | ||||
|         super().__init__(cfg) | ||||
|  | ||||
|     def prepare_data(self) -> None: | ||||
|         target = getattr(self.cfg.dataset, 'guidance_target', None) | ||||
|         print("target", target) | ||||
|         # try: | ||||
|         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] | ||||
|         # except NameError: | ||||
|         #     base_path = pathlib.Path(os.getcwd()).parents[2] | ||||
|         base_path = '/home/stud/hanzhang/Graph-Dit'  # hardcoded project root | ||||
|         root_path = os.path.join(base_path, self.datadir) | ||||
|         self.root_path = root_path | ||||
|  | ||||
|         batch_size = self.cfg.train.batch_size | ||||
|          | ||||
|         num_workers = self.cfg.train.num_workers | ||||
|         pin_memory = self.cfg.dataset.pin_memory | ||||
|  | ||||
|         # Load the dataset to the memory | ||||
|         # Dataset has target property, root path, and transform | ||||
|         source = './NAS-Bench-201-v1_1-096897.pth' | ||||
|         dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None) | ||||
|  | ||||
|         # if len(self.task.split('-')) == 2: | ||||
|         #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) | ||||
|         # else: | ||||
|         train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset) | ||||
|  | ||||
|         self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index | ||||
|         train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index) | ||||
|         if len(unlabeled_index) > 0: | ||||
|             train_index = torch.cat([train_index, unlabeled_index], dim=0) | ||||
|          | ||||
|         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index] | ||||
|         self.train_dataset = train_dataset   | ||||
|         print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset)) | ||||
|         self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory) | ||||
|  | ||||
|         self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False) | ||||
|         self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False) | ||||
|  | ||||
|         training_iterations = len(train_dataset) // batch_size | ||||
|         self.training_iterations = training_iterations | ||||
|      | ||||
|     def random_data_split(self, dataset): | ||||
|         nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item() | ||||
|         labeled_len = len(dataset) - nan_count | ||||
|         full_idx = list(range(labeled_len)) | ||||
|         train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2 | ||||
|         train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42) | ||||
|         train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42) | ||||
|         unlabeled_index = list(range(labeled_len, len(dataset))) | ||||
|         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index)) | ||||
|         return train_index, val_index, test_index, unlabeled_index | ||||
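|  | ||||
|     # Editor's note (illustrative): the two-stage split yields train/val/test | ||||
|     # = 60/20/20 of the labeled graphs. E.g. with 100 labeled graphs the first | ||||
|     # call holds out 20 for test; the second holds out 0.2/(0.2+0.6) = 25% of | ||||
|     # the remaining 80, i.e. 20 for validation, leaving 60 for training. | ||||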
|      | ||||
|     def fixed_split(self, dataset): | ||||
|         if self.task == 'O2-N2': | ||||
|             test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604] | ||||
|         else: | ||||
|             raise ValueError('Invalid task name: {}'.format(self.task)) | ||||
|         full_idx = list(range(len(dataset))) | ||||
|         full_idx = list(set(full_idx) - set(test_index)) | ||||
|         train_ratio = 0.8 | ||||
|         train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42) | ||||
|         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) | ||||
|         return train_index, val_index, test_index, [] | ||||
|  | ||||
|     def get_train_smiles(self): | ||||
|         raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.") | ||||
|  | ||||
|     def get_data_split(self): | ||||
|         raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.") | ||||
|  | ||||
|     def example_batch(self): | ||||
|         return next(iter(self.val_loader)) | ||||
|      | ||||
|     def train_dataloader(self): | ||||
|         return self.train_loader | ||||
|  | ||||
|     def val_dataloader(self): | ||||
|         return self.val_loader | ||||
|      | ||||
|     def test_dataloader(self): | ||||
|         return self.test_loader | ||||
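|  | ||||
| # Editor's sketch (illustrative; assumes the DataLoader imported above is | ||||
| # torch_geometric's, PyG 2.x API): the batching behaviour example_batch() uses. | ||||
| #   from torch_geometric.data import Data | ||||
| #   from torch_geometric.loader import DataLoader | ||||
| #   toy = [Data(x=torch.randn(8, 4)) for _ in range(10)] | ||||
| #   batch = next(iter(DataLoader(toy, batch_size=4)))  # batch.num_graphs == 4 | ||||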
|  | ||||
| def graphs_to_json(graphs, filename): | ||||
|     bonds = { | ||||
|         'nor_conv_1x1': 1, | ||||
|         'nor_conv_3x3': 2, | ||||
|         'avg_pool_3x3': 3, | ||||
|         'skip_connect': 4, | ||||
|         'input': 7, | ||||
|         'output': 5, | ||||
|         'none': 6 | ||||
|     } | ||||
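|     # Note: only indices 1-4 ever reach transition_E below, because edges | ||||
|     # touching 'input', 'output', or 'none' nodes are skipped; its 5 channels | ||||
|     # are 0 = no edge and 1-4 = the four real operations. Atom indices are | ||||
|     # offset by -2 because the element list starts at helium (Z = 2). | ||||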
|  | ||||
|     source_name = "nas-bench-201" | ||||
|     num_graph = len(graphs) | ||||
|     pt = Chem.GetPeriodicTable() | ||||
|     atom_name_list = [] | ||||
|     atom_count_list = [] | ||||
|     for i in range(2, 119): | ||||
|         atom_name_list.append(pt.GetElementSymbol(i)) | ||||
|         atom_count_list.append(0) | ||||
|     atom_name_list.append('*') | ||||
|     atom_count_list.append(0) | ||||
|     n_atoms_per_mol = [0] * 500 | ||||
|     bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0] | ||||
|     bond_type_to_index =  {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} | ||||
|     valencies = [0] * 500 | ||||
|     transition_E = np.zeros((118, 118, 5)) | ||||
|  | ||||
|     n_atom_list = [] | ||||
|     n_bond_list = [] | ||||
|     # graphs = [(adj_matrix, ops), ...] | ||||
|     for graph in graphs: | ||||
|         ops = graph[1] | ||||
|         adj = graph[0] | ||||
|         n_atom = len(ops) | ||||
|         n_bond = len(ops) | ||||
|         n_atom_list.append(n_atom) | ||||
|         n_bond_list.append(n_bond) | ||||
|  | ||||
|         n_atoms_per_mol[n_atom] += 1 | ||||
|         cur_atom_count_arr = np.zeros(118) | ||||
|  | ||||
|         for op in ops: | ||||
|             symbol = op_to_atom[op] | ||||
|             if symbol == 'H':  # legacy guard from molecular data; op_to_atom never yields 'H' | ||||
|                 continue | ||||
|             elif symbol == '*': | ||||
|                 atom_count_list[-1] += 1 | ||||
|                 cur_atom_count_arr[-1] += 1 | ||||
|             else: | ||||
|                 atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1 | ||||
|                 cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1 | ||||
|                 # print('symbol', symbol) | ||||
|                 # print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol)) | ||||
|                 # print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}') | ||||
|                 try: | ||||
|                     valencies[int(pt.GetDefaultValence(symbol))] += 1 | ||||
|                 except Exception: | ||||
|                     print('unexpected default valence for', symbol, ':', pt.GetDefaultValence(symbol)) | ||||
|         transition_E_temp = np.zeros((118, 118, 5)) | ||||
|         # print(n_atom) | ||||
|         for i in range(n_atom): | ||||
|             for j in range(n_atom): | ||||
|                 if i == j or adj[i][j] == 0: | ||||
|                     continue | ||||
|                 start_atom, end_atom = i, j | ||||
|                 if ops[start_atom] == 'input' or ops[end_atom] == 'input': | ||||
|                     continue | ||||
|                 if ops[start_atom] == 'output' or ops[end_atom] == 'output': | ||||
|                     continue | ||||
|                 if ops[start_atom] == 'none' or ops[end_atom] == 'none': | ||||
|                     continue | ||||
|                  | ||||
|                 start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2 | ||||
|                 end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2 | ||||
|                 bond_index = bonds[ops[end_atom]] | ||||
|                 bond_count_list[bond_index] += 2 | ||||
|  | ||||
|                 # print(start_index, end_index, bond_index) | ||||
|                  | ||||
|                 transition_E[start_index, end_index, bond_index] += 2 | ||||
|                 transition_E[end_index, start_index, bond_index] += 2 | ||||
|                 transition_E_temp[start_index, end_index, bond_index] += 2 | ||||
|                 transition_E_temp[end_index, start_index, bond_index] += 2 | ||||
|  | ||||
|         bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2 | ||||
|         print(bond_count_list) | ||||
|         cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2 | ||||
|         # print(f'cur_tot_bond={cur_tot_bond}')    | ||||
|         # find non-zero element in cur_tot_bond | ||||
|         # for i in range(118): | ||||
|         #     for j in range(118): | ||||
|         #         if cur_tot_bond[i][j] != 0: | ||||
|         #             print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}') | ||||
|         # n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol) | ||||
|         cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2 | ||||
|         # print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}") | ||||
|         transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1) | ||||
|         # find non-zero element in transition_E | ||||
|         # for i in range(118): | ||||
|         #     for j in range(118): | ||||
|         #         if transition_E[i][j][0] != 0: | ||||
|         #             print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}') | ||||
|         # Recorded transitions can never exceed the available atom-pair count. | ||||
|         assert (cur_tot_bond >= transition_E_temp.sum(axis=-1)).all(), f'pair/transition mismatch in graph with ops {ops}' | ||||
|      | ||||
|     n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol) | ||||
|     n_atoms_per_mol = n_atoms_per_mol.tolist()[:51] | ||||
|  | ||||
|     atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list) | ||||
|     print('processed meta info: ------', filename, '------') | ||||
|     print('len atom_count_list', len(atom_count_list)) | ||||
|     print('len atom_name_list', len(atom_name_list)) | ||||
|     active_atoms = np.array(atom_name_list)[atom_count_list > 0] | ||||
|     active_atoms = active_atoms.tolist() | ||||
|     atom_count_list = atom_count_list.tolist() | ||||
|  | ||||
|     bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list) | ||||
|     bond_count_list = bond_count_list.tolist() | ||||
|     valencies = np.array(valencies) / np.sum(valencies) | ||||
|     valencies = valencies.tolist() | ||||
|  | ||||
|     no_edge = np.sum(transition_E, axis=-1) == 0 | ||||
|     for i in range(118): | ||||
|         for j in range(118): | ||||
|             if not no_edge[i][j]: | ||||
|                 print(f'have an edge at i={i}, j={j}, transition_E[i][j]={transition_E[i][j]}') | ||||
|     # print(f'no_edge: {no_edge}') | ||||
|     first_elt = transition_E[:, :, 0] | ||||
|     first_elt[no_edge] = 1 | ||||
|     transition_E[:, :, 0] = first_elt | ||||
|  | ||||
|     transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True) | ||||
|  | ||||
|     # find non-zero element in transition_E again | ||||
|     for i in range(118): | ||||
|         for j in range(118): | ||||
|             if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1: | ||||
|                 print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}') | ||||
|  | ||||
|     meta_dict = { | ||||
|         'source': 'nasbench-201', | ||||
|         'num_graph': num_graph, | ||||
|         'n_atoms_per_mol_dist': n_atoms_per_mol[:51], | ||||
|         'max_node': max(n_atom_list), | ||||
|         'max_bond': max(n_bond_list), | ||||
|         'atom_type_dist': atom_count_list, | ||||
|         'bond_type_dist': bond_count_list, | ||||
|         'valencies': valencies, | ||||
|         'active_atoms': active_atoms, | ||||
|         'num_atom_type': len(active_atoms), | ||||
|         'transition_E': transition_E.tolist(), | ||||
|     } | ||||
|  | ||||
|     with open(f'{filename}.meta.json', 'w') as f: | ||||
|         json.dump(meta_dict, f) | ||||
|     return meta_dict | ||||
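|  | ||||
| # Editor's sketch (illustrative): graphs_to_json on a single toy graph. | ||||
| #   adj = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]]) | ||||
| #   ops = ['input', 'nor_conv_3x3', 'output'] | ||||
| #   meta = graphs_to_json([(adj, ops)], 'toy')  # writes toy.meta.json | ||||
| #   meta['max_node']      # -> 3 | ||||
| #   meta['active_atoms']  # -> ['He', 'N', 'Si'] | ||||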
|  | ||||
| class DataModule_original(AbstractDataModule): | ||||
|     def __init__(self, cfg): | ||||
|         self.datadir = cfg.dataset.datadir | ||||
|         self.task = cfg.dataset.task_name | ||||
| @@ -48,18 +309,6 @@ class DataModule(AbstractDataModule): | ||||
|         # Load the dataset to the memory | ||||
|         # Dataset has target property, root path, and transform | ||||
|         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None) | ||||
|         print("len dataset", len(dataset)) | ||||
|         def print_data(dataset): | ||||
|             print("dataset", dataset) | ||||
|             print("dataset keys", dataset.keys) | ||||
|             print("dataset x", dataset.x) | ||||
|             print("dataset edge_index", dataset.edge_index) | ||||
|             print("dataset edge_attr", dataset.edge_attr) | ||||
|             print("dataset y", dataset.y) | ||||
|             print("") | ||||
|         print_data(dataset=dataset[0]) | ||||
|         print_data(dataset=dataset[1]) | ||||
|  | ||||
|  | ||||
|         if len(self.task.split('-')) == 2: | ||||
|             train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset) | ||||
| @@ -138,8 +387,163 @@ class DataModule(AbstractDataModule): | ||||
|     def test_dataloader(self): | ||||
|         return self.test_loader | ||||
|  | ||||
| class Dataset(InMemoryDataset): | ||||
| class Dataset_origin(InMemoryDataset): | ||||
|     def __init__(self, source, root, target_prop=None, | ||||
|                  transform=None, pre_transform=None, pre_filter=None): | ||||
|         self.target_prop = target_prop | ||||
| @@ -223,8 +627,95 @@ class Dataset(InMemoryDataset): | ||||
|  | ||||
|         torch.save(self.collate(data_list), self.processed_paths[0]) | ||||
|  | ||||
| def parse_architecture_string(arch_str): | ||||
|     print(arch_str) | ||||
|     steps = arch_str.split('+') | ||||
|     nodes = ['input']  # Start with input node | ||||
|     edges = [] | ||||
|     for i, step in enumerate(steps): | ||||
|         step = step.strip('|').split('|') | ||||
|         for node in step: | ||||
|             op, idx = node.split('~') | ||||
|             # Edge from cell node idx to cell node i+1 (0 is the input node). | ||||
|             # Caveat: these are cell-node indices, while `nodes` grows by one | ||||
|             # entry per operation, so the two numberings differ. | ||||
|             edges.append((int(idx), i+1)) | ||||
|             nodes.append(op) | ||||
|     nodes.append('output')  # Add output node | ||||
|     return nodes, edges | ||||
|  | ||||
| def create_adj_matrix_and_ops(nodes, edges): | ||||
|     num_nodes = len(nodes) | ||||
|     adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int) | ||||
|     for (src, dst) in edges: | ||||
|         adj_matrix[src][dst] = 1 | ||||
|     return adj_matrix, nodes | ||||
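|  | ||||
| # Editor's sketch (arch string format as used by NAS-Bench-201): | ||||
| #   arch = ('|avg_pool_3x3~0|+|nor_conv_3x3~0|skip_connect~1|' | ||||
| #           '+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|') | ||||
| #   nodes, edges = parse_architecture_string(arch) | ||||
| #   # nodes: ['input'] + 6 ops + ['output']  (8 entries) | ||||
| #   # edges: [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]  (cell-node ids) | ||||
| #   adj, ops = create_adj_matrix_and_ops(nodes, edges)  # adj is 8x8 | ||||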
| class DataInfos(AbstractDatasetInfos): | ||||
|     def __init__(self, datamodule, cfg): | ||||
|         tasktype_dict = { | ||||
|             'hiv_b': 'classification', | ||||
|             'bace_b': 'classification', | ||||
|             'bbbp_b': 'classification', | ||||
|             'O2': 'regression', | ||||
|             'N2': 'regression', | ||||
|             'CO2': 'regression', | ||||
|         } | ||||
|         task_name = cfg.dataset.task_name | ||||
|         self.task = task_name | ||||
|         self.task_type = tasktype_dict.get(task_name, "regression") | ||||
|         self.ensure_connected = cfg.model.ensure_connected | ||||
|  | ||||
|         datadir = cfg.dataset.datadir | ||||
|  | ||||
|         base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] | ||||
|         meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json') | ||||
|         data_root = os.path.join(base_path, datadir, 'raw') | ||||
|         graphs = [] | ||||
|         length = 15625  # number of NAS-Bench-201 architectures (5 ops on 6 edges: 5**6) | ||||
|         ops_type = {} | ||||
|         len_ops = set() | ||||
|         api = API('../NAS-Bench-201-v1_0-e61699.pth') | ||||
|         for i in range(length): | ||||
|             arch_info = api.query_meta_info_by_index(i) | ||||
|             nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||
|             adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)     | ||||
|             if i < 5: | ||||
|                 print("Adjacency Matrix:") | ||||
|                 print(adj_matrix) | ||||
|                 print("Operations List:") | ||||
|                 print(ops) | ||||
|             for op in ops: | ||||
|                 if op not in ops_type: | ||||
|                     ops_type[op] = len(ops_type) | ||||
|             len_ops.add(len(ops)) | ||||
|             graphs.append((adj_matrix, ops)) | ||||
|  | ||||
|         meta_dict = graphs_to_json(graphs, 'nasbench-201') | ||||
|  | ||||
|         self.base_path = base_path | ||||
|         self.active_atoms = meta_dict['active_atoms'] | ||||
|         self.max_n_nodes = meta_dict['max_node'] | ||||
|         self.original_max_n_nodes = meta_dict['max_node'] | ||||
|         self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist']) | ||||
|         self.edge_types = torch.Tensor(meta_dict['bond_type_dist']) | ||||
|         self.transition_E = torch.Tensor(meta_dict['transition_E']) | ||||
|  | ||||
|         self.atom_decoder = meta_dict['active_atoms'] | ||||
|         node_types = torch.Tensor(meta_dict['atom_type_dist']) | ||||
|         active_index = (node_types > 0).nonzero().squeeze() | ||||
|         self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index] | ||||
|         self.nodes_dist = DistributionNodes(self.n_nodes) | ||||
|         self.active_index = active_index | ||||
|  | ||||
|         val_len = 3 * self.original_max_n_nodes - 2 | ||||
|         meta_val = torch.Tensor(meta_dict['valencies']) | ||||
|         self.valency_distribution = torch.zeros(val_len) | ||||
|         val_len = min(val_len, len(meta_val)) | ||||
|         self.valency_distribution[:val_len] = meta_val[:val_len] | ||||
|         self.y_prior = None | ||||
|         self.train_ymin = [] | ||||
|         self.train_ymax = [] | ||||
|  | ||||
|  | ||||
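| # Editor's sketch (sanity check): given a meta_dict returned by | ||||
| # graphs_to_json, every (i, j) slice of transition_E is, after normalisation, | ||||
| # a probability distribution over the 5 edge channels. | ||||
| #   T = np.array(meta_dict['transition_E'])  # shape (118, 118, 5) | ||||
| #   assert np.allclose(T.sum(axis=-1), 1.0) | ||||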
|  | ||||
| class DataInfos_origin(AbstractDatasetInfos): | ||||
|     def __init__(self, datamodule, cfg): | ||||
|         tasktype_dict = { | ||||
|             'hiv_b': 'classification', | ||||
|   | ||||