30 lines
789 B
Python
30 lines
789 B
Python
#####################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
|
#####################################################
|
|
# pytest tests/test_loader.py -s #
|
|
#####################################################
|
|
import unittest
|
|
import tempfile
|
|
import torch
|
|
|
|
from xautodl.datasets import get_datasets
|
|
|
|
|
|
def test_simple():
    """Smoke-test that the "cifar10" training data can be served by a DataLoader.

    Requests the "cifar10" datasets from ``get_datasets`` with a scratch
    directory as the data root, prints both returned datasets, wraps the
    training one in a ``torch.utils.data.DataLoader``, prints one batch, and
    then iterates the loader once while printing each batch index.

    There are no assertions: the test passes as long as nothing raises.
    """
    # TemporaryDirectory removes the scratch directory (and anything written
    # into it) on exit, even when the test raises; a bare mkdtemp() would
    # leave it behind after every run.
    with tempfile.TemporaryDirectory() as xdir:
        # NOTE(review): the meaning of the third argument (-1) is defined by
        # get_datasets, which is not visible here -- confirm against xautodl.
        train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1)
        print(train_data)
        print(valid_data)

        xloader = torch.utils.data.DataLoader(
            train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True
        )
        print(xloader)
        print(next(iter(xloader)))

        # Full pass over the loader; only the batch index is reported. The
        # loop stays inside the `with` so the data root still exists in case
        # the dataset reads from disk lazily.
        for i, _ in enumerate(xloader):
            print(i)
|
|
|
|
|
|
# Run the smoke test only when this file is executed directly. An unguarded
# call would also fire when pytest imports the module for collection, so the
# test would run twice.
if __name__ == "__main__":
    test_simple()
|