try to run the graph, commented sampling codes
This commit is contained in:
parent
e04ad5fbe7
commit
82299e5213
@ -41,6 +41,6 @@ train:
|
||||
check_val_every_n_epoch: 1
|
||||
dataset:
|
||||
datadir: 'data/'
|
||||
task_name: null
|
||||
guidance_target: null
|
||||
task_name: 'nasbench-201'
|
||||
guidance_target: 'nasbench-201'
|
||||
pin_memory: False
|
||||
|
@ -116,7 +116,7 @@ class AbstractDatasetInfos:
|
||||
def compute_input_output_dims(self, datamodule):
|
||||
example_batch = datamodule.example_batch()
|
||||
example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=118).float()[:, self.active_index]
|
||||
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=5).float()
|
||||
example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=10).float()
|
||||
|
||||
self.input_dims = {'X': example_batch_x.size(1),
|
||||
'E': example_batch_edge_attr.size(1),
|
||||
|
@ -81,6 +81,7 @@ class DataModule(AbstractDataModule):
|
||||
|
||||
train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
|
||||
self.train_dataset = train_dataset
|
||||
self.test_dataset = test_dataset
|
||||
print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
|
||||
print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
||||
print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
|
||||
@ -93,8 +94,9 @@ class DataModule(AbstractDataModule):
|
||||
self.training_iterations = training_iterations
|
||||
|
||||
def random_data_split(self, dataset):
|
||||
nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
|
||||
labeled_len = len(dataset) - nan_count
|
||||
# nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
|
||||
# labeled_len = len(dataset) - nan_count
|
||||
labeled_len = len(dataset)
|
||||
full_idx = list(range(labeled_len))
|
||||
train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
|
||||
train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
|
||||
@ -115,7 +117,7 @@ class DataModule(AbstractDataModule):
|
||||
print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
|
||||
return train_index, val_index, test_index, []
|
||||
|
||||
def parse_architecture_string(arch_str):
|
||||
def parse_architecture_string(self, arch_str):
|
||||
stages = arch_str.split('+')
|
||||
nodes = ['input']
|
||||
edges = []
|
||||
@ -130,19 +132,39 @@ class DataModule(AbstractDataModule):
|
||||
nodes.append('output') # Add the output node
|
||||
return nodes, edges
|
||||
|
||||
def create_molecule_from_graph(nodes, edges):
|
||||
# def create_molecule_from_graph(nodes, edges):
|
||||
def create_molecule_from_graph(self, graph):
|
||||
nodes = graph.x
|
||||
edges = graph.edge_index
|
||||
mol = Chem.RWMol() # RWMol allows for building the molecule step by step
|
||||
atom_indices = {}
|
||||
|
||||
num_to_op = {
|
||||
1 :'nor_conv_1x1',
|
||||
2 :'nor_conv_3x3',
|
||||
3 :'avg_pool_3x3',
|
||||
4 :'skip_connect',
|
||||
5 :'output',
|
||||
6 :'none',
|
||||
7 :'input'
|
||||
}
|
||||
|
||||
# Extract node operations from the data object
|
||||
|
||||
# Add atoms to the molecule
|
||||
for i, node in enumerate(nodes):
|
||||
atom_symbol = op_to_atom[node]
|
||||
for i, op_tensor in enumerate(nodes):
|
||||
op = op_tensor.item()
|
||||
if op == 0: continue
|
||||
op = num_to_op[op]
|
||||
atom_symbol = op_to_atom[op]
|
||||
atom = Chem.Atom(atom_symbol)
|
||||
atom_idx = mol.AddAtom(atom)
|
||||
atom_indices[i] = atom_idx
|
||||
|
||||
# Add bonds to the molecule
|
||||
for start, end in edges:
|
||||
edge_number = edges.shape[1]
|
||||
for i in range(edge_number):
|
||||
start = edges[0, i].item()
|
||||
end = edges[1, i].item()
|
||||
mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)
|
||||
|
||||
return mol
|
||||
@ -154,30 +176,23 @@ class DataModule(AbstractDataModule):
|
||||
return smiles
|
||||
|
||||
def get_train_smiles(self):
|
||||
# raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")
|
||||
# train_arch_strs = []
|
||||
# test_arch_strs = []
|
||||
|
||||
# for idx in self.train_index:
|
||||
# arch_info = self.train_dataset[idx]
|
||||
# arch_str = arch_info.arch_str
|
||||
# train_arch_strs.append(arch_str)
|
||||
# for idx in self.test_index:
|
||||
# arch_info = self.train_dataset[idx]
|
||||
# arch_str = arch_info.arch_str
|
||||
# test_arch_strs.append(arch_str)
|
||||
|
||||
train_smiles = []
|
||||
test_smiles = []
|
||||
|
||||
for idx in self.train_index:
|
||||
graph = self.train_dataset[idx]
|
||||
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
|
||||
for graph in self.train_dataset:
|
||||
# print(f'idx={idx}')
|
||||
# graph = self.train_dataset[idx]
|
||||
print(graph.x)
|
||||
print(graph.edge_index)
|
||||
print(f'class of graph.x: {graph.x.__class__}, class of graph.edge_index: {graph.edge_index.__class__}')
|
||||
mol = self.create_molecule_from_graph(graph)
|
||||
train_smiles.append(Chem.MolToSmiles(mol))
|
||||
|
||||
for idx in self.test_index:
|
||||
graph = self.train_dataset[idx]
|
||||
mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
|
||||
# for idx in self.test_index:
|
||||
for graph in self.test_dataset:
|
||||
# graph = self.dataset[idx]
|
||||
# mol = self.create_molecule_from_graph(graph.x, graph.edge_index)
|
||||
mol = self.create_molecule_from_graph(graph)
|
||||
test_smiles.append(Chem.MolToSmiles(mol))
|
||||
|
||||
# train_smiles = [self.arch_str_to_smiles(arch_str) for arch_str in train_arch_strs]
|
||||
@ -199,161 +214,8 @@ class DataModule(AbstractDataModule):
|
||||
def test_dataloader(self):
|
||||
return self.test_loader
|
||||
|
||||
def graphs_to_json(graphs, filename):
|
||||
bonds = {
|
||||
'nor_conv_1x1': 1,
|
||||
'nor_conv_3x3': 2,
|
||||
'avg_pool_3x3': 3,
|
||||
'skip_connect': 4,
|
||||
'input': 7,
|
||||
'output': 5,
|
||||
'none': 6
|
||||
}
|
||||
|
||||
source_name = "nas-bench-201"
|
||||
num_graph = len(graphs)
|
||||
pt = Chem.GetPeriodicTable()
|
||||
atom_name_list = []
|
||||
atom_count_list = []
|
||||
for i in range(2, 119):
|
||||
atom_name_list.append(pt.GetElementSymbol(i))
|
||||
atom_count_list.append(0)
|
||||
atom_name_list.append('*')
|
||||
atom_count_list.append(0)
|
||||
n_atoms_per_mol = [0] * 500
|
||||
bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
||||
valencies = [0] * 500
|
||||
transition_E = np.zeros((118, 118, 5))
|
||||
|
||||
n_atom_list = []
|
||||
n_bond_list = []
|
||||
# graphs = [(adj_matrix, ops), ...]
|
||||
for graph in graphs:
|
||||
ops = graph[1]
|
||||
adj = graph[0]
|
||||
n_atom = len(ops)
|
||||
n_bond = len(ops)
|
||||
n_atom_list.append(n_atom)
|
||||
n_bond_list.append(n_bond)
|
||||
|
||||
n_atoms_per_mol[n_atom] += 1
|
||||
cur_atom_count_arr = np.zeros(118)
|
||||
|
||||
for op in ops:
|
||||
symbol = op_to_atom[op]
|
||||
if symbol == 'H':
|
||||
continue
|
||||
elif symbol == '*':
|
||||
atom_count_list[-1] += 1
|
||||
cur_atom_count_arr[-1] += 1
|
||||
else:
|
||||
atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
|
||||
cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
|
||||
# print('symbol', symbol)
|
||||
# print('pt.GetDefaultValence(symbol)', pt.GetDefaultValence(symbol))
|
||||
# print(f'cur_atom_count_arr[{pt.GetAtomicNumber(symbol)-2}], {cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2]}')
|
||||
try:
|
||||
valencies[int(pt.GetDefaultValence(symbol))] += 1
|
||||
except:
|
||||
print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
|
||||
transition_E_temp = np.zeros((118, 118, 5))
|
||||
# print(n_atom)
|
||||
for i in range(n_atom):
|
||||
for j in range(n_atom):
|
||||
if i == j or adj[i][j] == 0:
|
||||
continue
|
||||
start_atom, end_atom = i, j
|
||||
if ops[start_atom] == 'input' or ops[end_atom] == 'input':
|
||||
continue
|
||||
if ops[start_atom] == 'output' or ops[end_atom] == 'output':
|
||||
continue
|
||||
if ops[start_atom] == 'none' or ops[end_atom] == 'none':
|
||||
continue
|
||||
|
||||
start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
|
||||
end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
|
||||
bond_index = bonds[ops[end_atom]]
|
||||
bond_count_list[bond_index] += 2
|
||||
|
||||
# print(start_index, end_index, bond_index)
|
||||
|
||||
transition_E[start_index, end_index, bond_index] += 2
|
||||
transition_E[end_index, start_index, bond_index] += 2
|
||||
transition_E_temp[start_index, end_index, bond_index] += 2
|
||||
transition_E_temp[end_index, start_index, bond_index] += 2
|
||||
|
||||
bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
|
||||
print(bond_count_list)
|
||||
cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2
|
||||
# print(f'cur_tot_bond={cur_tot_bond}')
|
||||
# find non-zero element in cur_tot_bond
|
||||
# for i in range(118):
|
||||
# for j in range(118):
|
||||
# if cur_tot_bond[i][j] != 0:
|
||||
# print(f'i={i}, j={j}, cur_tot_bond[i][j]={cur_tot_bond[i][j]}')
|
||||
# n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
|
||||
cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
|
||||
# print(f"transition_E[:,:,0]={cur_tot_bond - transition_E_temp.sum(axis=-1)}")
|
||||
transition_E[:, :, 0] += cur_tot_bond - transition_E_temp.sum(axis=-1)
|
||||
# find non-zero element in transition_E
|
||||
# for i in range(118):
|
||||
# for j in range(118):
|
||||
# if transition_E[i][j][0] != 0:
|
||||
# print(f'i={i}, j={j}, transition_E[i][j][0]={transition_E[i][j][0]}')
|
||||
assert (cur_tot_bond > transition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
|
||||
|
||||
n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
|
||||
n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]
|
||||
|
||||
atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
|
||||
print('processed meta info: ------', filename, '------')
|
||||
print('len atom_count_list', len(atom_count_list))
|
||||
print('len atom_name_list', len(atom_name_list))
|
||||
active_atoms = np.array(atom_name_list)[atom_count_list > 0]
|
||||
active_atoms = active_atoms.tolist()
|
||||
atom_count_list = atom_count_list.tolist()
|
||||
|
||||
bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
|
||||
bond_count_list = bond_count_list.tolist()
|
||||
valencies = np.array(valencies) / np.sum(valencies)
|
||||
valencies = valencies.tolist()
|
||||
|
||||
no_edge = np.sum(transition_E, axis=-1) == 0
|
||||
for i in range(118):
|
||||
for j in range(118):
|
||||
if no_edge[i][j] == False:
|
||||
print(f'have an edge at i={i} , j={j}, transition_E[i][j]={transition_E[i][j]}')
|
||||
# print(f'no_edge: {no_edge}')
|
||||
first_elt = transition_E[:, :, 0]
|
||||
first_elt[no_edge] = 1
|
||||
transition_E[:, :, 0] = first_elt
|
||||
|
||||
transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
|
||||
|
||||
# find non-zero element in transition_E again
|
||||
for i in range(118):
|
||||
for j in range(118):
|
||||
if transition_E[i][j][0] != 0 and transition_E[i][j][0] != 1 and transition_E[i][j][0] != -1:
|
||||
print(f'i={i}, j={j}, 2_transition_E[i][j][0]={transition_E[i][j][0]}')
|
||||
|
||||
meta_dict = {
|
||||
'source': 'nasbench-201',
|
||||
'num_graph': num_graph,
|
||||
'n_atoms_per_mol_dist': n_atoms_per_mol[:51],
|
||||
'max_node': max(n_atom_list),
|
||||
'max_bond': max(n_bond_list),
|
||||
'atom_type_dist': atom_count_list,
|
||||
'bond_type_dist': bond_count_list,
|
||||
'valencies': valencies,
|
||||
'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
|
||||
'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
|
||||
'transition_E': transition_E.tolist(),
|
||||
}
|
||||
|
||||
with open(f'{filename}.meta.json', 'w') as f:
|
||||
json.dump(meta_dict, f)
|
||||
return meta_dict
|
||||
|
||||
class DataModule_original(AbstractDataModule):
|
||||
def __init__(self, cfg):
|
||||
@ -482,7 +344,7 @@ def graphs_to_json(graphs, filename):
|
||||
bond_count_list = [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
||||
valencies = [0] * 500
|
||||
transition_E = np.zeros((118, 118, 5))
|
||||
transition_E = np.zeros((118, 118, 8))
|
||||
|
||||
n_atom_list = []
|
||||
n_bond_list = []
|
||||
@ -515,7 +377,7 @@ def graphs_to_json(graphs, filename):
|
||||
valencies[int(pt.GetDefaultValence(symbol))] += 1
|
||||
except:
|
||||
print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
|
||||
transition_E_temp = np.zeros((118, 118, 5))
|
||||
transition_E_temp = np.zeros((118, 118, 8))
|
||||
# print(n_atom)
|
||||
for i in range(n_atom):
|
||||
for j in range(n_atom):
|
||||
@ -612,14 +474,16 @@ def graphs_to_json(graphs, filename):
|
||||
with open(f'{filename}.meta.json', 'w') as f:
|
||||
json.dump(meta_dict, f)
|
||||
return meta_dict
|
||||
|
||||
class Dataset(InMemoryDataset):
|
||||
def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
|
||||
self.target_prop = target_prop
|
||||
source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||
self.source = source
|
||||
self.api = API(source) # Initialize NAS-Bench-201 API
|
||||
print('API loaded')
|
||||
super().__init__(root, transform, pre_transform, pre_filter)
|
||||
print('Dataset initialized')
|
||||
print(self.processed_paths[0])
|
||||
self.data, self.slices = torch.load(self.processed_paths[0])
|
||||
|
||||
@property
|
||||
@ -655,8 +519,11 @@ class Dataset(InMemoryDataset):
|
||||
|
||||
def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
|
||||
nodes, edges = parse_architecture_string(arch_str)
|
||||
|
||||
node_labels = [bonds[node] for node in nodes] # Replace with appropriate encoding if necessary
|
||||
assert 0 not in node_labels, f'Invalid node label: {node_labels}'
|
||||
x = torch.LongTensor(node_labels)
|
||||
print(f'in initialize Dataset, arch_to_Graph x={x}')
|
||||
|
||||
edges_list = [(start, end) for start, end in edges]
|
||||
edge_type = [bonds[nodes[end]] for start, end in edges] # Example: using end node type as edge type
|
||||
@ -671,6 +538,7 @@ class Dataset(InMemoryDataset):
|
||||
else:
|
||||
y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)
|
||||
|
||||
print(f'in initialize Dataset, Data_init, x={x}, y={y}, edge_index={edge_index}, edge_attr={edge_attr}')
|
||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
|
||||
return data, nodes
|
||||
|
||||
@ -679,9 +547,9 @@ class Dataset(InMemoryDataset):
|
||||
'nor_conv_3x3': 2,
|
||||
'avg_pool_3x3': 3,
|
||||
'skip_connect': 4,
|
||||
'input': 7,
|
||||
'output': 5,
|
||||
'none': 6
|
||||
'none': 6,
|
||||
'input': 7
|
||||
}
|
||||
|
||||
# Prepare to process NAS-Bench-201 data
|
||||
|
@ -13,9 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
|
||||
import utils
|
||||
|
||||
class Graph_DiT(pl.LightningModule):
|
||||
def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
|
||||
# def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
|
||||
def __init__(self, cfg, dataset_infos, visualization_tools):
|
||||
|
||||
super().__init__()
|
||||
self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
|
||||
# self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
|
||||
self.test_only = cfg.general.test_only
|
||||
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
||||
|
||||
@ -55,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.test_E_logp = SumExceptBatchMetric()
|
||||
self.test_y_collection = []
|
||||
|
||||
self.train_metrics = train_metrics
|
||||
self.sampling_metrics = sampling_metrics
|
||||
# self.train_metrics = train_metrics
|
||||
# self.sampling_metrics = sampling_metrics
|
||||
|
||||
self.visualization_tools = visualization_tools
|
||||
self.max_n_nodes = dataset_infos.max_n_nodes
|
||||
@ -171,7 +173,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.val_E_kl.reset()
|
||||
self.val_X_logp.reset()
|
||||
self.val_E_logp.reset()
|
||||
self.sampling_metrics.reset()
|
||||
# self.sampling_metrics.reset()
|
||||
self.val_y_collection = []
|
||||
|
||||
@torch.no_grad()
|
||||
@ -239,14 +241,15 @@ class Graph_DiT(pl.LightningModule):
|
||||
samples_left_to_generate -= to_generate
|
||||
chains_left_to_save -= chains_save
|
||||
|
||||
print(f"Computing sampling metrics", ' ...')
|
||||
valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
|
||||
print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
|
||||
# print(f"Computing sampling metrics", ' ...')
|
||||
# valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
|
||||
# print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
|
||||
|
||||
current_path = os.getcwd()
|
||||
result_path = os.path.join(current_path,
|
||||
f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
|
||||
self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
|
||||
self.sampling_metrics.reset()
|
||||
# self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
|
||||
# self.sampling_metrics.reset()
|
||||
|
||||
def on_test_epoch_start(self) -> None:
|
||||
print("Starting test...")
|
||||
|
@ -50,7 +50,6 @@ def get_resume_adaptive(cfg, model_kwargs):
|
||||
# Fetch path to this file to get base path
|
||||
current_path = os.path.dirname(os.path.realpath(__file__))
|
||||
root_dir = current_path.split("outputs")[0]
|
||||
|
||||
resume_path = os.path.join(root_dir, cfg.general.resume)
|
||||
|
||||
if cfg.model.type == "discrete":
|
||||
@ -80,21 +79,21 @@ def main(cfg: DictConfig):
|
||||
datamodule = dataset.DataModule(cfg)
|
||||
datamodule.prepare_data()
|
||||
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
|
||||
train_smiles, reference_smiles = datamodule.get_train_smiles()
|
||||
# train_smiles, reference_smiles = datamodule.get_train_smiles()
|
||||
|
||||
# get input output dimensions
|
||||
dataset_infos.compute_input_output_dims(datamodule=datamodule)
|
||||
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
|
||||
# train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
|
||||
|
||||
sampling_metrics = SamplingMolecularMetrics(
|
||||
dataset_infos, train_smiles, reference_smiles
|
||||
)
|
||||
# sampling_metrics = SamplingMolecularMetrics(
|
||||
# dataset_infos, train_smiles, reference_smiles
|
||||
# )
|
||||
visualization_tools = MolecularVisualization(dataset_infos)
|
||||
|
||||
model_kwargs = {
|
||||
"dataset_infos": dataset_infos,
|
||||
"train_metrics": train_metrics,
|
||||
"sampling_metrics": sampling_metrics,
|
||||
# "train_metrics": train_metrics,
|
||||
# "sampling_metrics": sampling_metrics,
|
||||
"visualization_tools": visualization_tools,
|
||||
}
|
||||
|
||||
@ -110,9 +109,10 @@ def main(cfg: DictConfig):
|
||||
model = Graph_DiT(cfg=cfg, **model_kwargs)
|
||||
trainer = Trainer(
|
||||
gradient_clip_val=cfg.train.clip_grad,
|
||||
accelerator="gpu"
|
||||
if torch.cuda.is_available() and cfg.general.gpus > 0
|
||||
else "cpu",
|
||||
# accelerator="gpu"
|
||||
# if torch.cuda.is_available() and cfg.general.gpus > 0
|
||||
# else "cpu",
|
||||
accelerator="cpu",
|
||||
devices=cfg.general.gpus
|
||||
if torch.cuda.is_available() and cfg.general.gpus > 0
|
||||
else None,
|
||||
|
Loading…
Reference in New Issue
Block a user