add resize to resize the images; cancel the acc; update the folder path
This commit is contained in:
		| @@ -4,7 +4,7 @@ | |||||||
|  |  | ||||||
| # # 加载CIFAR-10数据集 | # # 加载CIFAR-10数据集 | ||||||
| # transform = transforms.Compose([transforms.ToTensor()]) | # transform = transforms.Compose([transforms.ToTensor()]) | ||||||
| # trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) | # trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform) | ||||||
| # trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2) | # trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2) | ||||||
|  |  | ||||||
| # # 将所有数据加载到内存中 | # # 将所有数据加载到内存中 | ||||||
| @@ -18,6 +18,10 @@ | |||||||
| # print(f'Mean: {mean}') | # print(f'Mean: {mean}') | ||||||
| # print(f'Std: {std}') | # print(f'Std: {std}') | ||||||
|  |  | ||||||
|  | # results: | ||||||
|  | # Mean: tensor([0.4935, 0.4834, 0.4472]) | ||||||
|  | # Std: tensor([0.2476, 0.2446, 0.2626])   | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
| from torchvision import datasets, transforms | from torchvision import datasets, transforms | ||||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||||
| @@ -35,6 +39,7 @@ dataset_name = args.dataset | |||||||
|  |  | ||||||
| # 设置数据集的transform(这里只使用了ToTensor) | # 设置数据集的transform(这里只使用了ToTensor) | ||||||
| transform = transforms.Compose([ | transform = transforms.Compose([ | ||||||
|  |     transforms.Resize((224, 224)), | ||||||
|     transforms.ToTensor() |     transforms.ToTensor() | ||||||
| ]) | ]) | ||||||
|  |  | ||||||
| @@ -47,7 +52,10 @@ mean = torch.zeros(3) | |||||||
| std = torch.zeros(3) | std = torch.zeros(3) | ||||||
| nb_samples = 0 | nb_samples = 0 | ||||||
|  |  | ||||||
|  | count = 0 | ||||||
| for data in dataloader: | for data in dataloader: | ||||||
|  |     count += 1 | ||||||
|  |     print(f'Processing batch {count}/{len(dataloader)}', end='\r') | ||||||
|     batch_samples = data[0].size(0) |     batch_samples = data[0].size(0) | ||||||
|     data = data[0].view(batch_samples, data[0].size(1), -1) |     data = data[0].view(batch_samples, data[0].size(1), -1) | ||||||
|     mean += data.mean(2).sum(0) |     mean += data.mean(2).sum(0) | ||||||
|   | |||||||
| @@ -40,9 +40,9 @@ parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup | |||||||
| parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric') | parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric') | ||||||
| parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric') | parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric') | ||||||
| parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets') | parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets') | ||||||
|  | parser.add_argument('--start_index', default=0, type=int, help='start index of the networks to evaluate') | ||||||
|  |  | ||||||
| args = parser.parse_args() | args = parser.parse_args() | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|      |      | ||||||
|     device = torch.device(args.device) |     device = torch.device(args.device) | ||||||
| @@ -58,18 +58,21 @@ if __name__ == "__main__": | |||||||
|  |  | ||||||
|     # nasbench_len = 15625 |     # nasbench_len = 15625 | ||||||
|     nasbench_len = 15625 |     nasbench_len = 15625 | ||||||
|  |     filename = f'output/swap_results_{args.datasets}.csv' | ||||||
|  |     if args.datasets == 'aircraft': | ||||||
|  |         api_datasets = 'cifar10' | ||||||
|      |      | ||||||
|     # for index, i in arch_info.iterrows(): |     # for index, i in arch_info.iterrows(): | ||||||
|     for ind in range(nasbench_len): |     for ind in range(args.start_index,nasbench_len): | ||||||
|         # print(f'Evaluating network: {index}') |         # print(f'Evaluating network: {index}') | ||||||
|         print(f'Evaluating network: {ind}') |         print(f'Evaluating network: {ind}') | ||||||
|  |         config = api.get_net_config(ind, api_datasets) | ||||||
|         config = api.get_net_config(ind, args.datasets) |  | ||||||
|         network = get_cell_based_tiny_net(config) |         network = get_cell_based_tiny_net(config) | ||||||
|         # nas_results = api.query_by_index(i, 'cifar10') |         # nas_results = api.query_by_index(i, 'cifar10') | ||||||
|         # acc = nas_results[111].get_eval('ori-test') |         # acc = nas_results[111].get_eval('ori-test') | ||||||
|         nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False) |         # nas_results = api.get_more_info(ind, api_datasets, None, hp=200, is_random=False) | ||||||
|         acc = nas_results['test-accuracy'] |         # acc = nas_results['test-accuracy'] | ||||||
|  |         acc = 99 | ||||||
|  |  | ||||||
|         # print(type(network)) |         # print(type(network)) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
| @@ -98,6 +101,8 @@ if __name__ == "__main__": | |||||||
|         print(f'Elapsed time: {end_time - start_time:.2f} seconds') |         print(f'Elapsed time: {end_time - start_time:.2f} seconds') | ||||||
|  |  | ||||||
|         results.append([np.mean(swap_score), acc, ind]) |         results.append([np.mean(swap_score), acc, ind]) | ||||||
|  |         with open(filename, 'a') as f: | ||||||
|  |             f.write(f'{np.mean(swap_score)},{acc},{ind}\n') | ||||||
|  |  | ||||||
|     results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) |     results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) | ||||||
|     results.to_csv('output/swap_results.csv', float_format='%.4f', index=False) |     results.to_csv('output/swap_results.csv', float_format='%.4f', index=False) | ||||||
|   | |||||||
| @@ -3,21 +3,28 @@ import shutil | |||||||
|  |  | ||||||
| # 数据集路径 | # 数据集路径 | ||||||
| dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images' | dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images' | ||||||
| output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images' | test_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/test_sorted_images' | ||||||
|  | train_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/train_sorted_images' | ||||||
|  |  | ||||||
| # 类别文件,例如 'images_variant_trainval.txt' | # 类别文件,例如 'images_variant_trainval.txt' | ||||||
| labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt' | # 有两个文件,一个是训练集和验证集,一个是测试集 | ||||||
|  | test_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt' | ||||||
|  | train_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_train.txt' | ||||||
|  |  | ||||||
| # 创建输出文件夹 | # 创建输出文件夹 | ||||||
| if not os.path.exists(output_path): | if not os.path.exists(test_output_path): | ||||||
|     os.makedirs(output_path) |     os.makedirs(test_output_path) | ||||||
|  | if not os.path.exists(train_output_path): | ||||||
|  |     os.makedirs(train_output_path) | ||||||
|  |  | ||||||
| # 读取类别文件 | # 读取类别文件 | ||||||
| with open(labels_file, 'r') as f: | with open(test_labels_file, 'r') as f: | ||||||
|     lines = f.readlines() |     test_lines = f.readlines() | ||||||
|  | with open(train_labels_file, 'r') as f: | ||||||
|  |     train_lines = f.readlines() | ||||||
|  |  | ||||||
|  | def sort_images(lines, output_path): | ||||||
|     count = 0 |     count = 0 | ||||||
|  |  | ||||||
|     for line in lines: |     for line in lines: | ||||||
|         count += 1 |         count += 1 | ||||||
|         print(f'Processing image {count}/{len(lines)}', end='\r') |         print(f'Processing image {count}/{len(lines)}', end='\r') | ||||||
| @@ -38,4 +45,9 @@ for line in lines: | |||||||
|         else: |         else: | ||||||
|             print(f'Image {image_name} not found!') |             print(f'Image {image_name} not found!') | ||||||
|  |  | ||||||
|  | print("Sorting test images into folders by category...") | ||||||
|  | sort_images(test_lines, test_output_path) | ||||||
|  | print("Sorting train images into folders by category...") | ||||||
|  | sort_images(train_lines, train_output_path) | ||||||
|  |  | ||||||
| print("Images have been sorted into folders by category.") | print("Images have been sorted into folders by category.") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user