Update super cores
This commit is contained in:
		
							
								
								
									
										5
									
								
								lib/layers/super_core.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								lib/layers/super_core.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| from .super_module import SuperModule | ||||
| from .super_mlp import SuperLinear | ||||
| @@ -1,38 +1,71 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import torch.nn as nn | ||||
| from torch.nn.parameter import Parameter | ||||
| from typing import Optional | ||||
| from torch import Tensor | ||||
|  | ||||
| import math | ||||
| from typing import Optional, Union | ||||
|  | ||||
| import spaces | ||||
| from layers.super_module import SuperModule | ||||
| from layers.super_module import SuperModule | ||||
| from layers.super_module import SuperRunType | ||||
|  | ||||
| IntSpaceType = Union[int, spaces.Integer, spaces.Categorical] | ||||
| BoolSpaceType = Union[bool, spaces.Categorical] | ||||
|  | ||||
|  | ||||
| class SuperLinear(SuperModule): | ||||
|     """Applies a linear transformation to the incoming data: :math:`y = xA^T + b`""" | ||||
|  | ||||
|     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features: IntSpaceType, | ||||
|         out_features: IntSpaceType, | ||||
|         bias: BoolSpaceType = True, | ||||
|     ) -> None: | ||||
|         super(SuperLinear, self).__init__() | ||||
|         self.in_features = in_features | ||||
|         self.out_features = out_features | ||||
|         self.weight = Parameter(torch.Tensor(out_features, in_features)) | ||||
|  | ||||
|         # the raw input args | ||||
|         self._in_features = in_features | ||||
|         self._out_features = out_features | ||||
|         self._bias = bias | ||||
|  | ||||
|         self._super_weight = Parameter( | ||||
|             torch.Tensor(self.out_features, self.in_features) | ||||
|         ) | ||||
|         if bias: | ||||
|             self.bias = Parameter(torch.Tensor(out_features)) | ||||
|             self._super_bias = Parameter(torch.Tensor(self.out_features)) | ||||
|         else: | ||||
|             self.register_parameter("bias", None) | ||||
|             self.register_parameter("_super_bias", None) | ||||
|         self.reset_parameters() | ||||
|  | ||||
|     def reset_parameters(self) -> None: | ||||
|         init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | ||||
|         if self.bias is not None: | ||||
|             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) | ||||
|             bound = 1 / math.sqrt(fan_in) | ||||
|             init.uniform_(self.bias, -bound, bound) | ||||
|     @property | ||||
|     def in_features(self): | ||||
|         return spaces.get_max(self._in_features) | ||||
|  | ||||
|     def forward(self, input: Tensor) -> Tensor: | ||||
|         return F.linear(input, self.weight, self.bias) | ||||
|     @property | ||||
|     def out_features(self): | ||||
|         return spaces.get_max(self._out_features) | ||||
|  | ||||
|     @property | ||||
|     def bias(self): | ||||
|         return spaces.has_categorical(self._bias, True) | ||||
|  | ||||
|     def reset_parameters(self) -> None: | ||||
|         nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5)) | ||||
|         if self.bias: | ||||
|             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._super_weight) | ||||
|             bound = 1 / math.sqrt(fan_in) | ||||
|             nn.init.uniform_(self._super_bias, -bound, bound) | ||||
|  | ||||
|     def forward_raw(self, input: Tensor) -> Tensor: | ||||
|         return F.linear(input, self._super_weight, self._super_bias) | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "in_features={:}, out_features={:}, bias={:}".format( | ||||
|             self.in_features, self.out_features, self.bias is not None | ||||
|             self.in_features, self.out_features, self.bias | ||||
|         ) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -4,6 +4,14 @@ | ||||
|  | ||||
| import abc | ||||
| import torch.nn as nn | ||||
| from enum import Enum | ||||
|  | ||||
|  | ||||
| class SuperRunMode(Enum): | ||||
|     """This class defines the enumerations for Super Model Running Mode.""" | ||||
|  | ||||
|     FullModel = "fullmodel" | ||||
|     Default = "fullmodel" | ||||
|  | ||||
|  | ||||
| class SuperModule(abc.ABCMeta, nn.Module): | ||||
| @@ -11,7 +19,24 @@ class SuperModule(abc.ABCMeta, nn.Module): | ||||
|  | ||||
|     def __init__(self): | ||||
|         super(SuperModule, self).__init__() | ||||
|         self._super_run_type = SuperRunMode.default | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def abstract_search_space(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @property | ||||
|     def super_run_type(self): | ||||
|         return self._super_run_type | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def forward_raw(self, *inputs): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def forward(self, *inputs): | ||||
|         if self.super_run_type == SuperRunMode.FullModel: | ||||
|             return self.forward_raw(*inputs) | ||||
|         else: | ||||
|             raise ModeError( | ||||
|                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type) | ||||
|             ) | ||||
|   | ||||
| @@ -9,3 +9,5 @@ from .basic_space import Continuous | ||||
| from .basic_space import Integer | ||||
| from .basic_op import has_categorical | ||||
| from .basic_op import has_continuous | ||||
| from .basic_op import get_min | ||||
| from .basic_op import get_max | ||||
|   | ||||
| @@ -1,4 +1,7 @@ | ||||
| from spaces.basic_space import Space | ||||
| from spaces.basic_space import Integer | ||||
| from spaces.basic_space import Continuous | ||||
| from spaces.basic_space import Categorical | ||||
| from spaces.basic_space import _EPS | ||||
|  | ||||
|  | ||||
| @@ -14,3 +17,33 @@ def has_continuous(space_or_value, x): | ||||
|         return space_or_value.has(x) | ||||
|     else: | ||||
|         return abs(space_or_value - x) <= _EPS | ||||
|  | ||||
|  | ||||
| def get_max(space_or_value): | ||||
|     if isinstance(space_or_value, Integer): | ||||
|         return max(space_or_value.candidates) | ||||
|     elif isinstance(space_or_value, Continuous): | ||||
|         return space_or_value.upper | ||||
|     elif isinstance(space_or_value, Categorical): | ||||
|         values = [] | ||||
|         for index in range(len(space_or_value)): | ||||
|             max_value = get_max(space_or_value[index]) | ||||
|             values.append(max_value) | ||||
|         return max(values) | ||||
|     else: | ||||
|         return space_or_value | ||||
|  | ||||
|  | ||||
| def get_min(space_or_value): | ||||
|     if isinstance(space_or_value, Integer): | ||||
|         return min(space_or_value.candidates) | ||||
|     elif isinstance(space_or_value, Continuous): | ||||
|         return space_or_value.lower | ||||
|     elif isinstance(space_or_value, Categorical): | ||||
|         values = [] | ||||
|         for index in range(len(space_or_value)): | ||||
|             min_value = get_min(space_or_value[index]) | ||||
|             values.append(min_value) | ||||
|         return min(values) | ||||
|     else: | ||||
|         return space_or_value | ||||
|   | ||||
| @@ -10,6 +10,9 @@ import numpy as np | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
|  | ||||
| __all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"] | ||||
|  | ||||
| _EPS = 1e-9 | ||||
|  | ||||
|  | ||||
| @@ -54,6 +57,10 @@ class Categorical(Space): | ||||
|         ), "default >= {:}".format(len(self._candidates)) | ||||
|         assert len(self) > 0, "Please provide at least one candidate" | ||||
|  | ||||
|     @property | ||||
|     def candidates(self): | ||||
|         return self._candidates | ||||
|  | ||||
|     @property | ||||
|     def determined(self): | ||||
|         if len(self) == 1: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user