Update find_best API
This commit is contained in:
		
							
								
								
									
										47
									
								
								exps/show-dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								exps/show-dataset.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          # | ||||
| ############################################################################## | ||||
| # python ./exps/NATS-Bench/main-tss.py --mode meta                           # | ||||
| ############################################################################## | ||||
| import os, sys, time, torch, random, argparse | ||||
| from typing import List, Text, Dict, Any | ||||
| from PIL     import ImageFile | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import dict2config, load_config | ||||
| from datasets import get_datasets | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def show_imagenet_16_120(dataset_dir=None): | ||||
|   if dataset_dir is None: | ||||
|     torch_home_dir = os.environ['TORCH_HOME'] if 'TORCH_HOME' in os.environ else os.path.join(os.environ['HOME'], '.torch') | ||||
|     dataset_dir = os.path.join(torch_home_dir, 'cifar.python', 'ImageNet16') | ||||
|   train_data, valid_data, xshape, class_num = get_datasets('ImageNet16-120', dataset_dir, -1) | ||||
|   split_info  = load_config('configs/nas-benchmark/ImageNet16-120-split.txt', None, None) | ||||
|   print('=' * 10 + ' ImageNet-16-120 ' + '=' * 10) | ||||
|   print('Training Data: {:}'.format(train_data)) | ||||
|   print('Evaluation Data: {:}'.format(valid_data)) | ||||
|   print('Hold-out training: {:} images.'.format(len(split_info.train))) | ||||
|   print('Hold-out valid   : {:} images.'.format(len(split_info.valid))) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   # show_imagenet_16_120() | ||||
|   api_nats_tss = create(None, 'tss', fast_mode=True, verbose=True) | ||||
|  | ||||
|   valid_acc_12e = [] | ||||
|   test_acc_12e = [] | ||||
|   test_acc_200e = [] | ||||
|   for index in range(10000): | ||||
|     info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='12') | ||||
|     valid_acc_12e.append(info['valid-accuracy']) | ||||
|     test_acc_12e.append(info['test-accuracy']) | ||||
|     info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='200') | ||||
|     test_acc_200e.append(info['test-accuracy'])  # which I reported. | ||||
| @@ -92,6 +92,10 @@ class ImageNet16(data.Dataset): | ||||
|     #std_data  = np.mean(np.mean(std_data, axis=0), axis=0) | ||||
|     #print ('Std  : {:}'.format(std_data)) | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets)))) | ||||
|  | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     img, target = self.data[index], self.targets[index] - 1 | ||||
|  | ||||
| @@ -114,16 +118,16 @@ class ImageNet16(data.Dataset): | ||||
|         return False | ||||
|     return True | ||||
|  | ||||
| # | ||||
| """ | ||||
| if __name__ == '__main__': | ||||
|   train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)  | ||||
|   valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)  | ||||
|   train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)  | ||||
|   valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)  | ||||
|  | ||||
|   print ( len(train) ) | ||||
|   print ( len(valid) ) | ||||
|   image, label = train[111] | ||||
|   trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200) | ||||
|   validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200) | ||||
|   trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200) | ||||
|   validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200) | ||||
|   print ( len(trainX) ) | ||||
|   print ( len(validX) ) | ||||
|   #import pdb; pdb.set_trace() | ||||
| """ | ||||
|   | ||||
| @@ -482,6 +482,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|     best_index, highest_accuracy = -1, None | ||||
|     evaluated_indexes = sorted(list(self.evaluated_indexes)) | ||||
|     for arch_index in evaluated_indexes: | ||||
|       self._prepare_info(arch_index) | ||||
|       arch_info = self.arch2infos_dict[arch_index][hp] | ||||
|       info = arch_info.get_compute_costs(dataset)  # the information of costs | ||||
|       flop, param, latency = info['flops'], info['params'], info['latency'] | ||||
| @@ -622,6 +623,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|         print('<' * 40 + '------------' + '<' * 40) | ||||
|     else: | ||||
|       if 0 <= index < len(self.meta_archs): | ||||
|         if index not in self.evaluated_indexes: | ||||
|           self._prepare_info(index) | ||||
|         if index not in self.evaluated_indexes: | ||||
|           print('The {:}-th architecture has not been evaluated ' | ||||
|                 'or not saved.'.format(index)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user