Add the SuperMLP class
This commit is contained in:
parent
51c626c96d
commit
31b8122cc1
@ -96,7 +96,8 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
|
||||
|
||||
Please use
|
||||
```
|
||||
git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git
|
||||
git clone --recurse-submodules git@github.com:D-X-Y/AutoDL-Projects.git XAutoDL
|
||||
git clone --recurse-submodules https://github.com/D-X-Y/AutoDL-Projects.git XAutoDL
|
||||
```
|
||||
to download this repo with submodules.
|
||||
|
||||
|
@ -22,19 +22,32 @@ class Space(metaclass=abc.ABCMeta):
|
||||
All search space must inherit from this basic class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# used to avoid duplicate sample
|
||||
self._last_sample = None
|
||||
self._last_abstract = None
|
||||
|
||||
@abc.abstractproperty
|
||||
def xrepr(self, prefix="") -> Text:
|
||||
def xrepr(self, depth=0) -> Text:
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self) -> Text:
|
||||
return self.xrepr()
|
||||
|
||||
@abc.abstractproperty
|
||||
def abstract(self) -> "Space":
|
||||
def abstract(self, reuse_last=False) -> "Space":
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def random(self, recursion=True):
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def clean_last_sample(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def clean_last_abstract(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
@ -63,6 +76,7 @@ class VirtualNode(Space):
|
||||
"""
|
||||
|
||||
def __init__(self, id=None, value=None):
|
||||
super(VirtualNode, self).__init__()
|
||||
self._id = id
|
||||
self._value = value
|
||||
self._attributes = OrderedDict()
|
||||
@ -82,26 +96,51 @@ class VirtualNode(Space):
|
||||
# raise ValueError("Can not attach a determined value: {:}".format(value))
|
||||
self._attributes[key] = value
|
||||
|
||||
def xrepr(self, prefix=" ") -> Text:
|
||||
def xrepr(self, depth=0) -> Text:
|
||||
strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
|
||||
for key, value in self._attributes.items():
|
||||
strs.append(value.xrepr(prefix + " " + key + " = "))
|
||||
strs.append(key + " = " + value.xrepr(depth + 1))
|
||||
strs.append(")")
|
||||
return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs)
|
||||
if len(strs) == 2:
|
||||
return "".join(strs)
|
||||
else:
|
||||
space = " "
|
||||
xstrs = (
|
||||
[strs[0]]
|
||||
+ [space * (depth + 1) + x for x in strs[1:-1]]
|
||||
+ [space * depth + strs[-1]]
|
||||
)
|
||||
return ",\n".join(xstrs)
|
||||
|
||||
def abstract(self) -> Space:
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
node = VirtualNode(id(self))
|
||||
for key, value in self._attributes.items():
|
||||
if not value.determined:
|
||||
node.append(value.abstract())
|
||||
return node
|
||||
node.append(value.abstract(reuse_last))
|
||||
self._last_abstract = node
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True):
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
node = VirtualNode(None, self._value)
|
||||
for key, value in self._attributes.items():
|
||||
node.append(key, value.random(recursion))
|
||||
node.append(key, value.random(recursion, reuse_last))
|
||||
self._last_sample = node # record the last sample
|
||||
return node
|
||||
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
for key, value in self._attributes.items():
|
||||
value.clean_last_sample()
|
||||
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
for key, value in self._attributes.items():
|
||||
value.clean_last_abstract()
|
||||
|
||||
def has(self, x) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
if value.has(x):
|
||||
@ -117,7 +156,7 @@ class VirtualNode(Space):
|
||||
@property
|
||||
def determined(self) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
if not value.determined(x):
|
||||
if not value.determined:
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -138,6 +177,7 @@ class Categorical(Space):
|
||||
"""
|
||||
|
||||
def __init__(self, *data, default: Optional[int] = None):
|
||||
super(Categorical, self).__init__()
|
||||
self._candidates = [*data]
|
||||
self._default = default
|
||||
assert self._default is None or 0 <= self._default < len(
|
||||
@ -169,32 +209,54 @@ class Categorical(Space):
|
||||
def __len__(self):
|
||||
return len(self._candidates)
|
||||
|
||||
def abstract(self) -> Space:
|
||||
if self.determined:
|
||||
return VirtualNode(id(self), self)
|
||||
# [TO-IMPROVE]
|
||||
data = []
|
||||
for candidate in self.candidates:
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space):
|
||||
data.append(candidate.abstract())
|
||||
else:
|
||||
data.append(VirtualNode(id(candidate), candidate))
|
||||
return Categorical(*data, default=self._default)
|
||||
candidate.clean_last_sample()
|
||||
|
||||
def random(self, recursion=True):
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space):
|
||||
candidate.clean_last_abstract()
|
||||
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
if self.determined:
|
||||
result = VirtualNode(id(self), self)
|
||||
else:
|
||||
# [TO-IMPROVE]
|
||||
data = []
|
||||
for candidate in self.candidates:
|
||||
if isinstance(candidate, Space):
|
||||
data.append(candidate.abstract())
|
||||
else:
|
||||
data.append(VirtualNode(id(candidate), candidate))
|
||||
result = Categorical(*data, default=self._default)
|
||||
self._last_abstract = result
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
sample = random.choice(self._candidates)
|
||||
if recursion and isinstance(sample, Space):
|
||||
sample = sample.random(recursion)
|
||||
sample = sample.random(recursion, reuse_last)
|
||||
if isinstance(sample, VirtualNode):
|
||||
return sample.copy()
|
||||
sample = sample.copy()
|
||||
else:
|
||||
return VirtualNode(None, sample)
|
||||
sample = VirtualNode(None, sample)
|
||||
self._last_sample = sample
|
||||
return self._last_sample
|
||||
|
||||
def xrepr(self, prefix=""):
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
|
||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||
)
|
||||
return prefix + xrepr
|
||||
return xrepr
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
@ -213,7 +275,7 @@ class Categorical(Space):
|
||||
if self.default != other.default:
|
||||
return False
|
||||
for index in range(len(self)):
|
||||
if self.__getitem__[index] != other[index]:
|
||||
if self.__getitem__(index) != other[index]:
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -235,14 +297,15 @@ class Integer(Categorical):
|
||||
default = data.index(default)
|
||||
super(Integer, self).__init__(*data, default=default)
|
||||
|
||||
def xrepr(self, prefix=""):
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._raw_lower,
|
||||
upper=self._raw_upper,
|
||||
default=self._raw_default,
|
||||
)
|
||||
return prefix + xrepr
|
||||
return xrepr
|
||||
|
||||
|
||||
np_float_types = (np.float16, np.float32, np.float64)
|
||||
@ -269,6 +332,7 @@ class Continuous(Space):
|
||||
log: bool = False,
|
||||
eps: float = _EPS,
|
||||
):
|
||||
super(Continuous, self).__init__()
|
||||
self._lower = lower
|
||||
self._upper = upper
|
||||
self._default = default
|
||||
@ -295,19 +359,26 @@ class Continuous(Space):
|
||||
def eps(self):
|
||||
return self._eps
|
||||
|
||||
def abstract(self) -> Space:
|
||||
return self.copy()
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
self._last_abstract = self.copy()
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True):
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
del recursion
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
if self._log_scale:
|
||||
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
||||
sample = math.exp(sample)
|
||||
else:
|
||||
sample = random.uniform(self._lower, self._upper)
|
||||
return VirtualNode(None, sample)
|
||||
self._last_sample = VirtualNode(None, sample)
|
||||
return self._last_sample
|
||||
|
||||
def xrepr(self, prefix=""):
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._lower,
|
||||
@ -315,7 +386,7 @@ class Continuous(Space):
|
||||
default=self._default,
|
||||
log=self._log_scale,
|
||||
)
|
||||
return prefix + xrepr
|
||||
return xrepr
|
||||
|
||||
def convert(self, x):
|
||||
if isinstance(x, np_float_types) and x.size == 1:
|
||||
@ -338,6 +409,12 @@ class Continuous(Space):
|
||||
def determined(self):
|
||||
return abs(self.lower - self.upper) <= self._eps
|
||||
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Continuous):
|
||||
return False
|
||||
|
@ -3,4 +3,5 @@
|
||||
#####################################################
|
||||
from .super_module import SuperRunMode
|
||||
from .super_module import SuperModule
|
||||
from .super_mlp import SuperLinear
|
||||
from .super_linear import SuperLinear
|
||||
from .super_linear import SuperMLP
|
||||
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Callable
|
||||
|
||||
import spaces
|
||||
from .super_module import SuperModule
|
||||
@ -57,11 +57,15 @@ class SuperLinear(SuperModule):
|
||||
def abstract_search_space(self):
|
||||
root_node = spaces.VirtualNode(id(self))
|
||||
if not spaces.is_determined(self._in_features):
|
||||
root_node.append("_in_features", self._in_features.abstract())
|
||||
root_node.append(
|
||||
"_in_features", self._in_features.abstract(reuse_last=True)
|
||||
)
|
||||
if not spaces.is_determined(self._out_features):
|
||||
root_node.append("_out_features", self._out_features.abstract())
|
||||
root_node.append(
|
||||
"_out_features", self._out_features.abstract(reuse_last=True)
|
||||
)
|
||||
if not spaces.is_determined(self._bias):
|
||||
root_node.append("_bias", self._bias.abstract())
|
||||
root_node.append("_bias", self._bias.abstract(reuse_last=True))
|
||||
return root_node
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
@ -116,24 +120,51 @@ class SuperMLP(SuperModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer=nn.GELU,
|
||||
in_features: IntSpaceType,
|
||||
hidden_features: IntSpaceType,
|
||||
out_features: IntSpaceType,
|
||||
act_layer: Callable[[], nn.Module] = nn.GELU,
|
||||
drop: Optional[float] = None,
|
||||
):
|
||||
super(SuperMLP, self).__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self._in_features = in_features
|
||||
self._hidden_features = hidden_features
|
||||
self._out_features = out_features
|
||||
self._drop_rate = drop
|
||||
self.fc1 = SuperLinear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.fc2 = SuperLinear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop or 0.0)
|
||||
|
||||
def forward(self, x):
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
root_node = spaces.VirtualNode(id(self))
|
||||
space_fc1 = self.fc1.abstract_search_space
|
||||
space_fc2 = self.fc2.abstract_search_space
|
||||
if not spaces.is_determined(space_fc1):
|
||||
root_node.append("fc1", space_fc1)
|
||||
if not spaces.is_determined(space_fc2):
|
||||
root_node.append("fc2", space_fc2)
|
||||
return root_node
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return self._unified_forward(x)
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return self._unified_forward(x)
|
||||
|
||||
def _unified_forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format(
|
||||
self._in_features,
|
||||
self._hidden_features,
|
||||
self._out_features,
|
||||
self._drop_rate,
|
||||
)
|
@ -48,7 +48,7 @@ class TestBasicSpace(unittest.TestCase):
|
||||
space = Continuous(lower, upper, log=False)
|
||||
values = []
|
||||
for i in range(1000000):
|
||||
x = space.random().value
|
||||
x = space.random(reuse_last=False).value
|
||||
self.assertGreaterEqual(x, lower)
|
||||
self.assertGreaterEqual(upper, x)
|
||||
values.append(x)
|
||||
@ -97,6 +97,12 @@ class TestBasicSpace(unittest.TestCase):
|
||||
self.assertTrue(is_determined(1))
|
||||
self.assertFalse(is_determined(nested_space))
|
||||
|
||||
def test_duplicate(self):
|
||||
space = Categorical(1, 2, 3, 4)
|
||||
x = space.random()
|
||||
for _ in range(100):
|
||||
self.assertEqual(x, space.random(reuse_last=True))
|
||||
|
||||
|
||||
class TestAbstractSpace(unittest.TestCase):
|
||||
"""Test the abstract search spaces."""
|
||||
|
@ -48,3 +48,29 @@ class TestSuperLinear(unittest.TestCase):
|
||||
output_shape = (32, abstract_child["_out_features"].value)
|
||||
outputs = model(inputs)
|
||||
self.assertEqual(tuple(outputs.shape), output_shape)
|
||||
|
||||
def test_super_mlp(self):
|
||||
hidden_features = spaces.Categorical(12, 24, 36)
|
||||
out_features = spaces.Categorical(12, 24, 36)
|
||||
mlp = super_core.SuperMLP(10, hidden_features, out_features)
|
||||
print(mlp)
|
||||
self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
|
||||
|
||||
abstract_space = mlp.abstract_search_space
|
||||
print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space))
|
||||
self.assertEqual(
|
||||
abstract_space["fc1"]["_out_features"],
|
||||
abstract_space["fc2"]["_in_features"],
|
||||
)
|
||||
self.assertTrue(
|
||||
abstract_space["fc1"]["_out_features"]
|
||||
is abstract_space["fc2"]["_in_features"]
|
||||
)
|
||||
|
||||
abstract_space.clean_last_sample()
|
||||
abstract_child = abstract_space.random(reuse_last=True)
|
||||
print("The abstract child program is:\n{:}".format(abstract_child))
|
||||
self.assertEqual(
|
||||
abstract_child["fc1"]["_out_features"].value,
|
||||
abstract_child["fc2"]["_in_features"].value,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user