Complete Super Linear
This commit is contained in:
parent
9c5ae93494
commit
51c626c96d
@ -12,5 +12,6 @@ from .basic_space import VirtualNode
|
||||
from .basic_op import has_categorical
|
||||
from .basic_op import has_continuous
|
||||
from .basic_op import is_determined
|
||||
from .basic_op import get_determined_value
|
||||
from .basic_op import get_min
|
||||
from .basic_op import get_max
|
||||
|
@ -1,4 +1,5 @@
|
||||
from spaces.basic_space import Space
|
||||
from spaces.basic_space import VirtualNode
|
||||
from spaces.basic_space import Integer
|
||||
from spaces.basic_space import Continuous
|
||||
from spaces.basic_space import Categorical
|
||||
@ -26,6 +27,20 @@ def is_determined(space_or_value):
|
||||
return True
|
||||
|
||||
|
||||
def get_determined_value(space_or_value):
    """Return the single concrete value behind a determined space or value.

    Raises:
        ValueError: if `is_determined` reports that the input is not determined.
    """
    if not is_determined(space_or_value):
        raise ValueError("This input is not determined: {:}".format(space_or_value))
    # Anything that is not a Space is already a plain value.
    if not isinstance(space_or_value, Space):
        return space_or_value
    if isinstance(space_or_value, Continuous):
        return space_or_value.lower
    if isinstance(space_or_value, Categorical):
        # Unwrap the first candidate, which may itself be a space.
        return get_determined_value(space_or_value[0])
    # Remaining Space type: VirtualNode
    return space_or_value.value
