Update ViT
This commit is contained in:
		
							
								
								
									
										29
									
								
								tests/test_super_vit.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								tests/test_super_vit.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | # pytest ./tests/test_super_vit.py -s               # | ||||||
|  | ##################################################### | ||||||
|  | import sys | ||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from xautodl.xmodels import transformers | ||||||
|  | from xautodl.utils.flop_benchmark import count_parameters | ||||||
|  |  | ||||||
class TestSuperViT(unittest.TestCase):
    """Test the super vision transformer (ViT) models."""

    def test_super_vit(self):
        """Run one forward pass of ``vit-base`` and check the output shape."""
        model = transformers.get_transformer("vit-base")
        tensor = torch.rand((16, 3, 256, 256))
        print("The tensor shape: {:}".format(tensor.shape))
        print(model)
        outs = model(tensor)
        print("The output tensor shape: {:}".format(outs.shape))
        # The classification head maps the cls-token feature to one logit per
        # class, so the output is expected to be (batch, num_classes).
        num_classes = transformers.name2config["vit-base"]["num_classes"]
        self.assertEqual(tuple(outs.shape), (tensor.shape[0], num_classes))

    def test_model_size(self):
        """Build every registered configuration and print its parameter size."""
        name2config = transformers.name2config
        for name, config in name2config.items():
            model = transformers.get_transformer(config)
            size = count_parameters(model, "mb", True)
            print('{:10s} : size={:.2f}MB'.format(name, size))
							
								
								
									
										319
									
								
								xautodl/xlayers/super_mlp.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										319
									
								
								xautodl/xlayers/super_mlp.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,319 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  |  | ||||||
|  | import math | ||||||
|  | from typing import Optional, Callable | ||||||
|  |  | ||||||
|  | from xautodl import spaces | ||||||
|  | from .super_module import SuperModule | ||||||
|  | from .super_module import IntSpaceType | ||||||
|  | from .super_module import BoolSpaceType | ||||||
|  |  | ||||||
|  |  | ||||||
class SuperLinear(SuperModule):
    """Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    ``in_features``, ``out_features`` and ``bias`` may each be a fixed value or
    a search space from ``xautodl.spaces``. The weight is allocated once with
    the sizes reported by ``spaces.get_max``; a candidate forward pass slices
    the leading rows/columns it needs out of that weight (and bias).
    """

    def __init__(
        self,
        in_features: IntSpaceType,
        out_features: IntSpaceType,
        bias: BoolSpaceType = True,
    ) -> None:
        super(SuperLinear, self).__init__()

        # the raw input args
        self._in_features = in_features
        self._out_features = out_features
        self._bias = bias
        # weights to be optimized
        self.register_parameter(
            "_super_weight",
            torch.nn.Parameter(torch.Tensor(self.out_features, self.in_features)),
        )
        if self.bias:
            self.register_parameter(
                "_super_bias", torch.nn.Parameter(torch.Tensor(self.out_features))
            )
        else:
            # registered as None so that ``self._super_bias`` always exists
            self.register_parameter("_super_bias", None)
        self.reset_parameters()

    @property
    def in_features(self):
        """The input width of the super weight: ``spaces.get_max`` of the raw arg."""
        return spaces.get_max(self._in_features)

    @property
    def out_features(self):
        """The output width of the super weight: ``spaces.get_max`` of the raw arg."""
        return spaces.get_max(self._out_features)

    @property
    def bias(self):
        """The result of ``spaces.has_categorical(self._bias, True)``."""
        return spaces.has_categorical(self._bias, True)

    @property
    def abstract_search_space(self):
        """Return a ``VirtualNode`` holding every argument that is not determined."""
        root_node = spaces.VirtualNode(id(self))
        if not spaces.is_determined(self._in_features):
            root_node.append(
                "_in_features", self._in_features.abstract(reuse_last=True)
            )
        if not spaces.is_determined(self._out_features):
            root_node.append(
                "_out_features", self._out_features.abstract(reuse_last=True)
            )
        if not spaces.is_determined(self._bias):
            root_node.append("_bias", self._bias.abstract(reuse_last=True))
        return root_node

    def reset_parameters(self) -> None:
        """Re-initialize the super weight (Kaiming uniform) and the super bias."""
        nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))
        if self.bias:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._super_weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self._super_bias, -bound, bound)

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer using the currently applied candidate.

        Raises:
            ValueError: if the last dimension of ``input`` differs from the
                candidate's (or the determined) input dimension.
        """
        # check inputs ->
        if not spaces.is_determined(self._in_features):
            expected_input_dim = self.abstract_child["_in_features"].value
        else:
            expected_input_dim = spaces.get_determined_value(self._in_features)
        if input.size(-1) != expected_input_dim:
            raise ValueError(
                "Expect the input dim of {:} instead of {:}".format(
                    expected_input_dim, input.size(-1)
                )
            )
        # create the weight matrix
        if not spaces.is_determined(self._out_features):
            out_dim = self.abstract_child["_out_features"].value
        else:
            out_dim = spaces.get_determined_value(self._out_features)
        candidate_weight = self._super_weight[:out_dim, :expected_input_dim]
        # create the bias matrix
        if not spaces.is_determined(self._bias):
            use_bias = self.abstract_child["_bias"].value
        else:
            use_bias = spaces.get_determined_value(self._bias)
        candidate_bias = self._super_bias[:out_dim] if use_bias else None
        return F.linear(input, candidate_weight, candidate_bias)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer with the full super weight and super bias."""
        return F.linear(input, self._super_weight, self._super_bias)

    def extra_repr(self) -> str:
        return "in_features={:}, out_features={:}, bias={:}".format(
            self._in_features, self._out_features, self._bias
        )

    def forward_with_container(self, input, container, prefix=None):
        """Apply the linear layer with tensors looked up in ``container``.

        Args:
            input: the input tensor.
            container: an object providing ``query(name)`` and ``has(name)``.
            prefix: list of name components that are joined with "." in front
                of the parameter names; ``None`` means no prefix.
        """
        # Use None as the default to avoid sharing one mutable list object
        # across calls.
        prefix = [] if prefix is None else prefix
        super_weight_name = ".".join(prefix + ["_super_weight"])
        super_weight = container.query(super_weight_name)
        super_bias_name = ".".join(prefix + ["_super_bias"])
        if container.has(super_bias_name):
            super_bias = container.query(super_bias_name)
        else:
            super_bias = None
        return F.linear(input, super_weight, super_bias)
