Refine Transformer
This commit is contained in:
		| @@ -8,6 +8,9 @@ | ||||
| import os, sys, time, torch | ||||
| import pickle | ||||
| import tempfile | ||||
| from pathlib import Path | ||||
|  | ||||
| root_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
|  | ||||
| from xautodl.trade_models.quant_transformer import QuantTransformer | ||||
|  | ||||
| @@ -17,7 +20,7 @@ def test_create(): | ||||
|     if not torch.cuda.is_available(): | ||||
|         return | ||||
|     quant_model = QuantTransformer(GPU=0) | ||||
|     temp_dir = lib_dir / ".." / "tests" / ".pytest_cache" | ||||
|     temp_dir = root_dir / "tests" / ".pytest_cache" | ||||
|     temp_dir.mkdir(parents=True, exist_ok=True) | ||||
|     temp_file = temp_dir / "quant-model.pkl" | ||||
|     with temp_file.open("wb") as f: | ||||
| @@ -30,7 +33,7 @@ def test_create(): | ||||
|  | ||||
|  | ||||
| def test_load(): | ||||
|     temp_file = lib_dir / ".." / "tests" / ".pytest_cache" / "quant-model.pkl" | ||||
|     temp_file = root_dir / "tests" / ".pytest_cache" / "quant-model.pkl" | ||||
|     with temp_file.open("rb") as f: | ||||
|         model = pickle.load(f) | ||||
|         print(model.model) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user