Add super/norm layers in xcore
@@ -9,13 +9,27 @@ from .super_module import SuperModule
 from .super_container import SuperSequential
 from .super_linear import SuperLinear
 from .super_linear import SuperMLPv1, SuperMLPv2
+
+from .super_norm import SuperSimpleNorm
+from .super_norm import SuperLayerNorm1D
+from .super_norm import SuperSimpleLearnableNorm
+from .super_norm import SuperIdentity
+
+super_name2norm = {
+    "simple_norm": SuperSimpleNorm,
+    "simple_learn_norm": SuperSimpleLearnableNorm,
+    "layer_norm_1d": SuperLayerNorm1D,
+    "identity": SuperIdentity,
+}
+
 from .super_attention import SuperAttention
 from .super_transformer import SuperTransformerEncoderLayer

 from .super_activations import SuperReLU
 from .super_activations import SuperLeakyReLU

 super_name2activation = {"relu": SuperReLU, "leaky_relu": SuperLeakyReLU}
+

 from .super_trade_stem import SuperAlphaEBDv1
 from .super_positional_embedding import SuperPositionalEncoder

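Note: the `super_name2norm` and `super_name2activation` dicts act as string-to-class registries, so a config can name a layer and the code can instantiate it. A minimal sketch of that pattern (the `xlayers` import path is an assumption; use whatever package this `__init__.py` belongs to):

# Assumed import path, for illustration only.
from xlayers import super_name2norm, super_name2activation

norm_cls = super_name2norm["simple_learn_norm"]  # -> SuperSimpleLearnableNorm
norm = norm_cls(mean=0, std=1)                   # build a layer from a config string
act = super_name2activation["relu"]()            # same registry pattern for activations
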
@@ -30,6 +30,45 @@ class SuperRunMode(Enum):
     Default = "fullmodel"


+class TensorContainer:
+    """A class to maintain both parameters and buffers for a model."""
+
+    def __init__(self):
+        self._names = []
+        self._tensors = []
+        self._param_or_buffers = []
+        self._name2index = dict()
+
+    def append(self, name, tensor, param_or_buffer):
+        if not isinstance(tensor, torch.Tensor):
+            raise TypeError(
+                "The input tensor must be torch.Tensor instead of {:}".format(
+                    type(tensor)
+                )
+            )
+        self._names.append(name)
+        self._tensors.append(tensor)
+        self._param_or_buffers.append(param_or_buffer)
+        assert name not in self._name2index, "The [{:}] has already been added.".format(
+            name
+        )
+        self._name2index[name] = len(self._names) - 1
+
+    def numel(self):
+        total = 0
+        for tensor in self._tensors:
+            total += tensor.numel()
+        return total
+
+    def __len__(self):
+        return len(self._names)
+
+    def __repr__(self):
+        return "{name}({num} tensors)".format(
+            name=self.__class__.__name__, num=len(self)
+        )
+
+
 class SuperModule(abc.ABC, nn.Module):
     """This class equips the nn.Module class with the ability to apply AutoDL."""

@@ -71,6 +110,14 @@ class SuperModule(abc.ABC, nn.Module):
             )
         self._abstract_child = abstract_child

+    def named_parameters_buffers(self):
+        container = TensorContainer()
+        for name, param in self.named_parameters():
+            container.append(name, param, True)
+        for name, buf in self.named_buffers():
+            container.append(name, buf, False)
+        return container
+
     @property
     def abstract_search_space(self):
         raise NotImplementedError

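Note: `named_parameters_buffers()` flattens a module's parameters and buffers into one `TensorContainer`, with `True`/`False` marking parameter vs. buffer. A quick usage sketch (module choice and import path are illustrative assumptions):

from xlayers import SuperSimpleNorm  # assumed import path

model = SuperSimpleNorm(mean=0.5, std=0.25)  # registers _mean/_std as buffers
container = model.named_parameters_buffers()
print(container)                             # e.g. TensorContainer(2 tensors)
print(container.numel())                     # total elements across params + buffers
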
@@ -89,8 +89,8 @@ class SuperSimpleNorm(SuperModule):

     def __init__(self, mean, std, inplace=False) -> None:
         super(SuperSimpleNorm, self).__init__()
-        self._mean = mean
-        self._std = std
+        self.register_buffer("_mean", torch.tensor(mean, dtype=torch.float))
+        self.register_buffer("_std", torch.tensor(std, dtype=torch.float))
         self._inplace = inplace

     @property
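Note: replacing the plain attributes with `register_buffer` means `_mean`/`_std` now appear in `state_dict()` and follow `model.to(device)`, which is also what lets `named_parameters_buffers()` above pick them up. A small check, with the import path again assumed:

import torch
from xlayers import SuperSimpleNorm  # assumed import path

model = SuperSimpleNorm(mean=0.5, std=0.25)
print("_mean" in model.state_dict())  # True now that _mean is a buffer
if torch.cuda.is_available():
    model = model.cuda()
    print(model._mean.device)         # cuda:0 -- buffers move with the module
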
@@ -111,7 +111,7 @@ class SuperSimpleNorm(SuperModule):
         if (std == 0).any():
             raise ValueError(
                 "std evaluated to zero after conversion to {}, leading to division by zero.".format(
-                    dtype
+                    tensor.dtype
                 )
             )
         while mean.ndim < tensor.ndim:
@@ -119,6 +119,75 @@ class SuperSimpleNorm(SuperModule):
         return tensor.sub_(mean).div_(std)

     def extra_repr(self) -> str:
-        return "mean={mean}, std={mean}, inplace={inplace}".format(
-            mean=self._mean, std=self._std, inplace=self._inplace
+        return "mean={mean}, std={std}, inplace={inplace}".format(
+            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
         )
+
+
+class SuperSimpleLearnableNorm(SuperModule):
+    """Super simple normalization."""
+
+    def __init__(self, mean=0, std=1, eps=1e-6, inplace=False) -> None:
+        super(SuperSimpleLearnableNorm, self).__init__()
+        self.register_parameter(
+            "_mean", nn.Parameter(torch.tensor(mean, dtype=torch.float))
+        )
+        self.register_parameter(
+            "_std", nn.Parameter(torch.tensor(std, dtype=torch.float))
+        )
+        self._eps = eps
+        self._inplace = inplace
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        # check inputs ->
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        if not self._inplace:
+            tensor = input.clone()
+        else:
+            tensor = input
+        mean, std = (
+            self._mean.to(tensor.device),
+            torch.abs(self._std.to(tensor.device)) + self._eps,
+        )
+        if (std == 0).any():
+            raise ValueError("std leads to division by zero.")
+        while mean.ndim < tensor.ndim:
+            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
+        return tensor.sub_(mean).div_(std)
+
+    def extra_repr(self) -> str:
+        return "mean={mean}, std={std}, inplace={inplace}".format(
+            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
+        )
+
+
+class SuperIdentity(SuperModule):
+    """Super identity mapping layer."""
+
+    def __init__(self, inplace=False, **kwargs) -> None:
+        super(SuperIdentity, self).__init__()
+        self._inplace = inplace
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        # check inputs ->
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        if not self._inplace:
+            tensor = input.clone()
+        else:
+            tensor = input
+        return tensor
+
+    def extra_repr(self) -> str:
+        return "inplace={inplace}".format(inplace=self._inplace)

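Note: a closing sketch of the two new layers in action, assuming the default full-model run mode routes `forward()` to `forward_raw()` (import path again assumed):

import torch
from xlayers import SuperSimpleLearnableNorm, SuperIdentity  # assumed path

x = torch.randn(4, 8)
learn_norm = SuperSimpleLearnableNorm(mean=0, std=1)
y = learn_norm(x)                     # computes (x - _mean) / (|_std| + eps)
print([name for name, _ in learn_norm.named_parameters()])  # ['_mean', '_std']

identity = SuperIdentity()
assert torch.equal(identity(x), x)    # passes the input through unchanged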