try to transfer the code from jupyter notebook to dataset.py

This commit is contained in:
Hanzhang Ma 2024-06-11 17:48:25 +02:00
parent 2674a40b74
commit 99163a5150

View File

@ -2,6 +2,8 @@
import sys
sys.path.append('../')
from nas_201_api import NASBench201API as API
import os
import os.path as osp
import pathlib
@ -24,7 +26,266 @@ from diffusion.distributions import DistributionNodes
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
op_to_atom = {
'input': 'Si', # Hydrogen for input
'nor_conv_1x1': 'C', # Carbon for 1x1 convolution
'nor_conv_3x3': 'N', # Nitrogen for 3x3 convolution
'avg_pool_3x3': 'O', # Oxygen for 3x3 average pooling
'skip_connect': 'P', # Phosphorus for skip connection
'none': 'S', # Sulfur for no operation
'output': 'He' # Helium for output
}
class DataModule(AbstractDataModule):
def __init__(self, cfg):
self.datadir = cfg.dataset.datadir
self.task = cfg.dataset.task_name
print("DataModule")
print("task", self.task)
print("datadir", self.datadir)
super().__init__(cfg)
def prepare_data(self) -> None:
target = getattr(self.cfg.dataset, 'guidance_target', None)
print("target", target)
# try:
# base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
# except NameError:
# base_path = pathlib.Path(os.getcwd()).parent[2]
base_path = '/home/stud/hanzhang/Graph-Dit'
root_path = os.path.join(base_path, self.datadir)
self.root_path = root_path
batch_size = self.cfg.train.batch_size
num_workers = self.cfg.train.num_workers
pin_memory = self.cfg.dataset.pin_memory
# Load the dataset to the memory
# Dataset has target property, root path, and transform
source = './NAS-Bench-201-v1_1-096897.pth'
dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
# if len(self.task.split('-')) == 2:
# train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
# else:
train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
if len(unlabeled_index) > 0:
train_index = torch.cat([train_index, unlabeled_index], dim=0)
train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
self.train_dataset = train_dataset
print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)
self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
training_iterations = len(train_dataset) // batch_size
self.training_iterations = training_iterations
def random_data_split(self, dataset):
nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
labeled_len = len(dataset) - nan_count
full_idx = list(range(labeled_len))
train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
unlabeled_index = list(range(labeled_len, len(dataset)))
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
return train_index, val_index, test_index, unlabeled_index
def fixed_split(self, dataset):
if self.task == 'O2-N2':
test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]
else:
raise ValueError('Invalid task name: {}'.format(self.task))
full_idx = list(range(len(dataset)))
full_idx = list(set(full_idx) - set(test_index))
train_ratio = 0.8
train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
return train_index, val_index, test_index, []
def get_train_smiles(self):
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
def get_data_split(self):
raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
def example_batch(self):
return next(iter(self.val_loader))
def train_dataloader(self):
return self.train_loader
def val_dataloader(self):
return self.val_loader
def test_dataloader(self):
return self.test_loader
def graphs_to_json(graphs, filename):
bonds = {
'nor_conv_1x1': 1,
'nor_conv_3x3': 2,
'avg_pool_3x3': 3,
'skip_connect': 4,
'input': 7,
'output': 5,
'none': 6
}
source_name = "nas-bench-201"
num_graph = len(graphs)
pt = Chem.GetPeriodicTable()
atom_name_list = []
atom_count_list = []
for i in range(2, 119):
atom_name_list.append(pt.GetElementSymbol(i))
atom_count_list.append(0)
atom_name_list.append('*')
atom_count_list.append(0)
n_atoms_per_mol = [0] * 500
bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
valencies = [0] * 500
transition_E = np.zeros((118, 118, 5))
n_atom_list = []
n_bond_list = []
# graphs = [(adj_matrix, ops), ...]
for graph in graphs:
ops = graph[1]
adj = graph[0]
n_atom = len(ops)
n_bond = len(ops)
n_atom_list.append(n_atom)
n_bond_list.append(n_bond)
n_atoms_per_mol[n_atom] += 1
cur_atom_count_arr = np.zeros(118)
for op in ops:
symbol = op_to_atom[op]
if symbol == 'H':
continue
elif symbol == '*':
atom_count_list[-1] += 1
cur_atom_count_arr[-1] += 1
else:
atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
# print('symbol', symbol)
# print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol))
# print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}')
try:
valencies[int(pt.GetDefaultValence(symbol))] += 1
except:
print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
transition_E_temp = np.zeros((118, 118, 5))
# print(n_atom)
for i in range(n_atom):
for j in range(n_atom):
if i == j or adj[i][j] == 0:
continue
start_atom, end_atom = i, j
if ops[start_atom] == 'input' or ops[end_atom] == 'input':
continue
if ops[start_atom] == 'output' or ops[end_atom] == 'output':
continue
if ops[start_atom] == 'none' or ops[end_atom] == 'none':
continue
start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
bond_index = bonds[ops[end_atom]]
bond_count_list[bond_index] += 2
# print(start_index, end_index, bond_index)
transition_E[start_index, end_index, bond_index] += 2
transition_E[end_index, start_index, bond_index] += 2
transition_E_temp[start_index, end_index, bond_index] += 2
transition_E_temp[end_index, start_index, bond_index] += 2
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
print(bond_count_list)
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2
# print(f'cur_tot_bond={cur_tot_bond}')
# find non-zero element in cur_tot_bond
# for i in range(118):
# for j in range(118):
# if cur_tot_bond[i][j] != 0:
# print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}')
# n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
# print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}")
transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
# find non-zero element in transition_E
# for i in range(118):
# for j in range(118):
# if transition_E[i][j][0] != 0:
# print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}')
assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
print('processed meta info: ------', filename, '------')
print('len atom_count_list', len(atom_count_list))
print('len atom_name_list', len(atom_name_list))
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
active_atoms = active_atoms.tolist()
atom_count_list = atom_count_list.tolist()
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
bond_count_list = bond_count_list.tolist()
valencies = np.array(valencies) / np.sum(valencies)
valencies = valencies.tolist()
no_edge = np.sum(transition_E, axis=-1) == 0
for i in range(118):
for j in range(118):
if no_edge[i][j] == False:
print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
# print(f'no_edge: {no_edge}')
first_elt = transition_E[:, :, 0]
first_elt[no_edge] = 1
transition_E[:, :, 0] = first_elt
transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
# find non-zero element in transition_E again
for i in range(118):
for j in range(118):
if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
meta_dict = {
'source': 'nasbench-201',
'num_graph': num_graph,
'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
'max_node': max(n_atom_list),
'max_bond': max(n_bond_list),
'atom_type_dist': atom_count_list,
'bond_type_dist': bond_count_list,
'valencies': valencies,
'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
'transition_E': transition_E.tolist(),
}
with open(f'{filename}.meta.json', 'w') as f:
json.dump(meta_dict, f)
return meta_dict
class DataModule_original(AbstractDataModule):
def __init__(self, cfg):
self.datadir = cfg.dataset.datadir
self.task = cfg.dataset.task_name
@ -48,18 +309,6 @@ class DataModule(AbstractDataModule):
# Load the dataset to the memory
# Dataset has target property, root path, and transform
dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)
print("len dataset", len(dataset))
def print_data(dataset):
print("dataset", dataset)
print("dataset keys", dataset.keys)
print("dataset x", dataset.x)
print("dataset edge_index", dataset.edge_index)
print("dataset edge_attr", dataset.edge_attr)
print("dataset y", dataset.y)
print("")
print_data(dataset=dataset[0])
print_data(dataset=dataset[1])
if len(self.task.split('-')) == 2:
train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
@ -138,8 +387,163 @@ class DataModule(AbstractDataModule):
def test_dataloader(self):
return self.test_loader
def graphs_to_json(graphs, filename):
bonds = {
'nor_conv_1x1': 1,
'nor_conv_3x3': 2,
'avg_pool_3x3': 3,
'skip_connect': 4,
'input': 7,
'output': 5,
'none': 6
}
class Dataset(InMemoryDataset):
source_name = "nas-bench-201"
num_graph = len(graphs)
pt = Chem.GetPeriodicTable()
atom_name_list = []
atom_count_list = []
for i in range(2, 119):
atom_name_list.append(pt.GetElementSymbol(i))
atom_count_list.append(0)
atom_name_list.append('*')
atom_count_list.append(0)
n_atoms_per_mol = [0] * 500
bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
valencies = [0] * 500
transition_E = np.zeros((118, 118, 5))
n_atom_list = []
n_bond_list = []
# graphs = [(adj_matrix, ops), ...]
for graph in graphs:
ops = graph[1]
adj = graph[0]
n_atom = len(ops)
n_bond = len(ops)
n_atom_list.append(n_atom)
n_bond_list.append(n_bond)
n_atoms_per_mol[n_atom] += 1
cur_atom_count_arr = np.zeros(118)
for op in ops:
symbol = op_to_atom[op]
if symbol == 'H':
continue
elif symbol == '*':
atom_count_list[-1] += 1
cur_atom_count_arr[-1] += 1
else:
atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
# print('symbol', symbol)
# print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol))
# print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}')
try:
valencies[int(pt.GetDefaultValence(symbol))] += 1
except:
print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
transition_E_temp = np.zeros((118, 118, 5))
# print(n_atom)
for i in range(n_atom):
for j in range(n_atom):
if i == j or adj[i][j] == 0:
continue
start_atom, end_atom = i, j
if ops[start_atom] == 'input' or ops[end_atom] == 'input':
continue
if ops[start_atom] == 'output' or ops[end_atom] == 'output':
continue
if ops[start_atom] == 'none' or ops[end_atom] == 'none':
continue
start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
bond_index = bonds[ops[end_atom]]
bond_count_list[bond_index] += 2
# print(start_index, end_index, bond_index)
transition_E[start_index, end_index, bond_index] += 2
transition_E[end_index, start_index, bond_index] += 2
transition_E_temp[start_index, end_index, bond_index] += 2
transition_E_temp[end_index, start_index, bond_index] += 2
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
print(bond_count_list)
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2
# print(f'cur_tot_bond={cur_tot_bond}')
# find non-zero element in cur_tot_bond
# for i in range(118):
# for j in range(118):
# if cur_tot_bond[i][j] != 0:
# print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}')
# n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
# print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}")
transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
# find non-zero element in transition_E
# for i in range(118):
# for j in range(118):
# if transition_E[i][j][0] != 0:
# print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}')
assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
print('processed meta info: ------', filename, '------')
print('len atom_count_list', len(atom_count_list))
print('len atom_name_list', len(atom_name_list))
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
active_atoms = active_atoms.tolist()
atom_count_list = atom_count_list.tolist()
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
bond_count_list = bond_count_list.tolist()
valencies = np.array(valencies) / np.sum(valencies)
valencies = valencies.tolist()
no_edge = np.sum(transition_E, axis=-1) == 0
for i in range(118):
for j in range(118):
if no_edge[i][j] == False:
print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
# print(f'no_edge: {no_edge}')
first_elt = transition_E[:, :, 0]
first_elt[no_edge] = 1
transition_E[:, :, 0] = first_elt
transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
# find non-zero element in transition_E again
for i in range(118):
for j in range(118):
if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
meta_dict = {
'source': 'nasbench-201',
'num_graph': num_graph,
'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
'max_node': max(n_atom_list),
'max_bond': max(n_bond_list),
'atom_type_dist': atom_count_list,
'bond_type_dist': bond_count_list,
'valencies': valencies,
'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
'transition_E': transition_E.tolist(),
}
with open(f'{filename}.meta.json', 'w') as f:
json.dump(meta_dict, f)
return meta_dict
class Dataset_origin(InMemoryDataset):
def __init__(self, source, root, target_prop=None,
transform=None, pre_transform=None, pre_filter=None):
self.target_prop = target_prop
@ -223,8 +627,95 @@ class Dataset(InMemoryDataset):
torch.save(self.collate(data_list), self.processed_paths[0])
def parse_architecture_string(arch_str):
print(arch_str)
steps = arch_str.split('+')
nodes = ['input'] # Start with input node
edges = []
for i, step in enumerate(steps):
step = step.strip('|').split('|')
for node in step:
op, idx = node.split('~')
edges.append((int(idx), i+1)) # i+1 because 0 is input node
nodes.append(op)
nodes.append('output') # Add output node
return nodes, edges
def create_adj_matrix_and_ops(nodes, edges):
num_nodes = len(nodes)
adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
for (src, dst) in edges:
adj_matrix[src][dst] = 1
return adj_matrix, nodes
class DataInfos(AbstractDatasetInfos):
def __init__(self, datamodule, cfg):
tasktype_dict = {
'hiv_b': 'classification',
'bace_b': 'classification',
'bbbp_b': 'classification',
'O2': 'regression',
'N2': 'regression',
'CO2': 'regression',
}
task_name = cfg.dataset.task_name
self.task = task_name
self.task_type = tasktype_dict.get(task_name, "regression")
self.ensure_connected = cfg.model.ensure_connected
datadir = cfg.dataset.datadir
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
meta_filename = os.path.join(base_path, datadir, 'raw', f'{task_name}.meta.json')
data_root = os.path.join(base_path, datadir, 'raw')
graphs = []
length = 15625
ops_type = {}
len_ops = set()
api = API('../NAS-Bench-201-v1_0-e61699.pth')
for i in range(length):
arch_info = api.query_meta_info_by_index(i)
nodes, edges = parse_architecture_string(arch_info.arch_str)
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
if i < 5:
print("Adjacency Matrix:")
print(adj_matrix)
print("Operations List:")
print(ops)
for op in ops:
if op not in ops_type:
ops_type[op] = len(ops_type)
len_ops.add(len(ops))
graphs.append((adj_matrix, ops))
meta_dict = graphs_to_json(graphs, 'nasbench-201')
self.base_path = base_path
self.active_atoms = meta_dict['active_atoms']
self.max_n_nodes = meta_dict['max_node']
self.original_max_n_nodes = meta_dict['max_node']
self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
self.transition_E = torch.Tensor(meta_dict['transition_E'])
self.atom_decoder = meta_dict['active_atoms']
node_types = torch.Tensor(meta_dict['atom_type_dist'])
active_index = (node_types > 0).nonzero().squeeze()
self.node_types = torch.Tensor(meta_dict['atom_type_dist'])[active_index]
self.nodes_dist = DistributionNodes(self.n_nodes)
self.active_index = active_index
val_len = 3 * self.original_max_n_nodes - 2
meta_val = torch.Tensor(meta_dict['valencies'])
self.valency_distribution = torch.zeros(val_len)
val_len = min(val_len, len(meta_val))
self.valency_distribution[:val_len] = meta_val[:val_len]
self.y_prior = None
self.train_ymin = []
self.train_ymax = []
class DataInfos_origin(AbstractDatasetInfos):
def __init__(self, datamodule, cfg):
tasktype_dict = {
'hiv_b': 'classification',