|  |  | ||||||
|  |  | ||||||
class SuperMLPv1(SuperModule):
    """A two-layer perceptron: fc1 -> activation -> dropout -> fc2 -> dropout.

    Both fully-connected layers are ``SuperLinear`` modules; this class only
    forwards the search-space handling to them.
    """

    def __init__(
        self,
        in_features: IntSpaceType,
        hidden_features: IntSpaceType,
        out_features: IntSpaceType,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        drop: Optional[float] = None,
    ):
        super().__init__()
        # the raw arguments are kept for ``extra_repr``
        self._in_features = in_features
        self._hidden_features = hidden_features
        self._out_features = out_features
        self._drop_rate = drop
        # sub-modules, registered in execution order
        self.fc1 = SuperLinear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = SuperLinear(hidden_features, out_features)
        # a ``None`` rate falls back to 0.0
        self.drop = nn.Dropout(drop or 0.0)

    @property
    def abstract_search_space(self):
        """Collect the undetermined search spaces of ``fc1`` and ``fc2``."""
        node = spaces.VirtualNode(id(self))
        fc_spaces = {
            "fc1": self.fc1.abstract_search_space,
            "fc2": self.fc2.abstract_search_space,
        }
        for key, sub_space in fc_spaces.items():
            if spaces.is_determined(sub_space):
                continue
            node.append(key, sub_space)
        return node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        """Apply the candidate to this module, then to each listed FC layer."""
        super().apply_candidate(abstract_child)
        for key in ("fc1", "fc2"):
            if key in abstract_child:
                getattr(self, key).apply_candidate(abstract_child[key])

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # same computation path as the raw forward
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        hidden = self.drop(self.act(self.fc1(input)))
        return self.drop(self.fc2(hidden))

    def extra_repr(self) -> str:
        template = (
            "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, "
            "fc1 -> act -> drop -> fc2 -> drop,"
        )
        return template.format(
            self._in_features,
            self._hidden_features,
            self._out_features,
            self._drop_rate,
        )
