Try to run the graph; comment out sampling code
parent e04ad5fbe7
commit 82299e5213
@@ -41,6 +41,6 @@ train:
   check_val_every_n_epoch: 1
 dataset:
   datadir: 'data/'
-  task_name: null
-  guidance_target: null
+  task_name: 'nasbench-201'
+  guidance_target: 'nasbench-201'
   pin_memory: False
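Note: these two keys now point the data pipeline at NAS-Bench-201. A minimal sketch of how the values surface at runtime, assuming OmegaConf/Hydra-style config loading (the path and key layout are illustrative, not the repo's actual loader):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load('configs/config.yaml')  # hypothetical path
    # Downstream code branches on these two fields:
    print(cfg.dataset.task_name)         # 'nasbench-201'
    print(cfg.dataset.guidance_target)   # 'nasbench-201'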
@@ -116,7 +116,7 @@ class AbstractDatasetInfos:
     def compute_input_output_dims(self, datamodule):
         example_batch = datamodule.example_batch()
         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
-        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
+        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()

         self.input_dims = {'X': example_batch_x.size(1),
                            'E': example_batch_edge_attr.size(1),
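The edge one-hot width has to grow here because the NAS-Bench-201 operation codes used as edge labels run up to 7 ('input'), and torch.nn.functional.one_hot raises a RuntimeError when a class value is >= num_classes. A minimal sketch (values illustrative):

    import torch

    edge_attr = torch.tensor([1, 4, 7])  # op codes include 7 ('input')
    # one_hot(edge_attr, num_classes=5) would raise:
    #   RuntimeError: Class values must be smaller than num_classes.
    onehot = torch.nn.functional.one_hot(edge_attr, num_classes=10).float()
    print(onehot.shape)  # torch.Size([3, 10]) -> input_dims['E'] becomes 10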
@@ -81,6 +81,7 @@ class DataModule(AbstractDataModule):

         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
         self.train_dataset = train_dataset
+        self.test_dataset = test_dataset
         print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
         print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
         print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
@@ -93,8 +94,9 @@ class DataModule(AbstractDataModule):
         self.training_iterations = training_iterations

     def random_data_split(self, dataset):
-        nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
-        labeled_len = len(dataset) - nan_count
+        # nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
+        # labeled_len = len(dataset) - nan_count
+        labeled_len = len(dataset)
         full_idx = list(range(labeled_len))
         train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
         train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
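The hunk shows only the first of the two splits. A self-contained sketch of the full 60/20/20 scheme, assuming the second train_test_split mirrors the visible one:

    from sklearn.model_selection import train_test_split

    full_idx = list(range(100))
    train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2

    # First carve off the 20% test set ...
    train_index, test_index = train_test_split(full_idx, test_size=test_ratio, random_state=42)
    # ... then take the validation set out of the remaining 80% (0.2 / 0.8 = 0.25).
    train_index, val_index = train_test_split(
        train_index, test_size=valid_ratio / (train_ratio + valid_ratio), random_state=42)
    print(len(train_index), len(val_index), len(test_index))  # 60 20 20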
@@ -115,7 +117,7 @@ class DataModule(AbstractDataModule):
         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
         return train_index, val_index, test_index, []

-    def parse_architecture_string(arch_str):
+    def parse_architecture_string(self, arch_str):
         stages = arch_str.split('+')
         nodes = ['input']
         edges = []
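For context, NAS-Bench-201 serializes a cell as |op~j| tokens, one stage per '+', where op~j applies op to the output of node j. A minimal tokenization sketch of the format this method consumes:

    arch_str = '|nor_conv_3x3~0|+|skip_connect~0|avg_pool_3x3~1|+|none~0|nor_conv_1x1~1|skip_connect~2|'
    for stage_idx, stage in enumerate(arch_str.split('+')):
        ops = [tok.split('~') for tok in stage.strip('|').split('|')]
        print(stage_idx, ops)
    # 0 [['nor_conv_3x3', '0']]
    # 1 [['skip_connect', '0'], ['avg_pool_3x3', '1']]
    # 2 [['none', '0'], ['nor_conv_1x1', '1'], ['skip_connect', '2']]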
@@ -130,19 +132,39 @@ class DataModule(AbstractDataModule):
         nodes.append('output')  # Add the output node
         return nodes, edges

-    def create_molecule_from_graph(nodes, edges):
+    # def create_molecule_from_graph(nodes, edges):
+    def create_molecule_from_graph(self, graph):
+        nodes = graph.x
+        edges = graph.edge_index
         mol = Chem.RWMol()  # RWMol allows for building the molecule step by step
         atom_indices = {}
+        num_to_op = {
+            1: 'nor_conv_1x1',
+            2: 'nor_conv_3x3',
+            3: 'avg_pool_3x3',
+            4: 'skip_connect',
+            5: 'output',
+            6: 'none',
+            7: 'input'
+        }
+
+        # Extract node operations from the data object

         # Add atoms to the molecule
-        for i, node in enumerate(nodes):
-            atom_symbol = op_to_atom[node]
+        for i, op_tensor in enumerate(nodes):
+            op = op_tensor.item()
+            if op == 0: continue
+            op = num_to_op[op]
+            atom_symbol = op_to_atom[op]
             atom = Chem.Atom(atom_symbol)
             atom_idx = mol.AddAtom(atom)
             atom_indices[i] = atom_idx

         # Add bonds to the molecule
-        for start, end in edges:
+        edge_number = edges.shape[1]
+        for i in range(edge_number):
+            start = edges[0, i].item()
+            end = edges[1, i].item()
             mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)

         return mol
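A self-contained sketch of the encoding this method implements: operation codes become atoms, DAG edges become single bonds, and label 0 is treated as padding. The op_to_atom element choices below are illustrative; the repo defines its own mapping elsewhere.

    import torch
    from rdkit import Chem
    from rdkit.Chem import rdchem

    num_to_op = {1: 'nor_conv_1x1', 2: 'nor_conv_3x3', 3: 'avg_pool_3x3',
                 4: 'skip_connect', 5: 'output', 6: 'none', 7: 'input'}
    op_to_atom = {'input': 'Si', 'output': 'P', 'none': 'S', 'skip_connect': 'C',
                  'nor_conv_1x1': 'N', 'nor_conv_3x3': 'O', 'avg_pool_3x3': 'F'}  # hypothetical

    x = torch.tensor([7, 2, 4, 5])                     # input, conv3x3, skip, output
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 0->1, 1->2, 2->3

    mol = Chem.RWMol()
    atom_indices = {}
    for i, op_tensor in enumerate(x):
        op = op_tensor.item()
        if op == 0:  # 0 is padding, not an operation
            continue
        atom_indices[i] = mol.AddAtom(Chem.Atom(op_to_atom[num_to_op[op]]))
    for i in range(edge_index.shape[1]):
        start, end = edge_index[0, i].item(), edge_index[1, i].item()
        mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
    Chem.SanitizeMol(mol)
    print(Chem.MolToSmiles(mol))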
@@ -154,30 +176,23 @@ class DataModule(AbstractDataModule):
         return smiles

     def get_train_smiles(self):
-        # raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
-        # train_arch_strs = []
-        # test_arch_strs = []
-
-        # for idx in self.train_index:
-        #     arch_info = self.train_dataset[idx]
-        #     arch_str = arch_info.arch_str
-        #     train_arch_strs.append(arch_str)
-        # for idx in self.test_index:
-        #     arch_info = self.train_dataset[idx]
-        #     arch_str = arch_info.arch_str
-        #     test_arch_strs.append(arch_str)
-
         train_smiles = []
         test_smiles = []

-        for idx in self.train_index:
-            graph = self.train_dataset[idx]
-            mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
+        for graph in self.train_dataset:
+            # print(f'idx={idx}')
+            # graph = self.train_dataset[idx]
+            print(graph.x)
+            print(graph.edge_index)
+            print(f'class of graph.x: {graph.x.__class__}, class of graph.edge_index: {graph.edge_index.__class__}')
+            mol = self.create_molecule_from_graph(graph)
             train_smiles.append(Chem.MolToSmiles(mol))

-        for idx in self.test_index:
-            graph = self.train_dataset[idx]
-            mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
+        # for idx in self.test_index:
+        for graph in self.test_dataset:
+            # graph = self.dataset[idx]
+            # mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
+            mol = self.create_molecule_from_graph(graph)
             test_smiles.append(Chem.MolToSmiles(mol))

         # train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
@@ -199,161 +214,8 @@ class DataModule(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader

-def graphs_to_json(graphs, filename):
-    bonds = {
-        'nor_conv_1x1': 1,
-        'nor_conv_3x3': 2,
-        'avg_pool_3x3': 3,
-        'skip_connect': 4,
-        'input': 7,
-        'output': 5,
-        'none': 6
-    }
-
-    source_name = "nas-bench-201"
-    num_graph = len(graphs)
-    pt = Chem.GetPeriodicTable()
-    atom_name_list = []
-    atom_count_list = []
-    for i in range(2, 119):
-        atom_name_list.append(pt.GetElementSymbol(i))
-        atom_count_list.append(0)
-    atom_name_list.append('*')
-    atom_count_list.append(0)
-    n_atoms_per_mol = [0] * 500
-    bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
-    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
-    valencies = [0] * 500
-    transition_E = np.zeros((118, 118, 5))
-
-    n_atom_list = []
-    n_bond_list = []
-    # graphs = [(adj_matrix, ops), ...]
-    for graph in graphs:
-        ops = graph[1]
-        adj = graph[0]
-        n_atom = len(ops)
-        n_bond = len(ops)
-        n_atom_list.append(n_atom)
-        n_bond_list.append(n_bond)
-
-        n_atoms_per_mol[n_atom] += 1
-        cur_atom_count_arr = np.zeros(118)
-
-        for op in ops:
-            symbol = op_to_atom[op]
-            if symbol == 'H':
-                continue
-            elif symbol == '*':
-                atom_count_list[-1] += 1
-                cur_atom_count_arr[-1] += 1
-            else:
-                atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
-                cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
-                # print('symbol', symbol)
-                # print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol))
-                # print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}')
-            try:
-                valencies[int(pt.GetDefaultValence(symbol))] += 1
-            except:
-                print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
-        transition_E_temp = np.zeros((118, 118, 5))
-        # print(n_atom)
-        for i in range(n_atom):
-            for j in range(n_atom):
-                if i == j or adj[i][j] == 0:
-                    continue
-                start_atom, end_atom = i, j
-                if ops[start_atom] == 'input' or ops[end_atom] == 'input':
-                    continue
-                if ops[start_atom] == 'output' or ops[end_atom] == 'output':
-                    continue
-                if ops[start_atom] == 'none' or ops[end_atom] == 'none':
-                    continue
-
-                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
-                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
-                bond_index = bonds[ops[end_atom]]
-                bond_count_list[bond_index] += 2
-
-                # print(start_index, end_index, bond_index)
-
-                transition_E[start_index, end_index, bond_index] += 2
-                transition_E[end_index, start_index, bond_index] += 2
-                transition_E_temp[start_index, end_index, bond_index] += 2
-                transition_E_temp[end_index, start_index, bond_index] += 2
-
-        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
-        print(bond_count_list)
-        cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2
-        # print(f'cur_tot_bond={cur_tot_bond}')
-        # find non-zero element in cur_tot_bond
-        # for i in range(118):
-        #     for j in range(118):
-        #         if cur_tot_bond[i][j] != 0:
-        #             print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}')
-        # n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
-        # print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}")
-        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
-        # find non-zero element in transition_E
-        # for i in range(118):
-        #     for j in range(118):
-        #         if transition_E[i][j][0] != 0:
-        #             print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}')
-        assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
-
-    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
-
-    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
-    print('processed meta info: ------', filename, '------')
-    print('len atom_count_list', len(atom_count_list))
-    print('len atom_name_list', len(atom_name_list))
-    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
-    active_atoms = active_atoms.tolist()
-    atom_count_list = atom_count_list.tolist()
-
-    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
-    bond_count_list = bond_count_list.tolist()
-    valencies = np.array(valencies) / np.sum(valencies)
-    valencies = valencies.tolist()
-
-    no_edge = np.sum(transition_E, axis=-1) == 0
-    for i in range(118):
-        for j in range(118):
-            if no_edge[i][j] == False:
-                print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
-    # print(f'no_edge: {no_edge}')
-    first_elt = transition_E[:, :, 0]
-    first_elt[no_edge] = 1
-    transition_E[:, :, 0] = first_elt
-
-    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
-
-    # find non-zero element in transition_E again
-    for i in range(118):
-        for j in range(118):
-            if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
-                print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
-
-    meta_dict = {
-        'source': 'nasbench-201',
-        'num_graph': num_graph,
-        'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
-        'max_node': max(n_atom_list),
-        'max_bond': max(n_bond_list),
-        'atom_type_dist': atom_count_list,
-        'bond_type_dist': bond_count_list,
-        'valencies': valencies,
-        'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
-        'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
-        'transition_E': transition_E.tolist(),
-    }
-
-    with open(f'{filename}.meta.json', 'w') as f:
-        json.dump(meta_dict, f)
-    return meta_dict
-
 class DataModule_original(AbstractDataModule):
     def __init__(self, cfg):
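The deleted graphs_to_json duplicated the module-level helper that survives below. Its one non-obvious step is how it turns raw edge-type counts into per-pair distributions: atom pairs that never co-occur get all probability mass on the "no bond" channel before normalizing. In minimal form:

    import numpy as np

    transition_E = np.zeros((118, 118, 5))  # counts: [atom_i, atom_j, bond_type]
    # ... accumulate counts over the dataset ...
    no_edge = np.sum(transition_E, axis=-1) == 0
    transition_E[:, :, 0][no_edge] = 1  # channel 0 = "no bond"
    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
    assert np.allclose(transition_E.sum(axis=-1), 1.0)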
@@ -482,7 +344,7 @@ def graphs_to_json(graphs, filename):
     bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
     bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
     valencies = [0] * 500
-    transition_E = np.zeros((118, 118, 5))
+    transition_E = np.zeros((118, 118, 8))

     n_atom_list = []
     n_bond_list = []
@@ -515,7 +377,7 @@ def graphs_to_json(graphs, filename):
                 valencies[int(pt.GetDefaultValence(symbol))] += 1
             except:
                 print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
-        transition_E_temp = np.zeros((118, 118, 5))
+        transition_E_temp = np.zeros((118, 118, 8))
         # print(n_atom)
         for i in range(n_atom):
             for j in range(n_atom):
@@ -612,14 +474,16 @@ def graphs_to_json(graphs, filename):
     with open(f'{filename}.meta.json', 'w') as f:
         json.dump(meta_dict, f)
     return meta_dict

 class Dataset(InMemoryDataset):
     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
         self.target_prop = target_prop
         source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         self.source = source
         self.api = API(source)  # Initialize NAS-Bench-201 API
+        print('API loaded')
         super().__init__(root, transform, pre_transform, pre_filter)
+        print('Dataset initialized')
+        print(self.processed_paths[0])
         self.data, self.slices = torch.load(self.processed_paths[0])

     @property
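A minimal sketch of what the API object loaded here provides, assuming the nas_201_api package (the .pth path above is machine-specific):

    from nas_201_api import NASBench201API as API

    api = API('NAS-Bench-201-v1_1-096897.pth')
    print(len(api))         # 15625 candidate cells
    arch_str = api.arch(0)  # architecture string, e.g. '|avg_pool_3x3~0|+...'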
@@ -655,8 +519,11 @@ class Dataset(InMemoryDataset):

 def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
     nodes, edges = parse_architecture_string(arch_str)

     node_labels = [bonds[node] for node in nodes]  # Replace with appropriate encoding if necessary
+    assert 0 not in node_labels, f'Invalid node label: {node_labels}'
     x = torch.LongTensor(node_labels)
+    print(f'in initialize Dataset, arch_to_Graph x={x}')

     edges_list = [(start, end) for start, end in edges]
     edge_type = [bonds[nodes[end]] for start, end in edges]  # Example: using end node type as edge type
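A sketch of the tensors arch_to_graph assembles, assuming the bonds mapping shown further down; the real function also attaches the y targets:

    import torch
    from torch_geometric.data import Data

    bonds = {'nor_conv_1x1': 1, 'nor_conv_3x3': 2, 'avg_pool_3x3': 3,
             'skip_connect': 4, 'output': 5, 'none': 6, 'input': 7}
    nodes = ['input', 'nor_conv_3x3', 'skip_connect', 'output']
    edges = [(0, 1), (1, 2), (2, 3)]

    x = torch.LongTensor([bonds[n] for n in nodes])                 # labels are all >= 1
    edge_index = torch.tensor(list(zip(*edges)), dtype=torch.long)  # shape [2, E]
    edge_attr = torch.LongTensor([bonds[nodes[end]] for _, end in edges])
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)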
@@ -671,6 +538,7 @@ class Dataset(InMemoryDataset):
     else:
         y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)

+    print(f'in initialize Dataset, Data_init, x={x}, y={y}, edge_index={edge_index}, edge_attr={edge_attr}')
     data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
     return data, nodes
@@ -679,9 +547,9 @@ class Dataset(InMemoryDataset):
     'nor_conv_3x3': 2,
     'avg_pool_3x3': 3,
     'skip_connect': 4,
-    'input': 7,
     'output': 5,
-    'none': 6
+    'none': 6,
+    'input': 7
 }

 # Prepare to process NAS-Bench-201 data
@@ -13,9 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
 import utils

 class Graph_DiT(pl.LightningModule):
-    def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    def __init__(self, cfg, dataset_infos, visualization_tools):

         super().__init__()
-        self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
+        # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
         self.test_only = cfg.general.test_only
         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
@@ -55,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
         self.test_E_logp = SumExceptBatchMetric()
         self.test_y_collection = []

-        self.train_metrics = train_metrics
-        self.sampling_metrics = sampling_metrics
+        # self.train_metrics = train_metrics
+        # self.sampling_metrics = sampling_metrics

         self.visualization_tools = visualization_tools
         self.max_n_nodes = dataset_infos.max_n_nodes
@@ -171,7 +173,7 @@ class Graph_DiT(pl.LightningModule):
         self.val_E_kl.reset()
         self.val_X_logp.reset()
         self.val_E_logp.reset()
-        self.sampling_metrics.reset()
+        # self.sampling_metrics.reset()
         self.val_y_collection = []

     @torch.no_grad()
@@ -239,14 +241,15 @@ class Graph_DiT(pl.LightningModule):
             samples_left_to_generate -= to_generate
             chains_left_to_save -= chains_save

-        print(f"Computing sampling metrics", ' ...')
-        valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
-        print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+        # print(f"Computing sampling metrics", ' ...')
+        # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
+        # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')

         current_path = os.getcwd()
         result_path = os.path.join(current_path,
                                    f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
-        self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
-        self.sampling_metrics.reset()
+        # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
+        # self.sampling_metrics.reset()

     def on_test_epoch_start(self) -> None:
         print("Starting test...")
@@ -50,7 +50,6 @@ def get_resume_adaptive(cfg, model_kwargs):
     # Fetch path to this file to get base path
     current_path = os.path.dirname(os.path.realpath(__file__))
     root_dir = current_path.split("outputs")[0]
-
     resume_path = os.path.join(root_dir, cfg.general.resume)

     if cfg.model.type == "discrete":
@@ -80,21 +79,21 @@ def main(cfg: DictConfig):
     datamodule = dataset.DataModule(cfg)
     datamodule.prepare_data()
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
-    train_smiles, reference_smiles = datamodule.get_train_smiles()
+    # train_smiles, reference_smiles = datamodule.get_train_smiles()

     # get input output dimensions
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
-    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
+    # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)

-    sampling_metrics = SamplingMolecularMetrics(
-        dataset_infos, train_smiles, reference_smiles
-    )
+    # sampling_metrics = SamplingMolecularMetrics(
+    #     dataset_infos, train_smiles, reference_smiles
+    # )
     visualization_tools = MolecularVisualization(dataset_infos)

     model_kwargs = {
         "dataset_infos": dataset_infos,
-        "train_metrics": train_metrics,
-        "sampling_metrics": sampling_metrics,
+        # "train_metrics": train_metrics,
+        # "sampling_metrics": sampling_metrics,
         "visualization_tools": visualization_tools,
     }
@@ -110,9 +109,10 @@ def main(cfg: DictConfig):
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
-        accelerator="gpu"
-        if torch.cuda.is_available() and cfg.general.gpus > 0
-        else "cpu",
+        # accelerator="gpu"
+        # if torch.cuda.is_available() and cfg.general.gpus > 0
+        # else "cpu",
+        accelerator="cpu",
         devices=cfg.general.gpus
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else None,
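The hunk hard-codes the accelerator to CPU while devices still follows the GPU conditional. An equivalent, arguably tidier selection, sketched for PyTorch Lightning (note devices=None is a 1.x idiom; 2.x prefers an int or 'auto'):

    use_gpu = torch.cuda.is_available() and cfg.general.gpus > 0
    trainer = Trainer(
        gradient_clip_val=cfg.train.clip_grad,
        accelerator="gpu" if use_gpu else "cpu",
        devices=cfg.general.gpus if use_gpu else 1,
    )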