Add the SuperMLP class
This commit is contained in:
		| @@ -96,7 +96,8 @@ Some methods use knowledge distillation (KD), which require pre-trained models. | ||||
|  | ||||
| Please use | ||||
| ``` | ||||
| git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git | ||||
| git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git XAutoDL | ||||
| git clone --recurse-submodules https://github.com/D-X-Y/AutoDL-Projects.git XAutoDL | ||||
| ``` | ||||
| to download this repo with submodules. | ||||
|  | ||||
|   | ||||
| @@ -22,19 +22,32 @@ class Space(metaclass=abc.ABCMeta): | ||||
|     All search space must inherit from this basic class. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         # used to avoid duplicate sample | ||||
|         self._last_sample = None | ||||
|         self._last_abstract = None | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def xrepr(self, prefix="") -> Text: | ||||
|     def xrepr(self, depth=0) -> Text: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self) -> Text: | ||||
|         return self.xrepr() | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def abstract(self) -> "Space": | ||||
|     def abstract(self, reuse_last=False) -> "Space": | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def random(self, recursion=True): | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def clean_last_sample(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def clean_last_abstract(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractproperty | ||||
| @@ -63,6 +76,7 @@ class VirtualNode(Space): | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, id=None, value=None): | ||||
|         super(VirtualNode, self).__init__() | ||||
|         self._id = id | ||||
|         self._value = value | ||||
|         self._attributes = OrderedDict() | ||||
| @@ -82,26 +96,51 @@ class VirtualNode(Space): | ||||
|         #    raise ValueError("Can not attach a determined value: {:}".format(value)) | ||||
|         self._attributes[key] = value | ||||
|  | ||||
|     def xrepr(self, prefix="  ") -> Text: | ||||
|     def xrepr(self, depth=0) -> Text: | ||||
|         strs = [self.__class__.__name__ + "(value={:}".format(self._value)] | ||||
|         for key, value in self._attributes.items(): | ||||
|             strs.append(value.xrepr(prefix + "  " + key + " = ")) | ||||
|             strs.append(key + " = " + value.xrepr(depth + 1)) | ||||
|         strs.append(")") | ||||
|         return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs) | ||||
|         if len(strs) == 2: | ||||
|             return "".join(strs) | ||||
|         else: | ||||
|             space = "  " | ||||
|             xstrs = ( | ||||
|                 [strs[0]] | ||||
|                 + [space * (depth + 1) + x for x in strs[1:-1]] | ||||
|                 + [space * depth + strs[-1]] | ||||
|             ) | ||||
|             return ",\n".join(xstrs) | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         node = VirtualNode(id(self)) | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined: | ||||
|                 node.append(value.abstract()) | ||||
|         return node | ||||
|                 node.append(value.abstract(reuse_last)) | ||||
|         self._last_abstract = node | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         node = VirtualNode(None, self._value) | ||||
|         for key, value in self._attributes.items(): | ||||
|             node.append(key, value.random(recursion)) | ||||
|             node.append(key, value.random(recursion, reuse_last)) | ||||
|         self._last_sample = node  # record the last sample | ||||
|         return node | ||||
|  | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|         for key, value in self._attributes.items(): | ||||
|             value.clean_last_sample() | ||||
|  | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|         for key, value in self._attributes.items(): | ||||
|             value.clean_last_abstract() | ||||
|  | ||||
|     def has(self, x) -> bool: | ||||
|         for key, value in self._attributes.items(): | ||||
|             if value.has(x): | ||||
| @@ -117,7 +156,7 @@ class VirtualNode(Space): | ||||
|     @property | ||||
|     def determined(self) -> bool: | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined(x): | ||||
|             if not value.determined: | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
| @@ -138,6 +177,7 @@ class Categorical(Space): | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, *data, default: Optional[int] = None): | ||||
|         super(Categorical, self).__init__() | ||||
|         self._candidates = [*data] | ||||
|         self._default = default | ||||
|         assert self._default is None or 0 <= self._default < len( | ||||
| @@ -169,32 +209,54 @@ class Categorical(Space): | ||||
|     def __len__(self): | ||||
|         return len(self._candidates) | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|         if self.determined: | ||||
|             return VirtualNode(id(self), self) | ||||
|         # [TO-IMPROVE] | ||||
|         data = [] | ||||
|         for candidate in self.candidates: | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|         for candidate in self._candidates: | ||||
|             if isinstance(candidate, Space): | ||||
|                 data.append(candidate.abstract()) | ||||
|             else: | ||||
|                 data.append(VirtualNode(id(candidate), candidate)) | ||||
|         return Categorical(*data, default=self._default) | ||||
|                 candidate.clean_last_sample() | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|         for candidate in self._candidates: | ||||
|             if isinstance(candidate, Space): | ||||
|                 candidate.clean_last_abstract() | ||||
|  | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         if self.determined: | ||||
|             result = VirtualNode(id(self), self) | ||||
|         else: | ||||
|             # [TO-IMPROVE] | ||||
|             data = [] | ||||
|             for candidate in self.candidates: | ||||
|                 if isinstance(candidate, Space): | ||||
|                     data.append(candidate.abstract()) | ||||
|                 else: | ||||
|                     data.append(VirtualNode(id(candidate), candidate)) | ||||
|             result = Categorical(*data, default=self._default) | ||||
|         self._last_abstract = result | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         sample = random.choice(self._candidates) | ||||
|         if recursion and isinstance(sample, Space): | ||||
|             sample = sample.random(recursion) | ||||
|             sample = sample.random(recursion, reuse_last) | ||||
|         if isinstance(sample, VirtualNode): | ||||
|             return sample.copy() | ||||
|             sample = sample.copy() | ||||
|         else: | ||||
|             return VirtualNode(None, sample) | ||||
|             sample = VirtualNode(None, sample) | ||||
|         self._last_sample = sample | ||||
|         return self._last_sample | ||||
|  | ||||
|     def xrepr(self, prefix=""): | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( | ||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||
|         ) | ||||
|         return prefix + xrepr | ||||
|         return xrepr | ||||
|  | ||||
|     def has(self, x): | ||||
|         super().has(x) | ||||
| @@ -213,7 +275,7 @@ class Categorical(Space): | ||||
|         if self.default != other.default: | ||||
|             return False | ||||
|         for index in range(len(self)): | ||||
|             if self.__getitem__[index] != other[index]: | ||||
|             if self.__getitem__(index) != other[index]: | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
| @@ -235,14 +297,15 @@ class Integer(Categorical): | ||||
|             default = data.index(default) | ||||
|         super(Integer, self).__init__(*data, default=default) | ||||
|  | ||||
|     def xrepr(self, prefix=""): | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             lower=self._raw_lower, | ||||
|             upper=self._raw_upper, | ||||
|             default=self._raw_default, | ||||
|         ) | ||||
|         return prefix + xrepr | ||||
|         return xrepr | ||||
|  | ||||
|  | ||||
| np_float_types = (np.float16, np.float32, np.float64) | ||||
| @@ -269,6 +332,7 @@ class Continuous(Space): | ||||
|         log: bool = False, | ||||
|         eps: float = _EPS, | ||||
|     ): | ||||
|         super(Continuous, self).__init__() | ||||
|         self._lower = lower | ||||
|         self._upper = upper | ||||
|         self._default = default | ||||
| @@ -295,19 +359,26 @@ class Continuous(Space): | ||||
|     def eps(self): | ||||
|         return self._eps | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|         return self.copy() | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         self._last_abstract = self.copy() | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         del recursion | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         if self._log_scale: | ||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||
|             sample = math.exp(sample) | ||||
|         else: | ||||
|             sample = random.uniform(self._lower, self._upper) | ||||
|         return VirtualNode(None, sample) | ||||
|         self._last_sample = VirtualNode(None, sample) | ||||
|         return self._last_sample | ||||
|  | ||||
|     def xrepr(self, prefix=""): | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             lower=self._lower, | ||||
| @@ -315,7 +386,7 @@ class Continuous(Space): | ||||
|             default=self._default, | ||||
|             log=self._log_scale, | ||||
|         ) | ||||
|         return prefix + xrepr | ||||
|         return xrepr | ||||
|  | ||||
|     def convert(self, x): | ||||
|         if isinstance(x, np_float_types) and x.size == 1: | ||||
| @@ -338,6 +409,12 @@ class Continuous(Space): | ||||
|     def determined(self): | ||||
|         return abs(self.lower - self.upper) <= self._eps | ||||
|  | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|  | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, Continuous): | ||||
|             return False | ||||
|   | ||||
| @@ -3,4 +3,5 @@ | ||||
| ##################################################### | ||||
| from .super_module import SuperRunMode | ||||
| from .super_module import SuperModule | ||||
| from .super_mlp import SuperLinear | ||||
| from .super_linear import SuperLinear | ||||
| from .super_linear import SuperMLP | ||||
|   | ||||
| @@ -6,7 +6,7 @@ import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| 
 | ||||
