preprocess aircraft dataset to get the statistics. which can be used in swap-nas
This commit is contained in:
		
							
								
								
									
										61
									
								
								calculate_datasets_statistics.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								calculate_datasets_statistics.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | |||||||
|  | # import torch | ||||||
|  | # import torchvision | ||||||
|  | # import torchvision.transforms as transforms | ||||||
|  |  | ||||||
|  | # # 加载CIFAR-10数据集 | ||||||
|  | # transform = transforms.Compose([transforms.ToTensor()]) | ||||||
|  | # trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) | ||||||
|  | # trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2) | ||||||
|  |  | ||||||
|  | # # 将所有数据加载到内存中 | ||||||
|  | # data = next(iter(trainloader)) | ||||||
|  | # images, _ = data | ||||||
|  |  | ||||||
|  | # # 计算每个通道的均值和标准差 | ||||||
|  | # mean = images.mean([0, 2, 3]) | ||||||
|  | # std = images.std([0, 2, 3]) | ||||||
|  |  | ||||||
|  | # print(f'Mean: {mean}') | ||||||
|  | # print(f'Std: {std}') | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from torchvision import datasets, transforms | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  | import argparse | ||||||
|  |  | ||||||
|  | parser = argparse.ArgumentParser(description='Calculate mean and std of dataset') | ||||||
|  | parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name') | ||||||
|  | parser.add_argument('--data_path', type=str, default='./datasets/cifar-10-batches-py', help='path to dataset image folder') | ||||||
|  |  | ||||||
|  | args = parser.parse_args() | ||||||
|  |  | ||||||
|  | # 设置数据集路径 | ||||||
|  | dataset_path = args.data_path | ||||||
|  | dataset_name = args.dataset | ||||||
|  |  | ||||||
|  | # 设置数据集的transform(这里只使用了ToTensor) | ||||||
|  | transform = transforms.Compose([ | ||||||
|  |     transforms.ToTensor() | ||||||
|  | ]) | ||||||
|  |  | ||||||
|  | # 使用ImageFolder加载数据集 | ||||||
|  | dataset = datasets.ImageFolder(root=dataset_path, transform=transform) | ||||||
|  | dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4) | ||||||
|  |  | ||||||
|  | # 初始化变量来累积均值和标准差 | ||||||
|  | mean = torch.zeros(3) | ||||||
|  | std = torch.zeros(3) | ||||||
|  | nb_samples = 0 | ||||||
|  |  | ||||||
|  | for data in dataloader: | ||||||
|  |     batch_samples = data[0].size(0) | ||||||
|  |     data = data[0].view(batch_samples, data[0].size(1), -1) | ||||||
|  |     mean += data.mean(2).sum(0) | ||||||
|  |     std += data.std(2).sum(0) | ||||||
|  |     nb_samples += batch_samples | ||||||
|  |  | ||||||
|  | mean /= nb_samples | ||||||
|  | std /= nb_samples | ||||||
|  |  | ||||||
|  | print(f'Mean: {mean}') | ||||||
|  | print(f'Std: {std}') | ||||||
							
								
								
									
										41
									
								
								preprocess_aircraft.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								preprocess_aircraft.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | |||||||
|  | import os | ||||||
|  | import shutil | ||||||
|  |  | ||||||
|  | # 数据集路径 | ||||||
|  | dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images' | ||||||
|  | output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images' | ||||||
|  |  | ||||||
|  | # 类别文件,例如 'images_variant_trainval.txt' | ||||||
|  | labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt' | ||||||
|  |  | ||||||
|  | # 创建输出文件夹 | ||||||
|  | if not os.path.exists(output_path): | ||||||
|  |     os.makedirs(output_path) | ||||||
|  |  | ||||||
|  | # 读取类别文件 | ||||||
|  | with open(labels_file, 'r') as f: | ||||||
|  |     lines = f.readlines() | ||||||
|  |  | ||||||
|  | count = 0 | ||||||
|  |  | ||||||
|  | for line in lines: | ||||||
|  |     count += 1 | ||||||
|  |     print(f'Processing image {count}/{len(lines)}', end='\r') | ||||||
|  |     parts = line.strip().split(' ') | ||||||
|  |     image_name = parts[0] + '.jpg' | ||||||
|  |     category = '_'.join(parts[1:]).replace('/', '_') | ||||||
|  |  | ||||||
|  |     # 创建类别文件夹 | ||||||
|  |     category_path = os.path.join(output_path, category) | ||||||
|  |     if not os.path.exists(category_path): | ||||||
|  |         os.makedirs(category_path) | ||||||
|  |  | ||||||
|  |     # 移动图像到对应类别文件夹 | ||||||
|  |     src = os.path.join(dataset_path, image_name) | ||||||
|  |     dst = os.path.join(category_path, image_name) | ||||||
|  |     if os.path.exists(src): | ||||||
|  |         shutil.move(src, dst) | ||||||
|  |     else: | ||||||
|  |         print(f'Image {image_name} not found!') | ||||||
|  |  | ||||||
|  | print("Images have been sorted into folders by category.") | ||||||
		Reference in New Issue
	
	Block a user