add resize to resize the images; cancel the acc; update the folder path

This commit is contained in:
Mhrooz 2024-08-31 15:49:42 +02:00
parent 33452adc3b
commit 968157b657
3 changed files with 56 additions and 31 deletions

View File

@ -4,7 +4,7 @@
# # 加载CIFAR-10数据集
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2)
# # 将所有数据加载到内存中
@ -18,6 +18,10 @@
# print(f'Mean: {mean}')
# print(f'Std: {std}')
# results:
# Mean: tensor([0.4935, 0.4834, 0.4472])
# Std: tensor([0.2476, 0.2446, 0.2626])
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
@ -35,6 +39,7 @@ dataset_name = args.dataset
# 设置数据集的transform这里只使用了ToTensor
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
@ -47,7 +52,10 @@ mean = torch.zeros(3)
std = torch.zeros(3)
nb_samples = 0
count = 0
for data in dataloader:
count += 1
print(f'Processing batch {count}/{len(dataloader)}', end='\r')
batch_samples = data[0].size(0)
data = data[0].view(batch_samples, data[0].size(1), -1)
mean += data.mean(2).sum(0)

View File

@ -40,9 +40,9 @@ parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets')
parser.add_argument('--start_index', default=0, type=int, help='start index of the networks to evaluate')
args = parser.parse_args()
if __name__ == "__main__":
device = torch.device(args.device)
@ -58,18 +58,21 @@ if __name__ == "__main__":
# nasbench_len = 15625
nasbench_len = 15625
filename = f'output/swap_results_{args.datasets}.csv'
if args.datasets == 'aircraft':
api_datasets = 'cifar10'
# for index, i in arch_info.iterrows():
for ind in range(nasbench_len):
for ind in range(args.start_index,nasbench_len):
# print(f'Evaluating network: {index}')
print(f'Evaluating network: {ind}')
config = api.get_net_config(ind, args.datasets)
config = api.get_net_config(ind, api_datasets)
network = get_cell_based_tiny_net(config)
# nas_results = api.query_by_index(i, 'cifar10')
# acc = nas_results[111].get_eval('ori-test')
nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False)
acc = nas_results['test-accuracy']
# nas_results = api.get_more_info(ind, api_datasets, None, hp=200, is_random=False)
# acc = nas_results['test-accuracy']
acc = 99
# print(type(network))
start_time = time.time()
@ -98,6 +101,8 @@ if __name__ == "__main__":
print(f'Elapsed time: {end_time - start_time:.2f} seconds')
results.append([np.mean(swap_score), acc, ind])
with open(filename, 'a') as f:
f.write(f'{np.mean(swap_score)},{acc},{ind}\n')
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
results.to_csv('output/swap_results.csv', float_format='%.4f', index=False)

View File

@ -3,39 +3,51 @@ import shutil
# 数据集路径
dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images'
output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images'
test_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/test_sorted_images'
train_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/train_sorted_images'
# 类别文件,例如 'images_variant_trainval.txt'
labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt'
# 有两个文件,一个是训练集和验证集,一个是测试集
test_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt'
train_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_train.txt'
# 创建输出文件夹
if not os.path.exists(output_path):
os.makedirs(output_path)
if not os.path.exists(test_output_path):
os.makedirs(test_output_path)
if not os.path.exists(train_output_path):
os.makedirs(train_output_path)
# 读取类别文件
with open(labels_file, 'r') as f:
lines = f.readlines()
with open(test_labels_file, 'r') as f:
test_lines = f.readlines()
with open(train_labels_file, 'r') as f:
train_lines = f.readlines()
count = 0
def sort_images(lines, output_path):
count = 0
for line in lines:
count += 1
print(f'Processing image {count}/{len(lines)}', end='\r')
parts = line.strip().split(' ')
image_name = parts[0] + '.jpg'
category = '_'.join(parts[1:]).replace('/', '_')
for line in lines:
count += 1
print(f'Processing image {count}/{len(lines)}', end='\r')
parts = line.strip().split(' ')
image_name = parts[0] + '.jpg'
category = '_'.join(parts[1:]).replace('/', '_')
# 创建类别文件夹
category_path = os.path.join(output_path, category)
if not os.path.exists(category_path):
os.makedirs(category_path)
# 创建类别文件夹
category_path = os.path.join(output_path, category)
if not os.path.exists(category_path):
os.makedirs(category_path)
# 移动图像到对应类别文件夹
src = os.path.join(dataset_path, image_name)
dst = os.path.join(category_path, image_name)
if os.path.exists(src):
shutil.move(src, dst)
else:
print(f'Image {image_name} not found!')
# 移动图像到对应类别文件夹
src = os.path.join(dataset_path, image_name)
dst = os.path.join(category_path, image_name)
if os.path.exists(src):
shutil.move(src, dst)
else:
print(f'Image {image_name} not found!')
print("Sorting test images into folders by category...")
sort_images(test_lines, test_output_path)
print("Sorting train images into folders by category...")
sort_images(train_lines, train_output_path)
print("Images have been sorted into folders by category.")