"""Compute the per-channel mean and standard deviation of an image dataset.

The dataset is read with ``torchvision.datasets.ImageFolder`` (layout
``root/<class_name>/<image>``). Each image is resized to
``--image_size`` x ``--image_size`` (default 224) and converted with
``ToTensor`` before the statistics are accumulated over every pixel of
every image.

Reference values recorded earlier for CIFAR-10 were computed at the native
32x32 resolution (no resize) over the first 10,000 training images:

    Mean: tensor([0.4935, 0.4834, 0.4472])
    Std:  tensor([0.2476, 0.2446, 0.2626])
"""
import argparse

import torch
from torch.utils.data import DataLoader


def parse_args(argv=None):
    """Parse command-line options (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser(description='Calculate mean and std of dataset')
    # NOTE(review): --dataset is accepted for compatibility but does not select
    # a loader; the data is always read with ImageFolder.
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name')
    # NOTE(review): ImageFolder needs a root/<class>/<image> layout. The default
    # directory name presumably holds CIFAR-10's pickled batches, which
    # ImageFolder cannot read. Confirm the path or pass --data_path explicitly.
    parser.add_argument('--data_path', type=str,
                        default='./datasets/cifar-10-batches-py',
                        help='path to dataset image folder')
    parser.add_argument('--image_size', type=int, default=224,
                        help='side length images are resized to')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='images per batch')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='DataLoader worker processes')
    return parser.parse_args(argv)


def build_dataloader(dataset_path, image_size=224, batch_size=64, num_workers=4):
    """Return a non-shuffling DataLoader over the ImageFolder at *dataset_path*."""
    # Imported here so compute_mean_std stays usable without torchvision.
    from torchvision import datasets, transforms

    # Resize to a fixed size so images can be batched, then convert to tensors.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers)


def compute_mean_std(dataloader):
    """Return the per-channel ``(mean, std)`` over every pixel of every image.

    Args:
        dataloader: Sized iterable of batches whose first element is an image
            tensor shaped ``(batch, channels, ...)``. ``len()`` is used only
            for the progress message.

    Returns:
        Two float32 tensors of shape ``(channels,)``. ``std`` is the
        Bessel-corrected standard deviation over all pixels of a channel,
        i.e. what ``torch.std`` gives on the concatenated data.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    total_batches = len(dataloader)
    for batch_idx, batch in enumerate(dataloader, start=1):
        print(f'Processing batch {batch_idx}/{total_batches}', end='\r')
        # Accumulate in float64: the sums run over very many pixels, and the
        # sum-of-squares variance formula below is sensitive to rounding.
        images = batch[0].to(torch.float64)
        flat = images.reshape(images.size(0), images.size(1), -1)
        if channel_sum is None:
            channel_sum = torch.zeros(flat.size(1), dtype=torch.float64)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += flat.sum(dim=(0, 2))
        channel_sq_sum += flat.square().sum(dim=(0, 2))
        pixel_count += flat.size(0) * flat.size(2)
    print()  # end the '\r' progress line

    if pixel_count == 0:
        raise ValueError('dataset is empty')

    mean = channel_sum / pixel_count
    # Global variance from running sums. Averaging per-image stds instead
    # ignores between-image variation and underestimates the dataset std.
    denom = max(pixel_count - 1, 1)
    var = (channel_sq_sum - pixel_count * mean.square()) / denom
    std = var.clamp_min(0).sqrt()  # clamp guards tiny negative rounding error
    return mean.float(), std.float()


def main(argv=None):
    """Load the dataset named on the command line and print its mean and std."""
    args = parse_args(argv)
    dataloader = build_dataloader(args.data_path,
                                  image_size=args.image_size,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)
    mean, std = compute_mean_std(dataloader)
    print(f'Mean: {mean}')
    print(f'Std: {std}')


# Required when num_workers > 0 on platforms that start DataLoader workers
# with "spawn" (e.g. Windows, macOS): each worker re-imports this module.
if __name__ == '__main__':
    main()