Update ViT
This commit is contained in:
		
							
								
								
									
										29
									
								
								tests/test_super_vit.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								tests/test_super_vit.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| # pytest ./tests/test_super_vit.py -s               # | ||||
| ##################################################### | ||||
| import sys | ||||
| import unittest | ||||
|  | ||||
| import torch | ||||
| from xautodl.xmodels import transformers | ||||
| from xautodl.utils.flop_benchmark import count_parameters | ||||
|  | ||||
| class TestSuperViT(unittest.TestCase): | ||||
|     """Test the super re-arrange layer.""" | ||||
|  | ||||
|     def test_super_vit(self): | ||||
|         model = transformers.get_transformer("vit-base") | ||||
|         tensor = torch.rand((16, 3, 256, 256)) | ||||
|         print("The tensor shape: {:}".format(tensor.shape)) | ||||
|         print(model) | ||||
|         outs = model(tensor) | ||||
|         print("The output tensor shape: {:}".format(outs.shape)) | ||||
|  | ||||
|     def test_model_size(self): | ||||
|         name2config = transformers.name2config | ||||
|         for name, config in name2config.items(): | ||||
|             model = transformers.get_transformer(config) | ||||
|             size = count_parameters(model, "mb", True) | ||||
|             print('{:10s} : size={:.2f}MB'.format(name, size)) | ||||
							
								
								
									
										319
									
								
								xautodl/xlayers/super_mlp.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										319
									
								
								xautodl/xlayers/super_mlp.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,319 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| import math | ||||
