Update weight watcher codes
This commit is contained in:
		| @@ -77,6 +77,7 @@ class NASBench201API(NASBenchMetaAPI): | ||||
|     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] ) | ||||
|     # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults | ||||
|     self.arch2infos_dict = OrderedDict() | ||||
|     self._avaliable_hps = set(['12', '200']) | ||||
|     for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): | ||||
|       all_info = file_path_or_dict['arch2infos'][xkey] | ||||
|       hp2archres = OrderedDict() | ||||
|   | ||||
| @@ -75,11 +75,13 @@ class NASBench301API(NASBenchMetaAPI): | ||||
|     self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] ) | ||||
|     # This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults | ||||
|     self.arch2infos_dict = OrderedDict() | ||||
|     self._avaliable_hps = set() | ||||
|     for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())): | ||||
|       all_infos = file_path_or_dict['arch2infos'][xkey] | ||||
|       hp2archres = OrderedDict() | ||||
|       for hp_key, results in all_infos.items(): | ||||
|         hp2archres[hp_key] = ArchResults.create_from_state_dict(results) | ||||
|         self._avaliable_hps.add(hp_key)  # save the avaliable hyper-parameter | ||||
|       self.arch2infos_dict[xkey] = hp2archres | ||||
|     self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes'])) | ||||
|     self.archstr2index = {} | ||||
|   | ||||
| @@ -57,6 +57,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|   def __repr__(self): | ||||
|     return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename)) | ||||
|  | ||||
|   @property | ||||
|   def avaliable_hps(self): | ||||
|     return list(copy.deepcopy(self._avaliable_hps)) | ||||
|  | ||||
|   def random(self): | ||||
|     """Return a random index of all architectures.""" | ||||
|     return random.randint(0, len(self.meta_archs)-1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user