From 31b8122cc1eec923921d7a0d638ff6907962fe9e Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Fri, 19 Mar 2021 03:22:58 -0700
Subject: [PATCH] Add the SuperMLP class

---
 README.md                                     |   3 +-
 lib/spaces/basic_space.py                     | 151 +++++++++++++-----
 lib/xlayers/super_core.py                     |   3 +-
 lib/xlayers/{super_mlp.py => super_linear.py} |  57 +++++--
 tests/test_basic_space.py                     |   8 +-
 tests/test_super_model.py                     |  26 +++
 6 files changed, 195 insertions(+), 53 deletions(-)
 rename lib/xlayers/{super_mlp.py => super_linear.py} (70%)

diff --git a/README.md b/README.md
index e779fb4..e6d4849 100644
--- a/README.md
+++ b/README.md
@@ -96,7 +96,8 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
 
 Please use
 ```
-git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git
+git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git XAutoDL
+git clone --recurse-submodules https://github.com/D-X-Y/AutoDL-Projects.git XAutoDL
 ```
 to download this repo with submodules.
 
diff --git a/lib/spaces/basic_space.py b/lib/spaces/basic_space.py
index a0b6465..fe7ee7d 100644
--- a/lib/spaces/basic_space.py
+++ b/lib/spaces/basic_space.py
@@ -22,19 +22,32 @@ class Space(metaclass=abc.ABCMeta):
     All search space must inherit from this basic class.
     """
 
+    def __init__(self):
+        # used to avoid duplicate sample
+        self._last_sample = None
+        self._last_abstract = None
+
     @abc.abstractproperty
-    def xrepr(self, prefix="") -> Text:
+    def xrepr(self, depth=0) -> Text:
         raise NotImplementedError
 
     def __repr__(self) -> Text:
         return self.xrepr()
 
     @abc.abstractproperty
-    def abstract(self) -> "Space":
+    def abstract(self, reuse_last=False) -> "Space":
         raise NotImplementedError
 
     @abc.abstractmethod
-    def random(self, recursion=True):
+    def random(self, recursion=True, reuse_last=False):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def clean_last_sample(self):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def clean_last_abstract(self):
         raise NotImplementedError
 
     @abc.abstractproperty
@@ -63,6 +76,7 @@ class VirtualNode(Space):
     """
 
     def __init__(self, id=None, value=None):
+        super(VirtualNode, self).__init__()
         self._id = id
         self._value = value
         self._attributes = OrderedDict()
@@ -82,26 +96,51 @@ class VirtualNode(Space):
             # raise ValueError("Can not attach a determined value: {:}".format(value))
         self._attributes[key] = value
 
-    def xrepr(self, prefix=" ") -> Text:
+    def xrepr(self, depth=0) -> Text:
         strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
         for key, value in self._attributes.items():
-            strs.append(value.xrepr(prefix + " " + key + " = "))
+            strs.append(key + " = " + value.xrepr(depth + 1))
         strs.append(")")
-        return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs)
+        if len(strs) == 2:
+            return "".join(strs)
+        else:
+            space = "  "
+            xstrs = (
+                [strs[0]]
+                + [space * (depth + 1) + x for x in strs[1:-1]]
+                + [space * depth + strs[-1]]
+            )
+            return ",\n".join(xstrs)
 
-    def abstract(self) -> Space:
+    def abstract(self, reuse_last=False) -> Space:
+        if reuse_last and self._last_abstract is not None:
+            return self._last_abstract
         node = VirtualNode(id(self))
         for key, value in self._attributes.items():
             if not value.determined:
-                node.append(value.abstract())
-        return node
+                node.append(value.abstract(reuse_last))
+        self._last_abstract = node
+        return self._last_abstract
 
-    def random(self, recursion=True):
+    def random(self, recursion=True, reuse_last=False):
+        if reuse_last and self._last_sample is not None:
+            return self._last_sample
         node = VirtualNode(None, self._value)
         for key, value in self._attributes.items():
-            node.append(key, value.random(recursion))
+            node.append(key, value.random(recursion, reuse_last))
+        self._last_sample = node  # record the last sample
         return node
 
+    def clean_last_sample(self):
+        self._last_sample = None
+        for key, value in self._attributes.items():
+            value.clean_last_sample()
+
+    def clean_last_abstract(self):
+        self._last_abstract = None
+        for key, value in self._attributes.items():
+            value.clean_last_abstract()
+
     def has(self, x) -> bool:
         for key, value in self._attributes.items():
             if value.has(x):
@@ -117,7 +156,7 @@ class VirtualNode(Space):
     @property
     def determined(self) -> bool:
         for key, value in self._attributes.items():
-            if not value.determined(x):
+            if not value.determined:
                 return False
         return True
 
