import sys
sys.path.append('../')

from nas_201_api import NASBench201API as API

import os
import os.path as osp
import pathlib
import json
import random

import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import rdchem
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes
from naswot.score_networks import get_nasbench201_idx_score
from naswot import nasspace
from naswot import datasets as dt

import networkx as nx

bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}

# Each NAS-Bench-201 operation is mapped to a chemical element so that a cell
# can be handled with RDKit as if it were a molecule.
op_to_atom = {
    'input': 'Si',          # Silicon for input
    'nor_conv_1x1': 'C',    # Carbon for 1x1 convolution
    'nor_conv_3x3': 'N',    # Nitrogen for 3x3 convolution
    'avg_pool_3x3': 'O',    # Oxygen for 3x3 average pooling
    'skip_connect': 'P',    # Phosphorus for skip connection
    'none': 'S',            # Sulfur for no operation
    'output': 'He'          # Helium for output
}

op_type = {
    'input': 0,
    'nor_conv_1x1': 1,
    'nor_conv_3x3': 2,
    'avg_pool_3x3': 3,
    'skip_connect': 4,
    'none': 5,
    'output': 6,
}

num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
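
# Illustrative note (example string chosen for this comment, not taken from the
# API): a NAS-Bench-201 architecture string such as
#   '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|'
# lists the operation on each of the six cell edges, grouped by target node.
# parse_architecture_string below maps it to the fixed node sequence
# ['input', op1, ..., op6, 'output'] plus the fixed 8x8 cell DAG, and the two
# tables above invert each other: num_to_op[op_type[name]] == name.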
class DataModule(AbstractDataModule):
    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.task = cfg.dataset.task_name
        print("DataModule")
        print("task", self.task)
        print("datadir", self.datadir)
        super().__init__(cfg)

    def prepare_data(self) -> None:
        target = getattr(self.cfg.dataset, 'guidance_target', None)
        print("target", target)  # nasbench-201
        # try:
        #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        # except NameError:
        #     base_path = pathlib.Path(os.getcwd()).parent[2]
        base_path = '/nfs/data3/hanzhang/nasbenchDiT'
        root_path = os.path.join(base_path, self.datadir)
        self.root_path = root_path

        batch_size = self.cfg.train.batch_size

        num_workers = self.cfg.train.num_workers
        pin_memory = self.cfg.dataset.pin_memory

        # Load the dataset into memory.
        # Dataset has a target property, root path, and transform.
        source = './NAS-Bench-201-v1_1-096897.pth'
        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
        self.dataset = dataset
        # self.api = dataset.api

        # if len(self.task.split('-')) == 2:
        #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
        # else:
        train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)

        self.train_index, self.val_index, self.test_index, self.unlabeled_index = (
            train_index, val_index, test_index, unlabeled_index)
        train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
        if len(unlabeled_index) > 0:
            train_index = torch.cat([train_index, unlabeled_index], dim=0)

        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)

        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)

        training_iterations = len(train_dataset) // batch_size
        self.training_iterations = training_iterations

    def random_data_split(self, dataset):
        # nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
        # labeled_len = len(dataset) - nan_count
        labeled_len = len(dataset)
        full_idx = list(range(labeled_len))
        train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
        train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
        train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
        unlabeled_index = list(range(labeled_len, len(dataset)))
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
        return train_index, val_index, test_index, unlabeled_index

    def fixed_split(self, dataset):
        if self.task == 'O2-N2':
            test_index = [42, 43, 92, 122, 197, 198, 251, 254, 257, 355, 511, 512, 549, 602, 603, 604]
        else:
            raise ValueError('Invalid task name: {}'.format(self.task))
        full_idx = list(range(len(dataset)))
        full_idx = list(set(full_idx) - set(test_index))
        train_ratio = 0.8
        train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
        return train_index, val_index, test_index, []
|
    # def parse_architecture_string(self, arch_str):
    #     stages = arch_str.split('+')
    #     nodes = ['input']
    #     edges = []
    #
    #     for stage in stages:
    #         operations = stage.strip('|').split('|')
    #         for op in operations:
    #             operation, idx = op.split('~')
    #             idx = int(idx)
    #             edges.append((idx, len(nodes)))  # Add edge from idx to the new node
    #             nodes.append(operation)
    #     nodes.append('output')  # Add the output node
    #     return nodes, edges

    # NOTE: defined without `self`; call it via the class
    # (DataModule.parse_architecture_string), not on an instance.
    def parse_architecture_string(arch_str):
        # print(arch_str)
        steps = arch_str.split('+')
        nodes = ['input']  # Start with the input node
        adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
                            [0, 0, 0, 1, 0, 1, 0, 0],
                            [0, 0, 0, 0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 0]])
        steps = arch_str.split('+')
        steps_coding = ['0', '0', '1', '0', '1', '2']
        cont = 0
        for step in steps:
            step = step.strip('|').split('|')
            for node in step:
                n, idx = node.split('~')
                assert idx == steps_coding[cont]
                cont += 1
                nodes.append(n)
        nodes.append('output')  # Add the output node
        return nodes, adj_mat
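
    # Hedged usage sketch: the function above always returns the fixed 8-node
    # NAS-Bench-201 cell topology, only the operation labels vary, e.g.
    #   nodes, adj = DataModule.parse_architecture_string(
    #       '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|')
    #   # nodes -> ['input', 'nor_conv_3x3', 'skip_connect', 'nor_conv_1x1',
    #   #           'none', 'avg_pool_3x3', 'nor_conv_3x3', 'output']
    #   # adj.shape -> (8, 8)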
|
    # def create_molecule_from_graph(nodes, edges):
    def create_molecule_from_graph(self, graph):
        nodes = graph.x
        edges = graph.edge_index
        mol = Chem.RWMol()  # RWMol allows building the molecule step by step
        atom_indices = {}
        # NOTE: this local decoding uses the 1-based scheme also found in the
        # `bonds` dict inside graphs_to_json, not the module-level `op_type`
        # encoding; nodes encoded as 0 are skipped.
        num_to_op = {
            1: 'nor_conv_1x1',
            2: 'nor_conv_3x3',
            3: 'avg_pool_3x3',
            4: 'skip_connect',
            5: 'output',
            6: 'none',
            7: 'input'
        }

        # Add atoms to the molecule, one per (non-zero) node.
        for i, op_tensor in enumerate(nodes):
            op = op_tensor.item()
            if op == 0:
                continue
            op = num_to_op[op]
            atom_symbol = op_to_atom[op]
            atom = Chem.Atom(atom_symbol)
            atom_idx = mol.AddAtom(atom)
            atom_indices[i] = atom_idx

        # Add bonds to the molecule, one per directed edge.
        edge_number = edges.shape[1]
        for i in range(edge_number):
            start = edges[0, i].item()
            end = edges[1, i].item()
            mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)

        return mol
|
|
|
|
# def arch_str_to_smiles(self, arch_str):
|
|
# nodes, edges = self.parse_architecture_string(arch_str)
|
|
# mol = self.create_molecule_from_graph(nodes, edges)
|
|
# smiles = Chem.MolToSmiles(mol)
|
|
# return smiles
|
|
|
|
    def get_train_graphs(self):
        train_graphs = []
        test_graphs = []
        for graph in self.train_dataset:
            train_graphs.append(graph)
        for graph in self.test_dataset:
            test_graphs.append(graph)
        return train_graphs, test_graphs
|
|
|
|
|
|
# def get_train_smiles(self):
|
|
# filename = f'{self.task}.csv.gz'
|
|
# df = pd.read_csv(f'{self.root_path}/raw/{filename}')
|
|
# df_test = df.iloc[self.test_index]
|
|
# df = df.iloc[self.train_index]
|
|
# smiles_list = df['smiles'].tolist()
|
|
# smiles_list_test = df_test['smiles'].tolist()
|
|
# smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
|
|
# smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
|
|
# return smiles_list, smiles_list_test
|
|
|
|
    def get_train_smiles(self):
        train_smiles = []
        test_smiles = []

        for graph in self.train_dataset:
            # print(f'idx={idx}')
            # graph = self.train_dataset[idx]
            print(graph.x)
            print(graph.edge_index)
            print(f'class of graph.x: {graph.x.__class__}, class of graph.edge_index: {graph.edge_index.__class__}')
            mol = self.create_molecule_from_graph(graph)
            train_smiles.append(Chem.MolToSmiles(mol))

        # for idx in self.test_index:
        for graph in self.test_dataset:
            # graph = self.dataset[idx]
            # mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
            mol = self.create_molecule_from_graph(graph)
            test_smiles.append(Chem.MolToSmiles(mol))

        # train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
        # test_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in test_arch_strs]
        return train_smiles, test_smiles
|
|
|
|
    def get_data_split(self):
        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")

    def example_batch(self):
        return next(iter(self.val_loader))

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader
|
|
|
|
|
|
|
|
|
|
class DataModule_original(AbstractDataModule):
|
|
def __init__(self, cfg):
|
|
self.datadir = cfg.dataset.datadir
|
|
self.task = cfg.dataset.task_name
|
|
print("DataModule")
|
|
print("task", self.task)
|
|
print("datadir`",self.datadir)
|
|
super().__init__(cfg)
|
|
|
|
def prepare_data(self) -> None:
|
|
target = getattr(self.cfg.dataset, 'guidance_target', None)
|
|
print("target", target)
|
|
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
|
root_path = os.path.join(base_path, self.datadir)
|
|
self.root_path = root_path
|
|
|
|
batch_size = self.cfg.train.batch_size
|
|
|
|
num_workers = self.cfg.train.num_workers
|
|
pin_memory = self.cfg.dataset.pin_memory
|
|
|
|
# Load the dataset to the memory
|
|
# Dataset has target property, root path, and transform
|
|
dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
|
|
|
|
if len(self.task.split('-')) == 2:
|
|
train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
|
|
else:
|
|
train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
|
|
|
|
self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
|
|
train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
|
|
if len(unlabeled_index) > 0:
|
|
train_index = torch.cat([train_index, unlabeled_index], dim=0)
|
|
|
|
train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
|
|
self.train_dataset = train_dataset
|
|
print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
|
|
print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
|
print('dataset len', len(dataset) , 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
|
|
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
|
|
|
|
self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
|
|
self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
|
|
|
|
training_iterations = len(train_dataset) // batch_size
|
|
self.training_iterations = training_iterations
|
|
|
|
def random_data_split(self, dataset):
|
|
nan_count = torch.isnan(dataset.y[:, 0]).sum().item()
|
|
labeled_len = len(dataset) - nan_count
|
|
full_idx = list(range(labeled_len))
|
|
train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
|
|
train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
|
|
train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
|
|
unlabeled_index = list(range(labeled_len, len(dataset)))
|
|
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
|
|
return train_index, val_index, test_index, unlabeled_index
|
|
|
|
def fixed_split(self, dataset):
|
|
if self.task == 'O2-N2':
|
|
test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]
|
|
else:
|
|
raise ValueError('Invalid task name: {}'.format(self.task))
|
|
full_idx = list(range(len(dataset)))
|
|
full_idx = list(set(full_idx) - set(test_index))
|
|
train_ratio = 0.8
|
|
train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
|
|
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
|
return train_index, val_index, test_index, []
|
|
|
|
def get_train_smiles(self):
|
|
filename = f'{self.task}.csv.gz'
|
|
df = pd.read_csv(f'{self.root_path}/raw/{filename}')
|
|
df_test = df.iloc[self.test_index]
|
|
df = df.iloc[self.train_index]
|
|
smiles_list = df['smiles'].tolist()
|
|
smiles_list_test = df_test['smiles'].tolist()
|
|
smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
|
|
smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
|
|
return smiles_list, smiles_list_test
|
|
|
|
def get_data_split(self):
|
|
filename = f'{self.task}.csv.gz'
|
|
df = pd.read_csv(f'{self.root_path}/raw/{filename}')
|
|
df_val = df.iloc[self.val_index]
|
|
df_test = df.iloc[self.test_index]
|
|
df_train = df.iloc[self.train_index]
|
|
return df_train, df_val, df_test
|
|
|
|
def example_batch(self):
|
|
return next(iter(self.val_loader))
|
|
|
|
def train_dataloader(self):
|
|
return self.train_loader
|
|
|
|
def val_dataloader(self):
|
|
return self.val_loader
|
|
|
|
def test_dataloader(self):
|
|
return self.test_loader
|
|
|
|
def new_graphs_to_json(graphs, filename):
|
|
source_name = "nasbench-201"
|
|
num_graph = len(graphs)
|
|
|
|
node_name_list = []
|
|
node_count_list = []
|
|
node_name_list.append('*')
|
|
|
|
for op_name in op_type:
|
|
node_name_list.append(op_name)
|
|
node_count_list.append(0)
|
|
|
|
node_count_list.append(0)
|
|
n_nodes_per_graph = [0] * num_graph
|
|
edge_count_list = [0, 0]
|
|
valencies = [0] * (len(op_type) + 1)
|
|
transition_E = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
|
|
|
|
n_node_list = []
|
|
n_edge_list = []
|
|
|
|
for graph in graphs:
|
|
ops = graph[1]
|
|
adj = graph[0]
|
|
|
|
n_node = len(ops)
|
|
print(n_node)
|
|
n_edge = len(ops)
|
|
n_node_list.append(n_node)
|
|
n_edge_list.append(n_edge)
|
|
|
|
n_nodes_per_graph[n_node] += 1
|
|
cur_node_count_arr = np.zeros(len(op_type) + 1)
|
|
|
|
for op in ops:
|
|
node = op
|
|
# if node == '*':
|
|
# node_count_list[-1] += 1
|
|
# cur_node_count_arr[-1] += 1
|
|
# else:
|
|
node_count_list[node] += 1
|
|
cur_node_count_arr[node] += 1
|
|
try:
|
|
valencies[node] += 1
|
|
except:
|
|
print('int(op_type[node])', int(node))
|
|
|
|
transition_E_temp = np.zeros((len(op_type) + 1, len(op_type) + 1, 2))
|
|
for i in range(n_node):
|
|
for j in range(n_node):
|
|
if i == j or adj[i][j] == 0:
|
|
continue
|
|
start_node, end_node = i, j
|
|
|
|
start_index = ops[start_node]
|
|
end_index = ops[end_node]
|
|
bond_index = 1
|
|
edge_count_list[bond_index] += 2
|
|
|
|
transition_E[start_index, end_index, bond_index] += 2
|
|
transition_E[end_index, start_index, bond_index] += 2
|
|
transition_E_temp[start_index, end_index, bond_index] += 2
|
|
transition_E_temp[end_index, start_index, bond_index] += 2
|
|
|
|
edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
|
|
cur_tot_edge = cur_node_count_arr.reshape(-1,1) * cur_node_count_arr.reshape(1,-1) * 2
|
|
# print(f"cur_tot_edge={cur_tot_edge}, shape: {cur_tot_edge.shape}")
|
|
cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2
|
|
transition_E[:, :, 0] += cur_tot_edge - transition_E_temp.sum(axis=-1)
|
|
assert (cur_tot_edge > transition_E_temp.sum(axis=-1)).sum() >= 0
|
|
|
|
n_nodes_per_graph = np.array(n_nodes_per_graph) / np.sum(n_nodes_per_graph)
|
|
n_nodes_per_graph = n_nodes_per_graph.tolist()[:51]
|
|
|
|
node_count_list = np.array(node_count_list) / np.sum(node_count_list)
|
|
print('processed meta info: ------', filename, '------')
|
|
print('len node_count_list', len(node_count_list))
|
|
print('len node_name_list', len(node_name_list))
|
|
active_nodes = np.array(node_name_list)[node_count_list > 0]
|
|
active_nodes = active_nodes.tolist()
|
|
node_count_list = node_count_list.tolist()
|
|
|
|
edge_count_list = np.array(edge_count_list) / np.sum(edge_count_list)
|
|
edge_count_list = edge_count_list.tolist()
|
|
valencies = np.array(valencies) / np.sum(valencies)
|
|
valencies = valencies.tolist()
|
|
|
|
no_edge = np.sum(transition_E, axis=-1) == 0
|
|
first_elt = transition_E[:, :, 0]
|
|
first_elt[no_edge] = 1
|
|
transition_E[:, :, 0] = first_elt
|
|
|
|
transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
|
|
|
|
meta_dict = {
|
|
'source': source_name,
|
|
'num_graph': num_graph,
|
|
'n_nodes_per_graph': n_nodes_per_graph,
|
|
'max_n_nodes': max(n_node_list),
|
|
'max_n_edges': max(n_edge_list),
|
|
'node_type_list': node_count_list,
|
|
'edge_type_list': edge_count_list,
|
|
'valencies': valencies,
|
|
'active_nodes': active_nodes,
|
|
'num_active_nodes': len(active_nodes),
|
|
'transition_E': transition_E.tolist(),
|
|
}
|
|
|
|
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json', 'w') as f:
|
|
json.dump(meta_dict, f)
|
|
|
|
return meta_dict
|
|
|
|
|
|
|
|
|
|
def graphs_to_json(graphs, filename):
|
|
bonds = {
|
|
'nor_conv_1x1': 1,
|
|
'nor_conv_3x3': 2,
|
|
'avg_pool_3x3': 3,
|
|
'skip_connect': 4,
|
|
'input': 7,
|
|
'output': 5,
|
|
'none': 6
|
|
}
|
|
|
|
source_name = "nas-bench-201"
|
|
num_graph = len(graphs)
|
|
pt = Chem.GetPeriodicTable()
|
|
atom_name_list = []
|
|
atom_count_list = []
|
|
for i in range(2, 119):
|
|
atom_name_list.append(pt.GetElementSymbol(i))
|
|
atom_count_list.append(0)
|
|
atom_name_list.append('*')
|
|
atom_count_list.append(0)
|
|
n_atoms_per_mol = [0] * 500
|
|
bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
|
|
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
|
valencies = [0] * 500
|
|
transition_E = np.zeros((118, 118, 8))
|
|
|
|
n_atom_list = []
|
|
n_bond_list = []
|
|
# graphs = [(adj_matrix, ops), ...]
|
|
for graph in graphs:
|
|
ops = graph[1]
|
|
adj = graph[0]
|
|
n_atom = len(ops)
|
|
n_bond = len(ops)
|
|
n_atom_list.append(n_atom)
|
|
n_bond_list.append(n_bond)
|
|
|
|
n_atoms_per_mol[n_atom] += 1
|
|
cur_atom_count_arr = np.zeros(118)
|
|
|
|
for op in ops:
|
|
symbol = op_to_atom[op]
|
|
if symbol == 'H':
|
|
continue
|
|
elif symbol == '*':
|
|
atom_count_list[-1] += 1
|
|
cur_atom_count_arr[-1] += 1
|
|
else:
|
|
atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
|
|
cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
|
|
# print('symbol', symbol)
|
|
# print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol))
|
|
# print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}')
|
|
try:
|
|
valencies[int(pt.GetDefaultValence(symbol))] += 1
|
|
except:
|
|
print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
|
|
transition_E_temp = np.zeros((118, 118, 8))
|
|
# print(n_atom)
|
|
for i in range(n_atom):
|
|
for j in range(n_atom):
|
|
if i == j or adj[i][j] == 0:
|
|
continue
|
|
start_atom, end_atom = i, j
|
|
if ops[start_atom] == 'input' or ops[end_atom] == 'input':
|
|
continue
|
|
if ops[start_atom] == 'output' or ops[end_atom] == 'output':
|
|
continue
|
|
if ops[start_atom] == 'none' or ops[end_atom] == 'none':
|
|
continue
|
|
|
|
start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
|
|
end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
|
|
bond_index = bonds[ops[end_atom]]
|
|
bond_count_list[bond_index] += 2
|
|
|
|
# print(start_index, end_index, bond_index)
|
|
|
|
transition_E[start_index, end_index, bond_index] += 2
|
|
transition_E[end_index, start_index, bond_index] += 2
|
|
transition_E_temp[start_index, end_index, bond_index] += 2
|
|
transition_E_temp[end_index, start_index, bond_index] += 2
|
|
|
|
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
|
|
print(bond_count_list)
|
|
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2
|
|
# print(f'cur_tot_bond={cur_tot_bond}')
|
|
# find non-zero element in cur_tot_bond
|
|
# for i in range(118):
|
|
# for j in range(118):
|
|
# if cur_tot_bond[i][j] != 0:
|
|
# print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}')
|
|
# n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
|
|
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
|
|
# print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}")
|
|
transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
|
|
# find non-zero element in transition_E
|
|
# for i in range(118):
|
|
# for j in range(118):
|
|
# if transition_E[i][j][0] != 0:
|
|
# print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}')
|
|
assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
|
|
|
|
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
|
|
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
|
|
|
|
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
|
|
print('processed meta info: ------', filename, '------')
|
|
print('len atom_count_list', len(atom_count_list))
|
|
print('len atom_name_list', len(atom_name_list))
|
|
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
|
|
active_atoms = active_atoms.tolist()
|
|
atom_count_list = atom_count_list.tolist()
|
|
|
|
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
|
|
bond_count_list = bond_count_list.tolist()
|
|
valencies = np.array(valencies) / np.sum(valencies)
|
|
valencies = valencies.tolist()
|
|
|
|
no_edge = np.sum(transition_E, axis=-1) == 0
|
|
for i in range(118):
|
|
for j in range(118):
|
|
if no_edge[i][j] == False:
|
|
print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
|
|
# print(f'no_edge: {no_edge}')
|
|
first_elt = transition_E[:, :, 0]
|
|
first_elt[no_edge] = 1
|
|
transition_E[:, :, 0] = first_elt
|
|
|
|
transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
|
|
|
|
# find non-zero element in transition_E again
|
|
for i in range(118):
|
|
for j in range(118):
|
|
if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
|
|
print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
|
|
|
|
meta_dict = {
|
|
'source': 'nasbench-201',
|
|
'num_graph': num_graph,
|
|
'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
|
|
'max_node': max(n_atom_list),
|
|
'max_bond': max(n_bond_list),
|
|
'atom_type_dist': atom_count_list,
|
|
'bond_type_dist': bond_count_list,
|
|
'valencies': valencies,
|
|
'active_nodes': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
|
|
'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
|
|
'transition_E': transition_E.tolist(),
|
|
}
|
|
|
|
with open(f'{filename}.meta.json', 'w') as f:
|
|
json.dump(meta_dict, f)
|
|
return meta_dict
|
|
class Dataset(InMemoryDataset):
    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
        self.target_prop = target_prop
        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
        self.source = source
        # self.api = API(source)  # Initialize NAS-Bench-201 API
        # print('API loaded')
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths[0])  # /home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt
        self.data, self.slices = torch.load(self.processed_paths[0])
        print('Dataset initialized')
        self.data.edge_attr = self.data.edge_attr.squeeze()
        self.data.idx = torch.arange(len(self.data.y))
        print(f"self.data={self.data}, self.slices={self.slices}")

    @property
    def raw_file_names(self):
        return []  # NAS-Bench-201 data is loaded via the API, no raw files needed

    @property
    def processed_file_names(self):
        return [f'{self.source}.pt']
|
|
|
|
    def process(self):
        source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
        # self.api = API(source)

        data_list = []
        # len_data = len(self.api)
        len_data = 15625

        def check_valid_graph(nodes, edges):
            # The adjacency matrix must be square and match the node count.
            if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
                return False
            # The first node must be the input, the last node the output.
            if nodes[0] != 'input' or nodes[-1] != 'output':
                return False
            # No self-loops.
            for i in range(0, len(nodes)):
                if edges[i][i] == 1:
                    return False
            # Intermediate nodes must be known operations, and neither input nor output.
            for i in range(1, len(nodes) - 1):
                if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
                    return False
            # No edge may point into the input node.
            for i in range(0, len(nodes)):
                for j in range(i, len(nodes)):
                    if edges[i, j] == 1 and nodes[j] == 'input':
                        return False
            # No edge may leave the output node.
            for i in range(0, len(nodes)):
                for j in range(i, len(nodes)):
                    if edges[i, j] == 1 and nodes[i] == 'output':
                        return False
            # At least one edge must reach the output node.
            flag = 0
            for i in range(0, len(nodes)):
                if edges[i, -1] == 1:
                    flag = 1
                    break
            if flag == 0:
                return False
            return True
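
        # Hedged illustration of the checks above: with the fixed 8-node cell
        # encoding, a graph passes check_valid_graph only if nodes[0] == 'input',
        # nodes[-1] == 'output', the adjacency diagonal is all zero, and the last
        # column (the output node) has at least one incoming edge; a matrix whose
        # last row contains a 1 (an edge leaving 'output') is rejected.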
|
|
|
|
def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8,random_ratio=0.5):
|
|
# print(ori_nodes)
|
|
# print(ori_edges)
|
|
|
|
ori_edges = np.array(ori_edges)
|
|
# ori_nodes = np.array(ori_nodes)
|
|
nasbench_201_node_num = 8
|
|
# random.seed(random_seed)
|
|
nodes_num = random.randint(min_nodes, max_nodes)
|
|
# print(f'arch_str: {arch_str}, \nmax_nodes: {max_nodes}, min_nodes: {min_nodes}, nodes_num: {nodes_num},random_seed: {random_seed},random_ratio: {random_ratio}')
|
|
add_num = nodes_num - nasbench_201_node_num
|
|
# ori_nodes, ori_edges = parse_architecture_string(arch_str)
|
|
add_nodes = []
|
|
print(f'add_num: {add_num}')
|
|
for i in range(add_num):
|
|
add_nodes.append(random.choice(num_to_op[1:-1]))
|
|
# print(add_nodes)
|
|
print(f'ori_nodes[:-1]: {ori_nodes[:-1]}, add_nodes: {add_nodes}')
|
|
print(f'len(ori_nodes[:-1]): {len(ori_nodes[:-1])}, len(add_nodes): {len(add_nodes)}')
|
|
nodes = ori_nodes[:-1] + add_nodes + ['output']
|
|
edges = np.zeros((nodes_num , nodes_num))
|
|
edges[:6, :6] = ori_edges[:6, :6]
|
|
edges[0:8, -1] = ori_edges[0:8 , -1]
|
|
for i in range(0, nodes_num):
|
|
for j in range(max(7,i + 1), nodes_num):
|
|
rand = random.random()
|
|
if rand < random_ratio:
|
|
edges[i, j] = 1
|
|
if nodes_num < max_nodes:
|
|
edges = np.pad(edges, ((0, max_nodes - nodes_num), (0, max_nodes - nodes_num)), 'constant',constant_values=0)
|
|
while len(nodes) < max_nodes:
|
|
nodes.append('none')
|
|
print(f'edges size: {edges.shape}, nodes size: {len(nodes)}')
|
|
return edges,nodes
|
|
|
|
|
|
        def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
            # def graph_to_graph_data(graph):
            ops = graph[1]
            adj = graph[0]
            nodes = []
            for op in ops:
                nodes.append(op_type[op])
            x = torch.LongTensor(nodes)

            edges_list = []
            edge_type = []
            for start in range(len(ops)):
                for end in range(len(ops)):
                    if adj[start][end] == 1:
                        edges_list.append((start, end))
                        edge_type.append(1)
                        edges_list.append((end, start))
                        edge_type.append(1)

            edge_index = torch.tensor(edges_list, dtype=torch.long).t()
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = edge_type
            # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
            # y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device)
            y = self.swap_scores[idx]
            print(y, idx)
            # Binarize the score loaded from the swap_results CSV: architectures
            # scoring above 60000 get the positive label, the rest the negative one.
            if y > 60000:
                print(f'idx={idx}, y={y}')
                y = torch.tensor([1, 1], dtype=torch.float).view(1, -1)
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=idx)
            else:
                print(f'idx={idx}, y={y}')
                y = torch.tensor([0, 0], dtype=torch.float).view(1, -1)
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=idx)
            # return None
            return data
|
|
graph_list = []
|
|
class Args:
|
|
pass
|
|
args = Args()
|
|
args.trainval = True
|
|
args.augtype = 'none'
|
|
args.repeat = 1
|
|
args.score = 'hook_logdet'
|
|
args.sigma = 0.05
|
|
args.nasspace = 'nasbench201'
|
|
args.batch_size = 128
|
|
args.GPU = '0'
|
|
args.dataset = 'cifar10'
|
|
args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
|
args.data_loc = '../cifardata/'
|
|
args.seed = 777
|
|
args.init = ''
|
|
args.save_loc = 'results'
|
|
args.save_string = 'naswot'
|
|
args.dropout = False
|
|
args.maxofn = 1
|
|
args.n_samples = 100
|
|
args.n_runs = 500
|
|
args.stem_out_channels = 16
|
|
args.num_stacks = 3
|
|
args.num_modules_per_stack = 3
|
|
args.num_labels = 1
|
|
searchspace = nasspace.get_search_space(args)
|
|
        train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
        self.swap_scores = []
        import csv
        # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
        with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f:
            reader = csv.reader(f)
            header = next(reader)
            data = [row for row in reader]
            self.swap_scores = [float(row[0]) for row in data]
        device = torch.device('cuda:2')
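
        # Assumption (illustrative, matching the read above): the swap_results
        # CSV has a header row followed by one row per NAS-Bench-201 index, with
        # the score in the first column, e.g.
        #   score
        #   61234.5
        #   58710.0
        # so self.swap_scores[idx] lines up with the architecture index idx.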
|
        with tqdm(total=len_data) as pbar:
            active_nodes = set()
            file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
            with open(file_path, 'r') as f:
                graph_list = json.load(f)
            i = 0
            flex_graph_list = []
            flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json'
            for graph in graph_list:
                print(f'iterate every graph in graph_list, here is {i}')
                arch_info = graph['arch_str']
                ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4)
                for op in ops:
                    if op not in active_nodes:
                        active_nodes.add(op)
                data = graph_to_graph_data((adj_matrix, ops), idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)
                i += 1
                if data is None:
                    pbar.update(1)
                    continue
                flex_graph_list.append({
                    'adj_matrix': adj_matrix,
                    'ops': ops,
                })
                if i < 3:
                    print(f"i={i}, data={data}")
                    with open(f'{i}.json', 'w') as f:
                        f.write(str(data.x))
                        f.write(str(data.edge_index))
                        f.write(str(data.edge_attr))
                data_list.append(data)
|
# new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=9, random_ratio=0.5)
|
|
# flex_graph_list.append({
|
|
# 'adj_matrix':new_adj.tolist(),
|
|
# 'ops': new_ops,
|
|
# })
|
|
# data_list.append(graph_to_graph_data((new_adj, new_ops)))
|
|
|
|
# graph_list.append({
|
|
# "adj_matrix": adj_matrix,
|
|
# "ops": ops,
|
|
# "arch_str": arch_info.arch_str,
|
|
# "idx": i,
|
|
# "train": [{
|
|
# "iepoch": result.get_train()['iepoch'],
|
|
# "loss": result.get_train()['loss'],
|
|
# "accuracy": result.get_train()['accuracy'],
|
|
# "cur_time": result.get_train()['cur_time'],
|
|
# "all_time": result.get_train()['all_time'],
|
|
# "seed": seed,
|
|
# }for seed, result in results.items()],
|
|
# "valid": [{
|
|
# "iepoch": result.get_eval('x-valid')['iepoch'],
|
|
# "loss": result.get_eval('x-valid')['loss'],
|
|
# "accuracy": result.get_eval('x-valid')['accuracy'],
|
|
# "cur_time": result.get_eval('x-valid')['cur_time'],
|
|
# "all_time": result.get_eval('x-valid')['all_time'],
|
|
# "seed": seed,
|
|
# }for seed, result in results.items()],
|
|
# "test": [{
|
|
# "iepoch": result.get_eval('x-test')['iepoch'],
|
|
# "loss": result.get_eval('x-test')['loss'],
|
|
# "accuracy": result.get_eval('x-test')['accuracy'],
|
|
# "cur_time": result.get_eval('x-test')['cur_time'],
|
|
# "all_time": result.get_eval('x-test')['all_time'],
|
|
# "seed": seed,
|
|
# }for seed, result in results.items()]
|
|
# })
|
|
# i += 1
|
|
                pbar.update(1)

        for graph in graph_list:
            adj_matrix = graph['adj_matrix']
            if isinstance(adj_matrix, np.ndarray):
                adj_matrix = adj_matrix.tolist()
                graph['adj_matrix'] = adj_matrix
            ops = graph['ops']
            if isinstance(ops, np.ndarray):
                ops = ops.tolist()
                graph['ops'] = ops
        with open(f'nasbench-201-graph.json', 'w') as f:
            json.dump(graph_list, f)
        # with open(flex_graph_path, 'w') as f:
        #     json.dump(flex_graph_list, f)

        torch.save(self.collate(data_list), self.processed_paths[0])
|
|
|
|
# def parse_architecture_string(arch_str):
|
|
# stages = arch_str.split('+')
|
|
# nodes = ['input']
|
|
# edges = []
|
|
|
|
# for stage in stages:
|
|
# operations = stage.strip('|').split('|')
|
|
# for op in operations:
|
|
# operation, idx = op.split('~')
|
|
# idx = int(idx)
|
|
# edges.append((idx, len(nodes))) # Add edge from idx to the new node
|
|
# nodes.append(operation)
|
|
# nodes.append('output') # Add the output node
|
|
# return nodes, edges
|
|
|
|
# def create_graph(nodes, edges):
|
|
# G = nx.DiGraph()
|
|
# for i, node in enumerate(nodes):
|
|
# G.add_node(i, label=node)
|
|
# G.add_edges_from(edges)
|
|
# return G
|
|
|
|
# def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
|
|
# nodes, edges = parse_architecture_string(arch_str)
|
|
|
|
# node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary
|
|
# assert 0 not in node_labels, f'Invalid node label: {node_labels}'
|
|
# x = torch.LongTensor(node_labels)
|
|
# print(f'in initialize Dataset, arch_to_Graph x={x}')
|
|
|
|
# edges_list = [(start, end) for start, end in edges]
|
|
# edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type
|
|
# edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
|
|
# edge_type = torch.tensor(edge_type, dtype=torch.long)
|
|
# edge_attr = edge_type.view(-1, 1)
|
|
|
|
# if target3 is not None:
|
|
# y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
|
|
# elif target2 is not None:
|
|
# y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
|
|
# else:
|
|
# y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
|
|
|
|
# print(f'in initialize Dataset, Data_init, x={x}, y={y}, edge_index={edge_index}, edge_attr={edge_attr}')
|
|
# data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
|
|
# return data, nodes
|
|
|
|
# bonds = {
|
|
# 'nor_conv_1x1': 1,
|
|
# 'nor_conv_3x3': 2,
|
|
# 'avg_pool_3x3': 3,
|
|
# 'skip_connect': 4,
|
|
# 'output': 5,
|
|
# 'none': 6,
|
|
# 'input': 7
|
|
# }
|
|
|
|
# # Prepare to process NAS-Bench-201 data
|
|
# data_list = []
|
|
# len_data = len(self.api) # Number of architectures
|
|
# with tqdm(total=len_data) as pbar:
|
|
# for arch_index in range(len_data):
|
|
# arch_info = self.api.query_meta_info_by_index(arch_index)
|
|
# arch_str = arch_info.arch_str
|
|
# sa = np.random.rand() # Placeholder for synthetic accessibility
|
|
# sc = np.random.rand() # Placeholder for substructure count
|
|
# target = np.random.rand() # Placeholder for target value
|
|
# target2 = np.random.rand() # Placeholder for second target value
|
|
# target3 = np.random.rand() # Placeholder for third target value
|
|
|
|
# data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
|
|
# data_list.append(data)
|
|
# pbar.update(1)
|
|
|
|
# torch.save(self.collate(data_list), self.processed_paths[0])
|
|
|
|
class Dataset_origin(InMemoryDataset):
|
|
def __init__(self, source, root, target_prop=None,
|
|
transform=None, pre_transform=None, pre_filter=None):
|
|
self.target_prop = target_prop
|
|
self.source = source
|
|
super().__init__(root, transform, pre_transform, pre_filter)
|
|
self.data, self.slices = torch.load(self.processed_paths[0])
|
|
|
|
@property
|
|
def raw_file_names(self):
|
|
return [f'{self.source}.csv.gz']
|
|
|
|
@property
|
|
def processed_file_names(self):
|
|
return [f'{self.source}.pt']
|
|
|
|
def process(self):
|
|
RDLogger.DisableLog('rdApp.*')
|
|
data_path = osp.join(self.raw_dir, self.raw_file_names[0])
|
|
data_df = pd.read_csv(data_path)
|
|
|
|
def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None):
|
|
type_idx = []
|
|
heavy_atom_indices, active_atoms = [], []
|
|
for atom in mol.GetAtoms():
|
|
if atom.GetAtomicNum() != 1:
|
|
type_idx.append(119-2) if atom.GetSymbol() == '*' else type_idx.append(atom.GetAtomicNum()-2)
|
|
heavy_atom_indices.append(atom.GetIdx())
|
|
active_atoms.append(atom.GetSymbol())
|
|
if valid_atoms is not None:
|
|
if not atom.GetSymbol() in valid_atoms:
|
|
return None, None
|
|
x = torch.LongTensor(type_idx)
|
|
|
|
edges_list = []
|
|
edge_type = []
|
|
for bond in mol.GetBonds():
|
|
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
|
|
if start in heavy_atom_indices and end in heavy_atom_indices:
|
|
start_new, end_new = heavy_atom_indices.index(start), heavy_atom_indices.index(end)
|
|
edges_list.append((start_new, end_new))
|
|
edge_type.append(bonds[bond.GetBondType()])
|
|
edges_list.append((end_new, start_new))
|
|
edge_type.append(bonds[bond.GetBondType()])
|
|
edge_index = torch.tensor(edges_list, dtype=torch.long).t()
|
|
edge_type = torch.tensor(edge_type, dtype=torch.long)
|
|
edge_attr = edge_type
|
|
|
|
if target3 is not None:
|
|
y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1,-1)
|
|
elif target2 is not None:
|
|
y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1,-1)
|
|
else:
|
|
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1,-1)
|
|
|
|
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
|
|
if self.pre_transform is not None:
|
|
data = self.pre_transform(data)
|
|
return data, active_atoms
|
|
|
|
# Loop through every row in the DataFrame and apply the function
|
|
data_list = []
|
|
len_data = len(data_df)
|
|
with tqdm(total=len_data) as pbar:
|
|
# --- data processing start ---
|
|
active_atoms = set()
|
|
for i, (sms, df_row) in enumerate(data_df.iterrows()):
|
|
if i == sms:
|
|
sms = df_row['smiles']
|
|
mol = Chem.MolFromSmiles(sms, sanitize=False)
|
|
if len(self.target_prop.split('-')) == 2:
|
|
target1, target2 = self.target_prop.split('-')
|
|
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
|
|
elif len(self.target_prop.split('-')) == 3:
|
|
target1, target2, target3 = self.target_prop.split('-')
|
|
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
|
|
else:
|
|
data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
|
|
active_atoms.update(cur_active_atoms)
|
|
data_list.append(data)
|
|
pbar.update(1)
|
|
|
|
torch.save(self.collate(data_list), self.processed_paths[0])
|
|
|
|
def parse_architecture_string(arch_str, padding=0):
    # print(arch_str)
    steps = arch_str.split('+')
    nodes = ['input']  # Start with the input node
    ori_adj_mat = [[0, 1, 1, 0, 1, 0, 0, 0],
                   [0, 0, 0, 1, 0, 1, 0, 0],
                   [0, 0, 0, 0, 0, 0, 1, 0],
                   [0, 0, 0, 0, 0, 0, 1, 0],
                   [0, 0, 0, 0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0, 0, 0, 1],
                   [0, 0, 0, 0, 0, 0, 0, 0]]
    adj_mat = [[0, 1, 1, 0, 1, 0, 0, 0],
               [0, 0, 0, 1, 0, 1, 0, 0],
               [0, 0, 0, 0, 0, 0, 1, 0],
               [0, 0, 0, 0, 0, 0, 1, 0],
               [0, 0, 0, 0, 0, 0, 0, 1],
               [0, 0, 0, 0, 0, 0, 0, 1],
               [0, 0, 0, 0, 0, 0, 0, 1],
               [0, 0, 0, 0, 0, 0, 0, 0]]
    steps = arch_str.split('+')
    steps_coding = ['0', '0', '1', '0', '1', '2']
    cont = 0
    for step in steps:
        step = step.strip('|').split('|')
        for node in step:
            n, idx = node.split('~')
            assert idx == steps_coding[cont]
            cont += 1
            nodes.append(n)
    nodes.append('output')  # Add the output node
    ori_nodes = nodes.copy()
    if padding > 0:
        # Pad with 'none' nodes and extend the adjacency matrix with zero
        # rows/columns so every graph ends up with the same number of nodes.
        for i in range(padding):
            nodes.append('none')
        for adj_row in adj_mat:
            for i in range(padding):
                adj_row.append(0)
        # adj_mat = np.append(adj_mat, np.zeros((padding, len(nodes))))
        for i in range(padding):
            adj_mat.append([0] * len(nodes))
    # print(nodes)
    # print(adj_mat)
    # print(len(adj_mat))
    # print(f'len(ori_nodes): {len(ori_nodes)}, len(nodes): {len(nodes)}')
    return nodes, adj_mat, ori_nodes, ori_adj_mat
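
# Hedged usage sketch for the padded variant above: with padding=4 every cell
# is inflated from 8 to 12 nodes, e.g.
#   nodes, adj, ori_nodes, ori_adj = parse_architecture_string(
#       '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|',
#       padding=4)
#   # len(nodes) == 12, len(adj) == 12, len(adj[0]) == 12
#   # len(ori_nodes) == 8, len(ori_adj) == 8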
|
|
|
|
def create_adj_matrix_and_ops(nodes, edges):
    num_nodes = len(nodes)
    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
    for (src, dst) in edges:
        adj_matrix[src][dst] = 1
    return adj_matrix, nodes
|
|
class DataInfos(AbstractDatasetInfos):
|
|
def __init__(self, datamodule, cfg, dataset):
|
|
tasktype_dict = {
|
|
'hiv_b': 'classification',
|
|
'bace_b': 'classification',
|
|
'bbbp_b': 'classification',
|
|
'O2': 'regression',
|
|
'N2': 'regression',
|
|
'CO2': 'regression',
|
|
}
|
|
task_name = cfg.dataset.task_name
|
|
self.task = task_name
|
|
self.task_type = tasktype_dict.get(task_name, "regression")
|
|
self.ensure_connected = cfg.model.ensure_connected
|
|
# self.api = dataset.api
|
|
|
|
datadir = cfg.dataset.datadir
|
|
|
|
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
|
meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
|
|
data_root = os.path.join(base_path, datadir, 'raw')
|
|
graphs = []
|
|
length = 15625
|
|
ops_type = {}
|
|
len_ops = set()
|
|
# api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
|
|
|
|
|
|
def read_adj_ops_from_json(filename):
|
|
with open(filename, 'r') as json_file:
|
|
data = json.load(json_file)
|
|
|
|
adj_ops_pairs = []
|
|
for item in data:
|
|
print(item)
|
|
adj_matrix = np.array(item['adj_matrix'])
|
|
ops = item['ops']
|
|
ops = [op_type[op] for op in ops]
|
|
adj_ops_pairs.append((adj_matrix, ops))
|
|
|
|
return adj_ops_pairs
|
|
# for i in range(length):
|
|
# arch_info = self.api.query_meta_info_by_index(i)
|
|
# nodes, edges = parse_architecture_string(arch_info.arch_str)
|
|
# adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
|
|
# if i < 5:
|
|
# print("Adjacency Matrix:")
|
|
# print(adj_matrix)
|
|
# print("Operations List:")
|
|
# print(ops)
|
|
# for op in ops:
|
|
# if op not in ops_type:
|
|
# ops_type[op] = len(ops_type)
|
|
# len_ops.add(len(ops))
|
|
# graphs.append((adj_matrix, ops))
|
|
# graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json')
|
|
graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json')
|
|
|
|
# check first five graphs
|
|
for i in range(5):
|
|
print(f'graph {i} : {graphs[i]}')
|
|
# print(f'ops_type: {ops_type}')
|
|
|
|
meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
|
|
self.base_path = base_path
|
|
self.active_nodes = meta_dict['active_nodes']
|
|
self.max_n_nodes = meta_dict['max_n_nodes']
|
|
self.original_max_n_nodes = meta_dict['max_n_nodes']
|
|
self.n_nodes = torch.Tensor(meta_dict['n_nodes_per_graph'])
|
|
self.edge_types = torch.Tensor(meta_dict['edge_type_list'])
|
|
self.transition_E = torch.Tensor(meta_dict['transition_E'])
|
|
|
|
self.node_decoder = meta_dict['active_nodes']
|
|
node_types = torch.Tensor(meta_dict['node_type_list'])
|
|
active_index = (node_types > 0).nonzero().squeeze()
|
|
self.node_types = torch.Tensor(meta_dict['node_type_list'])[active_index]
|
|
self.nodes_dist = DistributionNodes(self.n_nodes)
|
|
self.active_index = active_index
|
|
|
|
val_len = 3 * self.original_max_n_nodes - 2
|
|
meta_val = torch.Tensor(meta_dict['valencies'])
|
|
self.valency_distribution = torch.zeros(val_len)
|
|
val_len = min(val_len, len(meta_val))
|
|
self.valency_distribution[:val_len] = meta_val[:val_len]
|
|
self.y_prior = None
|
|
self.train_ymin = []
|
|
self.train_ymax = []
|
|
|
|
|
|
|
|
class DataInfos_origin(AbstractDatasetInfos):
|
|
def __init__(self, datamodule, cfg):
|
|
tasktype_dict = {
|
|
'hiv_b': 'classification',
|
|
'bace_b': 'classification',
|
|
'bbbp_b': 'classification',
|
|
'O2': 'regression',
|
|
'N2': 'regression',
|
|
'CO2': 'regression',
|
|
}
|
|
task_name = cfg.dataset.task_name
|
|
self.task = task_name
|
|
self.task_type = tasktype_dict.get(task_name, "regression")
|
|
self.ensure_connected = cfg.model.ensure_connected
|
|
|
|
datadir = cfg.dataset.datadir
|
|
|
|
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
|
|
meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
|
|
data_root = os.path.join(base_path, datadir, 'raw')
|
|
if os.path.exists(meta_filename):
|
|
with open(meta_filename, 'r') as f:
|
|
meta_dict = json.load(f)
|
|
else:
|
|
meta_dict = compute_meta(data_root, task_name, datamodule.train_index, datamodule.test_index)
|
|
|
|
self.base_path = base_path
|
|
self.active_atoms = meta_dict['active_atoms']
|
|
self.max_n_nodes = meta_dict['max_node']
|
|
self.original_max_n_nodes = meta_dict['max_node']
|
|
self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
|
|
self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
|
|
self.transition_E = torch.Tensor(meta_dict['transition_E'])
|
|
|
|
self.atom_decoder = meta_dict['active_atoms']
|
|
node_types = torch.Tensor(meta_dict['atom_type_dist'])
|
|
active_index = (node_types > 0).nonzero().squeeze()
|
|
self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
|
|
self.nodes_dist = DistributionNodes(self.n_nodes)
|
|
self.active_index = active_index
|
|
|
|
val_len = 3 * self.original_max_n_nodes - 2
|
|
meta_val = torch.Tensor(meta_dict['valencies'])
|
|
self.valency_distribution = torch.zeros(val_len)
|
|
val_len = min(val_len, len(meta_val))
|
|
self.valency_distribution[:val_len] = meta_val[:val_len]
|
|
self.y_prior = None
|
|
self.train_ymin = []
|
|
self.train_ymax = []
|
|
|
|
|
|
def compute_meta(root, source_name, train_index, test_index):
|
|
# initialize the periodic table
|
|
# 118 elements + 1 for *
|
|
# Initializes arrays to count the number of atoms per molecule, bond types, valencies, and transition probabilities between atom types.
|
|
pt = Chem.GetPeriodicTable()
|
|
atom_name_list = []
|
|
atom_count_list = []
|
|
for i in range(2, 119):
|
|
atom_name_list.append(pt.GetElementSymbol(i))
|
|
atom_count_list.append(0)
|
|
atom_name_list.append('*')
|
|
atom_count_list.append(0)
|
|
n_atoms_per_mol = [0] * 500
|
|
bond_count_list = [0, 0, 0, 0, 0]
|
|
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
|
valencies = [0] * 500
|
|
tansition_E = np.zeros((118, 118, 5))
|
|
|
|
# Load the data from the source file
|
|
filename = f'{source_name}.csv.gz'
|
|
df = pd.read_csv(f'{root}/{filename}')
|
|
all_index = list(range(len(df)))
|
|
non_test_index = list(set(all_index) - set(test_index))
|
|
df = df.iloc[non_test_index]
|
|
# extract the smiles from the dataframe
|
|
tot_smiles = df['smiles'].tolist()
|
|
|
|
n_atom_list = []
|
|
n_bond_list = []
|
|
for i, sms in enumerate(tot_smiles):
|
|
try:
|
|
mol = Chem.MolFromSmiles(sms)
|
|
except:
|
|
continue
|
|
|
|
n_atom = mol.GetNumHeavyAtoms()
|
|
n_bond = mol.GetNumBonds()
|
|
n_atom_list.append(n_atom)
|
|
n_bond_list.append(n_bond)
|
|
|
|
n_atoms_per_mol[n_atom] += 1
|
|
cur_atom_count_arr = np.zeros(118)
|
|
for atom in mol.GetAtoms():
|
|
symbol = atom.GetSymbol()
|
|
if symbol == 'H':
|
|
continue
|
|
elif symbol == '*':
|
|
atom_count_list[-1] += 1
|
|
cur_atom_count_arr[-1] += 1
|
|
else:
|
|
atom_count_list[atom.GetAtomicNum()-2] += 1
|
|
cur_atom_count_arr[atom.GetAtomicNum()-2] += 1
|
|
try:
|
|
valencies[int(atom.GetExplicitValence())] += 1
|
|
except:
|
|
print('src', source_name,'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))
|
|
|
|
tansition_E_temp = np.zeros((118, 118, 5))
|
|
for bond in mol.GetBonds():
|
|
start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
|
|
if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
|
|
continue
|
|
|
|
if start_atom.GetSymbol() == '*':
|
|
start_index = 117
|
|
else:
|
|
start_index = start_atom.GetAtomicNum() - 2
|
|
if end_atom.GetSymbol() == '*':
|
|
end_index = 117
|
|
else:
|
|
end_index = end_atom.GetAtomicNum() - 2
|
|
|
|
bond_type = bond.GetBondType()
|
|
bond_index = bond_type_to_index[bond_type]
|
|
bond_count_list[bond_index] += 2
|
|
|
|
# Update the transition matrix
|
|
# The transition matrix is symmetric, so we update both directions
|
|
# We also update the temporary transition matrix to check for errors
|
|
# in the atom count
|
|
|
|
tansition_E[start_index, end_index, bond_index] += 2
|
|
tansition_E[end_index, start_index, bond_index] += 2
|
|
tansition_E_temp[start_index, end_index, bond_index] += 2
|
|
tansition_E_temp[end_index, start_index, bond_index] += 2
|
|
|
|
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
|
|
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2 # 118 * 118
|
|
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2 # 118 * 118
|
|
tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1)
|
|
assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
|
|
|
|
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
|
|
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
|
|
|
|
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
|
|
print('processed meta info: ------', filename, '------')
|
|
print('len atom_count_list', len(atom_count_list))
|
|
print('len atom_name_list', len(atom_name_list))
|
|
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
|
|
active_atoms = active_atoms.tolist()
|
|
atom_count_list = atom_count_list.tolist()
|
|
|
|
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
|
|
bond_count_list = bond_count_list.tolist()
|
|
valencies = np.array(valencies) / np.sum(valencies)
|
|
valencies = valencies.tolist()
|
|
|
|
no_edge = np.sum(tansition_E, axis=-1) == 0
|
|
first_elt = tansition_E[:, :, 0]
|
|
first_elt[no_edge] = 1
|
|
tansition_E[:, :, 0] = first_elt
|
|
|
|
tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True)
|
|
|
|
meta_dict = {
|
|
'source': source_name,
|
|
'num_graph': len(n_atom_list),
|
|
'n_atoms_per_mol_dist': n_atoms_per_mol,
|
|
'max_node': max(n_atom_list),
|
|
'max_bond': max(n_bond_list),
|
|
'atom_type_dist': atom_count_list,
|
|
'bond_type_dist': bond_count_list,
|
|
'valencies': valencies,
|
|
'active_atoms': active_atoms,
|
|
'num_atom_type': len(active_atoms),
|
|
'transition_E': tansition_E.tolist(),
|
|
}
|
|
|
|
with open(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench201.meta.json', "w") as f:
|
|
json.dump(meta_dict, f)
|
|
|
|
return meta_dict
|
|
|
|
|
|
if __name__ == "__main__":
|
|
dataset = Dataset(source='nasbench', root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
|
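
    # Minimal smoke-test sketch (an assumption about how this module might be
    # exercised, not part of the original pipeline): wrap the processed dataset
    # in a PyG DataLoader and inspect one batch of the Data objects built above.
    loader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(loader))
    print(batch.x.shape, batch.edge_index.shape, batch.y.shape)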