| import math | ||||
| from typing import Optional, Union | ||||
| from typing import Optional, Union, Callable | ||||
| 
 | ||||
| import spaces | ||||
| from .super_module import SuperModule | ||||
| @@ -57,11 +57,15 @@ class SuperLinear(SuperModule): | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         if not spaces.is_determined(self._in_features): | ||||
|             root_node.append("_in_features", self._in_features.abstract()) | ||||
|             root_node.append( | ||||
|                 "_in_features", self._in_features.abstract(reuse_last=True) | ||||
|             ) | ||||
|         if not spaces.is_determined(self._out_features): | ||||
|             root_node.append("_out_features", self._out_features.abstract()) | ||||
|             root_node.append( | ||||
|                 "_out_features", self._out_features.abstract(reuse_last=True) | ||||
|             ) | ||||
|         if not spaces.is_determined(self._bias): | ||||
|             root_node.append("_bias", self._bias.abstract()) | ||||
|             root_node.append("_bias", self._bias.abstract(reuse_last=True)) | ||||
|         return root_node | ||||
| 
 | ||||
|     def reset_parameters(self) -> None: | ||||
| @@ -116,24 +120,51 @@ class SuperMLP(SuperModule): | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features, | ||||
|         hidden_features: Optional[int] = None, | ||||
|         out_features: Optional[int] = None, | ||||
|         act_layer=nn.GELU, | ||||
|         in_features: IntSpaceType, | ||||
|         hidden_features: IntSpaceType, | ||||
|         out_features: IntSpaceType, | ||||
|         act_layer: Callable[[], nn.Module] = nn.GELU, | ||||
|         drop: Optional[float] = None, | ||||
|     ): | ||||
|         super(SuperMLP, self).__init__() | ||||
|         out_features = out_features or in_features | ||||
|         hidden_features = hidden_features or in_features | ||||
|         self.fc1 = nn.Linear(in_features, hidden_features) | ||||
|         self._in_features = in_features | ||||
|         self._hidden_features = hidden_features | ||||
|         self._out_features = out_features | ||||
|         self._drop_rate = drop | ||||
|         self.fc1 = SuperLinear(in_features, hidden_features) | ||||
|         self.act = act_layer() | ||||
|         self.fc2 = nn.Linear(hidden_features, out_features) | ||||
|         self.fc2 = SuperLinear(hidden_features, out_features) | ||||
|         self.drop = nn.Dropout(drop or 0.0) | ||||
| 
 | ||||