| from typing import Optional, Callable | ||||
|  | ||||
| from xautodl import spaces | ||||
| from .super_module import SuperModule | ||||
| from .super_module import IntSpaceType | ||||
| from .super_module import BoolSpaceType | ||||
|  | ||||
|  | ||||
| class SuperLinear(SuperModule): | ||||
|     """Applies a linear transformation to the incoming data: :math:`y = xA^T + b`""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features: IntSpaceType, | ||||
|         out_features: IntSpaceType, | ||||
|         bias: BoolSpaceType = True, | ||||
|     ) -> None: | ||||
|         super(SuperLinear, self).__init__() | ||||
|  | ||||
|         # the raw input args | ||||
|         self._in_features = in_features | ||||
|         self._out_features = out_features | ||||
|         self._bias = bias | ||||
|         # weights to be optimized | ||||
|         self.register_parameter( | ||||
|             "_super_weight", | ||||
|             torch.nn.Parameter(torch.Tensor(self.out_features, self.in_features)), | ||||
|         ) | ||||
|         if self.bias: | ||||
|             self.register_parameter( | ||||
|                 "_super_bias", torch.nn.Parameter(torch.Tensor(self.out_features)) | ||||
|             ) | ||||
|         else: | ||||
|             self.register_parameter("_super_bias", None) | ||||
|         self.reset_parameters() | ||||
|  | ||||
|     @property | ||||
|     def in_features(self): | ||||
|         return spaces.get_max(self._in_features) | ||||
|  | ||||
|     @property | ||||
|     def out_features(self): | ||||
|         return spaces.get_max(self._out_features) | ||||
|  | ||||
|     @property | ||||
|     def bias(self): | ||||
|         return spaces.has_categorical(self._bias, True) | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         if not spaces.is_determined(self._in_features): | ||||
|             root_node.append( | ||||
|                 "_in_features", self._in_features.abstract(reuse_last=True) | ||||
|             ) | ||||
|         if not spaces.is_determined(self._out_features): | ||||
|             root_node.append( | ||||
|                 "_out_features", self._out_features.abstract(reuse_last=True) | ||||
|             ) | ||||
|         if not spaces.is_determined(self._bias): | ||||
|             root_node.append("_bias", self._bias.abstract(reuse_last=True)) | ||||
|         return root_node | ||||
|  | ||||
|     def reset_parameters(self) -> None: | ||||
|         nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5)) | ||||
|         if self.bias: | ||||
|             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._super_weight) | ||||
|             bound = 1 / math.sqrt(fan_in) | ||||
|             nn.init.uniform_(self._super_bias, -bound, bound) | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         # check inputs -> | ||||
|         if not spaces.is_determined(self._in_features): | ||||
|             expected_input_dim = self.abstract_child["_in_features"].value | ||||
|         else: | ||||
|             expected_input_dim = spaces.get_determined_value(self._in_features) | ||||
|         if input.size(-1) != expected_input_dim: | ||||
|             raise ValueError( | ||||
|                 "Expect the input dim of {:} instead of {:}".format( | ||||
|                     expected_input_dim, input.size(-1) | ||||
|                 ) | ||||
|             ) | ||||
|         # create the weight matrix | ||||
|         if not spaces.is_determined(self._out_features): | ||||
|             out_dim = self.abstract_child["_out_features"].value | ||||
|         else: | ||||
|             out_dim = spaces.get_determined_value(self._out_features) | ||||
|         candidate_weight = self._super_weight[:out_dim, :expected_input_dim] | ||||
|         # create the bias matrix | ||||
|         if not spaces.is_determined(self._bias): | ||||
|             if self.abstract_child["_bias"].value: | ||||
|                 candidate_bias = self._super_bias[:out_dim] | ||||
|             else: | ||||
|                 candidate_bias = None | ||||
|         else: | ||||
|             if spaces.get_determined_value(self._bias): | ||||
|                 candidate_bias = self._super_bias[:out_dim] | ||||
|             else: | ||||
|                 candidate_bias = None | ||||
|         return F.linear(input, candidate_weight, candidate_bias) | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return F.linear(input, self._super_weight, self._super_bias) | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "in_features={:}, out_features={:}, bias={:}".format( | ||||
|             self._in_features, self._out_features, self._bias | ||||
|         ) | ||||
|  | ||||
|     def forward_with_container(self, input, container, prefix=[]): | ||||
|         super_weight_name = ".".join(prefix + ["_super_weight"]) | ||||
|         super_weight = container.query(super_weight_name) | ||||
|         super_bias_name = ".".join(prefix + ["_super_bias"]) | ||||
|         if container.has(super_bias_name): | ||||
|             super_bias = container.query(super_bias_name) | ||||
|         else: | ||||
|             super_bias = None | ||||
|         return F.linear(input, super_weight, super_bias) | ||||
|  | ||||
|  | ||||
| class SuperMLPv1(SuperModule): | ||||
|     """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features: IntSpaceType, | ||||
|         hidden_features: IntSpaceType, | ||||
|         out_features: IntSpaceType, | ||||
|         act_layer: Callable[[], nn.Module] = nn.GELU, | ||||
|         drop: Optional[float] = None, | ||||
|     ): | ||||
|         super(SuperMLPv1, self).__init__() | ||||
|         self._in_features = in_features | ||||
|         self._hidden_features = hidden_features | ||||
|         self._out_features = out_features | ||||
|         self._drop_rate = drop | ||||
|         self.fc1 = SuperLinear(in_features, hidden_features) | ||||
|         self.act = act_layer() | ||||
|         self.fc2 = SuperLinear(hidden_features, out_features) | ||||
|         self.drop = nn.Dropout(drop or 0.0) | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         space_fc1 = self.fc1.abstract_search_space | ||||
|         space_fc2 = self.fc2.abstract_search_space | ||||
|         if not spaces.is_determined(space_fc1): | ||||
|             root_node.append("fc1", space_fc1) | ||||
|         if not spaces.is_determined(space_fc2): | ||||
|             root_node.append("fc2", space_fc2) | ||||
|         return root_node | ||||
|  | ||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||
|         super(SuperMLPv1, self).apply_candidate(abstract_child) | ||||
|         if "fc1" in abstract_child: | ||||
|             self.fc1.apply_candidate(abstract_child["fc1"]) | ||||
|         if "fc2" in abstract_child: | ||||
|             self.fc2.apply_candidate(abstract_child["fc2"]) | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         return self.forward_raw(input) | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         x = self.fc1(input) | ||||
|         x = self.act(x) | ||||
|         x = self.drop(x) | ||||
|         x = self.fc2(x) | ||||
|         x = self.drop(x) | ||||
|         return x | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format( | ||||
|             self._in_features, | ||||
|             self._hidden_features, | ||||
|             self._out_features, | ||||
|             self._drop_rate, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SuperMLPv2(SuperModule): | ||||
|     """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features: IntSpaceType, | ||||
|         hidden_multiplier: IntSpaceType, | ||||
|         out_features: IntSpaceType, | ||||
|         act_layer: Callable[[], nn.Module] = nn.GELU, | ||||
|         drop: Optional[float] = None, | ||||
|     ): | ||||
|         super(SuperMLPv2, self).__init__() | ||||
|         self._in_features = in_features | ||||
|         self._hidden_multiplier = hidden_multiplier | ||||
|         self._out_features = out_features | ||||
|         self._drop_rate = drop | ||||
|         self._params = nn.ParameterDict({}) | ||||
|  | ||||
|         self._create_linear( | ||||
|             "fc1", self.in_features, int(self.in_features * self.hidden_multiplier) | ||||
|         ) | ||||
|         self._create_linear( | ||||
|             "fc2", int(self.in_features * self.hidden_multiplier), self.out_features | ||||
|         ) | ||||
|         self.act = act_layer() | ||||
|         self.drop = nn.Dropout(drop or 0.0) | ||||
|         self.reset_parameters() | ||||
|  | ||||
|     @property | ||||
|     def in_features(self): | ||||
|         return spaces.get_max(self._in_features) | ||||
|  | ||||
|     @property | ||||
|     def hidden_multiplier(self): | ||||
|         return spaces.get_max(self._hidden_multiplier) | ||||
|  | ||||
|     @property | ||||
|     def out_features(self): | ||||
|         return spaces.get_max(self._out_features) | ||||
|  | ||||
|     def _create_linear(self, name, inC, outC): | ||||
|         self._params["{:}_super_weight".format(name)] = torch.nn.Parameter( | ||||
|             torch.Tensor(outC, inC) | ||||
|         ) | ||||
|         self._params["{:}_super_bias".format(name)] = torch.nn.Parameter( | ||||
|             torch.Tensor(outC) | ||||
|         ) | ||||
|  | ||||
|     def reset_parameters(self) -> None: | ||||
|         nn.init.kaiming_uniform_(self._params["fc1_super_weight"], a=math.sqrt(5)) | ||||
|         nn.init.kaiming_uniform_(self._params["fc2_super_weight"], a=math.sqrt(5)) | ||||
|         fan_in, _ = nn.init._calculate_fan_in_and_fan_out( | ||||
|             self._params["fc1_super_weight"] | ||||
|         ) | ||||
|         bound = 1 / math.sqrt(fan_in) | ||||
|         nn.init.uniform_(self._params["fc1_super_bias"], -bound, bound) | ||||
|         fan_in, _ = nn.init._calculate_fan_in_and_fan_out( | ||||
|             self._params["fc2_super_weight"] | ||||
|         ) | ||||
|         bound = 1 / math.sqrt(fan_in) | ||||
|         nn.init.uniform_(self._params["fc2_super_bias"], -bound, bound) | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         if not spaces.is_determined(self._in_features): | ||||
|             root_node.append( | ||||
|                 "_in_features", self._in_features.abstract(reuse_last=True) | ||||
|             ) | ||||
|         if not spaces.is_determined(self._hidden_multiplier): | ||||
|             root_node.append( | ||||
|                 "_hidden_multiplier", self._hidden_multiplier.abstract(reuse_last=True) | ||||
|             ) | ||||
|         if not spaces.is_determined(self._out_features): | ||||
|             root_node.append( | ||||
|                 "_out_features", self._out_features.abstract(reuse_last=True) | ||||
|             ) | ||||
|         return root_node | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         # check inputs -> | ||||
|         if not spaces.is_determined(self._in_features): | ||||
|             expected_input_dim = self.abstract_child["_in_features"].value | ||||
|         else: | ||||
|             expected_input_dim = spaces.get_determined_value(self._in_features) | ||||
|         if input.size(-1) != expected_input_dim: | ||||
|             raise ValueError( | ||||
|                 "Expect the input dim of {:} instead of {:}".format( | ||||
|                     expected_input_dim, input.size(-1) | ||||
|                 ) | ||||
|             ) | ||||
|         # create the weight and bias matrix for fc1 | ||||
|         if not spaces.is_determined(self._hidden_multiplier): | ||||
|             hmul = self.abstract_child["_hidden_multiplier"].value * expected_input_dim | ||||
|         else: | ||||
|             hmul = spaces.get_determined_value(self._hidden_multiplier) | ||||
|         hidden_dim = int(expected_input_dim * hmul) | ||||
|         _fc1_weight = self._params["fc1_super_weight"][:hidden_dim, :expected_input_dim] | ||||
|         _fc1_bias = self._params["fc1_super_bias"][:hidden_dim] | ||||
|         x = F.linear(input, _fc1_weight, _fc1_bias) | ||||
|         x = self.act(x) | ||||
|         x = self.drop(x) | ||||
|         # create the weight and bias matrix for fc2 | ||||
|         if not spaces.is_determined(self._out_features): | ||||
|             out_dim = self.abstract_child["_out_features"].value | ||||
|         else: | ||||
|             out_dim = spaces.get_determined_value(self._out_features) | ||||
|         _fc2_weight = self._params["fc2_super_weight"][:out_dim, :hidden_dim] | ||||
|         _fc2_bias = self._params["fc2_super_bias"][:out_dim] | ||||
|         x = F.linear(x, _fc2_weight, _fc2_bias) | ||||
|         x = self.drop(x) | ||||
|         return x | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         x = F.linear( | ||||
|             input, self._params["fc1_super_weight"], self._params["fc1_super_bias"] | ||||
|         ) | ||||
|         x = self.act(x) | ||||
|         x = self.drop(x) | ||||
|         x = F.linear( | ||||
|             x, self._params["fc2_super_weight"], self._params["fc2_super_bias"] | ||||
|         ) | ||||
|         x = self.drop(x) | ||||
|         return x | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "in_features={:}, hidden_multiplier={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format( | ||||
|             self._in_features, | ||||
|             self._hidden_multiplier, | ||||
|             self._out_features, | ||||
|             self._drop_rate, | ||||
|         ) | ||||
| @@ -3,3 +3,5 @@ | ||||
| ##################################################### | ||||
| # The models in this folder is written with xlayers # | ||||
| ##################################################### | ||||
|  | ||||
| from .transformers import get_transformer | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 # | ||||
| ##################################################### | ||||
| # Vision Transformer: arxiv.org/pdf/2010.11929.pdf  # | ||||
| ##################################################### | ||||
| import math | ||||
| from functools import partial | ||||
| from typing import Optional, Text, List | ||||
| @@ -10,186 +12,163 @@ import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xautodl import spaces | ||||
| from xautodl.xlayers import trunc_normal_ | ||||
| from xautodl.xlayers import super_core | ||||
| from xautodl import xlayers | ||||
| from xautodl.xlayers import weight_init | ||||
|  | ||||
|  | ||||
| __all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"] | ||||
| def pair(t): | ||||
|     return t if isinstance(t, tuple) else (t, t) | ||||
|  | ||||
|  | ||||
| def _get_mul_specs(candidates, num): | ||||
|     results = [] | ||||
|     for i in range(num): | ||||
|         results.append(spaces.Categorical(*candidates)) | ||||
|     return results | ||||
| def _init_weights(m): | ||||
|     if isinstance(m, nn.Linear): | ||||
|         weight_init.trunc_normal_(m.weight, std=0.02) | ||||
|         if isinstance(m, nn.Linear) and m.bias is not None: | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|     elif isinstance(m, xlayers.SuperLinear): | ||||
|         weight_init.trunc_normal_(m._super_weight, std=0.02) | ||||
|         if m._super_bias is not None: | ||||
|             nn.init.constant_(m._super_bias, 0) | ||||
|     elif isinstance(m, xlayers.SuperLayerNorm1D): | ||||
|         nn.init.constant_(m.weight, 1.0) | ||||
|         nn.init.constant_(m.bias, 0) | ||||
|  | ||||
|  | ||||
| def _get_list_mul(num, multipler): | ||||
|     results = [] | ||||
|     for i in range(1, num + 1): | ||||
|         results.append(i * multipler) | ||||
|     return results | ||||
| name2config = { | ||||
|     "vit-base": dict( | ||||
|         type="vit", | ||||
|         image_size=256, | ||||
|         patch_size=16, | ||||
|         num_classes=1000, | ||||
|         dim=768, | ||||
|         depth=12, | ||||
|         heads=12, | ||||
|         dropout=0.1, | ||||
|         emb_dropout=0.1, | ||||
|     ), | ||||
|     "vit-large": dict( | ||||
|         type="vit", | ||||
|         image_size=256, | ||||
|         patch_size=16, | ||||
|         num_classes=1000, | ||||
|         dim=1024, | ||||
|         depth=24, | ||||
|         heads=16, | ||||
|         dropout=0.1, | ||||
|         emb_dropout=0.1, | ||||
|     ), | ||||
|     "vit-huge": dict( | ||||
|         type="vit", | ||||
|         image_size=256, | ||||
|         patch_size=16, | ||||
|         num_classes=1000, | ||||
|         dim=1280, | ||||
|         depth=32, | ||||
|         heads=16, | ||||
|         dropout=0.1, | ||||
|         emb_dropout=0.1, | ||||
|     ), | ||||
| } | ||||
|  | ||||
|  | ||||
| def _assert_types(x, expected_types): | ||||
|     if not isinstance(x, expected_types): | ||||
|         raise TypeError( | ||||
|             "The type [{:}] is expected to be {:}.".format(type(x), expected_types) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| DEFAULT_NET_CONFIG = None | ||||
| _default_max_depth = 5 | ||||
| DefaultSearchSpace = dict( | ||||
|     d_feat=6, | ||||
|     embed_dim=spaces.Categorical(*_get_list_mul(8, 16)), | ||||
|     num_heads=_get_mul_specs((1, 2, 4, 8), _default_max_depth), | ||||
|     mlp_hidden_multipliers=_get_mul_specs((0.5, 1, 2, 4, 8), _default_max_depth), | ||||
|     qkv_bias=True, | ||||
|     pos_drop=0.0, | ||||
|     other_drop=0.0, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class SuperTransformer(super_core.SuperModule): | ||||
| class SuperViT(xlayers.SuperModule): | ||||
|     """The super model for transformer.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         d_feat: int = 6, | ||||
|         embed_dim: List[super_core.IntSpaceType] = DefaultSearchSpace["embed_dim"], | ||||
|         num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"], | ||||
|         mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[ | ||||
|             "mlp_hidden_multipliers" | ||||
|         ], | ||||
|         qkv_bias: bool = DefaultSearchSpace["qkv_bias"], | ||||
|         pos_drop: float = DefaultSearchSpace["pos_drop"], | ||||
|         other_drop: float = DefaultSearchSpace["other_drop"], | ||||
|         max_seq_len: int = 65, | ||||
|         image_size, | ||||
|         patch_size, | ||||
|         num_classes, | ||||
|         dim, | ||||
|         depth, | ||||
|         heads, | ||||
|         mlp_multiplier=4, | ||||
|         channels=3, | ||||
|         dropout=0.0, | ||||
|         emb_dropout=0.0, | ||||
|     ): | ||||
|         super(SuperTransformer, self).__init__() | ||||
|         self._embed_dim = embed_dim | ||||
|         self._num_heads = num_heads | ||||
|         self._mlp_hidden_multipliers = mlp_hidden_multipliers | ||||
|         super(SuperViT, self).__init__() | ||||
|         image_height, image_width = pair(image_size) | ||||
|         patch_height, patch_width = pair(patch_size) | ||||
|  | ||||
|         # the stem part | ||||
|         self.input_embed = super_core.SuperAlphaEBDv1(d_feat, embed_dim) | ||||
|         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) | ||||
|         self.pos_embed = super_core.SuperPositionalEncoder( | ||||
|             d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop | ||||
|         if image_height % patch_height != 0 or image_width % patch_width != 0: | ||||
|             raise ValueError("Image dimensions must be divisible by the patch size.") | ||||
|  | ||||
|         num_patches = (image_height // patch_height) * (image_width // patch_width) | ||||
|         patch_dim = channels * patch_height * patch_width | ||||
|         self.to_patch_embedding = xlayers.SuperSequential( | ||||
|             xlayers.SuperReArrange( | ||||
|                 "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", | ||||
|                 p1=patch_height, | ||||
|                 p2=patch_width, | ||||
|             ), | ||||
|             xlayers.SuperLinear(patch_dim, dim), | ||||
|         ) | ||||
|         # build the transformer encode layers -->> check params | ||||
|         _assert_types(num_heads, (tuple, list)) | ||||
|         _assert_types(mlp_hidden_multipliers, (tuple, list)) | ||||
|         assert len(num_heads) == len(mlp_hidden_multipliers), "{:} vs {:}".format( | ||||
|             len(num_heads), len(mlp_hidden_multipliers) | ||||
|         ) | ||||
|         # build the transformer encode layers -->> backbone | ||||
|  | ||||
|         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) | ||||
|         self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) | ||||
|         self.dropout = nn.Dropout(emb_dropout) | ||||
|  | ||||
|         # build the transformer encode layers | ||||
|         layers = [] | ||||
|         for num_head, mlp_hidden_multiplier in zip(num_heads, mlp_hidden_multipliers): | ||||
|             layer = super_core.SuperTransformerEncoderLayer( | ||||
|                 embed_dim, | ||||
|                 num_head, | ||||
|                 qkv_bias, | ||||
|                 mlp_hidden_multiplier, | ||||
|                 other_drop, | ||||
|         for ilayer in range(depth): | ||||
|             layers.append( | ||||
|                 xlayers.SuperTransformerEncoderLayer( | ||||
|                     dim, heads, False, mlp_multiplier, dropout | ||||
|                 ) | ||||
|             ) | ||||
|             layers.append(layer) | ||||
|         self.backbone = super_core.SuperSequential(*layers) | ||||
|  | ||||
|         # the regression head | ||||
|         self.head = super_core.SuperSequential( | ||||
|             super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1) | ||||
|         self.backbone = xlayers.SuperSequential(*layers) | ||||
|         self.cls_head = xlayers.SuperSequential( | ||||
|             xlayers.SuperLayerNorm1D(dim), xlayers.SuperLinear(dim, num_classes) | ||||
|         ) | ||||
|         trunc_normal_(self.cls_token, std=0.02) | ||||
|         self.apply(self._init_weights) | ||||
|  | ||||
|     @property | ||||
|     def embed_dim(self): | ||||
|         return spaces.get_max(self._embed_dim) | ||||
|         weight_init.trunc_normal_(self.cls_token, std=0.02) | ||||
|         self.apply(_init_weights) | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         if not spaces.is_determined(self._embed_dim): | ||||
|             root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True)) | ||||
|         xdict = dict( | ||||
|             input_embed=self.input_embed.abstract_search_space, | ||||
|             pos_embed=self.pos_embed.abstract_search_space, | ||||
|             backbone=self.backbone.abstract_search_space, | ||||
|             head=self.head.abstract_search_space, | ||||
|         ) | ||||
|         for key, space in xdict.items(): | ||||
|             if not spaces.is_determined(space): | ||||
|                 root_node.append(key, space) | ||||
|         return root_node | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||
|         super(SuperTransformer, self).apply_candidate(abstract_child) | ||||
|         xkeys = ("input_embed", "pos_embed", "backbone", "head") | ||||
|         for key in xkeys: | ||||
|             if key in abstract_child: | ||||
|                 getattr(self, key).apply_candidate(abstract_child[key]) | ||||
|  | ||||
|     def _init_weights(self, m): | ||||
|         if isinstance(m, nn.Linear): | ||||
|             trunc_normal_(m.weight, std=0.02) | ||||
|             if isinstance(m, nn.Linear) and m.bias is not None: | ||||
|                 nn.init.constant_(m.bias, 0) | ||||
|         elif isinstance(m, super_core.SuperLinear): | ||||
|             trunc_normal_(m._super_weight, std=0.02) | ||||
|             if m._super_bias is not None: | ||||
|                 nn.init.constant_(m._super_bias, 0) | ||||
|         elif isinstance(m, super_core.SuperLayerNorm1D): | ||||
|             nn.init.constant_(m.weight, 1.0) | ||||
|             nn.init.constant_(m.bias, 0) | ||||
|         super(SuperViT, self).apply_candidate(abstract_child) | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         batch, flatten_size = input.shape | ||||
|         feats = self.input_embed(input)  # batch * 60 * 64 | ||||
|         if not spaces.is_determined(self._embed_dim): | ||||
|             embed_dim = self.abstract_child["_embed_dim"].value | ||||
|         else: | ||||
|             embed_dim = spaces.get_determined_value(self._embed_dim) | ||||
|         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||
|         cls_tokens = F.interpolate( | ||||
|             cls_tokens, size=(embed_dim), mode="linear", align_corners=True | ||||
|         ) | ||||
|         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) | ||||
|         feats_w_tp = self.pos_embed(feats_w_ct) | ||||
|         xfeats = self.backbone(feats_w_tp) | ||||
|         xfeats = xfeats[:, 0, :]  # use the feature for the first token | ||||
|         predicts = self.head(xfeats).squeeze(-1) | ||||
|         return predicts | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         batch, flatten_size = input.shape | ||||
|         feats = self.input_embed(input)  # batch * 60 * 64 | ||||
|         tensors = self.to_patch_embedding(input) | ||||
|         batch, seq, _ = tensors.shape | ||||
|  | ||||
|         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||
|         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) | ||||
|         feats_w_tp = self.pos_embed(feats_w_ct) | ||||
|         xfeats = self.backbone(feats_w_tp) | ||||
|         xfeats = xfeats[:, 0, :]  # use the feature for the first token | ||||
|         predicts = self.head(xfeats).squeeze(-1) | ||||
|         return predicts | ||||
|         feats = torch.cat((cls_tokens, tensors), dim=1) | ||||
|         feats = feats + self.pos_embedding[:, : seq + 1, :] | ||||
|         feats = self.dropout(feats) | ||||
|  | ||||
|         feats = self.backbone(feats) | ||||
|  | ||||
|         x = feats[:, 0]  # the features for cls-token | ||||
|  | ||||
|         return self.cls_head(x) | ||||
|  | ||||
|  | ||||
| def get_transformer(config): | ||||
|     if config is None: | ||||
|         return SuperTransformer(6) | ||||
|     if isinstance(config, str) and config.lower() in name2config: | ||||
|         config = name2config[config.lower()] | ||||
|     if not isinstance(config, dict): | ||||
|         raise ValueError("Invalid Configuration: {:}".format(config)) | ||||
|     name = config.get("name", "basic") | ||||
|     if name == "basic": | ||||
|         model = SuperTransformer( | ||||
|             d_feat=config.get("d_feat"), | ||||
|             embed_dim=config.get("embed_dim"), | ||||
|             num_heads=config.get("num_heads"), | ||||
|             mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"), | ||||
|             qkv_bias=config.get("qkv_bias"), | ||||
|             pos_drop=config.get("pos_drop"), | ||||
|             other_drop=config.get("other_drop"), | ||||
|     model_type = config.get("type", "vit").lower() | ||||
|     if model_type == "vit": | ||||
|         model = SuperViT( | ||||
|             image_size=config.get("image_size"), | ||||
|             patch_size=config.get("patch_size"), | ||||
|             num_classes=config.get("num_classes"), | ||||
|             dim=config.get("dim"), | ||||
|             depth=config.get("depth"), | ||||
|             heads=config.get("heads"), | ||||
|             dropout=config.get("dropout"), | ||||
|             emb_dropout=config.get("emb_dropout"), | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("Unknown model name: {:}".format(name)) | ||||
|         raise ValueError("Unknown model type: {:}".format(model_type)) | ||||
|     return model | ||||
|   | ||||
		Reference in New Issue
	
	Block a user