Add a Resize transform for the images; skip the accuracy lookup; update the folder paths
@@ -4,7 +4,7 @@

# # Load the CIFAR-10 dataset
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2)

# # Load all the data into memory
@@ -18,6 +18,10 @@
# print(f'Mean: {mean}')
# print(f'Std: {std}')

# results:
# Mean: tensor([0.4935, 0.4834, 0.4472])
# Std: tensor([0.2476, 0.2446, 0.2626])

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
@@ -35,6 +39,7 @@ dataset_name = args.dataset

# Set up the transform for the dataset (only ToTensor is used here)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

@@ -47,7 +52,10 @@ mean = torch.zeros(3)
std = torch.zeros(3)
nb_samples = 0

count = 0
for data in dataloader:
    count += 1
    print(f'Processing batch {count}/{len(dataloader)}', end='\r')
    batch_samples = data[0].size(0)
    data = data[0].view(batch_samples, data[0].size(1), -1)
    mean += data.mean(2).sum(0)

@@ -40,9 +40,9 @@ parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets')
parser.add_argument('--start_index', default=0, type=int, help='start index of the networks to evaluate')

args = parser.parse_args()

if __name__ == "__main__":

    device = torch.device(args.device)
@@ -58,18 +58,21 @@ if __name__ == "__main__":

    # nasbench_len = 15625
    nasbench_len = 15625
    filename = f'output/swap_results_{args.datasets}.csv'
    if args.datasets == 'aircraft':
        api_datasets = 'cifar10'

    # for index, i in arch_info.iterrows():
    for ind in range(nasbench_len):
    for ind in range(args.start_index, nasbench_len):
        # print(f'Evaluating network: {index}')
        print(f'Evaluating network: {ind}')

        config = api.get_net_config(ind, args.datasets)
        config = api.get_net_config(ind, api_datasets)
        network = get_cell_based_tiny_net(config)
        # nas_results = api.query_by_index(i, 'cifar10')
        # acc = nas_results[111].get_eval('ori-test')
        nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False)
        acc = nas_results['test-accuracy']
        # nas_results = api.get_more_info(ind, api_datasets, None, hp=200, is_random=False)
        # acc = nas_results['test-accuracy']
        acc = 99

        # print(type(network))
        start_time = time.time()
@@ -98,6 +101,8 @@ if __name__ == "__main__":
        print(f'Elapsed time: {end_time - start_time:.2f} seconds')

        results.append([np.mean(swap_score), acc, ind])
        with open(filename, 'a') as f:
            f.write(f'{np.mean(swap_score)},{acc},{ind}\n')

    results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
    results.to_csv('output/swap_results.csv', float_format='%.4f', index=False)

@@ -3,22 +3,29 @@ import shutil

# Dataset paths
dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images'
output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images'
test_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/test_sorted_images'
train_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/train_sorted_images'

# Category file, e.g. 'images_variant_trainval.txt'
labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt'
# There are two files: one for the train/val split and one for the test split
test_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt'
train_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_train.txt'

# Create the output folders
if not os.path.exists(output_path):
    os.makedirs(output_path)
if not os.path.exists(test_output_path):
    os.makedirs(test_output_path)
if not os.path.exists(train_output_path):
    os.makedirs(train_output_path)

# Read the category files
with open(labels_file, 'r') as f:
    lines = f.readlines()
with open(test_labels_file, 'r') as f:
    test_lines = f.readlines()
with open(train_labels_file, 'r') as f:
    train_lines = f.readlines()

count = 0

for line in lines:
def sort_images(lines, output_path):
    count = 0
    for line in lines:
        count += 1
        print(f'Processing image {count}/{len(lines)}', end='\r')
        parts = line.strip().split(' ')
@@ -38,4 +45,9 @@ for line in lines:
        else:
            print(f'Image {image_name} not found!')

print("Sorting test images into folders by category...")
sort_images(test_lines, test_output_path)
print("Sorting train images into folders by category...")
sort_images(train_lines, train_output_path)

print("Images have been sorted into folders by category.")
