Compare commits


2 Commits

Author        SHA1        Message                                          Date
mhz           82299e5213  try to run the graph, commented sampling codes   2024-06-25 00:09:27 +02:00
Hanzhang Ma   e04ad5fbe7  need to run the jupyternotebook                  2024-06-12 17:56:08 +02:00
6 changed files with 428 additions and 46263 deletions

View File

@@ -41,6 +41,6 @@ train:
   check_val_every_n_epoch: 1
 dataset:
   datadir: 'data/'
-  task_name: null
+  task_name: 'nasbench-201'
-  guidance_target: null
+  guidance_target: 'nasbench-201'
   pin_memory: False
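
These two fields are read back through the Hydra cfg object; Graph_DiT, for instance, fetches the target with getattr(cfg.dataset, 'guidance_target', None), so the previous nulls simply yielded None downstream. A minimal sketch of how the fields are consumed (the OmegaConf literal here is illustrative; the real cfg is assembled by Hydra):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({'dataset': {'datadir': 'data/',
                                        'task_name': 'nasbench-201',
                                        'guidance_target': 'nasbench-201',
                                        'pin_memory': False}})
    task_name = cfg.dataset.task_name                                # 'nasbench-201'
    guidance_target = getattr(cfg.dataset, 'guidance_target', None)  # same pattern as Graph_DiT.__init__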

View File

@@ -116,7 +116,7 @@ class AbstractDatasetInfos:
     def compute_input_output_dims(self, datamodule):
         example_batch = datamodule.example_batch()
         example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
-        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
+        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
         self.input_dims = {'X': example_batch_x.size(1),
                            'E': example_batch_edge_attr.size(1),
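
The jump from num_classes=5 to 10 tracks the larger edge vocabulary introduced in the dataset changes below: edges are now labeled with the NAS-Bench-201 op codes 1..7 rather than the four RDKit bond types, and one_hot requires num_classes to exceed the largest label. A quick check (10 is the commit's choice of headroom; any value above 7 satisfies one_hot):

    import torch
    import torch.nn.functional as F

    edge_attr = torch.tensor([1, 4, 7])  # e.g. nor_conv_1x1, skip_connect, input
    onehot = F.one_hot(edge_attr, num_classes=10).float()  # shape (3, 10)
    # F.one_hot(edge_attr, num_classes=5)  # RuntimeError: class values must be smaller than num_classes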

View File

@@ -13,6 +13,7 @@ import torch
 import torch.nn.functional as F
 from rdkit import Chem, RDLogger
 from rdkit.Chem.rdchem import BondType as BT
+from rdkit.Chem import rdchem
 from tqdm import tqdm
 import numpy as np
 import pandas as pd
@@ -24,6 +25,9 @@ import utils as utils
 from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
 from diffusion.distributions import DistributionNodes
+import networkx as nx

 bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}

 op_to_atom = {
@@ -77,6 +81,7 @@ class DataModule(AbstractDataModule):
         train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
         self.train_dataset = train_dataset
+        self.test_dataset = test_dataset
         print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
         print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
         print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
@@ -89,8 +94,9 @@ class DataModule(AbstractDataModule):
         self.training_iterations = training_iterations

     def random_data_split(self, dataset):
-        nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
-        labeled_len = len(dataset) - nan_count
+        # nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
+        # labeled_len = len(dataset) - nan_count
+        labeled_len = len(dataset)
         full_idx = list(range(labeled_len))
         train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
         train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
@@ -111,8 +117,87 @@ class DataModule(AbstractDataModule):
         print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
         return train_index, val_index, test_index, []

+    def parse_architecture_string(self, arch_str):
+        stages = arch_str.split('+')
+        nodes = ['input']
+        edges = []
+        for stage in stages:
+            operations = stage.strip('|').split('|')
+            for op in operations:
+                operation, idx = op.split('~')
+                idx = int(idx)
+                edges.append((idx, len(nodes)))  # add edge from idx to the new node
+                nodes.append(operation)
+        nodes.append('output')  # add the output node
+        return nodes, edges
+
+    # def create_molecule_from_graph(nodes, edges):
+    def create_molecule_from_graph(self, graph):
+        nodes = graph.x
+        edges = graph.edge_index
+        mol = Chem.RWMol()  # RWMol allows building the molecule step by step
+        atom_indices = {}
+        num_to_op = {
+            1: 'nor_conv_1x1',
+            2: 'nor_conv_3x3',
+            3: 'avg_pool_3x3',
+            4: 'skip_connect',
+            5: 'output',
+            6: 'none',
+            7: 'input'
+        }
+
+        # Extract node operations from the data object and add atoms to the molecule
+        for i, op_tensor in enumerate(nodes):
+            op = op_tensor.item()
+            if op == 0:
+                continue
+            op = num_to_op[op]
+            atom_symbol = op_to_atom[op]
+            atom = Chem.Atom(atom_symbol)
+            atom_idx = mol.AddAtom(atom)
+            atom_indices[i] = atom_idx
+
+        # Add bonds to the molecule
+        edge_number = edges.shape[1]
+        for i in range(edge_number):
+            start = edges[0, i].item()
+            end = edges[1, i].item()
+            mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
+        return mol
def arch_str_to_smiles(self, arch_str):
nodes, edges = self.parse_architecture_string(arch_str)
mol = self.create_molecule_from_graph(nodes, edges)
smiles = Chem.MolToSmiles(mol)
return smiles
     def get_train_smiles(self):
-        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
+        train_smiles = []
+        test_smiles = []
+        for graph in self.train_dataset:
+            mol = self.create_molecule_from_graph(graph)
+            train_smiles.append(Chem.MolToSmiles(mol))
+        for graph in self.test_dataset:
+            mol = self.create_molecule_from_graph(graph)
+            test_smiles.append(Chem.MolToSmiles(mol))
+        return train_smiles, test_smiles

     def get_data_split(self):
         raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
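
For orientation, here is what parse_architecture_string yields on a typical NAS-Bench-201 architecture string, traced by hand from the loop above (the string is an example; dm stands for an initialized DataModule):

    arch = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|avg_pool_3x3~2|'
    nodes, edges = dm.parse_architecture_string(arch)
    # nodes == ['input', 'nor_conv_3x3', 'skip_connect', 'nor_conv_3x3',
    #           'skip_connect', 'none', 'avg_pool_3x3', 'output']
    # edges == [(0, 1), (0, 2), (1, 3), (0, 4), (1, 5), (2, 6)]
    # Note: the trailing 'output' node is appended without any incident edge.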
@@ -129,161 +214,8 @@ class DataModule(AbstractDataModule):
     def test_dataloader(self):
         return self.test_loader
-def graphs_to_json(graphs, filename):
-    bonds = {
-        'nor_conv_1x1': 1,
-        'nor_conv_3x3': 2,
-        'avg_pool_3x3': 3,
-        'skip_connect': 4,
-        'input': 7,
-        'output': 5,
-        'none': 6
-    }
-
-    source_name = "nas-bench-201"
-    num_graph = len(graphs)
-    pt = Chem.GetPeriodicTable()
-    atom_name_list = []
-    atom_count_list = []
-    for i in range(2, 119):
-        atom_name_list.append(pt.GetElementSymbol(i))
-        atom_count_list.append(0)
-    atom_name_list.append('*')
-    atom_count_list.append(0)
-    n_atoms_per_mol = [0] * 500
-    bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
-    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
-    valencies = [0] * 500
-    transition_E = np.zeros((118, 118, 5))
-
-    n_atom_list = []
-    n_bond_list = []
-    # graphs = [(adj_matrix, ops), ...]
-    for graph in graphs:
-        ops = graph[1]
-        adj = graph[0]
-        n_atom = len(ops)
-        n_bond = len(ops)
-        n_atom_list.append(n_atom)
-        n_bond_list.append(n_bond)
-        n_atoms_per_mol[n_atom] += 1
-        cur_atom_count_arr = np.zeros(118)
-
-        for op in ops:
-            symbol = op_to_atom[op]
-            if symbol == 'H':
-                continue
-            elif symbol == '*':
-                atom_count_list[-1] += 1
-                cur_atom_count_arr[-1] += 1
-            else:
-                atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
-                cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
-            try:
-                valencies[int(pt.GetDefaultValence(symbol))] += 1
-            except:
-                print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
-
-        transition_E_temp = np.zeros((118, 118, 5))
-        for i in range(n_atom):
-            for j in range(n_atom):
-                if i == j or adj[i][j] == 0:
-                    continue
-                start_atom, end_atom = i, j
-                if ops[start_atom] == 'input' or ops[end_atom] == 'input':
-                    continue
-                if ops[start_atom] == 'output' or ops[end_atom] == 'output':
-                    continue
-                if ops[start_atom] == 'none' or ops[end_atom] == 'none':
-                    continue
-                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
-                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
-                bond_index = bonds[ops[end_atom]]
-                bond_count_list[bond_index] += 2
-                transition_E[start_index, end_index, bond_index] += 2
-                transition_E[end_index, start_index, bond_index] += 2
-                transition_E_temp[start_index, end_index, bond_index] += 2
-                transition_E_temp[end_index, start_index, bond_index] += 2
-
-        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
-        print(bond_count_list)
-        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2
-        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
-        transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
-        assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
-
-    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
-    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
-
-    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
-    print('processed meta info: ------', filename, '------')
-    print('len atom_count_list', len(atom_count_list))
-    print('len atom_name_list', len(atom_name_list))
-    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
-    active_atoms = active_atoms.tolist()
-    atom_count_list = atom_count_list.tolist()
-
-    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
-    bond_count_list = bond_count_list.tolist()
-    valencies = np.array(valencies) / np.sum(valencies)
-    valencies = valencies.tolist()
-
-    no_edge = np.sum(transition_E, axis=-1) == 0
-    for i in range(118):
-        for j in range(118):
-            if no_edge[i][j] == False:
-                print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
-    first_elt = transition_E[:, :, 0]
-    first_elt[no_edge] = 1
-    transition_E[:, :, 0] = first_elt
-    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
-
-    # find non-zero elements in transition_E again
-    for i in range(118):
-        for j in range(118):
-            if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
-                print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
-
-    meta_dict = {
-        'source': 'nasbench-201',
-        'num_graph': num_graph,
-        'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
-        'max_node': max(n_atom_list),
-        'max_bond': max(n_bond_list),
-        'atom_type_dist': atom_count_list,
-        'bond_type_dist': bond_count_list,
-        'valencies': valencies,
-        'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
-        'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
-        'transition_E': transition_E.tolist(),
-    }
-
-    with open(f'{filename}.meta.json', 'w') as f:
-        json.dump(meta_dict, f)
-    return meta_dict

 class DataModule_original(AbstractDataModule):
     def __init__(self, cfg):
@@ -412,7 +344,7 @@ def graphs_to_json(graphs, filename):
     bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
     bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
     valencies = [0] * 500
-    transition_E = np.zeros((118, 118, 5))
+    transition_E = np.zeros((118, 118, 8))

     n_atom_list = []
     n_bond_list = []
@@ -445,7 +377,7 @@ def graphs_to_json(graphs, filename):
                 valencies[int(pt.GetDefaultValence(symbol))] += 1
             except:
                 print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))

-        transition_E_temp = np.zeros((118, 118, 5))
+        transition_E_temp = np.zeros((118, 118, 8))
        # print(n_atom)
         for i in range(n_atom):
             for j in range(n_atom):
@@ -542,6 +474,102 @@ def graphs_to_json(graphs, filename):
     with open(f'{filename}.meta.json', 'w') as f:
         json.dump(meta_dict, f)
     return meta_dict
+class Dataset(InMemoryDataset):
+    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
+        self.target_prop = target_prop
+        # NOTE: the source argument is overridden with a hardcoded checkpoint path.
+        source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        self.source = source
+        self.api = API(source)  # initialize the NAS-Bench-201 API
+        print('API loaded')
+        super().__init__(root, transform, pre_transform, pre_filter)
+        print('Dataset initialized')
+        print(self.processed_paths[0])
+        self.data, self.slices = torch.load(self.processed_paths[0])
+
+    @property
+    def raw_file_names(self):
+        return []  # NAS-Bench-201 data is loaded via the API; no raw files needed
+
+    @property
+    def processed_file_names(self):
+        return [f'{self.source}.pt']
+
+    def process(self):
+        def parse_architecture_string(arch_str):
+            stages = arch_str.split('+')
+            nodes = ['input']
+            edges = []
+            for stage in stages:
+                operations = stage.strip('|').split('|')
+                for op in operations:
+                    operation, idx = op.split('~')
+                    idx = int(idx)
+                    edges.append((idx, len(nodes)))  # add edge from idx to the new node
+                    nodes.append(operation)
+            nodes.append('output')  # add the output node
+            return nodes, edges
+
+        def create_graph(nodes, edges):
+            G = nx.DiGraph()
+            for i, node in enumerate(nodes):
+                G.add_node(i, label=node)
+            G.add_edges_from(edges)
+            return G
+
+        def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
+            nodes, edges = parse_architecture_string(arch_str)
+            node_labels = [bonds[node] for node in nodes]  # encode operations as integers
+            assert 0 not in node_labels, f'Invalid node label: {node_labels}'
+            x = torch.LongTensor(node_labels)
+            edges_list = [(start, end) for start, end in edges]
+            edge_type = [bonds[nodes[end]] for start, end in edges]  # use the end node's op as the edge type
+            edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
+            edge_type = torch.tensor(edge_type, dtype=torch.long)
+            edge_attr = edge_type.view(-1, 1)
+
+            if target3 is not None:
+                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
+            elif target2 is not None:
+                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
+            else:
+                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
+
+            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
+            return data, nodes
+
+        bonds = {
+            'nor_conv_1x1': 1,
+            'nor_conv_3x3': 2,
+            'avg_pool_3x3': 3,
+            'skip_connect': 4,
+            'output': 5,
+            'none': 6,
+            'input': 7
+        }
+
+        # Prepare to process NAS-Bench-201 data
+        data_list = []
+        len_data = len(self.api)  # number of architectures
+        with tqdm(total=len_data) as pbar:
+            for arch_index in range(len_data):
+                arch_info = self.api.query_meta_info_by_index(arch_index)
+                arch_str = arch_info.arch_str
+                sa = np.random.rand()       # placeholder for synthetic accessibility
+                sc = np.random.rand()       # placeholder for substructure count
+                target = np.random.rand()   # placeholder for target value
+                target2 = np.random.rand()  # placeholder for second target value
+                target3 = np.random.rand()  # placeholder for third target value
+                data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
+                data_list.append(data)
+                pbar.update(1)
+
+        torch.save(self.collate(data_list), self.processed_paths[0])

 class Dataset_origin(InMemoryDataset):
     def __init__(self, source, root, target_prop=None,
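
A minimal usage sketch for the new Dataset (hypothetical paths; assumes the imports already present in this file plus torch_geometric's loader, and note that __init__ currently ignores the source argument in favor of the hardcoded checkpoint):

    from torch_geometric.loader import DataLoader

    ds = Dataset(source='NAS-Bench-201-v1_1-096897.pth', root='data/nasbench201')
    loader = DataLoader(ds, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    # batch.x: node op codes (1..7); batch.edge_attr: end-node op codes;
    # batch.y: five random placeholder targets per architecture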
@@ -671,7 +699,7 @@ class DataInfos(AbstractDatasetInfos):
             length = 15625
             ops_type = {}
             len_ops = set()
-            api = API('../NAS-Bench-201-v1_0-e61699.pth')
+            api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
             for i in range(length):
                 arch_info = api.query_meta_info_by_index(i)
                 nodes, edges = parse_architecture_string(arch_info.arch_str)

View File

@@ -13,9 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
 import utils

 class Graph_DiT(pl.LightningModule):
-    def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
+    def __init__(self, cfg, dataset_infos, visualization_tools):
         super().__init__()
-        self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
+        # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
         self.test_only = cfg.general.test_only
         self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
@@ -55,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
         self.test_E_logp = SumExceptBatchMetric()
         self.test_y_collection = []

-        self.train_metrics = train_metrics
-        self.sampling_metrics = sampling_metrics
+        # self.train_metrics = train_metrics
+        # self.sampling_metrics = sampling_metrics
         self.visualization_tools = visualization_tools
         self.max_n_nodes = dataset_infos.max_n_nodes
@@ -171,7 +173,7 @@ class Graph_DiT(pl.LightningModule):
         self.val_E_kl.reset()
         self.val_X_logp.reset()
         self.val_E_logp.reset()
-        self.sampling_metrics.reset()
+        # self.sampling_metrics.reset()
         self.val_y_collection = []

     @torch.no_grad()
@@ -239,14 +241,15 @@ class Graph_DiT(pl.LightningModule):
             samples_left_to_generate -= to_generate
             chains_left_to_save -= chains_save

-        print(f"Computing sampling metrics", ' ...')
-        valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
-        print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
+        # print(f"Computing sampling metrics", ' ...')
+        # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
+        # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
         current_path = os.getcwd()
         result_path = os.path.join(current_path,
                                    f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
-        self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
-        self.sampling_metrics.reset()
+        # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
+        # self.sampling_metrics.reset()

     def on_test_epoch_start(self) -> None:
         print("Starting test...")

View File

@@ -50,7 +50,6 @@ def get_resume_adaptive(cfg, model_kwargs):
     # Fetch path to this file to get base path
     current_path = os.path.dirname(os.path.realpath(__file__))
     root_dir = current_path.split("outputs")[0]
     resume_path = os.path.join(root_dir, cfg.general.resume)
-
     if cfg.model.type == "discrete":
@@ -80,21 +79,21 @@ def main(cfg: DictConfig):
     datamodule = dataset.DataModule(cfg)
     datamodule.prepare_data()
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
-    train_smiles, reference_smiles = datamodule.get_train_smiles()
+    # train_smiles, reference_smiles = datamodule.get_train_smiles()

     # get input output dimensions
     dataset_infos.compute_input_output_dims(datamodule=datamodule)

-    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
-    sampling_metrics = SamplingMolecularMetrics(
-        dataset_infos, train_smiles, reference_smiles
-    )
+    # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
+    # sampling_metrics = SamplingMolecularMetrics(
+    #     dataset_infos, train_smiles, reference_smiles
+    # )

     visualization_tools = MolecularVisualization(dataset_infos)

     model_kwargs = {
         "dataset_infos": dataset_infos,
-        "train_metrics": train_metrics,
-        "sampling_metrics": sampling_metrics,
+        # "train_metrics": train_metrics,
+        # "sampling_metrics": sampling_metrics,
         "visualization_tools": visualization_tools,
     }
@@ -110,9 +109,10 @@ def main(cfg: DictConfig):
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
-        accelerator="gpu"
-        if torch.cuda.is_available() and cfg.general.gpus > 0
-        else "cpu",
+        # accelerator="gpu"
+        # if torch.cuda.is_available() and cfg.general.gpus > 0
+        # else "cpu",
+        accelerator="cpu",
         devices=cfg.general.gpus
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else None,
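
Pinning accelerator="cpu" while leaving the conditional devices argument intact is a workable debugging stopgap; if automatic fallback is the eventual goal, a compact alternative using the same cfg fields would be (a sketch, not what this commit does):

    use_gpu = torch.cuda.is_available() and cfg.general.gpus > 0
    trainer = Trainer(
        gradient_clip_val=cfg.train.clip_grad,
        accelerator="gpu" if use_gpu else "cpu",
        devices=cfg.general.gpus if use_gpu else None,
    )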

File diff suppressed because one or more lines are too long