Add SuperSimpleNorm and update synthetic env
@@ -91,6 +91,8 @@ class SuperSequential(SuperModule):
     def abstract_search_space(self):
         root_node = spaces.VirtualNode(id(self))
         for index, module in enumerate(self):
+            if not isinstance(module, SuperModule):
+                continue
             space = module.abstract_search_space
             if not spaces.is_determined(space):
                 root_node.append(str(index), space)
@@ -98,9 +100,9 @@ class SuperSequential(SuperModule):
 
     def apply_candidate(self, abstract_child: spaces.VirtualNode):
         super(SuperSequential, self).apply_candidate(abstract_child)
-        for index in range(len(self)):
+        for index, module in enumerate(self):
             if str(index) in abstract_child:
-                self.__getitem__(index).apply_candidate(abstract_child[str(index)])
+                module.apply_candidate(abstract_child[str(index)])
 
     def forward_candidate(self, input):
         return self.forward_raw(input)

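Taken together, the two hunks above let a SuperSequential hold ordinary torch.nn.Module children next to SuperModule ones: abstract_search_space now skips children that have no search space to offer, and apply_candidate reuses the enumerated module instead of re-indexing the container. The sketch below illustrates the intended effect; the import paths and the spaces.Categorical choice are assumptions for illustration, not taken from this diff.

    import torch.nn as nn
    # Assumed imports: in this repository the super layers and the `spaces`
    # package typically live under lib/, so the exact paths may differ.
    import spaces
    from super_core import SuperSequential, SuperLinear

    net = SuperSequential(
        SuperLinear(32, spaces.Categorical(16, 32, 64)),  # searchable output width
        nn.ReLU(),  # plain module: previously broke the search-space walk, now skipped
    )
    # Only child "0" contributes to the search space; index "1" fails the
    # isinstance(module, SuperModule) check and is ignored.
    space = net.abstract_search_space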
@@ -9,6 +9,7 @@ from .super_module import SuperModule
 from .super_container import SuperSequential
 from .super_linear import SuperLinear
 from .super_linear import SuperMLPv1, SuperMLPv2
+from .super_norm import SuperSimpleNorm
 from .super_norm import SuperLayerNorm1D
 from .super_attention import SuperAttention
 from .super_transformer import SuperTransformerEncoderLayer

@@ -3,6 +3,7 @@
 #####################################################
 
 import abc
+import warnings
 from typing import Optional, Union, Callable
 import torch
 import torch.nn as nn
@@ -45,6 +46,17 @@ class SuperModule(abc.ABC, nn.Module):
 
         self.apply(_reset_super_run)
 
+    def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
+        if not isinstance(module, SuperModule):
+            warnings.warn(
+                "Add {:} module, which is not SuperModule, into {:}".format(
+                    name, self.__class__.__name__
+                )
+                + "\n"
+                + "It may cause some functions to be invalid."
+            )
+        super(SuperModule, self).add_module(name, module)
+
     def apply_verbose(self, verbose):
         def _reset_verbose(m):
             if isinstance(m, SuperModule):

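Containers such as SuperSequential presumably register their children via add_module (as nn.Sequential does), so constructing one with a non-super child now emits a UserWarning rather than leaving the mismatch unnoticed until a search-related call fails. A minimal sketch of that behaviour, again with an assumed import path:

    import warnings
    import torch.nn as nn
    from super_core import SuperSequential, SuperLinear  # assumed path

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # nn.ReLU is not a SuperModule, so add_module warns but still registers it.
        net = SuperSequential(SuperLinear(8, 4), nn.ReLU())
    assert any("not SuperModule" in str(w.message) for w in caught)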
@@ -82,3 +82,43 @@ class SuperLayerNorm1D(SuperModule):
                 elementwise_affine=self._elementwise_affine,
             )
         )
+
+
+class SuperSimpleNorm(SuperModule):
+    """Super simple normalization."""
+
+    def __init__(self, mean, std, inplace=False) -> None:
+        super(SuperSimpleNorm, self).__init__()
+        self._mean = mean
+        self._std = std
+        self._inplace = inplace
+
+    @property
+    def abstract_search_space(self):
+        return spaces.VirtualNode(id(self))
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        # check inputs ->
+        return self.forward_raw(input)
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        if not self._inplace:
+            tensor = input.clone()
+        else:
+            tensor = input
+        mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device)
+        std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device)
+        if (std == 0).any():
+            raise ValueError(
+                "std evaluated to zero after conversion to {}, leading to division by zero.".format(
+                    std.dtype
+                )
+            )
+        while mean.ndim < tensor.ndim:
+            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
+        return tensor.sub_(mean).div_(std)
+
+    def extra_repr(self) -> str:
|         return "mean={mean}, std={mean}, inplace={inplace}".format( | ||||
|             mean=self._mean, std=self._std, inplace=self._inplace | ||||
|         ) | ||||
|   | ||||
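SuperSimpleNorm contributes no searchable dimensions (its abstract search space is an empty VirtualNode); it applies the fixed (x - mean) / std transform, broadcasting scalar or per-feature statistics across the leading dimensions, similar in spirit to torchvision's Normalize but usable inside a super-network. A small sketch of the semantics, exercising forward_raw directly; the import path is an assumption:

    import torch
    from super_core import SuperSimpleNorm  # assumed path

    norm = SuperSimpleNorm(mean=0.5, std=2.0)
    x = torch.rand(4, 3)
    y = norm.forward_raw(x)
    # With inplace=False (the default) the input is cloned first, so x is
    # untouched and y matches the hand-computed normalization.
    assert torch.allclose(y, (x - 0.5) / 2.0)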