Preprocess the aircraft dataset to get its statistics, which can be used in swap-nas.
This commit is contained in:
parent
a7a6906a6d
commit
33452adc3b
61
calculate_datasets_statistics.py
Normal file
61
calculate_datasets_statistics.py
Normal file
@ -0,0 +1,61 @@
|
||||
# import torch
|
||||
# import torchvision
|
||||
# import torchvision.transforms as transforms
|
||||
|
||||
# # 加载CIFAR-10数据集
|
||||
# transform = transforms.Compose([transforms.ToTensor()])
|
||||
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
|
||||
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2)
|
||||
|
||||
# # 将所有数据加载到内存中
|
||||
# data = next(iter(trainloader))
|
||||
# images, _ = data
|
||||
|
||||
# # 计算每个通道的均值和标准差
|
||||
# mean = images.mean([0, 2, 3])
|
||||
# std = images.std([0, 2, 3])
|
||||
|
||||
# print(f'Mean: {mean}')
|
||||
# print(f'Std: {std}')
|
||||
|
||||
import torch
|
||||
from torchvision import datasets, transforms
|
||||
from torch.utils.data import DataLoader
|
||||
import argparse
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the statistics script.

    ``--dataset`` is kept for command-line compatibility; the computation
    itself only depends on ``--data_path``.
    """
    parser = argparse.ArgumentParser(description='Calculate mean and std of dataset')
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name')
    parser.add_argument('--data_path', type=str, default='./datasets/cifar-10-batches-py', help='path to dataset image folder')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='images per batch (use 1 when image sizes differ)')
    return parser.parse_args()


def compute_channel_stats(batches):
    """Return the per-channel mean and standard deviation over all pixels.

    Args:
        batches: Iterable of batches whose first element is an image tensor
            with samples on dim 0 and channels on dim 1 (e.g. what a
            ``DataLoader`` over an ``ImageFolder`` yields). Any further
            elements of a batch (labels) are ignored.

    Returns:
        A ``(mean, std)`` tuple of float32 tensors with one entry per channel.

    Raises:
        ValueError: If ``batches`` yields no pixels at all.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0

    for batch in batches:
        images = batch[0]
        n_images, n_channels = images.size(0), images.size(1)
        # Flatten the spatial dimensions; accumulate in float64 so that the
        # E[x^2] - E[x]^2 formula below does not suffer from cancellation.
        pixels = images.reshape(n_images, n_channels, -1).to(torch.float64)

        if channel_sum is None:
            # Size the accumulators from the data instead of assuming 3 channels.
            channel_sum = torch.zeros(n_channels, dtype=torch.float64, device=pixels.device)
            channel_sq_sum = torch.zeros(n_channels, dtype=torch.float64, device=pixels.device)

        channel_sum += pixels.sum(dim=(0, 2))
        channel_sq_sum += (pixels ** 2).sum(dim=(0, 2))
        pixel_count += n_images * pixels.size(2)

    if pixel_count == 0:
        raise ValueError('no images found: cannot compute dataset statistics')

    mean = channel_sum / pixel_count
    # Averaging per-image standard deviations is NOT the dataset standard
    # deviation (it ignores brightness differences between images), so the
    # std is derived from the globally accumulated moments instead.
    variance = (channel_sq_sum / pixel_count - mean ** 2).clamp(min=0)
    return mean.float(), variance.sqrt().float()


def main():
    """Load the image folder and print its per-channel mean and std."""
    args = parse_args()

    # Only ToTensor is applied, so statistics are computed on the raw images.
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Load the dataset with ImageFolder (expects one sub-folder per class).
    # NOTE: the default batch collation stacks the images of a batch, so all
    # images in a batch must share one size; pass --batch_size 1 otherwise.
    dataset = datasets.ImageFolder(root=args.data_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    mean, std = compute_channel_stats(dataloader)

    print(f'Mean: {mean}')
    print(f'Std: {std}')


# The guard is required for multi-worker DataLoaders on platforms that start
# worker processes by re-importing this module (spawn start method).
if __name__ == '__main__':
    main()
|
41
preprocess_aircraft.py
Normal file
41
preprocess_aircraft.py
Normal file
@ -0,0 +1,41 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
# Dataset paths: source image folder and the per-category output tree.
DEFAULT_DATASET_PATH = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images'
DEFAULT_OUTPUT_PATH = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images'

# Label file, e.g. 'images_variant_trainval.txt'. Each line holds an image id
# followed by the category name, separated by a space.
DEFAULT_LABELS_FILE = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt'


def sort_images_by_category(dataset_path, output_path, labels_file):
    """Move every labelled image into a sub-folder named after its category.

    Args:
        dataset_path: Folder that contains the ``<image id>.jpg`` files.
        output_path: Folder in which one sub-folder per category is created.
        labels_file: Text file with one ``<image id> <category>`` entry per
            line; blank lines are skipped.

    Returns:
        The list of image file names that were listed in ``labels_file`` but
        not found in ``dataset_path``.
    """
    # Create the output folder (no error if it already exists).
    os.makedirs(output_path, exist_ok=True)

    # Read the label file, ignoring blank lines so they are not mistaken
    # for an image called '.jpg'.
    with open(labels_file, 'r', encoding='utf-8') as f:
        entries = [line.strip() for line in f if line.strip()]

    missing = []
    total = len(entries)

    for count, entry in enumerate(entries, start=1):
        print(f'Processing image {count}/{total}', end='\r')
        image_id, _, variant = entry.partition(' ')
        image_name = image_id + '.jpg'
        # Spaces and slashes are replaced so the category is a single,
        # valid folder name.
        category = variant.replace(' ', '_').replace('/', '_')

        # Create the category folder.
        category_path = os.path.join(output_path, category)
        os.makedirs(category_path, exist_ok=True)

        # Move the image into its category folder.
        src = os.path.join(dataset_path, image_name)
        dst = os.path.join(category_path, image_name)
        if os.path.exists(src):
            shutil.move(src, dst)
        else:
            print(f'Image {image_name} not found!')
            missing.append(image_name)

    return missing


def main():
    """Sort the images using the default paths or command-line overrides."""
    # Imported here so the block does not depend on a file-level import.
    import argparse

    parser = argparse.ArgumentParser(description='Sort aircraft images into one folder per category')
    parser.add_argument('--dataset_path', type=str, default=DEFAULT_DATASET_PATH, help='folder containing the images')
    parser.add_argument('--output_path', type=str, default=DEFAULT_OUTPUT_PATH, help='folder receiving the category sub-folders')
    parser.add_argument('--labels_file', type=str, default=DEFAULT_LABELS_FILE, help='label file listing image ids and categories')
    args = parser.parse_args()

    sort_images_by_category(args.dataset_path, args.output_path, args.labels_file)

    print("Images have been sorted into folders by category.")


if __name__ == '__main__':
    main()
|
Loading…
Reference in New Issue
Block a user