Complete Super Linear
This commit is contained in:
		| @@ -12,5 +12,6 @@ from .basic_space import VirtualNode | |||||||
| from .basic_op import has_categorical | from .basic_op import has_categorical | ||||||
| from .basic_op import has_continuous | from .basic_op import has_continuous | ||||||
| from .basic_op import is_determined | from .basic_op import is_determined | ||||||
|  | from .basic_op import get_determined_value | ||||||
| from .basic_op import get_min | from .basic_op import get_min | ||||||
| from .basic_op import get_max | from .basic_op import get_max | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| from spaces.basic_space import Space | from spaces.basic_space import Space | ||||||
|  | from spaces.basic_space import VirtualNode | ||||||
| from spaces.basic_space import Integer | from spaces.basic_space import Integer | ||||||
| from spaces.basic_space import Continuous | from spaces.basic_space import Continuous | ||||||
| from spaces.basic_space import Categorical | from spaces.basic_space import Categorical | ||||||
| @@ -26,6 +27,20 @@ def is_determined(space_or_value): | |||||||
|         return True |         return True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_determined_value(space_or_value): | ||||||
|  |     if not is_determined(space_or_value): | ||||||
|  |         raise ValueError("This input is not determined: {:}".format(space_or_value)) | ||||||
|  |     if isinstance(space_or_value, Space): | ||||||
|  |         if isinstance(space_or_value, Continuous): | ||||||
|  |             return space_or_value.lower | ||||||
|  |         elif isinstance(space_or_value, Categorical): | ||||||
|  |             return get_determined_value(space_or_value[0]) | ||||||
|  |         else:  # VirtualNode | ||||||
|  |             return space_or_value.value | ||||||
|  |     else: | ||||||
|  |         return space_or_value | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_max(space_or_value): | def get_max(space_or_value): | ||||||
|     if isinstance(space_or_value, Integer): |     if isinstance(space_or_value, Integer): | ||||||
|         return max(space_or_value.candidates) |         return max(space_or_value.candidates) | ||||||
|   | |||||||
| @@ -23,7 +23,7 @@ class Space(metaclass=abc.ABCMeta): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     @abc.abstractproperty |     @abc.abstractproperty | ||||||
|     def xrepr(self, indent=0) -> Text: |     def xrepr(self, prefix="") -> Text: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __repr__(self) -> Text: |     def __repr__(self) -> Text: | ||||||
| @@ -67,17 +67,27 @@ class VirtualNode(Space): | |||||||
|         self._value = value |         self._value = value | ||||||
|         self._attributes = OrderedDict() |         self._attributes = OrderedDict() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def value(self): | ||||||
|  |         return self._value | ||||||
|  |  | ||||||
|     def append(self, key, value): |     def append(self, key, value): | ||||||
|  |         if not isinstance(key, str): | ||||||
|  |             raise TypeError( | ||||||
|  |                 "Only accept string as a key instead of {:}".format(type(key)) | ||||||
|  |             ) | ||||||
|         if not isinstance(value, Space): |         if not isinstance(value, Space): | ||||||
|             raise ValueError("Invalid type of value: {:}".format(type(value))) |             raise ValueError("Invalid type of value: {:}".format(type(value))) | ||||||
|  |         # if value.determined: | ||||||
|  |         #    raise ValueError("Can not attach a determined value: {:}".format(value)) | ||||||
|         self._attributes[key] = value |         self._attributes[key] = value | ||||||
|  |  | ||||||
|     def xrepr(self, indent=0) -> Text: |     def xrepr(self, prefix="  ") -> Text: | ||||||
|         strs = [self.__class__.__name__ + "("] |         strs = [self.__class__.__name__ + "(value={:}".format(self._value)] | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             strs.append(value.xrepr(indent + 2) + ",") |             strs.append(value.xrepr(prefix + "  " + key + " = ")) | ||||||
|         strs.append(")") |         strs.append(")") | ||||||
|         return "\n".join(strs) |         return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs) | ||||||
|  |  | ||||||
|     def abstract(self) -> Space: |     def abstract(self) -> Space: | ||||||
|         node = VirtualNode(id(self)) |         node = VirtualNode(id(self)) | ||||||
| @@ -87,7 +97,10 @@ class VirtualNode(Space): | |||||||
|         return node |         return node | ||||||
|  |  | ||||||
|     def random(self, recursion=True): |     def random(self, recursion=True): | ||||||
|         raise NotImplementedError |         node = VirtualNode(None, self._value) | ||||||
|  |         for key, value in self._attributes.items(): | ||||||
|  |             node.append(key, value.random(recursion)) | ||||||
|  |         return node | ||||||
|  |  | ||||||
|     def has(self, x) -> bool: |     def has(self, x) -> bool: | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
| @@ -101,6 +114,7 @@ class VirtualNode(Space): | |||||||
|     def __getitem__(self, key): |     def __getitem__(self, key): | ||||||
|         return self._attributes[key] |         return self._attributes[key] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|     def determined(self) -> bool: |     def determined(self) -> bool: | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             if not value.determined(x): |             if not value.determined(x): | ||||||
| @@ -165,20 +179,22 @@ class Categorical(Space): | |||||||
|                 data.append(candidate.abstract()) |                 data.append(candidate.abstract()) | ||||||
|             else: |             else: | ||||||
|                 data.append(VirtualNode(id(candidate), candidate)) |                 data.append(VirtualNode(id(candidate), candidate)) | ||||||
|         return Categorical(*data, self._default) |         return Categorical(*data, default=self._default) | ||||||
|  |  | ||||||
|     def random(self, recursion=True): |     def random(self, recursion=True): | ||||||
|         sample = random.choice(self._candidates) |         sample = random.choice(self._candidates) | ||||||
|         if recursion and isinstance(sample, Space): |         if recursion and isinstance(sample, Space): | ||||||
|             return sample.random(recursion) |             sample = sample.random(recursion) | ||||||
|  |         if isinstance(sample, VirtualNode): | ||||||
|  |             return sample.copy() | ||||||
|         else: |         else: | ||||||
|             return sample |             return VirtualNode(None, sample) | ||||||
|  |  | ||||||
|     def xrepr(self, indent=0): |     def xrepr(self, prefix=""): | ||||||
|         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( |         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( | ||||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default |             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||||
|         ) |         ) | ||||||
|         return " " * indent + xrepr |         return prefix + xrepr | ||||||
|  |  | ||||||
|     def has(self, x): |     def has(self, x): | ||||||
|         super().has(x) |         super().has(x) | ||||||
| @@ -219,14 +235,14 @@ class Integer(Categorical): | |||||||
|             default = data.index(default) |             default = data.index(default) | ||||||
|         super(Integer, self).__init__(*data, default=default) |         super(Integer, self).__init__(*data, default=default) | ||||||
|  |  | ||||||
|     def xrepr(self, indent=0): |     def xrepr(self, prefix=""): | ||||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( |         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             lower=self._raw_lower, |             lower=self._raw_lower, | ||||||
|             upper=self._raw_upper, |             upper=self._raw_upper, | ||||||
|             default=self._raw_default, |             default=self._raw_default, | ||||||
|         ) |         ) | ||||||
|         return " " * indent + xrepr |         return prefix + xrepr | ||||||
|  |  | ||||||
|  |  | ||||||
| np_float_types = (np.float16, np.float32, np.float64) | np_float_types = (np.float16, np.float32, np.float64) | ||||||
| @@ -286,11 +302,12 @@ class Continuous(Space): | |||||||
|         del recursion |         del recursion | ||||||
|         if self._log_scale: |         if self._log_scale: | ||||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) |             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||||
|             return math.exp(sample) |             sample = math.exp(sample) | ||||||
|         else: |         else: | ||||||
|             return random.uniform(self._lower, self._upper) |             sample = random.uniform(self._lower, self._upper) | ||||||
|  |         return VirtualNode(None, sample) | ||||||
|  |  | ||||||
|     def xrepr(self, indent=0): |     def xrepr(self, prefix=""): | ||||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( |         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             lower=self._lower, |             lower=self._lower, | ||||||
| @@ -298,7 +315,7 @@ class Continuous(Space): | |||||||
|             default=self._default, |             default=self._default, | ||||||
|             log=self._log_scale, |             log=self._log_scale, | ||||||
|         ) |         ) | ||||||
|         return " " * indent + xrepr |         return prefix + xrepr | ||||||
|  |  | ||||||
|     def convert(self, x): |     def convert(self, x): | ||||||
|         if isinstance(x, np_float_types) and x.size == 1: |         if isinstance(x, np_float_types) and x.size == 1: | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
| ##################################################### | ##################################################### | ||||||
|  | from .super_module import SuperRunMode | ||||||
| from .super_module import SuperModule | from .super_module import SuperModule | ||||||
| from .super_mlp import SuperLinear | from .super_mlp import SuperLinear | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  |  | ||||||
| import math | import math | ||||||
| from typing import Optional, Union | from typing import Optional, Union | ||||||
| @@ -52,14 +53,15 @@ class SuperLinear(SuperModule): | |||||||
|     def bias(self): |     def bias(self): | ||||||
|         return spaces.has_categorical(self._bias, True) |         return spaces.has_categorical(self._bias, True) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|         root_node = spaces.VirtualNode(id(self)) |         root_node = spaces.VirtualNode(id(self)) | ||||||
|         if not spaces.is_determined(self._in_features): |         if not spaces.is_determined(self._in_features): | ||||||
|             root_node.append("_in_features", self._in_features) |             root_node.append("_in_features", self._in_features.abstract()) | ||||||
|         if not spaces.is_determined(self._out_features): |         if not spaces.is_determined(self._out_features): | ||||||
|             root_node.append("_out_features", self._out_features) |             root_node.append("_out_features", self._out_features.abstract()) | ||||||
|         if not spaces.is_determined(self._bias): |         if not spaces.is_determined(self._bias): | ||||||
|             root_node.append("_bias", self._bias) |             root_node.append("_bias", self._bias.abstract()) | ||||||
|         return root_node |         return root_node | ||||||
|  |  | ||||||
|     def reset_parameters(self) -> None: |     def reset_parameters(self) -> None: | ||||||
| @@ -69,6 +71,37 @@ class SuperLinear(SuperModule): | |||||||
|             bound = 1 / math.sqrt(fan_in) |             bound = 1 / math.sqrt(fan_in) | ||||||
|             nn.init.uniform_(self._super_bias, -bound, bound) |             nn.init.uniform_(self._super_bias, -bound, bound) | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         # check inputs -> | ||||||
|  |         if not spaces.is_determined(self._in_features): | ||||||
|  |             expected_input_dim = self.abstract_child["_in_features"].value | ||||||
|  |         else: | ||||||
|  |             expected_input_dim = spaces.get_determined_value(self._in_features) | ||||||
|  |         if input.size(-1) != expected_input_dim: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "Expect the input dim of {:} instead of {:}".format( | ||||||
|  |                     expected_input_dim, input.size(-1) | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         # create the weight matrix | ||||||
|  |         if not spaces.is_determined(self._out_features): | ||||||
|  |             out_dim = self.abstract_child["_out_features"].value | ||||||
|  |         else: | ||||||
|  |             out_dim = spaces.get_determined_value(self._out_features) | ||||||
|  |         candidate_weight = self._super_weight[:out_dim, :expected_input_dim] | ||||||
|  |         # create the bias matrix | ||||||
|  |         if not spaces.is_determined(self._bias): | ||||||
|  |             if self.abstract_child["_bias"].value: | ||||||
|  |                 candidate_bias = self._super_bias[:out_dim] | ||||||
|  |             else: | ||||||
|  |                 candidate_bias = None | ||||||
|  |         else: | ||||||
|  |             if spaces.get_determined_value(self._bias): | ||||||
|  |                 candidate_bias = self._super_bias[:out_dim] | ||||||
|  |             else: | ||||||
|  |                 candidate_bias = None | ||||||
|  |         return F.linear(input, candidate_weight, candidate_bias) | ||||||
|  |  | ||||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|         return F.linear(input, self._super_weight, self._super_bias) |         return F.linear(input, self._super_weight, self._super_bias) | ||||||
|  |  | ||||||
| @@ -78,8 +111,9 @@ class SuperLinear(SuperModule): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class SuperMLP(nn.Module): | class SuperMLP(SuperModule): | ||||||
|     # MLP: FC -> Activation -> Drop -> FC -> Drop |     """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.""" | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         in_features, |         in_features, | ||||||
| @@ -88,13 +122,13 @@ class SuperMLP(nn.Module): | |||||||
|         act_layer=nn.GELU, |         act_layer=nn.GELU, | ||||||
|         drop: Optional[float] = None, |         drop: Optional[float] = None, | ||||||
|     ): |     ): | ||||||
|         super(MLP, self).__init__() |         super(SuperMLP, self).__init__() | ||||||
|         out_features = out_features or in_features |         out_features = out_features or in_features | ||||||
|         hidden_features = hidden_features or in_features |         hidden_features = hidden_features or in_features | ||||||
|         self.fc1 = nn.Linear(in_features, hidden_features) |         self.fc1 = nn.Linear(in_features, hidden_features) | ||||||
|         self.act = act_layer() |         self.act = act_layer() | ||||||
|         self.fc2 = nn.Linear(hidden_features, out_features) |         self.fc2 = nn.Linear(hidden_features, out_features) | ||||||
|         self.drop = nn.Dropout(drop or 0) |         self.drop = nn.Dropout(drop or 0.0) | ||||||
|  |  | ||||||
|     def forward(self, x): |     def forward(self, x): | ||||||
|         x = self.fc1(x) |         x = self.fc1(x) | ||||||
|   | |||||||
| @@ -6,11 +6,14 @@ import abc | |||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from enum import Enum | from enum import Enum | ||||||
|  |  | ||||||
|  | import spaces | ||||||
|  |  | ||||||
|  |  | ||||||
| class SuperRunMode(Enum): | class SuperRunMode(Enum): | ||||||
|     """This class defines the enumerations for Super Model Running Mode.""" |     """This class defines the enumerations for Super Model Running Mode.""" | ||||||
|  |  | ||||||
|     FullModel = "fullmodel" |     FullModel = "fullmodel" | ||||||
|  |     Candidate = "candidate" | ||||||
|     Default = "fullmodel" |     Default = "fullmodel" | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -20,8 +23,23 @@ class SuperModule(abc.ABC, nn.Module): | |||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super(SuperModule, self).__init__() |         super(SuperModule, self).__init__() | ||||||
|         self._super_run_type = SuperRunMode.Default |         self._super_run_type = SuperRunMode.Default | ||||||
|  |         self._abstract_child = None | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     def set_super_run_type(self, super_run_type): | ||||||
|  |         def _reset_super_run(m): | ||||||
|  |             if isinstance(m, SuperModule): | ||||||
|  |                 m._super_run_type = super_run_type | ||||||
|  |  | ||||||
|  |         self.apply(_reset_super_run) | ||||||
|  |  | ||||||
|  |     def apply_candiate(self, abstract_child): | ||||||
|  |         if not isinstance(abstract_child, spaces.VirtualNode): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "Invalid abstract child program: {:}".format(abstract_child) | ||||||
|  |             ) | ||||||
|  |         self._abstract_child = abstract_child | ||||||
|  |  | ||||||
|  |     @property | ||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
| @@ -29,13 +47,24 @@ class SuperModule(abc.ABC, nn.Module): | |||||||
|     def super_run_type(self): |     def super_run_type(self): | ||||||
|         return self._super_run_type |         return self._super_run_type | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_child(self): | ||||||
|  |         return self._abstract_child | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def forward_raw(self, *inputs): |     def forward_raw(self, *inputs): | ||||||
|  |         """Use the largest candidate for forward. Similar to the original PyTorch model.""" | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @abc.abstractmethod | ||||||
|  |     def forward_candidate(self, *inputs): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def forward(self, *inputs): |     def forward(self, *inputs): | ||||||
|         if self.super_run_type == SuperRunMode.FullModel: |         if self.super_run_type == SuperRunMode.FullModel: | ||||||
|             return self.forward_raw(*inputs) |             return self.forward_raw(*inputs) | ||||||
|  |         elif self.super_run_type == SuperRunMode.Candidate: | ||||||
|  |             return self.forward_candidate(*inputs) | ||||||
|         else: |         else: | ||||||
|             raise ModeError( |             raise ModeError( | ||||||
|                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type) |                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type) | ||||||
|   | |||||||
| @@ -41,14 +41,14 @@ class TestBasicSpace(unittest.TestCase): | |||||||
|     def test_continuous(self): |     def test_continuous(self): | ||||||
|         random.seed(999) |         random.seed(999) | ||||||
|         space = Continuous(0, 1) |         space = Continuous(0, 1) | ||||||
|         self.assertGreaterEqual(space.random(), 0) |         self.assertGreaterEqual(space.random().value, 0) | ||||||
|         self.assertGreaterEqual(1, space.random()) |         self.assertGreaterEqual(1, space.random().value) | ||||||
|  |  | ||||||
|         lower, upper = 1.5, 4.6 |         lower, upper = 1.5, 4.6 | ||||||
|         space = Continuous(lower, upper, log=False) |         space = Continuous(lower, upper, log=False) | ||||||
|         values = [] |         values = [] | ||||||
|         for i in range(1000000): |         for i in range(1000000): | ||||||
|             x = space.random() |             x = space.random().value | ||||||
|             self.assertGreaterEqual(x, lower) |             self.assertGreaterEqual(x, lower) | ||||||
|             self.assertGreaterEqual(upper, x) |             self.assertGreaterEqual(upper, x) | ||||||
|             values.append(x) |             values.append(x) | ||||||
| @@ -89,7 +89,7 @@ class TestBasicSpace(unittest.TestCase): | |||||||
|             Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11), |             Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11), | ||||||
|             12, |             12, | ||||||
|         ) |         ) | ||||||
|         print(nested_space) |         print("\nThe nested search space:\n{:}".format(nested_space)) | ||||||
|         for i in range(1, 13): |         for i in range(1, 13): | ||||||
|             self.assertTrue(nested_space.has(i)) |             self.assertTrue(nested_space.has(i)) | ||||||
|  |  | ||||||
| @@ -102,6 +102,19 @@ class TestAbstractSpace(unittest.TestCase): | |||||||
|     """Test the abstract search spaces.""" |     """Test the abstract search spaces.""" | ||||||
|  |  | ||||||
|     def test_continous(self): |     def test_continous(self): | ||||||
|  |         print("") | ||||||
|         space = Continuous(0, 1) |         space = Continuous(0, 1) | ||||||
|         self.assertEqual(space, space.abstract()) |         self.assertEqual(space, space.abstract()) | ||||||
|  |         print("The abstract search space for Continuous: {:}".format(space.abstract())) | ||||||
|  |  | ||||||
|  |         space = Categorical(1, 2, 3) | ||||||
|  |         self.assertEqual(len(space.abstract()), 3) | ||||||
|         print(space.abstract()) |         print(space.abstract()) | ||||||
|  |  | ||||||
|  |         nested_space = Categorical( | ||||||
|  |             Categorical(1, 2, 3), | ||||||
|  |             Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11), | ||||||
|  |             12, | ||||||
|  |         ) | ||||||
|  |         abstract_nested_space = nested_space.abstract() | ||||||
|  |         print("The abstract nested search space:\n{:}".format(abstract_nested_space)) | ||||||
|   | |||||||
| @@ -25,6 +25,26 @@ class TestSuperLinear(unittest.TestCase): | |||||||
|         out_features = spaces.Categorical(12, 24, 36) |         out_features = spaces.Categorical(12, 24, 36) | ||||||
|         bias = spaces.Categorical(True, False) |         bias = spaces.Categorical(True, False) | ||||||
|         model = super_core.SuperLinear(10, out_features, bias=bias) |         model = super_core.SuperLinear(10, out_features, bias=bias) | ||||||
|         print(model) |         print("The simple super linear module is:\n{:}".format(model)) | ||||||
|  |  | ||||||
|         print(model.super_run_type) |         print(model.super_run_type) | ||||||
|         print(model.abstract_search_space()) |         self.assertTrue(model.bias) | ||||||
|  |  | ||||||
|  |         inputs = torch.rand(32, 10) | ||||||
|  |         print("Input shape: {:}".format(inputs.shape)) | ||||||
|  |         print("Weight shape: {:}".format(model._super_weight.shape)) | ||||||
|  |         print("Bias shape: {:}".format(model._super_bias.shape)) | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), (32, 36)) | ||||||
|  |  | ||||||
|  |         abstract_space = model.abstract_search_space | ||||||
|  |         abstract_child = abstract_space.random() | ||||||
|  |         print("The abstract searc space:\n{:}".format(abstract_space)) | ||||||
|  |         print("The abstract child program:\n{:}".format(abstract_child)) | ||||||
|  |  | ||||||
|  |         model.set_super_run_type(super_core.SuperRunMode.Candidate) | ||||||
|  |         model.apply_candiate(abstract_child) | ||||||
|  |  | ||||||
|  |         output_shape = (32, abstract_child["_out_features"].value) | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), output_shape) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user