update baseline NAS algos
This commit is contained in:
		| @@ -1,5 +1,5 @@ | ||||
| import os, sys, copy, torch, numpy as np | ||||
|  | ||||
| from collections import OrderedDict | ||||
|  | ||||
|  | ||||
| def print_information(information, extra_info=None, show=False): | ||||
| @@ -29,20 +29,26 @@ def print_information(information, extra_info=None, show=False): | ||||
|  | ||||
| class AANASBenchAPI(object): | ||||
|  | ||||
|   def __init__(self, file_path_or_dict): | ||||
|   def __init__(self, file_path_or_dict, verbose=True): | ||||
|     if isinstance(file_path_or_dict, str): | ||||
|       if verbose: print('try to create AA-NAS-Bench api from {:}'.format(file_path_or_dict)) | ||||
|       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict) | ||||
|       file_path_or_dict = torch.load(file_path_or_dict) | ||||
|     else: | ||||
|       file_path_or_dict = copy.deepcopy( file_path_or_dict ) | ||||
|     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict)) | ||||
|     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes') | ||||
|     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key) | ||||
|     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] ) | ||||
|     self.arch2infos = copy.deepcopy( file_path_or_dict['arch2infos'] ) | ||||
|     self.evaluated_indexes = sorted(list( copy.deepcopy( file_path_or_dict['evaluated_indexes'] ) )) | ||||
|     self.arch2infos = OrderedDict() | ||||
|     for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): | ||||
|       self.arch2infos[xkey] = ArchResults.create_from_state_dict( file_path_or_dict['arch2infos'][xkey] ) | ||||
|     self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes'])) | ||||
|     self.archstr2index = {} | ||||
|     for idx, arch in enumerate(self.meta_archs): | ||||
|       assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()]) | ||||
|       self.archstr2index[ arch.tostr() ] = idx | ||||
|       #assert arch.tostr() not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch.tostr()]) | ||||
|       assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch]) | ||||
|       self.archstr2index[ arch ] = idx | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     return copy.deepcopy( self.meta_archs[index] ) | ||||
| @@ -54,12 +60,12 @@ class AANASBenchAPI(object): | ||||
|     return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) | ||||
|  | ||||
|   def query_index_by_arch(self, arch): | ||||
|     if arch.tostr() in self.archstr2index: | ||||
|       arch_index = self.archstr2index[ arch.tostr() ] | ||||
|     #else: | ||||
|     #  arch_str = Structure.str2fullstructure( arch.tostr() ).tostr() | ||||
|     #  if arch_str in self.archstr2index: | ||||
|     #    arch_index = self.archstr2index[ arch_str ] | ||||
|     if isinstance(arch, str): | ||||
|       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] | ||||
|       else                         : arch_index = -1 | ||||
|     elif hasattr(arch, 'tostr'): | ||||
|       if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ] | ||||
|       else                                 : arch_index = -1 | ||||
|     else: arch_index = -1 | ||||
|     return arch_index | ||||
|    | ||||
| @@ -80,6 +86,11 @@ class AANASBenchAPI(object): | ||||
|     info = archInfo.query(dataname) | ||||
|     return info | ||||
|  | ||||
|   def query_meta_info_by_index(self, arch_index): | ||||
|     assert arch_index in self.arch2infos, 'arch_index [{:}] does not in arch2info'.format(arch_index) | ||||
|     archInfo = copy.deepcopy( self.arch2infos[ arch_index ] ) | ||||
|     return archInfo | ||||
|  | ||||
|   def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None): | ||||
|     best_index, highest_accuracy = -1, None | ||||
|     for i, idx in enumerate(self.evaluated_indexes): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user