update the gpu id
This commit is contained in:
		| @@ -2,6 +2,7 @@ general: | ||||
|     name: 'graph_dit' | ||||
|     wandb: 'disabled'  | ||||
|     gpus: 1 | ||||
|     gpu_number: 3 | ||||
|     resume: null | ||||
|     test_only: null | ||||
|     sample_every_val: 2500 | ||||
| @@ -10,7 +11,7 @@ general: | ||||
|     chains_to_save: 1 | ||||
|     log_every_steps: 50 | ||||
|     number_chain_steps: 8 | ||||
|     final_model_samples_to_generate: 10000 | ||||
|     final_model_samples_to_generate: 100 | ||||
|     final_model_samples_to_save: 20 | ||||
|     final_model_chains_to_save: 1 | ||||
|     enable_progress_bar: False | ||||
| @@ -30,7 +31,7 @@ model: | ||||
|     lambda_train: [1, 10]  # node and edge training weight  | ||||
|     ensure_connected: True | ||||
| train: | ||||
|     n_epochs: 10000 | ||||
|     n_epochs: 5000 | ||||
|     batch_size: 1200 | ||||
|     lr: 0.0002 | ||||
|     clip_grad: null | ||||
|   | ||||
| @@ -175,6 +175,7 @@ def test(cfg: DictConfig): | ||||
|     elif cfg.general.resume is not None: | ||||
|         cfg, _ = get_resume_adaptive(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.resume.split("checkpoints")[0]) | ||||
|     # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number | ||||
|     model = Graph_DiT(cfg=cfg, **model_kwargs) | ||||
|     trainer = Trainer( | ||||
|         gradient_clip_val=cfg.train.clip_grad, | ||||
| @@ -182,7 +183,7 @@ def test(cfg: DictConfig): | ||||
|         accelerator="gpu" | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else "cpu", | ||||
|         devices=cfg.general.gpus | ||||
|         devices=[cfg.general.gpu_number] | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else None, | ||||
|         max_epochs=cfg.train.n_epochs, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user