try to run the graph, commented sampling codes

This commit is contained in:
mhz 2024-06-25 00:09:27 +02:00
parent e04ad5fbe7
commit 82299e5213
5 changed files with 80 additions and 209 deletions

View File

@ -41,6 +41,6 @@ train:
check_val_every_n_epoch: 1 check_val_every_n_epoch: 1
dataset: dataset:
datadir: 'data/' datadir: 'data/'
task_name: null task_name: 'nasbench-201'
guidance_target: null guidance_target: 'nasbench-201'
pin_memory: False pin_memory: False

View File

@ -116,7 +116,7 @@ class AbstractDatasetInfos:
def compute_input_output_dims(self, datamodule): def compute_input_output_dims(self, datamodule):
example_batch = datamodule.example_batch() example_batch = datamodule.example_batch()
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index] example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float() example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
self.input_dims = {'X': example_batch_x.size(1), self.input_dims = {'X': example_batch_x.size(1),
'E': example_batch_edge_attr.size(1), 'E': example_batch_edge_attr.size(1),

View File

@ -81,6 +81,7 @@ class DataModule(AbstractDataModule):
train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index] train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
self.train_dataset = train_dataset self.train_dataset = train_dataset
self.test_dataset = test_dataset
print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset)) print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset)) print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
@ -93,8 +94,9 @@ class DataModule(AbstractDataModule):
self.training_iterations = training_iterations self.training_iterations = training_iterations
def random_data_split(self, dataset): def random_data_split(self, dataset):
nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item() # nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
labeled_len = len(dataset) - nan_count # labeled_len = len(dataset) - nan_count
labeled_len = len(dataset)
full_idx = list(range(labeled_len)) full_idx = list(range(labeled_len))
train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2 train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42) train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
@ -115,7 +117,7 @@ class DataModule(AbstractDataModule):
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index)) print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
return train_index, val_index, test_index, [] return train_index, val_index, test_index, []
def parse_architecture_string(arch_str): def parse_architecture_string(self, arch_str):
stages = arch_str.split('+') stages = arch_str.split('+')
nodes = ['input'] nodes = ['input']
edges = [] edges = []
@ -130,19 +132,39 @@ class DataModule(AbstractDataModule):
nodes.append('output') # Add the output node nodes.append('output') # Add the output node
return nodes, edges return nodes, edges
def create_molecule_from_graph(nodes, edges): # def create_molecule_from_graph(nodes, edges):
def create_molecule_from_graph(self, graph):
nodes = graph.x
edges = graph.edge_index
mol = Chem.RWMol() # RWMol allows for building the molecule step by step mol = Chem.RWMol() # RWMol allows for building the molecule step by step
atom_indices = {} atom_indices = {}
num_to_op = {
1 :'nor_conv_1x1',
2 :'nor_conv_3x3',
3 :'avg_pool_3x3',
4 :'skip_connect',
5 :'output',
6 :'none',
7 :'input'
}
# Extract node operations from the data object
# Add atoms to the molecule # Add atoms to the molecule
for i, node in enumerate(nodes): for i, op_tensor in enumerate(nodes):
atom_symbol = op_to_atom[node] op = op_tensor.item()
if op == 0: continue
op = num_to_op[op]
atom_symbol = op_to_atom[op]
atom = Chem.Atom(atom_symbol) atom = Chem.Atom(atom_symbol)
atom_idx = mol.AddAtom(atom) atom_idx = mol.AddAtom(atom)
atom_indices[i] = atom_idx atom_indices[i] = atom_idx
# Add bonds to the molecule # Add bonds to the molecule
for start, end in edges: edge_number = edges.shape[1]
for i in range(edge_number):
start = edges[0, i].item()
end = edges[1, i].item()
mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE) mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
return mol return mol
@ -154,30 +176,23 @@ class DataModule(AbstractDataModule):
return smiles return smiles
def get_train_smiles(self): def get_train_smiles(self):
# raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
# train_arch_strs = []
# test_arch_strs = []
# for idx in self.train_index:
# arch_info = self.train_dataset[idx]
# arch_str = arch_info.arch_str
# train_arch_strs.append(arch_str)
# for idx in self.test_index:
# arch_info = self.train_dataset[idx]
# arch_str = arch_info.arch_str
# test_arch_strs.append(arch_str)
train_smiles = [] train_smiles = []
test_smiles = [] test_smiles = []
for idx in self.train_index: for graph in self.train_dataset:
graph = self.train_dataset[idx] # print(f'idx={idx}')
mol = self.create_molecule_from_graph(graph.x, graph.edge_index) # graph = self.train_dataset[idx]
print(graph.x)
print(graph.edge_index)
print(f'class of graph.x: {graph.x.__class__}, class of graph.edge_index: {graph.edge_index.__class__}')
mol = self.create_molecule_from_graph(graph)
train_smiles.append(Chem.MolToSmiles(mol)) train_smiles.append(Chem.MolToSmiles(mol))
for idx in self.test_index: # for idx in self.test_index:
graph = self.train_dataset[idx] for graph in self.test_dataset:
mol = self.create_molecule_from_graph(graph.x, graph.edge_index) # graph = self.dataset[idx]
# mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
mol = self.create_molecule_from_graph(graph)
test_smiles.append(Chem.MolToSmiles(mol)) test_smiles.append(Chem.MolToSmiles(mol))
# train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs] # train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
@ -199,161 +214,8 @@ class DataModule(AbstractDataModule):
def test_dataloader(self): def test_dataloader(self):
return self.test_loader return self.test_loader
def graphs_to_json(graphs, filename):
bonds = {
'nor_conv_1x1': 1,
'nor_conv_3x3': 2,
'avg_pool_3x3': 3,
'skip_connect': 4,
'input': 7,
'output': 5,
'none': 6
}
source_name = "nas-bench-201"
num_graph = len(graphs)
pt = Chem.GetPeriodicTable()
atom_name_list = []
atom_count_list = []
for i in range(2, 119):
atom_name_list.append(pt.GetElementSymbol(i))
atom_count_list.append(0)
atom_name_list.append('*')
atom_count_list.append(0)
n_atoms_per_mol = [0] * 500
bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
valencies = [0] * 500
transition_E = np.zeros((118, 118, 5))
n_atom_list = []
n_bond_list = []
# graphs = [(adj_matrix, ops), ...]
for graph in graphs:
ops = graph[1]
adj = graph[0]
n_atom = len(ops)
n_bond = len(ops)
n_atom_list.append(n_atom)
n_bond_list.append(n_bond)
n_atoms_per_mol[n_atom] += 1
cur_atom_count_arr = np.zeros(118)
for op in ops:
symbol = op_to_atom[op]
if symbol == 'H':
continue
elif symbol == '*':
atom_count_list[-1] += 1
cur_atom_count_arr[-1] += 1
else:
atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
# print('symbol', symbol)
# print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol))
# print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}')
try:
valencies[int(pt.GetDefaultValence(symbol))] += 1
except:
print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
transition_E_temp = np.zeros((118, 118, 5))
# print(n_atom)
for i in range(n_atom):
for j in range(n_atom):
if i == j or adj[i][j] == 0:
continue
start_atom, end_atom = i, j
if ops[start_atom] == 'input' or ops[end_atom] == 'input':
continue
if ops[start_atom] == 'output' or ops[end_atom] == 'output':
continue
if ops[start_atom] == 'none' or ops[end_atom] == 'none':
continue
start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
bond_index = bonds[ops[end_atom]]
bond_count_list[bond_index] += 2
# print(start_index, end_index, bond_index)
transition_E[start_index, end_index, bond_index] += 2
transition_E[end_index, start_index, bond_index] += 2
transition_E_temp[start_index, end_index, bond_index] += 2
transition_E_temp[end_index, start_index, bond_index] += 2
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
print(bond_count_list)
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2
# print(f'cur_tot_bond={cur_tot_bond}')
# find non-zero element in cur_tot_bond
# for i in range(118):
# for j in range(118):
# if cur_tot_bond[i][j] != 0:
# print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}')
# n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
# print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}")
transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
# find non-zero element in transition_E
# for i in range(118):
# for j in range(118):
# if transition_E[i][j][0] != 0:
# print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}')
assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
print('processed meta info: ------', filename, '------')
print('len atom_count_list', len(atom_count_list))
print('len atom_name_list', len(atom_name_list))
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
active_atoms = active_atoms.tolist()
atom_count_list = atom_count_list.tolist()
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
bond_count_list = bond_count_list.tolist()
valencies = np.array(valencies) / np.sum(valencies)
valencies = valencies.tolist()
no_edge = np.sum(transition_E, axis=-1) == 0
for i in range(118):
for j in range(118):
if no_edge[i][j] == False:
print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
# print(f'no_edge: {no_edge}')
first_elt = transition_E[:, :, 0]
first_elt[no_edge] = 1
transition_E[:, :, 0] = first_elt
transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
# find non-zero element in transition_E again
for i in range(118):
for j in range(118):
if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
meta_dict = {
'source': 'nasbench-201',
'num_graph': num_graph,
'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
'max_node': max(n_atom_list),
'max_bond': max(n_bond_list),
'atom_type_dist': atom_count_list,
'bond_type_dist': bond_count_list,
'valencies': valencies,
'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
'transition_E': transition_E.tolist(),
}
with open(f'{filename}.meta.json', 'w') as f:
json.dump(meta_dict, f)
return meta_dict
class DataModule_original(AbstractDataModule): class DataModule_original(AbstractDataModule):
def __init__(self, cfg): def __init__(self, cfg):
@ -482,7 +344,7 @@ def graphs_to_json(graphs, filename):
bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0] bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
valencies = [0] * 500 valencies = [0] * 500
transition_E = np.zeros((118, 118, 5)) transition_E = np.zeros((118, 118, 8))
n_atom_list = [] n_atom_list = []
n_bond_list = [] n_bond_list = []
@ -515,7 +377,7 @@ def graphs_to_json(graphs, filename):
valencies[int(pt.GetDefaultValence(symbol))] += 1 valencies[int(pt.GetDefaultValence(symbol))] += 1
except: except:
print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol))) print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
transition_E_temp = np.zeros((118, 118, 5)) transition_E_temp = np.zeros((118, 118, 8))
# print(n_atom) # print(n_atom)
for i in range(n_atom): for i in range(n_atom):
for j in range(n_atom): for j in range(n_atom):
@ -612,14 +474,16 @@ def graphs_to_json(graphs, filename):
with open(f'{filename}.meta.json', 'w') as f: with open(f'{filename}.meta.json', 'w') as f:
json.dump(meta_dict, f) json.dump(meta_dict, f)
return meta_dict return meta_dict
class Dataset(InMemoryDataset): class Dataset(InMemoryDataset):
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None): def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
self.target_prop = target_prop self.target_prop = target_prop
source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
self.source = source self.source = source
self.api = API(source) # Initialize NAS-Bench-201 API self.api = API(source) # Initialize NAS-Bench-201 API
print('API loaded')
super().__init__(root, transform, pre_transform, pre_filter) super().__init__(root, transform, pre_transform, pre_filter)
print('Dataset initialized')
print(self.processed_paths[0])
self.data, self.slices = torch.load(self.processed_paths[0]) self.data, self.slices = torch.load(self.processed_paths[0])
@property @property
@ -655,8 +519,11 @@ class Dataset(InMemoryDataset):
def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None): def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
nodes, edges = parse_architecture_string(arch_str) nodes, edges = parse_architecture_string(arch_str)
node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary
assert 0 not in node_labels, f'Invalid node label: {node_labels}'
x = torch.LongTensor(node_labels) x = torch.LongTensor(node_labels)
print(f'in initialize Dataset, arch_to_Graph x={x}')
edges_list = [(start, end) for start, end in edges] edges_list = [(start, end) for start, end in edges]
edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type
@ -671,6 +538,7 @@ class Dataset(InMemoryDataset):
else: else:
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1) y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
print(f'in initialize Dataset, Data_init, x={x}, y={y}, edge_index={edge_index}, edge_attr={edge_attr}')
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
return data, nodes return data, nodes
@ -679,9 +547,9 @@ class Dataset(InMemoryDataset):
'nor_conv_3x3': 2, 'nor_conv_3x3': 2,
'avg_pool_3x3': 3, 'avg_pool_3x3': 3,
'skip_connect': 4, 'skip_connect': 4,
'input': 7,
'output': 5, 'output': 5,
'none': 6 'none': 6,
'input': 7
} }
# Prepare to process NAS-Bench-201 data # Prepare to process NAS-Bench-201 data

