import sys
sys.path.append('../')
from nas_201_api import NASBench201API as API
import os
import os.path as osp
import pathlib
import json
import random

import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import rdchem
from tqdm import tqdm
import numpy as np
import pandas as pd

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes
from naswot.score_networks import get_nasbench201_idx_score
from naswot import nasspace
from naswot import datasets as dt
import networkx as nx

bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}

op_to_atom = {
    'input': 'Si',         # Silicon for input
    'nor_conv_1x1': 'C',   # Carbon for 1x1 convolution
    'nor_conv_3x3': 'N',   # Nitrogen for 3x3 convolution
    'avg_pool_3x3': 'O',   # Oxygen for 3x3 average pooling
    'skip_connect': 'P',   # Phosphorus for skip connection
    'none': 'S',           # Sulfur for no operation
    'output': 'He'         # Helium for output
}

op_type = {
    'input': 0,
    'nor_conv_1x1': 1,
    'nor_conv_3x3': 2,
    'avg_pool_3x3': 3,
    'skip_connect': 4,
    'none': 5,
    'output': 6,
}

num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']


class DataModule(AbstractDataModule):
    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.task = cfg.dataset.task_name
        print("DataModule")
        print("task", self.task)
        print("datadir", self.datadir)
        super().__init__(cfg)

    def prepare_data(self) -> None:
        target = getattr(self.cfg.dataset, 'guidance_target', None)
        print("target", target)
        # nasbench-201
        # try:
        #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        # except NameError:
        #     base_path = pathlib.Path(os.getcwd()).parent[2]
        base_path = '/nfs/data3/hanzhang/nasbenchDiT'
        root_path = os.path.join(base_path, self.datadir)
        self.root_path = root_path

        batch_size = self.cfg.train.batch_size
        num_workers = self.cfg.train.num_workers
        pin_memory = self.cfg.dataset.pin_memory

        # Load the dataset into memory.
        # Dataset has target property, root path, and transform.
        source = './NAS-Bench-201-v1_1-096897.pth'
        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
        self.dataset = dataset
        # self.api = dataset.api

        # if len(self.task.split('-')) == 2:
        #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
        # else:
        train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)

        self.train_index, self.val_index, self.test_index, self.unlabeled_index = (
            train_index, val_index, test_index, unlabeled_index)
        train_index = torch.LongTensor(train_index)
        val_index = torch.LongTensor(val_index)
        test_index = torch.LongTensor(test_index)
        unlabeled_index = torch.LongTensor(unlabeled_index)
        if len(unlabeled_index) > 0:
            train_index = torch.cat([train_index, unlabeled_index], dim=0)

        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                                       shuffle=True, pin_memory=pin_memory)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                                     shuffle=False, pin_memory=False)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                                      shuffle=False, pin_memory=False)
        training_iterations = len(train_dataset) // batch_size
        self.training_iterations = training_iterations

    def random_data_split(self, dataset):
        # nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
        # labeled_len = len(dataset) - nan_count
        labeled_len = len(dataset)
        full_idx = list(range(labeled_len))
        train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
        train_index, test_index, _, _ = train_test_split(full_idx, full_idx,
                                                         test_size=test_ratio, random_state=42)
        train_index, val_index, _, _ = train_test_split(train_index, train_index,
                                                        test_size=valid_ratio / (valid_ratio + train_ratio),
                                                        random_state=42)
        unlabeled_index = list(range(labeled_len, len(dataset)))
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),
              'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
        return train_index, val_index, test_index, unlabeled_index

    def fixed_split(self, dataset):
        if self.task == 'O2-N2':
            test_index = [42, 43, 92, 122, 197, 198, 251, 254, 257, 355, 511, 512, 549, 602, 603, 604]
        else:
            raise ValueError('Invalid task name: {}'.format(self.task))
        full_idx = list(range(len(dataset)))
        full_idx = list(set(full_idx) - set(test_index))
        train_ratio = 0.8
        train_index, val_index, _, _ = train_test_split(full_idx, full_idx,
                                                        test_size=1 - train_ratio, random_state=42)
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),
              'test len', len(test_index))
        return train_index, val_index, test_index, []

    @staticmethod
    def parse_architecture_string(arch_str):
        # Parse a NAS-Bench-201 architecture string into an operation list and the
        # fixed 8-node adjacency matrix of the cell (input, six intermediate ops, output).
        steps = arch_str.split('+')
        nodes = ['input']  # Start with the input node
        adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
                            [0, 0, 0, 1, 0, 1, 0, 0],
                            [0, 0, 0, 0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 0]])
        steps_coding = ['0', '0', '1', '0', '1', '2']
        cont = 0
        for step in steps:
            step = step.strip('|').split('|')
            for node in step:
                n, idx = node.split('~')
                assert idx == steps_coding[cont]
                cont += 1
                nodes.append(n)
        nodes.append('output')  # Add the output node
        return nodes, adj_mat
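
    # Illustrative example (not executed): for a NAS-Bench-201 string such as
    #   '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
    # parse_architecture_string returns
    #   nodes   = ['input', 'nor_conv_3x3', 'nor_conv_3x3', 'avg_pool_3x3',
    #              'skip_connect', 'nor_conv_3x3', 'skip_connect', 'output']
    #   adj_mat = the fixed 8x8 upper-triangular cell topology defined above.
    # The ~k suffix encodes which earlier node each operation consumes, which is why the
    # adjacency matrix can stay constant while only the node labels change.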
    def create_molecule_from_graph(self, graph):
        nodes = graph.x
        edges = graph.edge_index
        mol = Chem.RWMol()  # RWMol allows building the molecule step by step
        atom_indices = {}
        # Inverse of the module-level op_type encoding (graph.x stores op_type codes).
        num_to_op = {
            1: 'nor_conv_1x1',
            2: 'nor_conv_3x3',
            3: 'avg_pool_3x3',
            4: 'skip_connect',
            5: 'none',
            6: 'output',
        }

        # Add atoms to the molecule
        for i, op_tensor in enumerate(nodes):
            op = op_tensor.item()
            if op == 0:
                continue
            op = num_to_op[op]
            atom_symbol = op_to_atom[op]
            atom = Chem.Atom(atom_symbol)
            atom_idx = mol.AddAtom(atom)
            atom_indices[i] = atom_idx

        # Add bonds to the molecule
        edge_number = edges.shape[1]
        for i in range(edge_number):
            start = edges[0, i].item()
            end = edges[1, i].item()
            mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
        return mol

    def get_train_graphs(self):
        train_graphs = []
        test_graphs = []
        for graph in self.train_dataset:
            train_graphs.append(graph)
        for graph in self.test_dataset:
            test_graphs.append(graph)
        return train_graphs, test_graphs

    def get_train_smiles(self):
        train_smiles = []
        test_smiles = []
        for graph in self.train_dataset:
            print(graph.x)
            print(graph.edge_index)
            print(f'class of graph.x: {graph.x.__class__}, class of graph.edge_index: {graph.edge_index.__class__}')
            mol = self.create_molecule_from_graph(graph)
            train_smiles.append(Chem.MolToSmiles(mol))
        for graph in self.test_dataset:
            mol = self.create_molecule_from_graph(graph)
            test_smiles.append(Chem.MolToSmiles(mol))
        return train_smiles, test_smiles

    def get_data_split(self):
        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")

    def example_batch(self):
        return next(iter(self.val_loader))

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader
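
# Minimal sketch (illustrative only) of the graph -> molecule -> SMILES round trip that
# DataModule.get_train_smiles performs per graph. The toy Data object, the decoder dict,
# and the helper name are assumptions made for this example, not part of the pipeline.
def _demo_graph_to_smiles():
    # ops encoded with op_type: input, nor_conv_1x1, skip_connect, output
    toy = Data(x=torch.LongTensor([0, 1, 4, 6]),
               edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long))
    decoder = {1: 'nor_conv_1x1', 2: 'nor_conv_3x3', 3: 'avg_pool_3x3',
               4: 'skip_connect', 5: 'none', 6: 'output'}
    mol = Chem.RWMol()
    idx_map = {}
    for i, code in enumerate(toy.x.tolist()):
        if code == 0:  # 'input' nodes carry no atom, mirroring create_molecule_from_graph
            continue
        idx_map[i] = mol.AddAtom(Chem.Atom(op_to_atom[decoder[code]]))
    for start, end in toy.edge_index.t().tolist():
        if start in idx_map and end in idx_map:
            mol.AddBond(idx_map[start], idx_map[end], rdchem.BondType.SINGLE)
    mol.UpdatePropertyCache(strict=False)  # needed before writing SMILES for an unsanitized mol
    return Chem.MolToSmiles(mol)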

class DataModule_original(AbstractDataModule):
    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.task = cfg.dataset.task_name
        print("DataModule")
        print("task", self.task)
        print("datadir", self.datadir)
        super().__init__(cfg)

    def prepare_data(self) -> None:
        target = getattr(self.cfg.dataset, 'guidance_target', None)
        print("target", target)
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, self.datadir)
        self.root_path = root_path

        batch_size = self.cfg.train.batch_size
        num_workers = self.cfg.train.num_workers
        pin_memory = self.cfg.dataset.pin_memory

        # Load the dataset into memory.
        # Dataset has target property, root path, and transform.
        dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)

        if len(self.task.split('-')) == 2:
            train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
        else:
            train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)

        self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
        train_index = torch.LongTensor(train_index)
        val_index = torch.LongTensor(val_index)
        test_index = torch.LongTensor(test_index)
        unlabeled_index = torch.LongTensor(unlabeled_index)
        if len(unlabeled_index) > 0:
            train_index = torch.cat([train_index, unlabeled_index], dim=0)

        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
        self.train_dataset = train_dataset
        print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                                       shuffle=True, pin_memory=pin_memory)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                                     shuffle=False, pin_memory=False)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                                      shuffle=False, pin_memory=False)
        training_iterations = len(train_dataset) // batch_size
        self.training_iterations = training_iterations

    def random_data_split(self, dataset):
        nan_count = torch.isnan(dataset.y[:, 0]).sum().item()
        labeled_len = len(dataset) - nan_count
        full_idx = list(range(labeled_len))
        train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
        train_index, test_index, _, _ = train_test_split(full_idx, full_idx,
                                                         test_size=test_ratio, random_state=42)
        train_index, val_index, _, _ = train_test_split(train_index, train_index,
                                                        test_size=valid_ratio / (valid_ratio + train_ratio),
                                                        random_state=42)
        unlabeled_index = list(range(labeled_len, len(dataset)))
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),
              'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
        return train_index, val_index, test_index, unlabeled_index

    def fixed_split(self, dataset):
        if self.task == 'O2-N2':
            test_index = [42, 43, 92, 122, 197, 198, 251, 254, 257, 355, 511, 512, 549, 602, 603, 604]
        else:
            raise ValueError('Invalid task name: {}'.format(self.task))
        full_idx = list(range(len(dataset)))
        full_idx = list(set(full_idx) - set(test_index))
        train_ratio = 0.8
        train_index, val_index, _, _ = train_test_split(full_idx, full_idx,
                                                        test_size=1 - train_ratio, random_state=42)
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),
              'test len', len(test_index))
        return train_index, val_index, test_index, []

    def get_train_smiles(self):
        filename = f'{self.task}.csv.gz'
        df = pd.read_csv(f'{self.root_path}/raw/{filename}')
        df_test = df.iloc[self.test_index]
        df = df.iloc[self.train_index]
        smiles_list = df['smiles'].tolist()
        smiles_list_test = df_test['smiles'].tolist()
        smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
        smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
        return smiles_list, smiles_list_test

    def get_data_split(self):
        filename = f'{self.task}.csv.gz'
        df = pd.read_csv(f'{self.root_path}/raw/{filename}')
        df_val = df.iloc[self.val_index]
        df_test = df.iloc[self.test_index]
        df_train = df.iloc[self.train_index]
        return df_train, df_val, df_test

    def example_batch(self):
        return next(iter(self.val_loader))

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader
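
# The two meta-statistics helpers below both consume a list of (adjacency_matrix, ops) pairs.
# For new_graphs_to_json the ops are already integer-encoded with op_type; a single pair could
# look like the sketch below (illustrative values only, not read from the benchmark), and the
# real call is made over the full 15,625-cell list loaded in DataInfos:
#
#     toy_adj = np.zeros((8, 8), dtype=int)
#     toy_adj[0, 1] = toy_adj[1, 2] = toy_adj[2, 7] = 1
#     toy_ops = [op_type['input'], op_type['nor_conv_3x3'], op_type['skip_connect'],
#                op_type['none'], op_type['none'], op_type['none'], op_type['none'],
#                op_type['output']]
#
# graphs_to_json, by contrast, expects the ops as raw strings ('nor_conv_3x3', ...) and maps
# them onto RDKit atom types via op_to_atom.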

def new_graphs_to_json(graphs, filename):
    source_name = "nasbench-201"
    num_graph = len(graphs)

    node_name_list = []
    node_count_list = []
    node_name_list.append('*')
    for op_name in op_type:
        node_name_list.append(op_name)
        node_count_list.append(0)
    node_count_list.append(0)

    n_nodes_per_graph = [0] * num_graph
    edge_count_list = [0, 0]
    valencies = [0] * (len(op_type) + 1)
    transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))

    n_node_list = []
    n_edge_list = []
    for graph in graphs:
        ops = graph[1]
        adj = graph[0]
        n_node = len(ops)
        print(n_node)
        n_edge = len(ops)
        n_node_list.append(n_node)
        n_edge_list.append(n_edge)
        n_nodes_per_graph[n_node] += 1

        cur_node_count_arr = np.zeros(len(op_type) + 1)
        for op in ops:
            node = op
            node_count_list[node] += 1
            cur_node_count_arr[node] += 1
            try:
                valencies[node] += 1
            except IndexError:
                print('unexpected node index', int(node))

        transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
        for i in range(n_node):
            for j in range(n_node):
                if i == j or adj[i][j] == 0:
                    continue
                start_node, end_node = i, j
                start_index = ops[start_node]
                end_index = ops[end_node]
                bond_index = 1
                edge_count_list[bond_index] += 2
                transition_E[start_index, end_index, bond_index] += 2
                transition_E[end_index, start_index, bond_index] += 2
                transition_E_temp[start_index, end_index, bond_index] += 2
                transition_E_temp[end_index, start_index, bond_index] += 2

        edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
        cur_tot_edge = cur_node_count_arr.reshape(-1, 1) * cur_node_count_arr.reshape(1, -1) * 2
        cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2
        transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1)
        assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0

    n_nodes_per_graph = np.array(n_nodes_per_graph) / np.sum(n_nodes_per_graph)
    n_nodes_per_graph = n_nodes_per_graph.tolist()[:51]
    node_count_list = np.array(node_count_list) / np.sum(node_count_list)
    print('processed meta info: ------', filename, '------')
    print('len node_count_list', len(node_count_list))
    print('len node_name_list', len(node_name_list))
    active_nodes = np.array(node_name_list)[node_count_list > 0]
    active_nodes = active_nodes.tolist()
    node_count_list = node_count_list.tolist()

    edge_count_list = np.array(edge_count_list) / np.sum(edge_count_list)
    edge_count_list = edge_count_list.tolist()
    valencies = np.array(valencies) / np.sum(valencies)
    valencies = valencies.tolist()

    no_edge = np.sum(transition_E, axis=-1) == 0
    first_elt = transition_E[:, :, 0]
    first_elt[no_edge] = 1
    transition_E[:, :, 0] = first_elt
    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)

    meta_dict = {
        'source': source_name,
        'num_graph': num_graph,
        'n_nodes_per_graph': n_nodes_per_graph,
        'max_n_nodes': max(n_node_list),
        'max_n_edges': max(n_edge_list),
        'node_type_list': node_count_list,
        'edge_type_list': edge_count_list,
        'valencies': valencies,
        'active_nodes': active_nodes,
        'num_active_nodes': len(active_nodes),
        'transition_E': transition_E.tolist(),
    }

    with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
        json.dump(meta_dict, f)
    return meta_dict
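
# Note on the normalisation above: transition_E[i, j] holds, for every ordered pair of node
# types, a histogram over edge states (index 0 = "no edge", index 1 = an edge). Pairs that
# never co-occur would otherwise normalise 0/0, so their "no edge" slot is set to 1 first;
# afterwards every transition_E[i, j] row sums to 1. A quick check one could run on the
# returned dict (illustrative, not part of the pipeline):
#
#     T = np.array(meta_dict['transition_E'])
#     assert np.allclose(T.sum(axis=-1), 1.0)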

def graphs_to_json(graphs, filename):
    bonds = {
        'nor_conv_1x1': 1,
        'nor_conv_3x3': 2,
        'avg_pool_3x3': 3,
        'skip_connect': 4,
        'input': 7,
        'output': 5,
        'none': 6
    }

    source_name = "nas-bench-201"
    num_graph = len(graphs)
    pt = Chem.GetPeriodicTable()

    atom_name_list = []
    atom_count_list = []
    for i in range(2, 119):
        atom_name_list.append(pt.GetElementSymbol(i))
        atom_count_list.append(0)
    atom_name_list.append('*')
    atom_count_list.append(0)

    n_atoms_per_mol = [0] * 500
    bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
    valencies = [0] * 500
    transition_E = np.zeros((118, 118, 8))

    n_atom_list = []
    n_bond_list = []
    # graphs = [(adj_matrix, ops), ...]
    for graph in graphs:
        ops = graph[1]
        adj = graph[0]
        n_atom = len(ops)
        n_bond = len(ops)
        n_atom_list.append(n_atom)
        n_bond_list.append(n_bond)
        n_atoms_per_mol[n_atom] += 1

        cur_atom_count_arr = np.zeros(118)
        for op in ops:
            symbol = op_to_atom[op]
            if symbol == 'H':
                continue
            elif symbol == '*':
                atom_count_list[-1] += 1
                cur_atom_count_arr[-1] += 1
            else:
                atom_count_list[pt.GetAtomicNumber(symbol) - 2] += 1
                cur_atom_count_arr[pt.GetAtomicNumber(symbol) - 2] += 1
            try:
                valencies[int(pt.GetDefaultValence(symbol))] += 1
            except IndexError:
                print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))

        transition_E_temp = np.zeros((118, 118, 8))
        for i in range(n_atom):
            for j in range(n_atom):
                if i == j or adj[i][j] == 0:
                    continue
                start_atom, end_atom = i, j
                if ops[start_atom] == 'input' or ops[end_atom] == 'input':
                    continue
                if ops[start_atom] == 'output' or ops[end_atom] == 'output':
                    continue
                if ops[start_atom] == 'none' or ops[end_atom] == 'none':
                    continue
                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
                bond_index = bonds[ops[end_atom]]
                bond_count_list[bond_index] += 2
                transition_E[start_index, end_index, bond_index] += 2
                transition_E[end_index, start_index, bond_index] += 2
                transition_E_temp[start_index, end_index, bond_index] += 2
                transition_E_temp[end_index, start_index, bond_index] += 2

        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
        print(bond_count_list)
        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2
        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
        assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0

    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
    print('processed meta info: ------', filename, '------')
    print('len atom_count_list', len(atom_count_list))
    print('len atom_name_list', len(atom_name_list))
    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
    active_atoms = active_atoms.tolist()
    atom_count_list = atom_count_list.tolist()

    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
    bond_count_list = bond_count_list.tolist()
    valencies = np.array(valencies) / np.sum(valencies)
    valencies = valencies.tolist()

    no_edge = np.sum(transition_E, axis=-1) == 0
    for i in range(118):
        for j in range(118):
            if not no_edge[i][j]:
                print(f'have an edge at i={i}, j={j}, transition_E[i][j]={transition_E[i][j]}')
    first_elt = transition_E[:, :, 0]
    first_elt[no_edge] = 1
    transition_E[:, :, 0] = first_elt
    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)

    # find non-zero elements in transition_E again
    for i in range(118):
        for j in range(118):
            if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
                print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')

    meta_dict = {
        'source': 'nasbench-201',
        'num_graph': num_graph,
        'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
        'max_node': max(n_atom_list),
        'max_bond': max(n_bond_list),
        'atom_type_dist': atom_count_list,
        'bond_type_dist': bond_count_list,
        'valencies': valencies,
        'active_nodes': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
        'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
        'transition_E': transition_E.tolist(),
    }

    with open(f'{filename}.meta.json', 'w') as f:
        json.dump(meta_dict, f)
    return meta_dict


class Dataset(InMemoryDataset):
    def __init__(self, source, root, target_prop=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.target_prop = target_prop
        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
        self.source = source
        # self.api = API(source)  # Initialize NAS-Bench-201 API
        # print('API loaded')
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths[0])  # e.g. .../NAS-Bench-201-v1_1-096897.pth.pt
        self.data, self.slices = torch.load(self.processed_paths[0])
        print('Dataset initialized')
        self.data.edge_attr = self.data.edge_attr.squeeze()
        self.data.idx = torch.arange(len(self.data.y))
        print(f"self.data={self.data}, self.slices={self.slices}")

    @property
    def raw_file_names(self):
        return []  # NAS-Bench-201 data is loaded via the API, no raw files needed

    @property
    def processed_file_names(self):
        return [f'{self.source}.pt']

    def process(self):
        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
        # self.api = API(source)
        data_list = []
        # len_data = len(self.api)
        len_data = 15625

        def check_valid_graph(nodes, edges):
            # A valid cell graph starts at 'input', ends at 'output', has no self loops,
            # never points an edge back into 'input' or out of 'output', and has at least
            # one edge into the output node.
            if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
                return False
            if nodes[0] != 'input' or nodes[-1] != 'output':
                return False
            for i in range(0, len(nodes)):
                if edges[i][i] == 1:
                    return False
            for i in range(1, len(nodes) - 1):
                if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
                    return False
            for i in range(0, len(nodes)):
                for j in range(i, len(nodes)):
                    if edges[i, j] == 1 and nodes[j] == 'input':
                        return False
            for i in range(0, len(nodes)):
                for j in range(i, len(nodes)):
                    if edges[i, j] == 1 and nodes[i] == 'output':
                        return False
            flag = 0
            for i in range(0, len(nodes)):
                if edges[i, -1] == 1:
                    flag = 1
                    break
            if flag == 0:
                return False
            return True
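
        # Quick illustration of check_valid_graph on the canonical 8-node cell (illustrative
        # only; the adjacency is the same fixed topology returned by parse_architecture_string,
        # passed as a numpy array):
        #
        #     nodes = ['input', 'nor_conv_3x3', 'skip_connect', 'none',
        #              'avg_pool_3x3', 'none', 'nor_conv_1x1', 'output']
        #     edges = np.array(DataModule.parse_architecture_string(arch_str)[1])
        #     check_valid_graph(nodes, edges)  # -> True for any benchmark cell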
        def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8, random_ratio=0.5):
            # Randomly grow a NAS-Bench-201 cell (8 nodes) to up to max_nodes nodes by
            # appending extra operation nodes and wiring them in with probability random_ratio.
            ori_edges = np.array(ori_edges)
            nasbench_201_node_num = 8
            nodes_num = random.randint(min_nodes, max_nodes)
            add_num = nodes_num - nasbench_201_node_num
            add_nodes = []
            print(f'add_num: {add_num}')
            for i in range(add_num):
                add_nodes.append(random.choice(num_to_op[1:-1]))
            print(f'ori_nodes[:-1]: {ori_nodes[:-1]}, add_nodes: {add_nodes}')
            print(f'len(ori_nodes[:-1]): {len(ori_nodes[:-1])}, len(add_nodes): {len(add_nodes)}')
            nodes = ori_nodes[:-1] + add_nodes + ['output']
            edges = np.zeros((nodes_num, nodes_num))
            edges[:6, :6] = ori_edges[:6, :6]
            edges[0:8, -1] = ori_edges[0:8, -1]
            for i in range(0, nodes_num):
                for j in range(max(7, i + 1), nodes_num):
                    rand = random.random()
                    if rand < random_ratio:
                        edges[i, j] = 1
            if nodes_num < max_nodes:
                edges = np.pad(edges, ((0, max_nodes - nodes_num), (0, max_nodes - nodes_num)),
                               'constant', constant_values=0)
            while len(nodes) < max_nodes:
                nodes.append('none')
            print(f'edges size: {edges.shape}, nodes size: {len(nodes)}')
            return edges, nodes

        def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
            ops = graph[1]
            adj = graph[0]
            nodes = []
            for op in ops:
                nodes.append(op_type[op])
            x = torch.LongTensor(nodes)

            edges_list = []
            edge_type = []
            for start in range(len(ops)):
                for end in range(len(ops)):
                    if adj[start][end] == 1:
                        edges_list.append((start, end))
                        edge_type.append(1)
                        edges_list.append((end, start))
                        edge_type.append(1)
            edge_index = torch.tensor(edges_list, dtype=torch.long).t()
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = edge_type

            # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
            y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device)
            print(y, idx)
            if y > 1600:
                print(f'idx={idx}, y={y}')
                y = torch.tensor([1, 1], dtype=torch.float).view(1, -1)
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=idx)
            else:
                print(f'idx={idx}, y={y}')
                y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=idx)
                # Architectures below the naswot-score threshold are skipped.
                return None
            return data

        graph_list = []

        class Args:
            pass

        args = Args()
        args.trainval = True
        args.augtype = 'none'
        args.repeat = 1
        args.score = 'hook_logdet'
        args.sigma = 0.05
        args.nasspace = 'nasbench201'
        args.batch_size = 128
        args.GPU = '0'
        args.dataset = 'cifar10'
        args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
        args.data_loc = '../cifardata/'
        args.seed = 777
        args.init = ''
        args.save_loc = 'results'
        args.save_string = 'naswot'
        args.dropout = False
        args.maxofn = 1
        args.n_samples = 100
        args.n_runs = 500
        args.stem_out_channels = 16
        args.num_stacks = 3
        args.num_modules_per_stack = 3
        args.num_labels = 1

        searchspace = nasspace.get_search_space(args)
        train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size,
                                   args.augtype, args.repeat, args)
        device = torch.device('cuda:2')

        with tqdm(total=len_data) as pbar:
            active_nodes = set()
            file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
            with open(file_path, 'r') as f:
                graph_list = json.load(f)
            i = 0
            flex_graph_list = []
            flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
            for graph in graph_list:
                print(f'iterate every graph in graph_list, here is {i}')
                # arch_info = self.api.query_meta_info_by_index(i)
                # results = self.api.query_by_index(i, 'cifar100')
                arch_info = graph['arch_str']
                ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4)
                # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
                for op in ops:
                    if op not in active_nodes:
                        active_nodes.add(op)
                data = graph_to_graph_data((adj_matrix, ops), idx=i, train_loader=train_loader,
                                           searchspace=searchspace, args=args, device=device)
                i += 1
                if data is None:
                    pbar.update(1)
                    continue

                flex_graph_list.append({
                    'adj_matrix': adj_matrix,
                    'ops': ops,
                })
                if i < 3:
                    print(f"i={i}, data={data}")
                    with open(f'{i}.json', 'w') as f:
                        f.write(str(data.x))
                        f.write(str(data.edge_index))
                        f.write(str(data.edge_attr))
                data_list.append(data)
                pbar.update(1)

        for graph in graph_list:
            adj_matrix = graph['adj_matrix']
            if isinstance(adj_matrix, np.ndarray):
                adj_matrix = adj_matrix.tolist()
            graph['adj_matrix'] = adj_matrix
            ops = graph['ops']
            if isinstance(ops, np.ndarray):
                ops = ops.tolist()
            graph['ops'] = ops
        with open('nasbench-201-graph.json', 'w') as f:
            json.dump(graph_list, f)
        # with open(flex_graph_path, 'w') as f:
        #     json.dump(flex_graph_list, f)

        torch.save(self.collate(data_list), self.processed_paths[0])
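
# Each record in nasbench-201-graph.json is expected to look roughly like
#   {"adj_matrix": [[0, 1, ...], ...], "ops": ["input", ..., "output"], "arch_str": "|...|+|...|"}.
# A minimal, illustrative way to peek at the file outside of Dataset.process (path and field
# names as assumed above):
#
#     with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') as fh:
#         records = json.load(fh)
#     print(len(records), records[0]['arch_str'])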

class Dataset_origin(InMemoryDataset):
    def __init__(self, source, root, target_prop=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.target_prop = target_prop
        self.source = source
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [f'{self.source}.csv.gz']

    @property
    def processed_file_names(self):
        return [f'{self.source}.pt']

    def process(self):
        RDLogger.DisableLog('rdApp.*')
        data_path = osp.join(self.raw_dir, self.raw_file_names[0])
        data_df = pd.read_csv(data_path)

        def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None):
            type_idx = []
            heavy_atom_indices, active_atoms = [], []
            for atom in mol.GetAtoms():
                if atom.GetAtomicNum() != 1:
                    if atom.GetSymbol() == '*':
                        type_idx.append(119 - 2)
                    else:
                        type_idx.append(atom.GetAtomicNum() - 2)
                    heavy_atom_indices.append(atom.GetIdx())
                    active_atoms.append(atom.GetSymbol())
                    if valid_atoms is not None:
                        if not atom.GetSymbol() in valid_atoms:
                            return None, None
            x = torch.LongTensor(type_idx)

            edges_list = []
            edge_type = []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                if start in heavy_atom_indices and end in heavy_atom_indices:
                    start_new, end_new = heavy_atom_indices.index(start), heavy_atom_indices.index(end)
                    edges_list.append((start_new, end_new))
                    edge_type.append(bonds[bond.GetBondType()])
                    edges_list.append((end_new, start_new))
                    edge_type.append(bonds[bond.GetBondType()])
            edge_index = torch.tensor(edges_list, dtype=torch.long).t()
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = edge_type

            if target3 is not None:
                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
            elif target2 is not None:
                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
            else:
                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            return data, active_atoms

        # Loop through every row in the DataFrame and apply the function
        data_list = []
        len_data = len(data_df)
        with tqdm(total=len_data) as pbar:
            # --- data processing start ---
            active_atoms = set()
            for i, (sms, df_row) in enumerate(data_df.iterrows()):
                if i == sms:
                    sms = df_row['smiles']
                    mol = Chem.MolFromSmiles(sms, sanitize=False)
                    if len(self.target_prop.split('-')) == 2:
                        target1, target2 = self.target_prop.split('-')
                        data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'],
                                                              df_row[target1], target2=df_row[target2])
                    elif len(self.target_prop.split('-')) == 3:
                        target1, target2, target3 = self.target_prop.split('-')
                        data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'],
                                                              df_row[target1], target2=df_row[target2],
                                                              target3=df_row[target3])
                    else:
                        data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'],
                                                              df_row[self.target_prop])
                    active_atoms.update(cur_active_atoms)
                    data_list.append(data)
                pbar.update(1)

        torch.save(self.collate(data_list), self.processed_paths[0])


def parse_architecture_string(arch_str, padding=0):
    steps = arch_str.split('+')
    nodes = ['input']  # Start with the input node
    ori_adj_mat = [[0, 1, 1, 0, 1, 0, 0, 0],
                   [0, 0, 0, 1, 0, 1, 0, 0],
                   [0, 0, 0, 0, 0, 0, 1, 0],
                   [0, 0, 0, 0, 0, 0, 1, 0],
                   [0, 0, 0, 0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0, 0, 0, 0]]
    adj_mat = [[0, 1, 1, 0, 1, 0, 0, 0],
               [0, 0, 0, 1, 0, 1, 0, 0],
               [0, 0, 0, 0, 0, 0, 1, 0],
               [0, 0, 0, 0, 0, 0, 1, 0],
               [0, 0, 0, 0, 0, 0, 0, 1],
               [0, 0, 0, 0, 0, 0, 0, 1],
               [0, 0, 0, 0, 0, 0, 0, 1],
               [0, 0, 0, 0, 0, 0, 0, 0]]
    steps_coding = ['0', '0', '1', '0', '1', '2']
    cont = 0
    for step in steps:
        step = step.strip('|').split('|')
        for node in step:
            n, idx = node.split('~')
            assert idx == steps_coding[cont]
            cont += 1
            nodes.append(n)
    nodes.append('output')  # Add the output node
    ori_nodes = nodes.copy()

    if padding > 0:
        for i in range(padding):
            nodes.append('none')
        for adj_row in adj_mat:
            for i in range(padding):
                adj_row.append(0)
        for i in range(padding):
            adj_mat.append([0] * len(nodes))
    return nodes, adj_mat, ori_nodes, ori_adj_mat


def create_adj_matrix_and_ops(nodes, edges):
    num_nodes = len(nodes)
    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
    for (src, dst) in edges:
        adj_matrix[src][dst] = 1
    return adj_matrix, nodes
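
# Illustrative call (not executed at import time): with padding=4 the module-level parser
# returns both the padded 12-node cell used by the diffusion model and the original 8-node
# cell, e.g.
#
#     nodes, adj, ori_nodes, ori_adj = parse_architecture_string(
#         '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|',
#         padding=4)
#     len(nodes), len(adj), len(ori_nodes), len(ori_adj)   # -> (12, 12, 8, 8)
#
# The four padded entries are 'none' operations with no incident edges.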

class DataInfos(AbstractDatasetInfos):
    def __init__(self, datamodule, cfg, dataset):
        tasktype_dict = {
            'hiv_b': 'classification',
            'bace_b': 'classification',
            'bbbp_b': 'classification',
            'O2': 'regression',
            'N2': 'regression',
            'CO2': 'regression',
        }
        task_name = cfg.dataset.task_name
        self.task = task_name
        self.task_type = tasktype_dict.get(task_name, "regression")
        self.ensure_connected = cfg.model.ensure_connected
        # self.api = dataset.api

        datadir = cfg.dataset.datadir
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
        data_root = os.path.join(base_path, datadir, 'raw')

        graphs = []
        length = 15625
        ops_type = {}
        len_ops = set()
        # api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')

        def read_adj_ops_from_json(filename):
            with open(filename, 'r') as json_file:
                data = json.load(json_file)
            adj_ops_pairs = []
            for item in data:
                print(item)
                adj_matrix = np.array(item['adj_matrix'])
                ops = item['ops']
                ops = [op_type[op] for op in ops]
                adj_ops_pairs.append((adj_matrix, ops))
            return adj_ops_pairs

        # graphs = read_adj_ops_from_json('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
        graphs = read_adj_ops_from_json('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
        # check the first five graphs
        for i in range(5):
            print(f'graph {i} : {graphs[i]}')

        meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
        self.base_path = base_path
        self.active_nodes = meta_dict['active_nodes']
        self.max_n_nodes = meta_dict['max_n_nodes']
        self.original_max_n_nodes = meta_dict['max_n_nodes']
        self.n_nodes = torch.Tensor(meta_dict['n_nodes_per_graph'])
        self.edge_types = torch.Tensor(meta_dict['edge_type_list'])
        self.transition_E = torch.Tensor(meta_dict['transition_E'])

        self.node_decoder = meta_dict['active_nodes']
        node_types = torch.Tensor(meta_dict['node_type_list'])
        active_index = (node_types > 0).nonzero().squeeze()
        self.node_types = torch.Tensor(meta_dict['node_type_list'])[active_index]
        self.nodes_dist = DistributionNodes(self.n_nodes)
        self.active_index = active_index

        val_len = 3 * self.original_max_n_nodes - 2
        meta_val = torch.Tensor(meta_dict['valencies'])
        self.valency_distribution = torch.zeros(val_len)
        val_len = min(val_len, len(meta_val))
        self.valency_distribution[:val_len] = meta_val[:val_len]
        self.y_prior = None
        self.train_ymin = []
        self.train_ymax = []
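
# DataInfos above consumes the NAS-Bench-201 meta dictionary produced by new_graphs_to_json
# (keys: 'n_nodes_per_graph', 'node_type_list', 'edge_type_list', 'max_n_nodes', ...),
# whereas DataInfos_origin below reads the molecular meta file produced by compute_meta
# (keys: 'n_atoms_per_mol_dist', 'atom_type_dist', 'bond_type_dist', 'max_node', ...).
# The two classes are otherwise parallel: both expose n_nodes, node_types, edge_types,
# transition_E, and a valency distribution to the diffusion model.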

class DataInfos_origin(AbstractDatasetInfos):
    def __init__(self, datamodule, cfg):
        tasktype_dict = {
            'hiv_b': 'classification',
            'bace_b': 'classification',
            'bbbp_b': 'classification',
            'O2': 'regression',
            'N2': 'regression',
            'CO2': 'regression',
        }
        task_name = cfg.dataset.task_name
        self.task = task_name
        self.task_type = tasktype_dict.get(task_name, "regression")
        self.ensure_connected = cfg.model.ensure_connected

        datadir = cfg.dataset.datadir
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
        data_root = os.path.join(base_path, datadir, 'raw')
        if os.path.exists(meta_filename):
            with open(meta_filename, 'r') as f:
                meta_dict = json.load(f)
        else:
            meta_dict = compute_meta(data_root, task_name, datamodule.train_index, datamodule.test_index)

        self.base_path = base_path
        self.active_atoms = meta_dict['active_atoms']
        self.max_n_nodes = meta_dict['max_node']
        self.original_max_n_nodes = meta_dict['max_node']
        self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
        self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
        self.transition_E = torch.Tensor(meta_dict['transition_E'])

        self.atom_decoder = meta_dict['active_atoms']
        node_types = torch.Tensor(meta_dict['atom_type_dist'])
        active_index = (node_types > 0).nonzero().squeeze()
        self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
        self.nodes_dist = DistributionNodes(self.n_nodes)
        self.active_index = active_index

        val_len = 3 * self.original_max_n_nodes - 2
        meta_val = torch.Tensor(meta_dict['valencies'])
        self.valency_distribution = torch.zeros(val_len)
        val_len = min(val_len, len(meta_val))
        self.valency_distribution[:val_len] = meta_val[:val_len]
        self.y_prior = None
        self.train_ymin = []
        self.train_ymax = []


def compute_meta(root, source_name, train_index, test_index):
    # Initialize the periodic table: 118 elements plus one slot for '*'.
    # The arrays below count atoms per molecule, bond types, valencies, and transition
    # probabilities between atom types.
    pt = Chem.GetPeriodicTable()
    atom_name_list = []
    atom_count_list = []
    for i in range(2, 119):
        atom_name_list.append(pt.GetElementSymbol(i))
        atom_count_list.append(0)
    atom_name_list.append('*')
    atom_count_list.append(0)
    n_atoms_per_mol = [0] * 500
    bond_count_list = [0, 0, 0, 0, 0]
    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
    valencies = [0] * 500
    transition_E = np.zeros((118, 118, 5))

    # Load the data from the source file
    filename = f'{source_name}.csv.gz'
    df = pd.read_csv(f'{root}/{filename}')
    all_index = list(range(len(df)))
    non_test_index = list(set(all_index) - set(test_index))
    df = df.iloc[non_test_index]
    # extract the smiles from the dataframe
    tot_smiles = df['smiles'].tolist()

    n_atom_list = []
    n_bond_list = []
    for i, sms in enumerate(tot_smiles):
        try:
            mol = Chem.MolFromSmiles(sms)
        except Exception:
            continue

        n_atom = mol.GetNumHeavyAtoms()
        n_bond = mol.GetNumBonds()
        n_atom_list.append(n_atom)
        n_bond_list.append(n_bond)
        n_atoms_per_mol[n_atom] += 1

        cur_atom_count_arr = np.zeros(118)
        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            if symbol == 'H':
                continue
            elif symbol == '*':
                atom_count_list[-1] += 1
                cur_atom_count_arr[-1] += 1
            else:
                atom_count_list[atom.GetAtomicNum() - 2] += 1
                cur_atom_count_arr[atom.GetAtomicNum() - 2] += 1
            try:
                valencies[int(atom.GetExplicitValence())] += 1
            except Exception:
                print('src', source_name, 'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))

        transition_E_temp = np.zeros((118, 118, 5))
        for bond in mol.GetBonds():
            start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
            if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
                continue
            if start_atom.GetSymbol() == '*':
                start_index = 117
            else:
                start_index = start_atom.GetAtomicNum() - 2
            if end_atom.GetSymbol() == '*':
                end_index = 117
            else:
                end_index = end_atom.GetAtomicNum() - 2

            bond_type = bond.GetBondType()
            bond_index = bond_type_to_index[bond_type]
            bond_count_list[bond_index] += 2
            # The transition matrix is symmetric, so both directions are updated;
            # the temporary copy is used below to sanity-check the atom counts.
            transition_E[start_index, end_index, bond_index] += 2
            transition_E[end_index, start_index, bond_index] += 2
            transition_E_temp[start_index, end_index, bond_index] += 2
            transition_E_temp[end_index, start_index, bond_index] += 2

        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2  # 118 * 118
        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2  # 118 * 118
        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
        assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'

    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
    print('processed meta info: ------', filename, '------')
    print('len atom_count_list', len(atom_count_list))
    print('len atom_name_list', len(atom_name_list))
    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
    active_atoms = active_atoms.tolist()
    atom_count_list = atom_count_list.tolist()
    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
    bond_count_list = bond_count_list.tolist()
    valencies = np.array(valencies) / np.sum(valencies)
    valencies = valencies.tolist()

    no_edge = np.sum(transition_E, axis=-1) == 0
    first_elt = transition_E[:, :, 0]
    first_elt[no_edge] = 1
    transition_E[:, :, 0] = first_elt
    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)

    meta_dict = {
        'source': source_name,
        'num_graph': len(n_atom_list),
        'n_atoms_per_mol_dist': n_atoms_per_mol,
        'max_node': max(n_atom_list),
        'max_bond': max(n_bond_list),
        'atom_type_dist': atom_count_list,
        'bond_type_dist': bond_count_list,
        'valencies': valencies,
        'active_atoms': active_atoms,
        'num_atom_type': len(active_atoms),
        'transition_E': transition_E.tolist(),
    }

    with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
        json.dump(meta_dict, f)
    return meta_dict


if __name__ == "__main__":
    dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit',
                      target_prop='Class', transform=None)
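    # Optional smoke test (assumption: nasbench-201-graph.json already exists at the path used
    # in DataInfos). It rebuilds only the meta statistics without touching the naswot scoring
    # or the processed .pt file:
    #
    #     with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') as fh:
    #         graphs = [(np.array(g['adj_matrix']), [op_type[o] for o in g['ops']])
    #                   for g in json.load(fh)]
    #     new_graphs_to_json(graphs, 'nasbench-201')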