|  |  | ||||||
|  |  | ||||||
class SuperMLPv2(SuperModule):
    """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.

    The two fully-connected layers are stored as raw parameters in
    ``self._params``. The hidden width is ``int(in_features * hidden_multiplier)``;
    the super weights are allocated with the sizes reported by
    ``spaces.get_max`` and sliced in ``forward_candidate``.
    """

    def __init__(
        self,
        in_features: IntSpaceType,
        hidden_multiplier: IntSpaceType,
        out_features: IntSpaceType,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        drop: Optional[float] = None,
    ):
        super(SuperMLPv2, self).__init__()
        self._in_features = in_features
        self._hidden_multiplier = hidden_multiplier
        self._out_features = out_features
        self._drop_rate = drop
        self._params = nn.ParameterDict({})

        self._create_linear(
            "fc1", self.in_features, int(self.in_features * self.hidden_multiplier)
        )
        self._create_linear(
            "fc2", int(self.in_features * self.hidden_multiplier), self.out_features
        )
        self.act = act_layer()
        # a ``None`` rate falls back to 0.0
        self.drop = nn.Dropout(drop or 0.0)
        self.reset_parameters()

    @property
    def in_features(self):
        return spaces.get_max(self._in_features)

    @property
    def hidden_multiplier(self):
        return spaces.get_max(self._hidden_multiplier)

    @property
    def out_features(self):
        return spaces.get_max(self._out_features)

    def _create_linear(self, name, inC, outC):
        """Register ``{name}_super_weight`` (outC x inC) and ``{name}_super_bias``."""
        self._params["{:}_super_weight".format(name)] = torch.nn.Parameter(
            torch.Tensor(outC, inC)
        )
        self._params["{:}_super_bias".format(name)] = torch.nn.Parameter(
            torch.Tensor(outC)
        )

    def reset_parameters(self) -> None:
        """Re-initialize both layers: Kaiming-uniform weights, uniform biases."""
        # Weights first, then biases, so the random draws happen in a fixed order.
        for name in ("fc1", "fc2"):
            nn.init.kaiming_uniform_(
                self._params["{:}_super_weight".format(name)], a=math.sqrt(5)
            )
        for name in ("fc1", "fc2"):
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                self._params["{:}_super_weight".format(name)]
            )
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(
                self._params["{:}_super_bias".format(name)], -bound, bound
            )

    @property
    def abstract_search_space(self):
        """Return a ``VirtualNode`` holding every argument that is not determined."""
        root_node = spaces.VirtualNode(id(self))
        if not spaces.is_determined(self._in_features):
            root_node.append(
                "_in_features", self._in_features.abstract(reuse_last=True)
            )
        if not spaces.is_determined(self._hidden_multiplier):
            root_node.append(
                "_hidden_multiplier", self._hidden_multiplier.abstract(reuse_last=True)
            )
        if not spaces.is_determined(self._out_features):
            root_node.append(
                "_out_features", self._out_features.abstract(reuse_last=True)
            )
        return root_node

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        """Run the MLP with weights sliced for the currently applied candidate.

        Raises:
            ValueError: if the last dimension of ``input`` differs from the
                candidate's (or the determined) input dimension.
        """
        # check inputs ->
        if not spaces.is_determined(self._in_features):
            expected_input_dim = self.abstract_child["_in_features"].value
        else:
            expected_input_dim = spaces.get_determined_value(self._in_features)
        if input.size(-1) != expected_input_dim:
            raise ValueError(
                "Expect the input dim of {:} instead of {:}".format(
                    expected_input_dim, input.size(-1)
                )
            )
        # create the weight and bias matrix for fc1
        # ``hmul`` is the multiplier itself in both branches; it is scaled by
        # the input dim exactly once, when computing ``hidden_dim`` below.
        if not spaces.is_determined(self._hidden_multiplier):
            hmul = self.abstract_child["_hidden_multiplier"].value
        else:
            hmul = spaces.get_determined_value(self._hidden_multiplier)
        hidden_dim = int(expected_input_dim * hmul)
        _fc1_weight = self._params["fc1_super_weight"][:hidden_dim, :expected_input_dim]
        _fc1_bias = self._params["fc1_super_bias"][:hidden_dim]
        x = F.linear(input, _fc1_weight, _fc1_bias)
        x = self.act(x)
        x = self.drop(x)
        # create the weight and bias matrix for fc2
        if not spaces.is_determined(self._out_features):
            out_dim = self.abstract_child["_out_features"].value
        else:
            out_dim = spaces.get_determined_value(self._out_features)
        _fc2_weight = self._params["fc2_super_weight"][:out_dim, :hidden_dim]
        _fc2_bias = self._params["fc2_super_bias"][:out_dim]
        x = F.linear(x, _fc2_weight, _fc2_bias)
        x = self.drop(x)
        return x

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        """Run the MLP with the full super weights and biases."""
        x = F.linear(
            input, self._params["fc1_super_weight"], self._params["fc1_super_bias"]
        )
        x = self.act(x)
        x = self.drop(x)
        x = F.linear(
            x, self._params["fc2_super_weight"], self._params["fc2_super_bias"]
        )
        x = self.drop(x)
        return x

    def extra_repr(self) -> str:
        return "in_features={:}, hidden_multiplier={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format(
            self._in_features,
            self._hidden_multiplier,
            self._out_features,
            self._drop_rate,
        )
