update the gpu id
This commit is contained in:
		| @@ -175,6 +175,7 @@ def test(cfg: DictConfig): | ||||
|     elif cfg.general.resume is not None: | ||||
|         cfg, _ = get_resume_adaptive(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.resume.split("checkpoints")[0]) | ||||
|     # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number | ||||
|     model = Graph_DiT(cfg=cfg, **model_kwargs) | ||||
|     trainer = Trainer( | ||||
|         gradient_clip_val=cfg.train.clip_grad, | ||||
| @@ -182,7 +183,7 @@ def test(cfg: DictConfig): | ||||
|         accelerator="gpu" | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else "cpu", | ||||
|         devices=cfg.general.gpus | ||||
|         devices=[cfg.general.gpu_number] | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else None, | ||||
|         max_epochs=cfg.train.n_epochs, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user