From 73324083ce6a2562a9731f23572f7f54d108b488 Mon Sep 17 00:00:00 2001
From: mhz
Date: Wed, 3 Jul 2024 15:25:46 +0200
Subject: [PATCH] update the gpu id

---
 configs/config.yaml | 5 +++--
 graph_dit/main.py   | 3 ++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/configs/config.yaml b/configs/config.yaml
index 234f679..881f765 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -2,6 +2,7 @@ general:
   name: 'graph_dit'
   wandb: 'disabled'
   gpus: 1
+  gpu_number: 3
   resume: null
   test_only: null
   sample_every_val: 2500
@@ -10,7 +11,7 @@ general:
   chains_to_save: 1
   log_every_steps: 50
   number_chain_steps: 8
-  final_model_samples_to_generate: 10000
+  final_model_samples_to_generate: 100
   final_model_samples_to_save: 20
   final_model_chains_to_save: 1
   enable_progress_bar: False
@@ -30,7 +31,7 @@ model:
   lambda_train: [1, 10] # node and edge training weight
   ensure_connected: True
 train:
-  n_epochs: 10000
+  n_epochs: 5000
   batch_size: 1200
   lr: 0.0002
   clip_grad: null
diff --git a/graph_dit/main.py b/graph_dit/main.py
index 684cb8f..f3d89e5 100644
--- a/graph_dit/main.py
+++ b/graph_dit/main.py
@@ -175,6 +175,7 @@ def test(cfg: DictConfig):
     elif cfg.general.resume is not None:
         cfg, _ = get_resume_adaptive(cfg, model_kwargs)
         os.chdir(cfg.general.resume.split("checkpoints")[0])
+    # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
@@ -182,7 +183,7 @@ def test(cfg: DictConfig):
         accelerator="gpu"
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else "cpu",
-        devices=cfg.general.gpus
+        devices=[cfg.general.gpu_number]
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else None,
         max_epochs=cfg.train.n_epochs,