Add save/load_best for xlayers
This commit is contained in:
		| @@ -2,7 +2,9 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
|  | ||||
| import os | ||||
| import abc | ||||
| import tempfile | ||||
| import warnings | ||||
| from typing import Optional, Union, Callable | ||||
| import torch | ||||
| @@ -16,6 +18,9 @@ from .super_utils import LayerOrder, SuperRunMode | ||||
| from .super_utils import TensorContainer | ||||
| from .super_utils import ShapeContainer | ||||
|  | ||||
| BEST_DIR_KEY = "best_model_dir" | ||||
| BEST_SCORE_KEY = "best_model_score" | ||||
|  | ||||
|  | ||||
| class SuperModule(abc.ABC, nn.Module): | ||||
|     """This class equips the nn.Module class with the ability to apply AutoDL.""" | ||||
| @@ -25,6 +30,7 @@ class SuperModule(abc.ABC, nn.Module): | ||||
|         self._super_run_type = SuperRunMode.Default | ||||
|         self._abstract_child = None | ||||
|         self._verbose = False | ||||
|         self._meta_info = {} | ||||
|  | ||||
|     def set_super_run_type(self, super_run_type): | ||||
|         def _reset_super_run(m): | ||||
| @@ -84,6 +90,34 @@ class SuperModule(abc.ABC, nn.Module): | ||||
|                 total += buf.numel() | ||||
|         return total | ||||
|  | ||||
|     def save_best(self, score): | ||||
|         if BEST_DIR_KEY not in self._meta_info: | ||||
|             tempdir = tempfile.mkdtemp("-xlayers") | ||||
|             self._meta_info[BEST_DIR_KEY] = tempdir | ||||
|         if BEST_SCORE_KEY not in self._meta_info: | ||||
|             self._meta_info[BEST_SCORE_KEY] = None | ||||
|         best_score = self._meta_info[BEST_SCORE_KEY] | ||||
|         if best_score is None or best_score < score: | ||||
|             best_save_path = os.path.join( | ||||
|                 self._meta_info[BEST_DIR_KEY], | ||||
|                 "best-{:}.pth".format(self.__class__.__name__), | ||||
|             ) | ||||
|             self._meta_info[BEST_SCORE_KEY] = score | ||||
|             torch.save(self.state_dict(), best_save_path) | ||||
|             return True, self._meta_info[BEST_SCORE_KEY] | ||||
|         else: | ||||
|             return False, self._meta_info[BEST_SCORE_KEY] | ||||
|  | ||||
|     def load_best(self): | ||||
|         if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info: | ||||
|             raise ValueError("Please call save_best at first") | ||||
|         best_save_path = os.path.join( | ||||
|             self._meta_info[BEST_DIR_KEY], | ||||
|             "best-{:}.pth".format(self.__class__.__name__), | ||||
|         ) | ||||
|         state_dict = torch.load(best_save_path) | ||||
|         self.load_state_dict(state_dict) | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         raise NotImplementedError | ||||
|   | ||||
		Reference in New Issue
	
	Block a user