diff --git a/correlation/calculate_dataset_statistics.py b/correlation/calculate_dataset_statistics.py new file mode 100755 index 0000000..795d238 --- /dev/null +++ b/correlation/calculate_dataset_statistics.py @@ -0,0 +1,139 @@ +# import torch +# import torchvision +# import torchvision.transforms as transforms + +# # 加载CIFAR-10数据集 +# transform = transforms.Compose([transforms.ToTensor()]) +# trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform) +# trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2) + +# # 将所有数据加载到内存中 +# data = next(iter(trainloader)) +# images, _ = data + +# # 计算每个通道的均值和标准差 +# mean = images.mean([0, 2, 3]) +# std = images.std([0, 2, 3]) + +# print(f'Mean: {mean}') +# print(f'Std: {std}') + +# results: +# Mean: tensor([0.4935, 0.4834, 0.4472]) +# Std: tensor([0.2476, 0.2446, 0.2626]) + +import itertools +import torch +from torchvision import datasets, transforms +from torch.utils.data import DataLoader, TensorDataset +import argparse +import numpy as np +import os + +parser = argparse.ArgumentParser(description='Calculate mean and std of dataset') +parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name') +parser.add_argument('--data_path', type=str, default='./datasets/cifar-10-batches-py', help='path to dataset image folder') +parser.add_argument('--train_dataset_path', type=str, default='train', help='train dataset path') +parser.add_argument('--test_dataset_path', type=str, default='test', help='test dataset path') + +args = parser.parse_args() + + +# 设置数据集路径 +dataset_path = args.data_path +dataset_name = args.dataset + +if dataset_name == 'cifar10': + transform = transforms.Compose([ + transforms.ToTensor() + ]) +elif dataset_name == 'aircraft' or dataset_name == 'oxford': + transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor() + ]) + + +def to_tensor(pic): + """Convert a PIL Image to a PyTorch tensor. + + Args: + pic (PIL.Image.Image): Image to be converted to tensor. + + Returns: + Tensor: Converted image tensor with shape (C, H, W) and pixel values in range [0.0, 1.0]. + """ + + # Convert the image to a NumPy array + img = np.array(pic, dtype=np.float32) + + # If image has an alpha channel, discard it + if img.shape[-1] == 4: + img = img[:, :, :3] + + # Handle grayscale images (no channels dimension) + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + + # Transpose the dimensions from (H, W, C) to (C, H, W) + img = img.transpose((2, 0, 1)) + + # Normalize the pixel values to [0.0, 1.0] + img = img / 255.0 + + # Convert the NumPy array to a PyTorch tensor + tensor = torch.from_numpy(img) + + return tensor + +# 使用ImageFolder加载数据集 +if args.dataset == 'oxford': + train_data = torch.load(os.path.join(dataset_path, args.train_dataset_path)) + test_data = torch.load(os.path.join(dataset_path, args.test_dataset_path)) + + train_tensor_data = [(image, label) for image, label in train_data] + test_tensor_data = [(image, label) for image, label in test_data] + sum_data = train_tensor_data + test_tensor_data + + train_images = [image for image, label in train_tensor_data] + train_labels = torch.tensor([label for image, label in train_tensor_data]) + test_images = [image for image, label in test_tensor_data] + test_labels = torch.tensor([label for image, label in test_tensor_data]) + sum_images = [image for image, label in sum_data] + sum_labels = torch.tensor([label for image, label in sum_data]) + + train_tensors = torch.stack([transform(image) for image in train_images]) + test_tensors = torch.stack([transform(image) for image in test_images]) + sum_tensors = torch.stack([transform(image) for image in sum_images]) + + train_dataset = TensorDataset(train_tensors, train_labels) + test_dataset = TensorDataset(test_tensors, test_labels) + sum_dataset = TensorDataset(sum_tensors, sum_labels) + + train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=4) + test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4) + dataloader = DataLoader(sum_dataset, batch_size=64, shuffle=False, num_workers=4) +else: + dataset = datasets.ImageFolder(root=dataset_path, transform=transform) + dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4) + +# 初始化变量来累积均值和标准差 +mean = torch.zeros(3) +std = torch.zeros(3) +nb_samples = 0 + +count = 0 +for data in dataloader: + count += 1 + print(f'Processing batch {count}/{len(dataloader)}', end='\r') + batch_samples = data[0].size(0) + data = data[0].view(batch_samples, data[0].size(1), -1) + mean += data.mean(2).sum(0) + std += data.std(2).sum(0) + nb_samples += batch_samples + +mean /= nb_samples +std /= nb_samples + +print(f'Mean: {mean}') +print(f'Std: {std}')