xautodl/tests/test_loader.py
2021-05-19 08:10:42 +00:00

30 lines
789 B
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest tests/test_loader.py -s #
#####################################################
import unittest
import tempfile
import torch
from xautodl.datasets import get_datasets
def test_simple():
xdir = tempfile.mkdtemp()
train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1)
print(train_data)
print(valid_data)
xloader = torch.utils.data.DataLoader(
train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True
)
print(xloader)
print(next(iter(xloader)))
for i, data in enumerate(xloader):
print(i)
test_simple()