# Calculate the per-channel mean and std of an image dataset.
# import torch
# import torchvision
# import torchvision.transforms as transforms

# # Load the CIFAR-10 training set
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2)

# # Load one batch into memory. NOTE: with batch_size=10000 this is only the
# # first 10,000 training images, not the whole dataset.
# data = next(iter(trainloader))
# images, _ = data

# # Compute the per-channel mean and standard deviation
# mean = images.mean([0, 2, 3])
# std = images.std([0, 2, 3])

# print(f'Mean: {mean}')
# print(f'Std: {std}')

# Results (presumably from the snippet above, i.e. the first 10,000 training images only):
# Mean: tensor([0.4935, 0.4834, 0.4472])
# Std: tensor([0.2476, 0.2446, 0.2626])

import argparse

import torch
from torch.utils.data import DataLoader


def compute_mean_std(batches, num_batches=None):
    """Compute the dataset-wide per-channel mean and standard deviation.

    Statistics are taken over every pixel of every image -- the same values
    ``torch.cat(all_images).mean([0, 2, 3])`` / ``.std([0, 2, 3])`` would
    give -- by accumulating per-channel sums and sums of squares batch by
    batch. (Averaging per-image standard deviations, as an earlier version of
    this script did, ignores the variance between images and underestimates
    the dataset std.)

    Args:
        batches: Iterable of batches; ``batch[0]`` must be an image tensor
            whose dim 0 is the batch and dim 1 the channel (as yielded by a
            DataLoader over an ImageFolder dataset). Remaining dims are
            treated as pixels.
        num_batches: Total number of batches. When given, a progress line is
            printed per batch; when ``None`` nothing is printed.

    Returns:
        ``(mean, std)``: float32 tensors with one entry per channel. ``std``
        uses the Bessel-corrected (n - 1) denominator like ``torch.std``'s
        default; with a single pixel it is 0.

    Raises:
        ValueError: If ``batches`` contains no pixels.
    """
    channel_sum = None
    channel_sq_sum = None
    n_pixels = 0

    for index, batch in enumerate(batches, start=1):
        if num_batches is not None:
            print(f'Processing batch {index}/{num_batches}', end='\r')
        images = batch[0]
        if images.numel() == 0:
            continue  # nothing to accumulate; also avoids an ambiguous reshape
        # (batch, channel, ...) -> (batch, channel, pixels). float64 keeps the
        # running sum of squares accurate over a very large number of pixels.
        pixels = images.reshape(images.size(0), images.size(1), -1).to(torch.float64)
        batch_sum = pixels.sum(dim=(0, 2))
        batch_sq_sum = pixels.pow(2).sum(dim=(0, 2))
        if channel_sum is None:
            channel_sum, channel_sq_sum = batch_sum, batch_sq_sum
        else:
            channel_sum += batch_sum
            channel_sq_sum += batch_sq_sum
        n_pixels += pixels.size(0) * pixels.size(2)

    if num_batches is not None:
        print()  # terminate the '\r' progress line so later output starts fresh

    if n_pixels == 0:
        raise ValueError('cannot compute mean/std of an empty dataset')

    mean = channel_sum / n_pixels
    denom = n_pixels - 1 if n_pixels > 1 else 1
    # Var = (sum(x^2) - n * mean^2) / (n - 1); the clamp removes tiny negative
    # values that floating-point rounding can produce.
    var = ((channel_sq_sum - n_pixels * mean.pow(2)) / denom).clamp(min=0)
    return mean.float(), var.sqrt().float()


def build_dataloader(data_path):
    """Return a DataLoader over the ImageFolder-style directory *data_path*."""
    # Imported here rather than at module level so compute_mean_std can be
    # imported (e.g. for testing) without torchvision installed.
    from torchvision import datasets, transforms

    # Resize every image to a fixed 224x224 so images of different sizes can
    # be stacked into batches, then convert to a tensor. NOTE: resampling can
    # change pixel statistics slightly, so these stats correspond to a
    # pipeline that resizes the same way.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # ImageFolder expects one sub-directory per class under data_path.
    dataset = datasets.ImageFolder(root=data_path, transform=transform)
    return DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)


def main():
    """Parse CLI arguments, then print the dataset's per-channel mean and std."""
    parser = argparse.ArgumentParser(description='Calculate mean and std of dataset')
    # NOTE(review): --dataset is accepted for CLI compatibility but is not
    # used to choose a loader; ImageFolder is always used.
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name')
    # NOTE(review): the default points at CIFAR-10's pickled batch directory,
    # which has no per-class image sub-folders, so ImageFolder cannot read it;
    # pass an image-folder path explicitly.
    parser.add_argument('--data_path', type=str, default='./datasets/cifar-10-batches-py', help='path to dataset image folder')
    args = parser.parse_args()

    dataloader = build_dataloader(args.data_path)
    mean, std = compute_mean_std(dataloader, num_batches=len(dataloader))

    print(f'Mean: {mean}')
    print(f'Std: {std}')


# DataLoader(num_workers=4) starts worker processes; with the spawn start
# method (e.g. Windows) workers re-import this module, so the work must not
# run at import time.
if __name__ == '__main__':
    main()
|