View File

@ -13,9 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
import utils import utils
class Graph_DiT(pl.LightningModule): class Graph_DiT(pl.LightningModule):
def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools): # def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
def __init__(self, cfg, dataset_infos, visualization_tools):
super().__init__() super().__init__()
self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics']) # self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
self.test_only = cfg.general.test_only self.test_only = cfg.general.test_only
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None) self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
@ -55,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
self.test_E_logp = SumExceptBatchMetric() self.test_E_logp = SumExceptBatchMetric()
self.test_y_collection = [] self.test_y_collection = []
self.train_metrics = train_metrics # self.train_metrics = train_metrics
self.sampling_metrics = sampling_metrics # self.sampling_metrics = sampling_metrics
self.visualization_tools = visualization_tools self.visualization_tools = visualization_tools
self.max_n_nodes = dataset_infos.max_n_nodes self.max_n_nodes = dataset_infos.max_n_nodes
@ -171,7 +173,7 @@ class Graph_DiT(pl.LightningModule):
self.val_E_kl.reset() self.val_E_kl.reset()
self.val_X_logp.reset() self.val_X_logp.reset()
self.val_E_logp.reset() self.val_E_logp.reset()
self.sampling_metrics.reset() # self.sampling_metrics.reset()
self.val_y_collection = [] self.val_y_collection = []
@torch.no_grad() @torch.no_grad()
@ -239,14 +241,15 @@ class Graph_DiT(pl.LightningModule):
samples_left_to_generate -= to_generate samples_left_to_generate -= to_generate
chains_left_to_save -= chains_save chains_left_to_save -= chains_save
print(f"Computing sampling metrics", ' ...') # print(f"Computing sampling metrics", ' ...')
valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False) # valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
print(f'Done. Sampling took {time.time() - start:.2f} seconds\n') # print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
current_path = os.getcwd() current_path = os.getcwd()
result_path = os.path.join(current_path, result_path = os.path.join(current_path,
f'graphs/{self.name}/epoch{self.current_epoch}_b0/') f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save) # self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
self.sampling_metrics.reset() # self.sampling_metrics.reset()
def on_test_epoch_start(self) -> None: def on_test_epoch_start(self) -> None:
print("Starting test...") print("Starting test...")

