update the gpu id

commit 73324083ce (parent 0c3cfb195a)
@@ -2,6 +2,7 @@ general:
   name: 'graph_dit'
   wandb: 'disabled'
   gpus: 1
+  gpu_number: 3
   resume: null
   test_only: null
   sample_every_val: 2500
@@ -10,7 +11,7 @@ general:
   chains_to_save: 1
   log_every_steps: 50
   number_chain_steps: 8
-  final_model_samples_to_generate: 10000
+  final_model_samples_to_generate: 100
   final_model_samples_to_save: 20
   final_model_chains_to_save: 1
   enable_progress_bar: False
@@ -30,7 +31,7 @@ model:
   lambda_train: [1, 10] # node and edge training weight
   ensure_connected: True
 train:
-  n_epochs: 10000
+  n_epochs: 5000
   batch_size: 1200
   lr: 0.0002
   clip_grad: null
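For reference (not part of the commit): a minimal sketch of how the new general.gpu_number key can be read from this config. It assumes the YAML is loaded as an OmegaConf/Hydra DictConfig, as the test(cfg: DictConfig) signature below suggests; the file path is purely illustrative.

# Minimal sketch (assumption): reading the new general.gpu_number key with
# OmegaConf, the library behind Hydra's DictConfig. The path is hypothetical.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/config.yaml")  # illustrative path
print(cfg.general.gpus)        # 1 -> how many devices the Trainer should use
print(cfg.general.gpu_number)  # 3 -> index of the physical GPU to pin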
@@ -175,6 +175,7 @@ def test(cfg: DictConfig):
     elif cfg.general.resume is not None:
         cfg, _ = get_resume_adaptive(cfg, model_kwargs)
         os.chdir(cfg.general.resume.split("checkpoints")[0])
+    # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
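The environment-variable route stays commented out in this hunk. If it were enabled, two details would matter: os.environ values must be strings (assigning an int raises TypeError), and the variable has to be set before CUDA is initialized. A hedged sketch, with gpu_number standing in for cfg.general.gpu_number:

# Sketch only (assumption): the CUDA_VISIBLE_DEVICES alternative for pinning a GPU.
import os

gpu_number = 3  # stands in for cfg.general.gpu_number
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_number)  # must be a string, set before CUDA init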
@@ -182,7 +183,7 @@ def test(cfg: DictConfig):
         accelerator="gpu"
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else "cpu",
-        devices=cfg.general.gpus
+        devices=[cfg.general.gpu_number]
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else None,
         max_epochs=cfg.train.n_epochs,
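For context, a self-contained sketch (not the repo's code) of the device-pinning pattern this hunk adopts: passing a list to Lightning's devices argument selects specific GPU indices, whereas an int only sets how many devices to use. Values are illustrative.

# Minimal sketch (assumption): pinning a PyTorch Lightning Trainer to one GPU index,
# mirroring devices=[cfg.general.gpu_number] above. Values are illustrative.
import torch
from pytorch_lightning import Trainer

gpu_number = 3  # stands in for cfg.general.gpu_number

use_gpu = torch.cuda.is_available() and gpu_number >= 0
trainer = Trainer(
    accelerator="gpu" if use_gpu else "cpu",
    devices=[gpu_number] if use_gpu else 1,  # a list picks specific device indices
    max_epochs=5000,  # matches train.n_epochs after this commit
)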