|     def forward(self, x): | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         space_fc1 = self.fc1.abstract_search_space | ||||
|         space_fc2 = self.fc2.abstract_search_space | ||||
|         if not spaces.is_determined(space_fc1): | ||||
|             root_node.append("fc1", space_fc1) | ||||
|         if not spaces.is_determined(space_fc2): | ||||
|             root_node.append("fc2", space_fc2) | ||||
|         return root_node | ||||
| 
 | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return self._unified_forward(x) | ||||
| 
 | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return self._unified_forward(x) | ||||
| 
 | ||||
|     def _unified_forward(self, x): | ||||
|         x = self.fc1(x) | ||||
|         x = self.act(x) | ||||
|         x = self.drop(x) | ||||
|         x = self.fc2(x) | ||||
|         x = self.drop(x) | ||||
|         return x | ||||
| 
 | ||||
|     def extra_repr(self) -> str: | ||||
|         return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format( | ||||
|             self._in_features, | ||||
|             self._hidden_features, | ||||
|             self._out_features, | ||||
|             self._drop_rate, | ||||
|         ) | ||||
| @@ -48,7 +48,7 @@ class TestBasicSpace(unittest.TestCase): | ||||
|         space = Continuous(lower, upper, log=False) | ||||
|         values = [] | ||||
|         for i in range(1000000): | ||||
|             x = space.random().value | ||||
|             x = space.random(reuse_last=False).value | ||||
|             self.assertGreaterEqual(x, lower) | ||||
|             self.assertGreaterEqual(upper, x) | ||||
|             values.append(x) | ||||
| @@ -97,6 +97,12 @@ class TestBasicSpace(unittest.TestCase): | ||||
|         self.assertTrue(is_determined(1)) | ||||
|         self.assertFalse(is_determined(nested_space)) | ||||
|  | ||||
|     def test_duplicate(self): | ||||
|         space = Categorical(1, 2, 3, 4) | ||||
|         x = space.random() | ||||
|         for _ in range(100): | ||||
|             self.assertEqual(x, space.random(reuse_last=True)) | ||||
|  | ||||
|  | ||||
| class TestAbstractSpace(unittest.TestCase): | ||||
|     """Test the abstract search spaces.""" | ||||
|   | ||||
| @@ -48,3 +48,29 @@ class TestSuperLinear(unittest.TestCase): | ||||
|         output_shape = (32, abstract_child["_out_features"].value) | ||||
|         outputs = model(inputs) | ||||
|         self.assertEqual(tuple(outputs.shape), output_shape) | ||||
|  | ||||
|     def test_super_mlp(self): | ||||
|         hidden_features = spaces.Categorical(12, 24, 36) | ||||
|         out_features = spaces.Categorical(12, 24, 36) | ||||
|         mlp = super_core.SuperMLP(10, hidden_features, out_features) | ||||
|         print(mlp) | ||||
|         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features) | ||||
|  | ||||
|         abstract_space = mlp.abstract_search_space | ||||
|         print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space)) | ||||
|         self.assertEqual( | ||||
|             abstract_space["fc1"]["_out_features"], | ||||
|             abstract_space["fc2"]["_in_features"], | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             abstract_space["fc1"]["_out_features"] | ||||
|             is abstract_space["fc2"]["_in_features"] | ||||
|         ) | ||||
|  | ||||
|         abstract_space.clean_last_sample() | ||||
|         abstract_child = abstract_space.random(reuse_last=True) | ||||
|         print("The abstract child program is:\n{:}".format(abstract_child)) | ||||
|         self.assertEqual( | ||||
|             abstract_child["fc1"]["_out_features"].value, | ||||
|             abstract_child["fc2"]["_in_features"].value, | ||||
|         ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user