|
||||
|
||||
|
||||
def get_max(space_or_value):
|
||||
if isinstance(space_or_value, Integer):
|
||||
return max(space_or_value.candidates)
|
||||
|
@ -23,7 +23,7 @@ class Space(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
def xrepr(self, prefix="") -> Text:
    """Return a textual description of this space.

    Declared with `abc.abstractmethod` rather than the deprecated
    `abc.abstractproperty`: `xrepr` takes an argument and is invoked as a
    regular method (`value.xrepr(prefix)`), so it must not be a property.

    Args:
        prefix: text that the concrete implementations in this file put in
            front of their description.

    Raises:
        NotImplementedError: always; subclasses provide the implementation.
    """
    raise NotImplementedError
|
||||
|
||||
def __repr__(self) -> Text:
|
||||
@ -67,17 +67,27 @@ class VirtualNode(Space):
|
||||
self._value = value
|
||||
self._attributes = OrderedDict()
|
||||
|
||||
@property
def value(self):
    """Read-only view of the value stored in `self._value`."""
    stored = self._value
    return stored
|
||||
|
||||
def append(self, key, value):
    """Attach the space `value` to this node under the name `key`.

    Raises:
        TypeError: if `key` is not a string.
        ValueError: if `value` is not a `Space` instance.
    """
    key_is_string = isinstance(key, str)
    if not key_is_string:
        message = "Only accept string as a key instead of {:}".format(type(key))
        raise TypeError(message)
    if isinstance(value, Space):
        # NOTE: determined spaces are accepted too; no check on
        # `value.determined` is performed here.
        self._attributes[key] = value
        return
    raise ValueError("Invalid type of value: {:}".format(type(value)))
|
||||
|
||||
def xrepr(self, prefix=" ") -> Text:
    """Render this node, and any attached attribute spaces, as text.

    Args:
        prefix: text placed in front of the rendered node. Each attribute is
            rendered with ``prefix + " " + key + " = "`` as its own prefix.

    Returns:
        One line for a node without attributes; otherwise the header, one
        entry per attribute and the closing parenthesis joined by ``",\\n"``.
    """
    strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
    for key, value in self._attributes.items():
        strs.append(value.xrepr(prefix + " " + key + " = "))
    strs.append(")")
    # A conditional expression binds looser than `+`. Without building the
    # body first, the prefix (which carries the "key = " label of a nested
    # node) was dropped whenever the node had attributes.
    body = "".join(strs) if len(strs) == 2 else ",\n".join(strs)
    return prefix + body
|
||||
|
||||
def abstract(self) -> Space:
|
||||
node = VirtualNode(id(self))
|
||||
@ -87,7 +97,10 @@ class VirtualNode(Space):
|
||||
return node
|
||||
|
||||
def random(self, recursion=True):
    """Sample a new node by sampling every attached attribute space.

    The stale ``raise NotImplementedError`` that preceded this body made the
    implementation unreachable and is removed.

    Args:
        recursion: forwarded unchanged to each attribute's `random`.

    Returns:
        A new `VirtualNode` carrying this node's value and, under the same
        keys, the result of `random(recursion)` for each attribute.
    """
    node = VirtualNode(None, self._value)
    for key, value in self._attributes.items():
        node.append(key, value.random(recursion))
    return node
|
||||
|
||||
def has(self, x) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
@ -101,6 +114,7 @@ class VirtualNode(Space):
|
||||
def __getitem__(self, key):
|
||||
return self._attributes[key]
|
||||
|
||||
@property
|
||||
def determined(self) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
if not value.determined(x):
|
||||
@ -165,20 +179,22 @@ class Categorical(Space):
|
||||
data.append(candidate.abstract())
|
||||
else:
|
||||
data.append(VirtualNode(id(candidate), candidate))
|
||||
return Categorical(*data, self._default)
|
||||
return Categorical(*data, default=self._default)
|
||||
|
||||
def random(self, recursion=True):
    """Pick one candidate at random and return it as a `VirtualNode`.

    The stale early returns (``return sample.random(recursion)`` and
    ``return sample``) short-circuited the wrapping logic below and are
    removed, so the result is always a `VirtualNode`.

    Args:
        recursion: when true and the chosen candidate is itself a `Space`,
            that space is sampled in turn.
    """
    sample = random.choice(self._candidates)
    if recursion and isinstance(sample, Space):
        sample = sample.random(recursion)
    if isinstance(sample, VirtualNode):
        return sample.copy()
    # Plain values (and unsampled spaces) are wrapped in an anonymous node.
    return VirtualNode(None, sample)
|
||||
|
||||
def xrepr(self, prefix=""):
    """Describe this space as ``Name(candidates=..., default_index=...)``.

    Only the `prefix` signature is kept; the stale ``indent`` variant returned
    ``" " * indent + xrepr``, which references a parameter that no longer
    exists.

    Args:
        prefix: text placed in front of the description.
    """
    template = "{name:}(candidates={cs:}, default_index={default:})"
    description = template.format(
        name=self.__class__.__name__, cs=self._candidates, default=self._default
    )
    return prefix + description
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
@ -219,14 +235,14 @@ class Integer(Categorical):
|
||||
default = data.index(default)
|
||||
super(Integer, self).__init__(*data, default=default)
|
||||
|
||||
def xrepr(self, prefix=""):
    """Describe this space as ``Name(lower=..., upper=..., default=...)``.

    Only the `prefix` signature is kept; the stale ``indent`` variant returned
    ``" " * indent + xrepr``, which references a parameter that no longer
    exists.

    Args:
        prefix: text placed in front of the description.
    """
    template = "{name:}(lower={lower:}, upper={upper:}, default={default:})"
    description = template.format(
        name=self.__class__.__name__,
        lower=self._raw_lower,
        upper=self._raw_upper,
        default=self._raw_default,
    )
    return prefix + description
|
||||
|
||||
|
||||
np_float_types = (np.float16, np.float32, np.float64)
|
||||
@ -286,11 +302,12 @@ class Continuous(Space):
|
||||
del recursion
|
||||
if self._log_scale:
|
||||
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
||||
return math.exp(sample)
|
||||
sample = math.exp(sample)
|
||||
else:
|
||||
return random.uniform(self._lower, self._upper)
|
||||
sample = random.uniform(self._lower, self._upper)
|
||||
return VirtualNode(None, sample)
|
||||
|
||||
def xrepr(self, indent=0):
|
||||
def xrepr(self, prefix=""):
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._lower,
|
||||
@ -298,7 +315,7 @@ class Continuous(Space):
|
||||
default=self._default,
|
||||
log=self._log_scale,
|
||||
)
|
||||
return " " * indent + xrepr
|
||||
return prefix + xrepr
|
||||
|
||||
def convert(self, x):
|
||||
if isinstance(x, np_float_types) and x.size == 1:
|
||||
|
@ -1,5 +1,6 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
from .super_module import SuperRunMode
|
||||
from .super_module import SuperModule
|
||||
from .super_mlp import SuperLinear
|
||||
|
@ -3,6 +3,7 @@
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
@ -52,14 +53,15 @@ class SuperLinear(SuperModule):
|
||||
def bias(self):
|
||||
return spaces.has_categorical(self._bias, True)
|
||||
|
||||
@property
def abstract_search_space(self):
    """Build a `VirtualNode` holding the abstract form of each undetermined hyper-parameter.

    Only `.abstract()` versions are attached; the stale appends of the raw
    spaces under the same keys are dropped, and the three identical branches
    are folded into one loop.

    Returns:
        A `spaces.VirtualNode` created with `id(self)` and, for each of
        `_in_features`, `_out_features` and `_bias` that is not determined,
        an entry under that name.
    """
    root_node = spaces.VirtualNode(id(self))
    for name in ("_in_features", "_out_features", "_bias"):
        space = getattr(self, name)
        if not spaces.is_determined(space):
            root_node.append(name, space.abstract())
    return root_node
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
@ -69,6 +71,37 @@ class SuperLinear(SuperModule):
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self._super_bias, -bound, bound)
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
    """Run the linear layer using the currently applied candidate.

    Each of `_in_features`, `_out_features` and `_bias` is read from
    `self.abstract_child` when it is not determined, and from
    `spaces.get_determined_value` otherwise.

    Raises:
        ValueError: if the last dimension of `input` differs from the
            resolved input dimension.
    """

    def _resolve(name):
        # Determined attributes have a fixed value; the others come from the
        # abstract child under the same key.
        attribute = getattr(self, name)
        if spaces.is_determined(attribute):
            return spaces.get_determined_value(attribute)
        return self.abstract_child[name].value

    # check inputs ->
    expected_input_dim = _resolve("_in_features")
    actual_input_dim = input.size(-1)
    if actual_input_dim != expected_input_dim:
        raise ValueError(
            "Expect the input dim of {:} instead of {:}".format(
                expected_input_dim, input.size(-1)
            )
        )
    # create the weight matrix
    out_dim = _resolve("_out_features")
    candidate_weight = self._super_weight[:out_dim, :expected_input_dim]
    # create the bias matrix
    use_bias = _resolve("_bias")
    candidate_bias = self._super_bias[:out_dim] if use_bias else None
    return F.linear(input, candidate_weight, candidate_bias)
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
    """Apply the linear map with the full `_super_weight` and `_super_bias`."""
    weight = self._super_weight
    bias = self._super_bias
    return F.linear(input, weight, bias)
