add how to prepare dataset
correlation/calculate_dataset_statistics.py (new executable file, 139 lines)
@@ -0,0 +1,139 @@
# Quick single-pass version (kept commented out for reference):
#
# import torch
# import torchvision
# import torchvision.transforms as transforms
#
# # Load the CIFAR-10 training set
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2)
#
# # Pull a single batch of 10000 images into memory
# data = next(iter(trainloader))
# images, _ = data
#
# # Compute the per-channel mean and standard deviation
# mean = images.mean([0, 2, 3])
# std = images.std([0, 2, 3])
#
# print(f'Mean: {mean}')
# print(f'Std: {std}')
#
# Results:
# Mean: tensor([0.4935, 0.4834, 0.4472])
# Std: tensor([0.2476, 0.2446, 0.2626])
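# How these statistics are consumed downstream is not part of this file; as a
# minimal sketch (an assumption about the intended use), they would typically
# be plugged into a normalization transform:
#
#   normalize = transforms.Normalize(mean=[0.4935, 0.4834, 0.4472],
#                                    std=[0.2476, 0.2446, 0.2626])
#   transform = transforms.Compose([transforms.ToTensor(), normalize])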
import itertools
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import argparse
import numpy as np
import os

parser = argparse.ArgumentParser(description='Calculate mean and std of dataset')
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name')
parser.add_argument('--data_path', type=str, default='./datasets/cifar-10-batches-py', help='path to dataset image folder')
parser.add_argument('--train_dataset_path', type=str, default='train', help='train dataset path')
parser.add_argument('--test_dataset_path', type=str, default='test', help='test dataset path')

args = parser.parse_args()

# Set the dataset path and name
dataset_path = args.data_path
dataset_name = args.dataset

if dataset_name == 'cifar10':
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
elif dataset_name == 'aircraft' or dataset_name == 'oxford':
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
else:
    # Fail fast on unsupported dataset names; otherwise `transform` would be
    # undefined and the script would crash later with a NameError
    raise ValueError(f'Unsupported dataset: {dataset_name}')

def to_tensor(pic):
    """Convert a PIL Image to a PyTorch tensor.

    Args:
        pic (PIL.Image.Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image tensor with shape (C, H, W) and pixel values in range [0.0, 1.0].
    """
    # Convert the image to a NumPy array
    img = np.array(pic, dtype=np.float32)

    # Handle grayscale images first (no channel dimension), so the alpha check
    # below only ever looks at a real channel axis
    if img.ndim == 2:
        img = np.expand_dims(img, axis=-1)

    # If the image has an alpha channel, discard it
    if img.shape[-1] == 4:
        img = img[:, :, :3]

    # Transpose the dimensions from (H, W, C) to (C, H, W)
    img = img.transpose((2, 0, 1))

    # Normalize the pixel values to [0.0, 1.0]
    img = img / 255.0

    # Convert the NumPy array to a PyTorch tensor
    tensor = torch.from_numpy(img)

    return tensor

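# Usage sketch for to_tensor (illustrative only; the file name below is a
# hypothetical example, and the rest of this script relies on
# transforms.ToTensor() rather than this helper):
#
#   from PIL import Image
#   img = Image.open('example.jpg')
#   t = to_tensor(img)   # tensor of shape (C, H, W), values in [0.0, 1.0]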
# Load the dataset: the Oxford split files are read with torch.load,
# everything else with ImageFolder
if args.dataset == 'oxford':
    train_data = torch.load(os.path.join(dataset_path, args.train_dataset_path))
    test_data = torch.load(os.path.join(dataset_path, args.test_dataset_path))

    train_tensor_data = [(image, label) for image, label in train_data]
    test_tensor_data = [(image, label) for image, label in test_data]
    sum_data = train_tensor_data + test_tensor_data

    train_images = [image for image, label in train_tensor_data]
    train_labels = torch.tensor([label for image, label in train_tensor_data])
    test_images = [image for image, label in test_tensor_data]
    test_labels = torch.tensor([label for image, label in test_tensor_data])
    sum_images = [image for image, label in sum_data]
    sum_labels = torch.tensor([label for image, label in sum_data])

    train_tensors = torch.stack([transform(image) for image in train_images])
    test_tensors = torch.stack([transform(image) for image in test_images])
    sum_tensors = torch.stack([transform(image) for image in sum_images])

    train_dataset = TensorDataset(train_tensors, train_labels)
    test_dataset = TensorDataset(test_tensors, test_labels)
    sum_dataset = TensorDataset(sum_tensors, sum_labels)

    # Statistics below are computed over train + test combined
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
    dataloader = DataLoader(sum_dataset, batch_size=64, shuffle=False, num_workers=4)
else:
    dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

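# For the ImageFolder branch, torchvision expects the images grouped into one
# sub-folder per class under --data_path, e.g. (hypothetical layout, the actual
# folder names are not specified in this script):
#
#   ./datasets/<dataset_name>/
#       class_a/xxx.png
#       class_b/yyy.png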
# Initialize accumulators for the per-channel mean and standard deviation.
# Note: this averages the per-image standard deviations, which approximates
# but does not exactly equal the std computed over all pixels at once (the
# commented-out version at the top computes the latter).
mean = torch.zeros(3)
std = torch.zeros(3)
nb_samples = 0

count = 0
for data in dataloader:
    count += 1
    print(f'Processing batch {count}/{len(dataloader)}', end='\r')
    batch_samples = data[0].size(0)
    # Flatten each image to (C, H*W) so per-channel statistics can be taken
    data = data[0].view(batch_samples, data[0].size(1), -1)
    mean += data.mean(2).sum(0)
    std += data.std(2).sum(0)
    nb_samples += batch_samples

mean /= nb_samples
std /= nb_samples

print(f'Mean: {mean}')
print(f'Std: {std}')
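# ---------------------------------------------------------------------------
# Example invocations (a sketch; the cifar10 values are the argparse defaults
# above, while the oxford --data_path is an assumption about the on-disk
# layout):
#
#   python correlation/calculate_dataset_statistics.py \
#       --dataset cifar10 --data_path ./datasets/cifar-10-batches-py
#
#   python correlation/calculate_dataset_statistics.py \
#       --dataset oxford --data_path ./datasets/oxford \
#       --train_dataset_path train --test_dataset_path test
# ---------------------------------------------------------------------------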