Preprocess the aircraft dataset to compute its statistics, which can be used in swap-nas.
This commit is contained in:
		
							
								
								
									
										61
									
								
								calculate_datasets_statistics.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								calculate_datasets_statistics.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| # import torch | ||||
| # import torchvision | ||||
| # import torchvision.transforms as transforms | ||||
|  | ||||
| # # 加载CIFAR-10数据集 | ||||
| # transform = transforms.Compose([transforms.ToTensor()]) | ||||
| # trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) | ||||
| # trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2) | ||||
|  | ||||
| # # 将所有数据加载到内存中 | ||||
| # data = next(iter(trainloader)) | ||||
| # images, _ = data | ||||
|  | ||||
| # # 计算每个通道的均值和标准差 | ||||
| # mean = images.mean([0, 2, 3]) | ||||
| # std = images.std([0, 2, 3]) | ||||
|  | ||||
| # print(f'Mean: {mean}') | ||||
| # print(f'Std: {std}') | ||||
|  | ||||
| import torch | ||||
| from torchvision import datasets, transforms | ||||
| from torch.utils.data import DataLoader | ||||
| import argparse | ||||
|  | ||||
def collate_images(samples):
    """Return the image tensors of a batch as a plain list, without stacking.

    The default collate function stacks samples into one tensor, which fails
    when images in the folder have different heights/widths. Labels are
    dropped because the statistics do not depend on them.

    Defined at module level (not as a lambda) so that it can be pickled for
    DataLoader worker processes.
    """
    return [image for image, _ in samples]


def compute_channel_stats(batches):
    """Compute the per-channel mean and standard deviation over all pixels.

    Args:
        batches: iterable of batches. Each batch is an iterable of image
            tensors shaped (C, H, W); a stacked (N, C, H, W) tensor also
            works because iterating it yields (C, H, W) slices. Images may
            differ in H and W but must share the same channel count.

    Returns:
        A ``(mean, std)`` tuple of float32 tensors of shape (C,). ``std`` is
        the population standard deviation of every pixel in the dataset,
        not the average of per-image standard deviations (which
        under-reports the spread whenever images differ in brightness).

    Raises:
        ValueError: if ``batches`` contains no images.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0

    for batch in batches:
        for image in batch:
            # Flatten the spatial dims; accumulate in float64 so that sums
            # over many pixels do not lose precision.
            pixels = image.reshape(image.size(0), -1).to(torch.float64)
            if channel_sum is None:
                channel_sum = torch.zeros(pixels.size(0), dtype=torch.float64)
                channel_sq_sum = torch.zeros(pixels.size(0), dtype=torch.float64)
            channel_sum += pixels.sum(1)
            channel_sq_sum += (pixels ** 2).sum(1)
            pixel_count += pixels.size(1)

    if pixel_count == 0:
        raise ValueError('no images found; cannot compute statistics')

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative values
    # caused by floating-point rounding.
    variance = (channel_sq_sum / pixel_count - mean ** 2).clamp(min=0)
    return mean.float(), variance.sqrt().float()


def main():
    """Parse CLI arguments, load the image folder and print its statistics."""
    parser = argparse.ArgumentParser(description='Calculate mean and std of dataset')
    # NOTE: --dataset is accepted for backward compatibility; the statistics
    # depend only on --data_path.
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name')
    parser.add_argument('--data_path', type=str, default='./datasets/cifar-10-batches-py', help='path to dataset image folder')
    args = parser.parse_args()

    # Only ToTensor is applied, so the statistics describe the raw images.
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Load the dataset with ImageFolder (one sub-folder per class).
    dataset = datasets.ImageFolder(root=args.data_path, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_images,
    )

    mean, std = compute_channel_stats(dataloader)

    print(f'Mean: {mean}')
    print(f'Std: {std}')


# Worker processes re-import this module on spawn-based platforms; the guard
# keeps them from re-running the whole script.
if __name__ == '__main__':
    main()
							
								
								
									
										41
									
								
								preprocess_aircraft.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								preprocess_aircraft.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| import os | ||||
| import shutil | ||||
|  | ||||
# Default dataset locations; each can be overridden on the command line.
dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images'
output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images'

# Label file, e.g. 'images_variant_trainval.txt'. Each line holds an image id
# followed by its category name, separated by a space.
labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt'


def parse_label_line(line):
    """Split one label line into ``(image_name, category)``.

    The first space-separated token is the image id (``.jpg`` is appended);
    the remaining tokens form the category, joined with ``_`` and with ``/``
    replaced by ``_`` so the category is usable as a single folder name.

    Returns:
        The ``(image_name, category)`` tuple, or ``None`` when the line is
        blank or has no category after the image id.
    """
    parts = line.strip().split(' ')
    category = '_'.join(parts[1:]).replace('/', '_')
    if not parts[0] or not category:
        return None
    return parts[0] + '.jpg', category


def sort_images(images_dir, output_dir, labels_path):
    """Move every labelled image into ``output_dir/<category>/``.

    Args:
        images_dir: folder that contains the flat ``<id>.jpg`` images.
        output_dir: folder that receives one sub-folder per category;
            created if it does not exist.
        labels_path: text file with one ``<id> <category>`` entry per line.

    Returns:
        A ``(moved, missing)`` tuple: how many images were moved and how
        many listed images were not found in ``images_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(labels_path, 'r', encoding='utf-8') as f:
        # Blank lines carry no entry; drop them so the progress total is right.
        lines = [line for line in f if line.strip()]

    moved = 0
    missing = 0
    for count, line in enumerate(lines, start=1):
        print(f'Processing image {count}/{len(lines)}', end='\r')
        parsed = parse_label_line(line)
        if parsed is None:
            # Without a category there is no destination folder for the image.
            print(f'Skipping malformed line: {line.strip()!r}')
            continue
        image_name, category = parsed

        # Create the category folder.
        category_path = os.path.join(output_dir, category)
        os.makedirs(category_path, exist_ok=True)

        # Move the image into its category folder.
        src = os.path.join(images_dir, image_name)
        dst = os.path.join(category_path, image_name)
        if os.path.exists(src):
            shutil.move(src, dst)
            moved += 1
        else:
            print(f'Image {image_name} not found!')
            missing += 1

    return moved, missing


def main():
    """Sort the images using the default paths or command-line overrides."""
    # Imported here because only the command-line entry point needs it.
    import argparse

    parser = argparse.ArgumentParser(description='Sort images into one folder per category')
    parser.add_argument('--dataset_path', type=str, default=dataset_path, help='folder containing the source images')
    parser.add_argument('--output_path', type=str, default=output_path, help='folder receiving the per-category sub-folders')
    parser.add_argument('--labels_file', type=str, default=labels_file, help='label file with one "<id> <category>" entry per line')
    args = parser.parse_args()

    sort_images(args.dataset_path, args.output_path, args.labels_file)
    print("Images have been sorted into folders by category.")


if __name__ == '__main__':
    main()
		Reference in New Issue
	
	Block a user