|
||||
|
||||
@ -78,8 +111,9 @@ class SuperLinear(SuperModule):
|
||||
)
|
||||
|
||||
|
||||
class SuperMLP(nn.Module):
|
||||
# MLP: FC -> Activation -> Drop -> FC -> Drop
|
||||
class SuperMLP(SuperModule):
|
||||
"""An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
@ -88,13 +122,13 @@ class SuperMLP(nn.Module):
|
||||
act_layer=nn.GELU,
|
||||
drop: Optional[float] = None,
|
||||
):
|
||||
super(MLP, self).__init__()
|
||||
super(SuperMLP, self).__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop or 0)
|
||||
self.drop = nn.Dropout(drop or 0.0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
|
@ -6,11 +6,14 @@ import abc
|
||||
import torch.nn as nn
|
||||
from enum import Enum
|
||||
|
||||
import spaces
|
||||
|
||||
|
||||
class SuperRunMode(Enum):
    """Running modes of a super model."""

    FullModel = "fullmodel"
    Candidate = "candidate"
    # Inside the class body `FullModel` is still the raw string, so this
    # reuses the same value and makes `Default` an alias of `FullModel`.
    Default = FullModel
|
||||
|
||||
|
||||
@ -20,8 +23,23 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
def __init__(self):
    """Initialise the module with the default run mode and no candidate applied."""
    super(SuperModule, self).__init__()
    # Start in the default run mode.
    self._super_run_type = SuperRunMode.Default
    # No abstract child until `apply_candiate` stores one.
    self._abstract_child = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_super_run_type(self, super_run_type):
    """Set the run mode via `self.apply`, updating each visited `SuperModule`."""

    def _assign_mode(module):
        # Modules that are not SuperModule instances are left untouched.
        if not isinstance(module, SuperModule):
            return
        module._super_run_type = super_run_type

    self.apply(_assign_mode)
|
||||
|
||||
def apply_candiate(self, abstract_child):
    """Store `abstract_child` as the candidate exposed by `abstract_child`.

    Raises:
        ValueError: if `abstract_child` is not a `spaces.VirtualNode`.
    """
    if isinstance(abstract_child, spaces.VirtualNode):
        self._abstract_child = abstract_child
        return
    raise ValueError(
        "Invalid abstract child program: {:}".format(abstract_child)
    )
|
||||
|
||||
@property
def abstract_search_space(self):
    """Abstract search space of the module; this base version always raises."""
    raise NotImplementedError()
|
||||
|
||||
@ -29,13 +47,24 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
def super_run_type(self):
|
||||
return self._super_run_type
|
||||
|
||||
@property
def abstract_child(self):
    """Read-only view of the value stored in `self._abstract_child`."""
    current = self._abstract_child
    return current
|
||||
|
||||
@abc.abstractmethod
def forward_raw(self, *inputs):
    """Forward pass with the largest candidate, like the original PyTorch model.

    Abstract: this base version always raises `NotImplementedError`.
    """
    raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
def forward_candidate(self, *inputs):
    """Forward pass used in `SuperRunMode.Candidate`; this base version always raises."""
    raise NotImplementedError()
|
||||
|
||||
def forward(self, *inputs):
|
||||
if self.super_run_type == SuperRunMode.FullModel:
|
||||
return self.forward_raw(*inputs)
|
||||
elif self.super_run_type == SuperRunMode.Candidate:
|
||||
return self.forward_candidate(*inputs)
|
||||
else:
|
||||
raise ModeError(
|
||||
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
|
||||
|
@ -41,14 +41,14 @@ class TestBasicSpace(unittest.TestCase):
|
||||
def test_continuous(self):
|
||||
random.seed(999)
|
||||
space = Continuous(0, 1)
|
||||
self.assertGreaterEqual(space.random(), 0)
|
||||
self.assertGreaterEqual(1, space.random())
|
||||
self.assertGreaterEqual(space.random().value, 0)
|
||||
self.assertGreaterEqual(1, space.random().value)
|
||||
|
||||
lower, upper = 1.5, 4.6
|
||||
space = Continuous(lower, upper, log=False)
|
||||
values = []
|
||||
for i in range(1000000):
|
||||
x = space.random()
|
||||
x = space.random().value
|
||||
self.assertGreaterEqual(x, lower)
|
||||
self.assertGreaterEqual(upper, x)
|
||||
values.append(x)
|
||||
@ -89,7 +89,7 @@ class TestBasicSpace(unittest.TestCase):
|
||||
Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11),
|
||||
12,
|
||||
)
|
||||
print(nested_space)
|
||||
print("\nThe nested search space:\n{:}".format(nested_space))
|
||||
for i in range(1, 13):
|
||||
self.assertTrue(nested_space.has(i))
|
||||
|
||||
@ -102,6 +102,19 @@ class TestAbstractSpace(unittest.TestCase):
|
||||
"""Test the abstract search spaces."""
|
||||
|
||||
def test_continous(self):
    """Exercise `abstract()` on Continuous, flat Categorical and nested Categorical spaces."""
    print("")
    # A Continuous space is expected to compare equal to its abstract form.
    continuous_space = Continuous(0, 1)
    self.assertEqual(continuous_space, continuous_space.abstract())
    print(
        "The abstract search space for Continuous: {:}".format(
            continuous_space.abstract()
        )
    )

    # The abstract form of a flat Categorical keeps its three entries.
    flat_space = Categorical(1, 2, 3)
    self.assertEqual(len(flat_space.abstract()), 3)
    print(flat_space.abstract())

    # Nested case: only built and printed, with no assertion.
    nested_space = Categorical(
        Categorical(1, 2, 3),
        Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11),
        12,
    )
    abstract_nested_space = nested_space.abstract()
    print("The abstract nested search space:\n{:}".format(abstract_nested_space))
|
||||
|
@ -25,6 +25,26 @@ class TestSuperLinear(unittest.TestCase):
|
||||
out_features = spaces.Categorical(12, 24, 36)
|
||||
bias = spaces.Categorical(True, False)
|
||||
model = super_core.SuperLinear(10, out_features, bias=bias)
|
||||
print(model)
|
||||
print("The simple super linear module is:\n{:}".format(model))
|
||||
|
||||
print(model.super_run_type)
|
||||
print(model.abstract_search_space())
|
||||
self.assertTrue(model.bias)
|
||||
|
||||
inputs = torch.rand(32, 10)
|
||||
print("Input shape: {:}".format(inputs.shape))
|
||||
print("Weight shape: {:}".format(model._super_weight.shape))
|
||||
print("Bias shape: {:}".format(model._super_bias.shape))
|
||||
outputs = model(inputs)
|
||||
self.assertEqual(tuple(outputs.shape), (32, 36))
|
||||
|
||||
abstract_space = model.abstract_search_space
|
||||
abstract_child = abstract_space.random()
|
||||
print("The abstract searc space:\n{:}".format(abstract_space))
|
||||
print("The abstract child program:\n{:}".format(abstract_child))
|
||||
|
||||
model.set_super_run_type(super_core.SuperRunMode.Candidate)
|
||||
model.apply_candiate(abstract_child)
|
||||
|
||||
output_shape = (32, abstract_child["_out_features"].value)
|
||||
outputs = model(inputs)
|
||||
self.assertEqual(tuple(outputs.shape), output_shape)
|
||||
|
Loading…
Reference in New Issue
Block a user