# Source: autodl-projects/tests/test_torch_gpu_bugs.py
# (44 lines, 1.4 KiB, Python; revision of 2021-03-30 14:05:52 +02:00)
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
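# Usage (two steps; test_load reads the pickle written by test_create):
#   1. On a machine with CUDA:
#        pytest ./tests/test_torch_gpu_bugs.py::test_create
#   2. With CUDA hidden:
#        CUDA_VISIBLE_DEVICES="" pytest ./tests/test_torch_gpu_bugs.py::test_load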
#####################################################
import os, sys, time, torch
import pickle
import tempfile
from pathlib import Path
# Prepend the repository's ``lib`` directory (a sibling of ``tests``) to
# ``sys.path`` unless it is already listed. This lets the project import that
# follows (``trade_models``) resolve when pytest collects this file.
lib_dir = Path(__file__).parent.joinpath("..", "lib").resolve()
print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
    sys.path[:0] = [str(lib_dir)]
from trade_models.quant_transformer import QuantTransformer
def test_create():
    """Build a ``QuantTransformer``, move it to CPU and pickle it to disk.

    The model is constructed with ``GPU=0``, moved to ``"cpu"``, and pickled to
    ``tests/.pytest_cache/quant-model.pkl``. ``test_load`` reads that file back.

    Raises:
        unittest.SkipTest: when CUDA is unavailable. pytest reports this as a
            skip rather than a silent pass, because nothing is exercised.
    """
    if not torch.cuda.is_available():
        # Imported locally because it is only needed on this path.
        import unittest

        raise unittest.SkipTest("test_create needs CUDA (model is built with GPU=0)")
    quant_model = QuantTransformer(GPU=0)
    # Move the model off the GPU *before* opening the output file. If this
    # step fails, an empty/truncated pickle must not replace a previous good
    # one (test_load would then fail on the corrupt file instead).
    # NOTE: earlier experiments tried ``quant_model.to(None)`` and deleting
    # ``quant_model.model`` / ``quant_model.train_optimizer`` before pickling.
    quant_model.to("cpu")
    temp_dir = lib_dir / ".." / "tests" / ".pytest_cache"
    temp_dir.mkdir(parents=True, exist_ok=True)
    temp_file = temp_dir / "quant-model.pkl"
    with temp_file.open("wb") as f:
        pickle.dump(quant_model, f)
    print("save into {:}".format(temp_file))
def test_load():
    """Unpickle the model written by ``test_create`` and print its parts.

    Reads ``tests/.pytest_cache/quant-model.pkl`` (the path ``test_create``
    writes to). It then prints the restored object's ``model`` and
    ``train_optimizer`` attributes.

    The file header suggests running this with ``CUDA_VISIBLE_DEVICES=""`` to
    see whether the pickle can be restored without a visible GPU. There are no
    assertions: the test fails only if loading or attribute access raises, for
    example FileNotFoundError when test_create has never been run.
    """
    pickle_path = lib_dir.joinpath("..", "tests", ".pytest_cache", "quant-model.pkl")
    # NOTE: pickle.load can execute arbitrary code. As far as this file shows,
    # the input is produced locally by test_create, not by an external source.
    with open(pickle_path, "rb") as stream:
        restored = pickle.load(stream)
    print(restored.model)
    print(restored.train_optimizer)