@@ -138,6 +177,7 @@ class Categorical(Space):
     """
 
    def __init__(self, *data, default: Optional[int] = None):
+        super(Categorical, self).__init__()
         self._candidates = [*data]
         self._default = default
         assert self._default is None or 0 <= self._default < len(
@@ -169,32 +209,54 @@ class Categorical(Space):
     def __len__(self):
         return len(self._candidates)
 
-    def abstract(self) -> Space:
-        if self.determined:
-            return VirtualNode(id(self), self)
-        # [TO-IMPROVE]
-        data = []
-        for candidate in self.candidates:
+    def clean_last_sample(self):
+        self._last_sample = None
+        for candidate in self._candidates:
             if isinstance(candidate, Space):
-                data.append(candidate.abstract())
-            else:
-                data.append(VirtualNode(id(candidate), candidate))
-        return Categorical(*data, default=self._default)
+                candidate.clean_last_sample()
 
-    def random(self, recursion=True):
+    def clean_last_abstract(self):
+        self._last_abstract = None
+        for candidate in self._candidates:
+            if isinstance(candidate, Space):
+                candidate.clean_last_abstract()
+
+    def abstract(self, reuse_last=False) -> Space:
+        if reuse_last and self._last_abstract is not None:
+            return self._last_abstract
+        if self.determined:
+            result = VirtualNode(id(self), self)
+        else:
+            # [TO-IMPROVE]
+            data = []
+            for candidate in self.candidates:
+                if isinstance(candidate, Space):
+                    data.append(candidate.abstract())
+                else:
+                    data.append(VirtualNode(id(candidate), candidate))
+            result = Categorical(*data, default=self._default)
+        self._last_abstract = result
+        return self._last_abstract
+
+    def random(self, recursion=True, reuse_last=False):
+        if reuse_last and self._last_sample is not None:
+            return self._last_sample
         sample = random.choice(self._candidates)
         if recursion and isinstance(sample, Space):
-            sample = sample.random(recursion)
+            sample = sample.random(recursion, reuse_last)
         if isinstance(sample, VirtualNode):
-            return sample.copy()
+            sample = sample.copy()
         else:
-            return VirtualNode(None, sample)
+            sample = VirtualNode(None, sample)
+        self._last_sample = sample
+        return self._last_sample
 
-    def xrepr(self, prefix=""):
+    def xrepr(self, depth=0):
+        del depth
         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
             name=self.__class__.__name__, cs=self._candidates, default=self._default
         )
-        return prefix + xrepr
+        return xrepr
 
     def has(self, x):
         super().has(x)
@@ -213,7 +275,7 @@ class Categorical(Space):
         if self.default != other.default:
             return False
         for index in range(len(self)):
-            if self.__getitem__[index] != other[index]:
+            if self.__getitem__(index) != other[index]:
                 return False
         return True
 
@@ -235,14 +297,15 @@ class Integer(Categorical):
             default = data.index(default)
         super(Integer, self).__init__(*data, default=default)
 
-    def xrepr(self, prefix=""):
+    def xrepr(self, depth=0):
+        del depth
         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
             name=self.__class__.__name__,
             lower=self._raw_lower,
             upper=self._raw_upper,
             default=self._raw_default,
         )
-        return prefix + xrepr
+        return xrepr
 
 
 np_float_types = (np.float16, np.float32, np.float64)
@@ -269,6 +332,7 @@ class Continuous(Space):
         log: bool = False,
         eps: float = _EPS,
     ):
+        super(Continuous, self).__init__()
         self._lower = lower
         self._upper = upper
         self._default = default
@@ -295,19 +359,26 @@ class Continuous(Space):
     def eps(self):
         return self._eps
 
-    def abstract(self) -> Space:
-        return self.copy()
+    def abstract(self, reuse_last=False) -> Space:
+        if reuse_last and self._last_abstract is not None:
+            return self._last_abstract
+        self._last_abstract = self.copy()
+        return self._last_abstract
 
-    def random(self, recursion=True):
+    def random(self, recursion=True, reuse_last=False):
         del recursion
+        if reuse_last and self._last_sample is not None:
+            return self._last_sample
         if self._log_scale:
             sample = random.uniform(math.log(self._lower), math.log(self._upper))
             sample = math.exp(sample)
         else:
             sample = random.uniform(self._lower, self._upper)
-        return VirtualNode(None, sample)
+        self._last_sample = VirtualNode(None, sample)
+        return self._last_sample
 
-    def xrepr(self, prefix=""):
+    def xrepr(self, depth=0):
+        del depth
         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
             name=self.__class__.__name__,
             lower=self._lower,
@@ -315,7 +386,7 @@ class Continuous(Space):
             default=self._default,
             log=self._log_scale,
         )
-        return prefix + xrepr
+        return xrepr
 
     def convert(self, x):
         if isinstance(x, np_float_types) and x.size == 1:
@@ -338,6 +409,12 @@ class Continuous(Space):
     def determined(self):
         return abs(self.lower - self.upper) <= self._eps
 
+    def clean_last_sample(self):
+        self._last_sample = None
+
+    def clean_last_abstract(self):
+        self._last_abstract = None
+
     def __eq__(self, other):
         if not isinstance(other, Continuous):
             return False
diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py
index 8c7b056..4f6a090 100644
--- a/lib/xlayers/super_core.py
+++ b/lib/xlayers/super_core.py
@@ -3,4 +3,5 @@
 #####################################################
 from .super_module import SuperRunMode
 from .super_module import SuperModule
-from .super_mlp import SuperLinear
+from .super_linear import SuperLinear
+from .super_linear import SuperMLP
diff --git a/lib/xlayers/super_mlp.py b/lib/xlayers/super_linear.py
similarity index 70%
rename from lib/xlayers/super_mlp.py
rename to lib/xlayers/super_linear.py
index d80ceed..30420d3 100644
--- a/lib/xlayers/super_mlp.py
+++ b/lib/xlayers/super_linear.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
 
-from typing import Optional, Union
+from typing import Optional, Union, Callable
 
 import spaces
 from .super_module import SuperModule
@@ -57,11 +57,15 @@ class SuperLinear(SuperModule):
     def abstract_search_space(self):
         root_node = spaces.VirtualNode(id(self))
         if not spaces.is_determined(self._in_features):
-            root_node.append("_in_features", self._in_features.abstract())
+            root_node.append(
+                "_in_features", self._in_features.abstract(reuse_last=True)
+            )
         if not spaces.is_determined(self._out_features):
-            root_node.append("_out_features", self._out_features.abstract())
+            root_node.append(
+                "_out_features", self._out_features.abstract(reuse_last=True)
+            )
         if not spaces.is_determined(self._bias):
-            root_node.append("_bias", self._bias.abstract())
root_node.append("_bias", self._bias.abstract(reuse_last=True)) return root_node def reset_parameters(self) -> None: @@ -116,24 +120,51 @@ class SuperMLP(SuperModule): def __init__( self, - in_features, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer=nn.GELU, + in_features: IntSpaceType, + hidden_features: IntSpaceType, + out_features: IntSpaceType, + act_layer: Callable[[], nn.Module] = nn.GELU, drop: Optional[float] = None, ): super(SuperMLP, self).__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) + self._in_features = in_features + self._hidden_features = hidden_features + self._out_features = out_features + self._drop_rate = drop + self.fc1 = SuperLinear(in_features, hidden_features) self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = SuperLinear(hidden_features, out_features) self.drop = nn.Dropout(drop or 0.0) - def forward(self, x): + @property + def abstract_search_space(self): + root_node = spaces.VirtualNode(id(self)) + space_fc1 = self.fc1.abstract_search_space + space_fc2 = self.fc2.abstract_search_space + if not spaces.is_determined(space_fc1): + root_node.append("fc1", space_fc1) + if not spaces.is_determined(space_fc2): + root_node.append("fc2", space_fc2) + return root_node + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + return self._unified_forward(x) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + return self._unified_forward(x) + + def _unified_forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x + + def extra_repr(self) -> str: + return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format( + self._in_features, + self._hidden_features, + self._out_features, + self._drop_rate, + ) diff --git a/tests/test_basic_space.py b/tests/test_basic_space.py index 2de430e..713ec2e 100644 --- a/tests/test_basic_space.py +++ b/tests/test_basic_space.py @@ -48,7 +48,7 @@ class TestBasicSpace(unittest.TestCase): space = Continuous(lower, upper, log=False) values = [] for i in range(1000000): - x = space.random().value + x = space.random(reuse_last=False).value self.assertGreaterEqual(x, lower) self.assertGreaterEqual(upper, x) values.append(x) @@ -97,6 +97,12 @@ class TestBasicSpace(unittest.TestCase): self.assertTrue(is_determined(1)) self.assertFalse(is_determined(nested_space)) + def test_duplicate(self): + space = Categorical(1, 2, 3, 4) + x = space.random() + for _ in range(100): + self.assertEqual(x, space.random(reuse_last=True)) + class TestAbstractSpace(unittest.TestCase): """Test the abstract search spaces.""" diff --git a/tests/test_super_model.py b/tests/test_super_model.py index 7df1f4a..dfb75b7 100644 --- a/tests/test_super_model.py +++ b/tests/test_super_model.py @@ -48,3 +48,29 @@ class TestSuperLinear(unittest.TestCase): output_shape = (32, abstract_child["_out_features"].value) outputs = model(inputs) self.assertEqual(tuple(outputs.shape), output_shape) + + def test_super_mlp(self): + hidden_features = spaces.Categorical(12, 24, 36) + out_features = spaces.Categorical(12, 24, 36) + mlp = super_core.SuperMLP(10, hidden_features, out_features) + print(mlp) + self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features) + + abstract_space = mlp.abstract_search_space + print("The abstract search space for SuperMLP 
is:\n{:}".format(abstract_space)) + self.assertEqual( + abstract_space["fc1"]["_out_features"], + abstract_space["fc2"]["_in_features"], + ) + self.assertTrue( + abstract_space["fc1"]["_out_features"] + is abstract_space["fc2"]["_in_features"] + ) + + abstract_space.clean_last_sample() + abstract_child = abstract_space.random(reuse_last=True) + print("The abstract child program is:\n{:}".format(abstract_child)) + self.assertEqual( + abstract_child["fc1"]["_out_features"].value, + abstract_child["fc2"]["_in_features"].value, + )