View File

@ -50,7 +50,6 @@ def get_resume_adaptive(cfg, model_kwargs):
# Fetch path to this file to get base path # Fetch path to this file to get base path
current_path = os.path.dirname(os.path.realpath(__file__)) current_path = os.path.dirname(os.path.realpath(__file__))
root_dir = current_path.split("outputs")[0] root_dir = current_path.split("outputs")[0]
resume_path = os.path.join(root_dir, cfg.general.resume) resume_path = os.path.join(root_dir, cfg.general.resume)
if cfg.model.type == "discrete": if cfg.model.type == "discrete":
@ -80,21 +79,21 @@ def main(cfg: DictConfig):
datamodule = dataset.DataModule(cfg) datamodule = dataset.DataModule(cfg)
datamodule.prepare_data() datamodule.prepare_data()
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg) dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
train_smiles, reference_smiles = datamodule.get_train_smiles() # train_smiles, reference_smiles = datamodule.get_train_smiles()
# get input output dimensions # get input output dimensions
dataset_infos.compute_input_output_dims(datamodule=datamodule) dataset_infos.compute_input_output_dims(datamodule=datamodule)
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos) # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
sampling_metrics = SamplingMolecularMetrics( # sampling_metrics = SamplingMolecularMetrics(
dataset_infos, train_smiles, reference_smiles # dataset_infos, train_smiles, reference_smiles
) # )
visualization_tools = MolecularVisualization(dataset_infos) visualization_tools = MolecularVisualization(dataset_infos)
model_kwargs = { model_kwargs = {
"dataset_infos": dataset_infos, "dataset_infos": dataset_infos,
"train_metrics": train_metrics, # "train_metrics": train_metrics,
"sampling_metrics": sampling_metrics, # "sampling_metrics": sampling_metrics,
"visualization_tools": visualization_tools, "visualization_tools": visualization_tools,
} }
@ -110,9 +109,10 @@ def main(cfg: DictConfig):
model = Graph_DiT(cfg=cfg, **model_kwargs) model = Graph_DiT(cfg=cfg, **model_kwargs)
trainer = Trainer( trainer = Trainer(
gradient_clip_val=cfg.train.clip_grad, gradient_clip_val=cfg.train.clip_grad,
accelerator="gpu" # accelerator="gpu"
if torch.cuda.is_available() and cfg.general.gpus > 0 # if torch.cuda.is_available() and cfg.general.gpus > 0
else "cpu", # else "cpu",
accelerator="cpu",
devices=cfg.general.gpus devices=cfg.general.gpus
if torch.cuda.is_available() and cfg.general.gpus > 0 if torch.cuda.is_available() and cfg.general.gpus > 0
else None, else None,