Complete Super Linear
@@ -12,5 +12,6 @@ from .basic_space import VirtualNode
from .basic_op import has_categorical
from .basic_op import has_continuous
from .basic_op import is_determined
from .basic_op import get_determined_value
from .basic_op import get_min
from .basic_op import get_max

@@ -1,4 +1,5 @@
from spaces.basic_space import Space
from spaces.basic_space import VirtualNode
from spaces.basic_space import Integer
from spaces.basic_space import Continuous
from spaces.basic_space import Categorical
@@ -26,6 +27,20 @@ def is_determined(space_or_value):
        return True


def get_determined_value(space_or_value):
    if not is_determined(space_or_value):
        raise ValueError("This input is not determined: {:}".format(space_or_value))
    if isinstance(space_or_value, Space):
        if isinstance(space_or_value, Continuous):
            return space_or_value.lower
        elif isinstance(space_or_value, Categorical):
            return get_determined_value(space_or_value[0])
        else:  # VirtualNode
            return space_or_value.value
    else:
        return space_or_value


def get_max(space_or_value):
    if isinstance(space_or_value, Integer):
        return max(space_or_value.candidates)

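For reference, a minimal usage sketch of the new helpers above (assuming the spaces package that collects them is importable, and that a single-candidate Categorical counts as determined, as the recursion in get_determined_value suggests):

import spaces

single = spaces.Categorical(4)                    # one candidate only
assert spaces.is_determined(single)
assert spaces.get_determined_value(single) == 4   # recurses into the sole candidate

assert spaces.is_determined(7)                    # plain values are trivially determined
assert spaces.get_determined_value(7) == 7

multi = spaces.Categorical(4, 8)
assert not spaces.is_determined(multi)
# spaces.get_determined_value(multi) raises ValueError: not determined
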
@@ -23,7 +23,7 @@ class Space(metaclass=abc.ABCMeta):
    """

    @abc.abstractproperty
    def xrepr(self, indent=0) -> Text:
    def xrepr(self, prefix="") -> Text:
        raise NotImplementedError

    def __repr__(self) -> Text:
@@ -67,17 +67,27 @@ class VirtualNode(Space):
        self._value = value
        self._attributes = OrderedDict()

    @property
    def value(self):
        return self._value

    def append(self, key, value):
        if not isinstance(key, str):
            raise TypeError(
                "Only accept string as a key instead of {:}".format(type(key))
            )
        if not isinstance(value, Space):
            raise ValueError("Invalid type of value: {:}".format(type(value)))
        # if value.determined:
        #    raise ValueError("Can not attach a determined value: {:}".format(value))
        self._attributes[key] = value

    def xrepr(self, indent=0) -> Text:
        strs = [self.__class__.__name__ + "("]
    def xrepr(self, prefix="  ") -> Text:
        strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
        for key, value in self._attributes.items():
            strs.append(value.xrepr(indent + 2) + ",")
            strs.append(value.xrepr(prefix + "  " + key + " = "))
        strs.append(")")
        return "\n".join(strs)
        return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs)

    def abstract(self) -> Space:
        node = VirtualNode(id(self))
@@ -87,7 +97,10 @@ class VirtualNode(Space):
        return node

    def random(self, recursion=True):
        raise NotImplementedError
        node = VirtualNode(None, self._value)
        for key, value in self._attributes.items():
            node.append(key, value.random(recursion))
        return node

    def has(self, x) -> bool:
        for key, value in self._attributes.items():
@@ -101,6 +114,7 @@ class VirtualNode(Space):
    def __getitem__(self, key):
        return self._attributes[key]

    @property
    def determined(self) -> bool:
        for key, value in self._attributes.items():
            if not value.determined:
@@ -165,20 +179,22 @@ class Categorical(Space):
                data.append(candidate.abstract())
            else:
                data.append(VirtualNode(id(candidate), candidate))
        return Categorical(*data, self._default)
        return Categorical(*data, default=self._default)

    def random(self, recursion=True):
        sample = random.choice(self._candidates)
        if recursion and isinstance(sample, Space):
            return sample.random(recursion)
            sample = sample.random(recursion)
        if isinstance(sample, VirtualNode):
            return sample.copy()
        else:
            return sample
            return VirtualNode(None, sample)

    def xrepr(self, indent=0):
    def xrepr(self, prefix=""):
        xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
            name=self.__class__.__name__, cs=self._candidates, default=self._default
        )
        return " " * indent + xrepr
        return prefix + xrepr

    def has(self, x):
        super().has(x)
@@ -219,14 +235,14 @@ class Integer(Categorical):
            default = data.index(default)
        super(Integer, self).__init__(*data, default=default)

    def xrepr(self, indent=0):
    def xrepr(self, prefix=""):
        xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
            name=self.__class__.__name__,
            lower=self._raw_lower,
            upper=self._raw_upper,
            default=self._raw_default,
        )
        return " " * indent + xrepr
        return prefix + xrepr


np_float_types = (np.float16, np.float32, np.float64)
@@ -286,11 +302,12 @@ class Continuous(Space):
        del recursion
        if self._log_scale:
            sample = random.uniform(math.log(self._lower), math.log(self._upper))
            return math.exp(sample)
            sample = math.exp(sample)
        else:
            return random.uniform(self._lower, self._upper)
            sample = random.uniform(self._lower, self._upper)
        return VirtualNode(None, sample)

    def xrepr(self, indent=0):
    def xrepr(self, prefix=""):
        xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
            name=self.__class__.__name__,
            lower=self._lower,
@@ -298,7 +315,7 @@ class Continuous(Space):
            default=self._default,
            log=self._log_scale,
        )
        return " " * indent + xrepr
        return prefix + xrepr

    def convert(self, x):
        if isinstance(x, np_float_types) and x.size == 1:

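Because random() on the basic spaces now returns a VirtualNode wrapper rather than a bare value, callers read samples through .value; a short sketch of the new behaviour (imports mirror the unit tests further below, and sampled values vary with the RNG):

from spaces.basic_space import Categorical, Continuous

cont = Continuous(0.5, 2.5)
node = cont.random()                 # a VirtualNode wrapping the sampled float
print(type(node).__name__)           # VirtualNode
print(0.5 <= node.value <= 2.5)      # True

cat = Categorical(1, 2, 3)
print(cat.random().value)            # one of 1, 2, 3
print(cat.xrepr(prefix="  "))        # indentation is now passed as a string prefix
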
@@ -1,5 +1,6 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from .super_module import SuperRunMode
from .super_module import SuperModule
from .super_mlp import SuperLinear

@@ -3,6 +3,7 @@
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from typing import Optional, Union
@@ -52,14 +53,15 @@ class SuperLinear(SuperModule):
    def bias(self):
        return spaces.has_categorical(self._bias, True)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        if not spaces.is_determined(self._in_features):
            root_node.append("_in_features", self._in_features)
            root_node.append("_in_features", self._in_features.abstract())
        if not spaces.is_determined(self._out_features):
            root_node.append("_out_features", self._out_features)
            root_node.append("_out_features", self._out_features.abstract())
        if not spaces.is_determined(self._bias):
            root_node.append("_bias", self._bias)
            root_node.append("_bias", self._bias.abstract())
        return root_node

    def reset_parameters(self) -> None:
@@ -69,6 +71,37 @@ class SuperLinear(SuperModule):
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self._super_bias, -bound, bound)

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # check inputs ->
        if not spaces.is_determined(self._in_features):
            expected_input_dim = self.abstract_child["_in_features"].value
        else:
            expected_input_dim = spaces.get_determined_value(self._in_features)
        if input.size(-1) != expected_input_dim:
            raise ValueError(
                "Expect the input dim of {:} instead of {:}".format(
                    expected_input_dim, input.size(-1)
                )
            )
        # create the weight matrix
        if not spaces.is_determined(self._out_features):
            out_dim = self.abstract_child["_out_features"].value
        else:
            out_dim = spaces.get_determined_value(self._out_features)
        candidate_weight = self._super_weight[:out_dim, :expected_input_dim]
        # create the bias matrix
        if not spaces.is_determined(self._bias):
            if self.abstract_child["_bias"].value:
                candidate_bias = self._super_bias[:out_dim]
            else:
                candidate_bias = None
        else:
            if spaces.get_determined_value(self._bias):
                candidate_bias = self._super_bias[:out_dim]
            else:
                candidate_bias = None
        return F.linear(input, candidate_weight, candidate_bias)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self._super_weight, self._super_bias)

@@ -78,8 +111,9 @@ class SuperLinear(SuperModule):
        )


class SuperMLP(nn.Module):
    # MLP: FC -> Activation -> Drop -> FC -> Drop
class SuperMLP(SuperModule):
    """An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""

    def __init__(
        self,
        in_features,
@@ -88,13 +122,13 @@ class SuperMLP(nn.Module):
        act_layer=nn.GELU,
        drop: Optional[float] = None,
    ):
        super(MLP, self).__init__()
        super(SuperMLP, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop or 0)
        self.drop = nn.Dropout(drop or 0.0)

    def forward(self, x):
        x = self.fc1(x)

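To make the slicing inside forward_candidate concrete, the tensor arithmetic it performs for one sampled child can be reproduced with plain PyTorch; the sizes below are hypothetical stand-ins for _super_weight, _super_bias and a child's _out_features value:

import torch
import torch.nn.functional as F

super_weight = torch.randn(36, 10)   # largest out_features x largest in_features
super_bias = torch.randn(36)

out_dim, in_dim = 24, 10             # values a sampled abstract child might carry
x = torch.rand(32, in_dim)

candidate_weight = super_weight[:out_dim, :in_dim]   # top-left slice of the super weight
candidate_bias = super_bias[:out_dim]
y = F.linear(x, candidate_weight, candidate_bias)
print(tuple(y.shape))                # (32, 24)
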
@@ -6,11 +6,14 @@ import abc
import torch.nn as nn
from enum import Enum

import spaces


class SuperRunMode(Enum):
    """This class defines the enumerations for Super Model Running Mode."""

    FullModel = "fullmodel"
    Candidate = "candidate"
    Default = "fullmodel"


@@ -20,8 +23,23 @@ class SuperModule(abc.ABC, nn.Module):
    def __init__(self):
        super(SuperModule, self).__init__()
        self._super_run_type = SuperRunMode.Default
        self._abstract_child = None

    @abc.abstractmethod
    def set_super_run_type(self, super_run_type):
        def _reset_super_run(m):
            if isinstance(m, SuperModule):
                m._super_run_type = super_run_type

        self.apply(_reset_super_run)

    def apply_candidate(self, abstract_child):
        if not isinstance(abstract_child, spaces.VirtualNode):
            raise ValueError(
                "Invalid abstract child program: {:}".format(abstract_child)
            )
        self._abstract_child = abstract_child

    @property
    def abstract_search_space(self):
        raise NotImplementedError

@@ -29,13 +47,24 @@ class SuperModule(abc.ABC, nn.Module):
    def super_run_type(self):
        return self._super_run_type

    @property
    def abstract_child(self):
        return self._abstract_child

    @abc.abstractmethod
    def forward_raw(self, *inputs):
        """Use the largest candidate for forward. Similar to the original PyTorch model."""
        raise NotImplementedError

    @abc.abstractmethod
    def forward_candidate(self, *inputs):
        raise NotImplementedError

    def forward(self, *inputs):
        if self.super_run_type == SuperRunMode.FullModel:
            return self.forward_raw(*inputs)
        elif self.super_run_type == SuperRunMode.Candidate:
            return self.forward_candidate(*inputs)
        else:
            raise ValueError(
                "Unknown Super Model Run Mode: {:}".format(self.super_run_type)

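Taken together, the intended calling sequence for a SuperModule subclass looks roughly like this sketch, which mirrors the unit test at the end of this commit (it assumes super_core and spaces are importable as in those tests):

import torch
import spaces
import super_core

model = super_core.SuperLinear(10, spaces.Categorical(12, 24, 36))
child = model.abstract_search_space.random()       # sample one candidate sub-layer

model.set_super_run_type(super_core.SuperRunMode.Candidate)
model.apply_candidate(child)                       # forward() now routes to forward_candidate
out = model(torch.rand(4, 10))
print(tuple(out.shape))                            # (4, child["_out_features"].value)
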
@@ -41,14 +41,14 @@ class TestBasicSpace(unittest.TestCase):
    def test_continuous(self):
        random.seed(999)
        space = Continuous(0, 1)
        self.assertGreaterEqual(space.random(), 0)
        self.assertGreaterEqual(1, space.random())
        self.assertGreaterEqual(space.random().value, 0)
        self.assertGreaterEqual(1, space.random().value)

        lower, upper = 1.5, 4.6
        space = Continuous(lower, upper, log=False)
        values = []
        for i in range(1000000):
            x = space.random()
            x = space.random().value
            self.assertGreaterEqual(x, lower)
            self.assertGreaterEqual(upper, x)
            values.append(x)
@@ -89,7 +89,7 @@ class TestBasicSpace(unittest.TestCase):
            Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11),
            12,
        )
        print(nested_space)
        print("\nThe nested search space:\n{:}".format(nested_space))
        for i in range(1, 13):
            self.assertTrue(nested_space.has(i))

@@ -102,6 +102,19 @@ class TestAbstractSpace(unittest.TestCase):
    """Test the abstract search spaces."""

    def test_continuous(self):
        print("")
        space = Continuous(0, 1)
        self.assertEqual(space, space.abstract())
        print("The abstract search space for Continuous: {:}".format(space.abstract()))

        space = Categorical(1, 2, 3)
        self.assertEqual(len(space.abstract()), 3)
        print(space.abstract())

        nested_space = Categorical(
            Categorical(1, 2, 3),
            Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11),
            12,
        )
        abstract_nested_space = nested_space.abstract()
        print("The abstract nested search space:\n{:}".format(abstract_nested_space))

@@ -25,6 +25,26 @@ class TestSuperLinear(unittest.TestCase):
        out_features = spaces.Categorical(12, 24, 36)
        bias = spaces.Categorical(True, False)
        model = super_core.SuperLinear(10, out_features, bias=bias)
        print(model)
        print("The simple super linear module is:\n{:}".format(model))

        print(model.super_run_type)
        print(model.abstract_search_space)
        self.assertTrue(model.bias)

        inputs = torch.rand(32, 10)
        print("Input shape: {:}".format(inputs.shape))
        print("Weight shape: {:}".format(model._super_weight.shape))
        print("Bias shape: {:}".format(model._super_bias.shape))
        outputs = model(inputs)
        self.assertEqual(tuple(outputs.shape), (32, 36))

        abstract_space = model.abstract_search_space
        abstract_child = abstract_space.random()
        print("The abstract search space:\n{:}".format(abstract_space))
        print("The abstract child program:\n{:}".format(abstract_child))

        model.set_super_run_type(super_core.SuperRunMode.Candidate)
        model.apply_candidate(abstract_child)

        output_shape = (32, abstract_child["_out_features"].value)
        outputs = model(inputs)
        self.assertEqual(tuple(outputs.shape), output_shape)
