nasbench.ipynb
This commit is contained in:
parent 6dc5ef1da8
commit 7831979db7
@@ -21,7 +21,7 @@ from sklearn.model_selection import train_test_split
 import utils as utils
 from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
 from diffusion.distributions import DistributionNodes
-from nas_201_api import NASBench201API
 bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
 
 class DataModule(AbstractDataModule):
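
Note: the import removed above was this file's only dependency on the NAS-Bench-201 API; the rest of this commit deletes the code that used it. For reference, a minimal sketch of that usage, reusing only the calls that appear in the removed lines below (the checkpoint filename and get_arch() are copied from this diff, not verified against the nas_201_api package):

    from nas_201_api import NASBench201API

    # Checkpoint path as hard-coded in the removed process() code.
    api = NASBench201API('./NAS-Bench-201-v1_1-096897.pth')
    # The removed loop iterated arch_index over range(15625).
    arch_info = api.get_arch(0)
    print(arch_info['arch_str'])  # e.g. '|nor_conv_3x3~0|+...'
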
@@ -48,17 +48,17 @@ class DataModule(AbstractDataModule):
         # Load the dataset to the memory
         # Dataset has target property, root path, and transform
         dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
-        # print("len dataset", len(dataset))
-        # def print_data(dataset):
-        #     print("dataset", dataset)
-        #     print("dataset keys", dataset.keys)
-        #     print("dataset x", dataset.x)
-        #     print("dataset edge_index", dataset.edge_index)
-        #     print("dataset edge_attr", dataset.edge_attr)
-        #     print("dataset y", dataset.y)
-        #     print("")
-        # print_data(dataset=dataset[0])
-        # print_data(dataset=dataset[1])
+        print("len dataset", len(dataset))
+        def print_data(dataset):
+            print("dataset", dataset)
+            print("dataset keys", dataset.keys)
+            print("dataset x", dataset.x)
+            print("dataset edge_index", dataset.edge_index)
+            print("dataset edge_attr", dataset.edge_attr)
+            print("dataset y", dataset.y)
+            print("")
+        print_data(dataset=dataset[0])
+        print_data(dataset=dataset[1])
 
 
         if len(self.task.split('-')) == 2:
@@ -155,30 +155,7 @@ class Dataset(InMemoryDataset):
     def processed_file_names(self):
         return [f'{self.source}.pt']
 
-    def create_adj_matrix_and_ops(nodes, edges):
-        num_nodes = len(nodes)
-        adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
-        for (src, dst) in edges:
-            adj_matrix[src][dst] = 1
-        return adj_matrix, nodes
-
-    def parse_architecture_string(arch_str):
-        print(arch_str)
-        steps = arch_str.split('+')
-        nodes = ['input']  # Start with input node
-        edges = []
-        for i, step in enumerate(steps):
-            step = step.strip('|').split('|')
-            for node in step:
-                op, idx = node.split('~')
-                edges.append((int(idx), i+1))  # i+1 because 0 is input node
-                nodes.append(op)
-        nodes.append('output')  # Add output node
-        return nodes, edges
-
     def process(self):
-        # return
-        api = NASBench201API('./NAS-Bench-201-v1_1-096897.pth')
         RDLogger.DisableLog('rdApp.*')
         data_path = osp.join(self.raw_dir, self.raw_file_names[0])
         data_df = pd.read_csv(data_path)
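
Note: the two helpers deleted above are the core of the architecture-string handling. NAS-Bench-201 encodes a cell as '|op~src|' tokens joined by '+', one group per intermediate node. A self-contained sketch of the removed behavior, runnable outside the repo (the architecture string is a made-up but format-valid example):

    import numpy as np

    def parse_architecture_string(arch_str):
        steps = arch_str.split('+')
        nodes = ['input']  # node 0 is the cell input
        edges = []
        for i, step in enumerate(steps):
            for token in step.strip('|').split('|'):
                op, idx = token.split('~')
                edges.append((int(idx), i + 1))  # i+1 because 0 is the input node
                nodes.append(op)
        nodes.append('output')
        return nodes, edges

    def create_adj_matrix_and_ops(nodes, edges):
        adj_matrix = np.zeros((len(nodes), len(nodes)), dtype=int)
        for src, dst in edges:
            adj_matrix[src][dst] = 1
        return adj_matrix, nodes

    arch = '|nor_conv_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|skip_connect~0|avg_pool_3x3~1|none~2|'
    nodes, edges = parse_architecture_string(arch)
    adj, ops = create_adj_matrix_and_ops(nodes, edges)
    print(len(ops), adj.shape)  # 8 (8, 8)
    print(edges)                # [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]

One quirk carried over from the deleted code: every op token appends a node, but its edge always targets node i+1, so the adjacency matrix ends up larger than the set of indices the edges actually use.
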
@@ -223,65 +200,26 @@ class Dataset(InMemoryDataset):
             return data, active_atoms
 
         # Loop through every row in the DataFrame and apply the function
-        # data_list = []
-
-        # len_data = len(data_df)
-        len_data = 15625
         data_list = []
-        bonds = {
-            'nor_conv_1x1': 1,
-            'nor_conv_3x3': 2,
-            'avg_pool_3x3': 3,
-            'skip_connect': 4,
-            'input': 0,
-            'output': 5
-        }
-
-        def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
-            nodes, edges = Dataset.parse_architecture_string(arch_str)
-            node_labels = [bonds[node] for node in nodes]  # Replace with appropriate encoding if necessary
-            x = torch.LongTensor(node_labels)
-
-            edges_list = [(start, end) for start, end in edges]
-            edge_type = [bonds[nodes[end]] for start, end in edges]  # Example: using end node type as edge type
-            edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
-            edge_type = torch.tensor(edge_type, dtype=torch.long)
-            edge_attr = edge_type.view(-1, 1)
-
-            if target3 is not None:
-                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
-            elif target2 is not None:
-                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
-            else:
-                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
-
-            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
-            return data, nodes
-
-        # --- data processing start ---
-        # active_atoms = set()
-        # for i, (sms, df_row) in enumerate(data_df.iterrows()):
-        #     if i == sms:
-        #         sms = df_row['smiles']
-        #     mol = Chem.MolFromSmiles(sms, sanitize=False)
-        #     if len(self.target_prop.split('-')) == 2:
-        #         target1, target2 = self.target_prop.split('-')
-        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
-        #     elif len(self.target_prop.split('-')) == 3:
-        #         target1, target2, target3 = self.target_prop.split('-')
-        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
-        #     else:
-        #         data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
-        #     active_atoms.update(cur_active_atoms)
-        #     data_list.append(data)
-        #     pbar.update(1)
-        for arch_index in range(len_data):
-            arch_info = api.get_arch(arch_index)
-            arch_str = arch_info['arch_str']
-
-            nodes, edges = Dataset.parse_architecture_string(arch_str)
-            adj_matrix, nodes = Dataset.create_adj_matrix_and_ops(nodes, edges)
-
-            data, cur_active_atoms = graph
-
+        len_data = len(data_df)
+        with tqdm(total=len_data) as pbar:
+            # --- data processing start ---
+            active_atoms = set()
+            for i, (sms, df_row) in enumerate(data_df.iterrows()):
+                if i == sms:
+                    sms = df_row['smiles']
+                mol = Chem.MolFromSmiles(sms, sanitize=False)
+                if len(self.target_prop.split('-')) == 2:
+                    target1, target2 = self.target_prop.split('-')
+                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
+                elif len(self.target_prop.split('-')) == 3:
+                    target1, target2, target3 = self.target_prop.split('-')
+                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
+                else:
+                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
+                active_atoms.update(cur_active_atoms)
+                data_list.append(data)
+                pbar.update(1)
 
         torch.save(self.collate(data_list), self.processed_paths[0])
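
Note: the deleted arch_to_graph above and the restored mol_to_graph calls share the same pattern for assembling the regression target: y is a (1, k) float tensor whose width grows with the number of supplied properties. Isolated as a standalone helper for reference (build_y is a name introduced here for illustration):

    import torch

    def build_y(sa, sc, target, target2=None, target3=None):
        # Mirrors the y construction in the deleted arch_to_graph.
        if target3 is not None:
            vals = [sa, sc, target, target2, target3]
        elif target2 is not None:
            vals = [sa, sc, target, target2]
        else:
            vals = [sa, sc, target]
        return torch.tensor(vals, dtype=torch.float).view(1, -1)

    print(build_y(0.5, 0.2, 1.3).shape)            # torch.Size([1, 3])
    print(build_y(0.5, 0.2, 1.3, 0.7, 0.9).shape)  # torch.Size([1, 5])
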
@@ -296,10 +234,8 @@ class DataInfos(AbstractDatasetInfos):
             'N2': 'regression',
             'CO2': 'regression',
         }
 
         task_name = cfg.dataset.task_name
         self.task = task_name
-        print(self.task)
         self.task_type = tasktype_dict.get(task_name, "regression")
         self.ensure_connected = cfg.model.ensure_connected
 
@@ -473,181 +409,6 @@ def compute_meta(root, source_name, train_index, test_index):
 
     return meta_dict
 
-op_to_atom = {
-    'input': 'Si',           # Silicon for input
-    'nor_conv_1x1': 'C',     # Carbon for 1x1 convolution
-    'nor_conv_3x3': 'N',     # Nitrogen for 3x3 convolution
-    'avg_pool_3x3': 'O',     # Oxygen for 3x3 average pooling
-    'skip_connect': 'P',     # Phosphorus for skip connection
-    'none': 'S',             # Sulfur for no operation
-    'output': 'He'           # Helium for output
-}
-
-def get_sample_nasbench_graph():
-    adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
-                        [0, 0, 0, 1, 0, 1, 0, 0],
-                        [0, 0, 0, 0, 0, 0, 1, 0],
-                        [0, 0, 0, 0, 0, 0, 1, 0],
-                        [0, 0, 0, 0, 0, 0, 0, 1],
-                        [0, 0, 0, 0, 0, 0, 0, 1],
-                        [0, 0, 0, 0, 0, 0, 0, 1],
-                        [0, 0, 0, 0, 0, 0, 0, 0]])
-    ops = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
-    return adj_mat, ops
-
-def nasbench_to_molecule(adj_mat, ops):
-    mol = Chem.RWMol()  # Create a new editable molecule
-    atom_map = {}  # Map to keep track of node to atom mapping
-
-    # Add atoms to the molecule
-    for i, op in enumerate(ops):
-        atom_type = op_to_atom.get(op, 'C')  # Default to Carbon if operation not found
-        atom = Chem.Atom(atom_type)  # Create an atom of the specified type
-        idx = mol.AddAtom(atom)
-        atom_map[i] = idx
-
-    # Add bonds to the molecule
-    for i in range(adj_mat.shape[0]):
-        for j in range(adj_mat.shape[1]):
-            if adj_mat[i, j] == 1:
-                mol.AddBond(atom_map[i], atom_map[j], Chem.rdchem.BondType.SINGLE)
-
-    return mol
-
-def compute_meta_graph(root, source_name, train_index, test_index):
-    # initialize the periodic table
-    # 118 elements + 1 for *
-    # Initializes arrays to count the number of atoms per molecule, bond types,
-    # valencies, and transition probabilities between atom types.
-    pt = Chem.GetPeriodicTable()
-    atom_name_list = []
-    atom_count_list = []
-    for i in range(2, 119):
-        atom_name_list.append(pt.GetElementSymbol(i))
-        atom_count_list.append(0)
-    atom_name_list.append('*')
-    atom_count_list.append(0)
-    n_atoms_per_mol = [0] * 500
-    bond_count_list = [0, 0, 0, 0, 0]
-    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
-    valencies = [0] * 500
-    tansition_E = np.zeros((118, 118, 5))
-
-    # Load the data from the source file
-    filename = f'{source_name}.csv.gz'
-    df = pd.read_csv(f'{root}/{filename}')
-    all_index = list(range(len(df)))
-    non_test_index = list(set(all_index) - set(test_index))
-    df = df.iloc[non_test_index]
-    # extract the smiles from the dataframe
-    tot_smiles = df['smiles'].tolist()
-
-    n_atom_list = []
-    n_bond_list = []
-    for i, sms in enumerate(tot_smiles):
-        try:
-            mol = Chem.MolFromSmiles(sms)
-        except:
-            continue
-
-        n_atom = mol.GetNumHeavyAtoms()
-        n_bond = mol.GetNumBonds()
-        n_atom_list.append(n_atom)
-        n_bond_list.append(n_bond)
-
-        n_atoms_per_mol[n_atom] += 1
-        cur_atom_count_arr = np.zeros(118)
-
-        for atom in mol.GetAtoms():
-            symbol = atom.GetSymbol()
-            if symbol == 'H':
-                continue
-            elif symbol == '*':
-                atom_count_list[-1] += 1
-                cur_atom_count_arr[-1] += 1
-            else:
-                atom_count_list[atom.GetAtomicNum()-2] += 1
-                cur_atom_count_arr[atom.GetAtomicNum()-2] += 1
-            try:
-                valencies[int(atom.GetExplicitValence())] += 1
-            except:
-                print('src', source_name, 'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))
-
-        tansition_E_temp = np.zeros((118, 118, 5))
-        for bond in mol.GetBonds():
-            start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
-            if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
-                continue
-
-            if start_atom.GetSymbol() == '*':
-                start_index = 117
-            else:
-                start_index = start_atom.GetAtomicNum() - 2
-            if end_atom.GetSymbol() == '*':
-                end_index = 117
-            else:
-                end_index = end_atom.GetAtomicNum() - 2
-
-            bond_type = bond.GetBondType()
-            bond_index = bond_type_to_index[bond_type]
-            bond_count_list[bond_index] += 2
-
-            # Update the transition matrix
-            # The transition matrix is symmetric, so we update both directions
-            # We also update the temporary transition matrix to check for errors
-            # in the atom count
-            tansition_E[start_index, end_index, bond_index] += 2
-            tansition_E[end_index, start_index, bond_index] += 2
-            tansition_E_temp[start_index, end_index, bond_index] += 2
-            tansition_E_temp[end_index, start_index, bond_index] += 2
-
-        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
-        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2  # 118 * 118
-        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2  # 118 * 118
-        tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1)
-        assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
-
-    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
-
-    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
-    print('processed meta info: ------', filename, '------')
-    print('len atom_count_list', len(atom_count_list))
-    print('len atom_name_list', len(atom_name_list))
-    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
-    active_atoms = active_atoms.tolist()
-    atom_count_list = atom_count_list.tolist()
-
-    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
-    bond_count_list = bond_count_list.tolist()
-    valencies = np.array(valencies) / np.sum(valencies)
-    valencies = valencies.tolist()
-
-    no_edge = np.sum(tansition_E, axis=-1) == 0
-    first_elt = tansition_E[:, :, 0]
-    first_elt[no_edge] = 1
-    tansition_E[:, :, 0] = first_elt
-
-    tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True)
-
-    meta_dict = {
-        'source': source_name,
-        'num_graph': len(n_atom_list),
-        'n_atoms_per_mol_dist': n_atoms_per_mol,
-        'max_node': max(n_atom_list),
-        'max_bond': max(n_bond_list),
-        'atom_type_dist': atom_count_list,
-        'bond_type_dist': bond_count_list,
-        'valencies': valencies,
-        'active_atoms': active_atoms,
-        'num_atom_type': len(active_atoms),
-        'transition_E': tansition_E.tolist(),
-    }
-
-    with open(f'{root}/{source_name}.meta.json', "w") as f:
-        json.dump(meta_dict, f)
-
-    return meta_dict
-
 if __name__ == "__main__":
     pass
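
Note: the deleted op_to_atom table mapped each NAS-Bench-201 operation to a chemical element so that a cell graph could be handed to RDKit. A self-contained sketch mirroring the removed nasbench_to_molecule on a hypothetical three-node toy graph (only the toy graph is new; the mapping and construction follow the deleted code):

    import numpy as np
    from rdkit import Chem

    op_to_atom = {'input': 'Si', 'nor_conv_1x1': 'C', 'nor_conv_3x3': 'N',
                  'avg_pool_3x3': 'O', 'skip_connect': 'P', 'none': 'S',
                  'output': 'He'}

    adj_mat = np.array([[0, 1, 0],
                        [0, 0, 1],
                        [0, 0, 0]])
    ops = ['input', 'nor_conv_3x3', 'output']

    mol = Chem.RWMol()
    atom_map = {i: mol.AddAtom(Chem.Atom(op_to_atom.get(op, 'C'))) for i, op in enumerate(ops)}
    for i in range(adj_mat.shape[0]):
        for j in range(adj_mat.shape[1]):
            if adj_mat[i, j] == 1:
                mol.AddBond(atom_map[i], atom_map[j], Chem.rdchem.BondType.SINGLE)
    print(mol.GetNumAtoms(), mol.GetNumBonds())  # 3 2
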
graph_dit/test_nasbench.ipynb (new file, 47037 lines)
File diff suppressed because it is too large.