| @@ -3,3 +3,5 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Models in this folder are written with xlayers    # | # Models in this folder are written with xlayers    # | ||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
|  | from .transformers import get_transformer | ||||||
|   | |||||||
| @@ -1,6 +1,8 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||||
| ##################################################### | ##################################################### | ||||||
|  | # Vision Transformer: arxiv.org/pdf/2010.11929.pdf  # | ||||||
|  | ##################################################### | ||||||
| import math | import math | ||||||
| from functools import partial | from functools import partial | ||||||
| from typing import Optional, Text, List | from typing import Optional, Text, List | ||||||
| @@ -10,186 +12,163 @@ import torch.nn as nn | |||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
|  |  | ||||||
| from xautodl import spaces | from xautodl import spaces | ||||||
| from xautodl.xlayers import trunc_normal_ | from xautodl import xlayers | ||||||
| from xautodl.xlayers import super_core | from xautodl.xlayers import weight_init | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"] | def pair(t): | ||||||
|  |     return t if isinstance(t, tuple) else (t, t) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_mul_specs(candidates, num): | def _init_weights(m): | ||||||
|     results = [] |     if isinstance(m, nn.Linear): | ||||||
|     for i in range(num): |         weight_init.trunc_normal_(m.weight, std=0.02) | ||||||
|         results.append(spaces.Categorical(*candidates)) |         if isinstance(m, nn.Linear) and m.bias is not None: | ||||||
|     return results |             nn.init.constant_(m.bias, 0) | ||||||
|  |     elif isinstance(m, xlayers.SuperLinear): | ||||||
|  |         weight_init.trunc_normal_(m._super_weight, std=0.02) | ||||||
|  |         if m._super_bias is not None: | ||||||
|  |             nn.init.constant_(m._super_bias, 0) | ||||||
|  |     elif isinstance(m, xlayers.SuperLayerNorm1D): | ||||||
|  |         nn.init.constant_(m.weight, 1.0) | ||||||
|  |         nn.init.constant_(m.bias, 0) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_list_mul(num, multipler): | name2config = { | ||||||
|     results = [] |     "vit-base": dict( | ||||||
|     for i in range(1, num + 1): |         type="vit", | ||||||
|         results.append(i * multipler) |         image_size=256, | ||||||
|     return results |         patch_size=16, | ||||||
|  |         num_classes=1000, | ||||||
|  |         dim=768, | ||||||
|  |         depth=12, | ||||||
|  |         heads=12, | ||||||
|  |         dropout=0.1, | ||||||
|  |         emb_dropout=0.1, | ||||||
|  |     ), | ||||||
|  |     "vit-large": dict( | ||||||
|  |         type="vit", | ||||||
|  |         image_size=256, | ||||||
|  |         patch_size=16, | ||||||
|  |         num_classes=1000, | ||||||
|  |         dim=1024, | ||||||
|  |         depth=24, | ||||||
|  |         heads=16, | ||||||
|  |         dropout=0.1, | ||||||
|  |         emb_dropout=0.1, | ||||||
|  |     ), | ||||||
|  |     "vit-huge": dict( | ||||||
|  |         type="vit", | ||||||
|  |         image_size=256, | ||||||
|  |         patch_size=16, | ||||||
|  |         num_classes=1000, | ||||||
|  |         dim=1280, | ||||||
|  |         depth=32, | ||||||
|  |         heads=16, | ||||||
|  |         dropout=0.1, | ||||||
|  |         emb_dropout=0.1, | ||||||
|  |     ), | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
| def _assert_types(x, expected_types): | class SuperViT(xlayers.SuperModule): | ||||||
|     if not isinstance(x, expected_types): |  | ||||||
|         raise TypeError( |  | ||||||
|             "The type [{:}] is expected to be {:}.".format(type(x), expected_types) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| DEFAULT_NET_CONFIG = None |  | ||||||
| _default_max_depth = 5 |  | ||||||
| DefaultSearchSpace = dict( |  | ||||||
|     d_feat=6, |  | ||||||
|     embed_dim=spaces.Categorical(*_get_list_mul(8, 16)), |  | ||||||
|     num_heads=_get_mul_specs((1, 2, 4, 8), _default_max_depth), |  | ||||||
|     mlp_hidden_multipliers=_get_mul_specs((0.5, 1, 2, 4, 8), _default_max_depth), |  | ||||||
|     qkv_bias=True, |  | ||||||
|     pos_drop=0.0, |  | ||||||
|     other_drop=0.0, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SuperTransformer(super_core.SuperModule): |  | ||||||
|     """The super model for transformer.""" |     """The super model for transformer.""" | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         d_feat: int = 6, |         image_size, | ||||||
|         embed_dim: List[super_core.IntSpaceType] = DefaultSearchSpace["embed_dim"], |         patch_size, | ||||||
|         num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"], |         num_classes, | ||||||
|         mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[ |         dim, | ||||||
|             "mlp_hidden_multipliers" |         depth, | ||||||
|         ], |         heads, | ||||||
|         qkv_bias: bool = DefaultSearchSpace["qkv_bias"], |         mlp_multiplier=4, | ||||||
|         pos_drop: float = DefaultSearchSpace["pos_drop"], |         channels=3, | ||||||
|         other_drop: float = DefaultSearchSpace["other_drop"], |         dropout=0.0, | ||||||
|         max_seq_len: int = 65, |         emb_dropout=0.0, | ||||||
|     ): |     ): | ||||||
|         super(SuperTransformer, self).__init__() |         super(SuperViT, self).__init__() | ||||||
|         self._embed_dim = embed_dim |         image_height, image_width = pair(image_size) | ||||||
|         self._num_heads = num_heads |         patch_height, patch_width = pair(patch_size) | ||||||
|         self._mlp_hidden_multipliers = mlp_hidden_multipliers |  | ||||||
|  |  | ||||||
|         # the stem part |         if image_height % patch_height != 0 or image_width % patch_width != 0: | ||||||
|         self.input_embed = super_core.SuperAlphaEBDv1(d_feat, embed_dim) |             raise ValueError("Image dimensions must be divisible by the patch size.") | ||||||
|         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) |  | ||||||
|         self.pos_embed = super_core.SuperPositionalEncoder( |         num_patches = (image_height // patch_height) * (image_width // patch_width) | ||||||
|             d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop |         patch_dim = channels * patch_height * patch_width | ||||||
|  |         self.to_patch_embedding = xlayers.SuperSequential( | ||||||
|  |             xlayers.SuperReArrange( | ||||||
|  |                 "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", | ||||||
|  |                 p1=patch_height, | ||||||
|  |                 p2=patch_width, | ||||||
|  |             ), | ||||||
|  |             xlayers.SuperLinear(patch_dim, dim), | ||||||
|         ) |         ) | ||||||
|         # build the transformer encode layers -->> check params |  | ||||||
|         _assert_types(num_heads, (tuple, list)) |         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) | ||||||
|         _assert_types(mlp_hidden_multipliers, (tuple, list)) |         self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) | ||||||
|         assert len(num_heads) == len(mlp_hidden_multipliers), "{:} vs {:}".format( |         self.dropout = nn.Dropout(emb_dropout) | ||||||
|             len(num_heads), len(mlp_hidden_multipliers) |  | ||||||
|         ) |         # build the transformer encode layers | ||||||
|         # build the transformer encode layers -->> backbone |  | ||||||
|         layers = [] |         layers = [] | ||||||
|         for num_head, mlp_hidden_multiplier in zip(num_heads, mlp_hidden_multipliers): |         for ilayer in range(depth): | ||||||
|             layer = super_core.SuperTransformerEncoderLayer( |             layers.append( | ||||||
|                 embed_dim, |                 xlayers.SuperTransformerEncoderLayer( | ||||||
|                 num_head, |                     dim, heads, False, mlp_multiplier, dropout | ||||||
|                 qkv_bias, |  | ||||||
|                 mlp_hidden_multiplier, |  | ||||||
|                 other_drop, |  | ||||||
|                 ) |                 ) | ||||||
|             layers.append(layer) |  | ||||||
|         self.backbone = super_core.SuperSequential(*layers) |  | ||||||
|  |  | ||||||
|         # the regression head |  | ||||||
|         self.head = super_core.SuperSequential( |  | ||||||
|             super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1) |  | ||||||
|             ) |             ) | ||||||
|         trunc_normal_(self.cls_token, std=0.02) |         self.backbone = xlayers.SuperSequential(*layers) | ||||||
|         self.apply(self._init_weights) |         self.cls_head = xlayers.SuperSequential( | ||||||
|  |             xlayers.SuperLayerNorm1D(dim), xlayers.SuperLinear(dim, num_classes) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @property |         weight_init.trunc_normal_(self.cls_token, std=0.02) | ||||||
|     def embed_dim(self): |         self.apply(_init_weights) | ||||||
|         return spaces.get_max(self._embed_dim) |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|         root_node = spaces.VirtualNode(id(self)) |         raise NotImplementedError | ||||||
|         if not spaces.is_determined(self._embed_dim): |  | ||||||
|             root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True)) |  | ||||||
|         xdict = dict( |  | ||||||
|             input_embed=self.input_embed.abstract_search_space, |  | ||||||
|             pos_embed=self.pos_embed.abstract_search_space, |  | ||||||
|             backbone=self.backbone.abstract_search_space, |  | ||||||
|             head=self.head.abstract_search_space, |  | ||||||
|         ) |  | ||||||
|         for key, space in xdict.items(): |  | ||||||
|             if not spaces.is_determined(space): |  | ||||||
|                 root_node.append(key, space) |  | ||||||
|         return root_node |  | ||||||
|  |  | ||||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): |     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||||
|         super(SuperTransformer, self).apply_candidate(abstract_child) |         super(SuperViT, self).apply_candidate(abstract_child) | ||||||
|         xkeys = ("input_embed", "pos_embed", "backbone", "head") |         raise NotImplementedError | ||||||
|         for key in xkeys: |  | ||||||
|             if key in abstract_child: |  | ||||||
|                 getattr(self, key).apply_candidate(abstract_child[key]) |  | ||||||
|  |  | ||||||
|     def _init_weights(self, m): |  | ||||||
|         if isinstance(m, nn.Linear): |  | ||||||
|             trunc_normal_(m.weight, std=0.02) |  | ||||||
|             if isinstance(m, nn.Linear) and m.bias is not None: |  | ||||||
|                 nn.init.constant_(m.bias, 0) |  | ||||||
|         elif isinstance(m, super_core.SuperLinear): |  | ||||||
|             trunc_normal_(m._super_weight, std=0.02) |  | ||||||
|             if m._super_bias is not None: |  | ||||||
|                 nn.init.constant_(m._super_bias, 0) |  | ||||||
|         elif isinstance(m, super_core.SuperLayerNorm1D): |  | ||||||
|             nn.init.constant_(m.weight, 1.0) |  | ||||||
|             nn.init.constant_(m.bias, 0) |  | ||||||
|  |  | ||||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|         batch, flatten_size = input.shape |         raise NotImplementedError | ||||||
|         feats = self.input_embed(input)  # batch * 60 * 64 |  | ||||||
|         if not spaces.is_determined(self._embed_dim): |  | ||||||
|             embed_dim = self.abstract_child["_embed_dim"].value |  | ||||||
|         else: |  | ||||||
|             embed_dim = spaces.get_determined_value(self._embed_dim) |  | ||||||
|         cls_tokens = self.cls_token.expand(batch, -1, -1) |  | ||||||
|         cls_tokens = F.interpolate( |  | ||||||
|             cls_tokens, size=(embed_dim), mode="linear", align_corners=True |  | ||||||
|         ) |  | ||||||
|         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) |  | ||||||
|         feats_w_tp = self.pos_embed(feats_w_ct) |  | ||||||
|         xfeats = self.backbone(feats_w_tp) |  | ||||||
|         xfeats = xfeats[:, 0, :]  # use the feature for the first token |  | ||||||
|         predicts = self.head(xfeats).squeeze(-1) |  | ||||||
|         return predicts |  | ||||||
|  |  | ||||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|         batch, flatten_size = input.shape |         tensors = self.to_patch_embedding(input) | ||||||
|         feats = self.input_embed(input)  # batch * 60 * 64 |         batch, seq, _ = tensors.shape | ||||||
|  |  | ||||||
|         cls_tokens = self.cls_token.expand(batch, -1, -1) |         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||||
|         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) |         feats = torch.cat((cls_tokens, tensors), dim=1) | ||||||
|         feats_w_tp = self.pos_embed(feats_w_ct) |         feats = feats + self.pos_embedding[:, : seq + 1, :] | ||||||
|         xfeats = self.backbone(feats_w_tp) |         feats = self.dropout(feats) | ||||||
|         xfeats = xfeats[:, 0, :]  # use the feature for the first token |  | ||||||
|         predicts = self.head(xfeats).squeeze(-1) |         feats = self.backbone(feats) | ||||||
|         return predicts |  | ||||||
|  |         x = feats[:, 0]  # the features for cls-token | ||||||
|  |  | ||||||
|  |         return self.cls_head(x) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_transformer(config): | def get_transformer(config): | ||||||
|     if config is None: |     if isinstance(config, str) and config.lower() in name2config: | ||||||
|         return SuperTransformer(6) |         config = name2config[config.lower()] | ||||||
|     if not isinstance(config, dict): |     if not isinstance(config, dict): | ||||||
|         raise ValueError("Invalid Configuration: {:}".format(config)) |         raise ValueError("Invalid Configuration: {:}".format(config)) | ||||||
|     name = config.get("name", "basic") |     model_type = config.get("type", "vit").lower() | ||||||
|     if name == "basic": |     if model_type == "vit": | ||||||
|         model = SuperTransformer( |         model = SuperViT( | ||||||
|             d_feat=config.get("d_feat"), |             image_size=config.get("image_size"), | ||||||
|             embed_dim=config.get("embed_dim"), |             patch_size=config.get("patch_size"), | ||||||
|             num_heads=config.get("num_heads"), |             num_classes=config.get("num_classes"), | ||||||
|             mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"), |             dim=config.get("dim"), | ||||||
|             qkv_bias=config.get("qkv_bias"), |             depth=config.get("depth"), | ||||||
|             pos_drop=config.get("pos_drop"), |             heads=config.get("heads"), | ||||||
|             other_drop=config.get("other_drop"), |             dropout=config.get("dropout"), | ||||||
|  |             emb_dropout=config.get("emb_dropout"), | ||||||
|         ) |         ) | ||||||
|     else: |     else: | ||||||
|         raise ValueError("Unknown model name: {:}".format(name)) |         raise ValueError("Unknown model type: {:}".format(model_type)) | ||||||
|     return model |     return model | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user