Update SuperViT
@@ -10,20 +10,31 @@ import torch
 from xautodl.xmodels import transformers
 from xautodl.utils.flop_benchmark import count_parameters
 
 
 class TestSuperViT(unittest.TestCase):
     """Test the super re-arrange layer."""
 
     def test_super_vit(self):
-        model = transformers.get_transformer("vit-base")
-        tensor = torch.rand((16, 3, 256, 256))
+        model = transformers.get_transformer("vit-base-16")
+        tensor = torch.rand((16, 3, 224, 224))
         print("The tensor shape: {:}".format(tensor.shape))
-        print(model)
+        # print(model)
         outs = model(tensor)
+        print("The output tensor shape: {:}".format(outs.shape))
 
-    def test_model_size(self):
+    def test_imagenet(self):
         name2config = transformers.name2config
+        print("There are {:} models in total.".format(len(name2config)))
         for name, config in name2config.items():
+            if "cifar" in name:
+                tensor = torch.rand((16, 3, 32, 32))
+            else:
+                tensor = torch.rand((16, 3, 224, 224))
             model = transformers.get_transformer(config)
+            outs = model(tensor)
             size = count_parameters(model, "mb", True)
-            print('{:10s} : size={:.2f}MB'.format(name, size))
+            print(
+                "{:10s} : size={:.2f}MB, out-shape: {:}".format(
+                    name, size, tuple(outs.shape)
+                )
+            )
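For reference, the single-model check in test_super_vit can be reproduced outside unittest with exactly the calls used above; this is an illustrative sketch, and the (16, 1000) output shape is an assumption based on num_classes=1000 in the vit-base-16 config:

    import torch
    from xautodl.xmodels import transformers
    from xautodl.utils.flop_benchmark import count_parameters

    # Build the ImageNet-scale ViT-Base/16 config added in this commit and run one forward pass.
    model = transformers.get_transformer("vit-base-16")
    tensor = torch.rand((16, 3, 224, 224))
    outs = model(tensor)  # expected out-shape: (16, 1000), assuming the head returns class logits
    size = count_parameters(model, "mb", True)
    print("{:10s} : size={:.2f}MB, out-shape: {:}".format("vit-base-16", size, tuple(outs.shape)))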
@@ -13,6 +13,7 @@ from xautodl import spaces
 from .super_module import SuperModule
 from .super_module import IntSpaceType
 from .super_module import BoolSpaceType
+from .super_dropout import SuperDropout, SuperDrop
 from .super_linear import SuperLinear
 
 
@@ -22,7 +23,7 @@ class SuperSelfAttention(SuperModule):
     def __init__(
         self,
         input_dim: IntSpaceType,
-        proj_dim: IntSpaceType,
+        proj_dim: Optional[IntSpaceType],
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         attn_drop: Optional[float] = None,
@@ -37,13 +38,17 @@ class SuperSelfAttention(SuperModule):
         self._use_mask = use_mask
         self._infinity = 1e9
 
-        self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
-        self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
-        self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
+        mul_head_dim = (input_dim // num_heads) * num_heads
+        self.q_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
+        self.k_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
+        self.v_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
 
-        self.attn_drop = nn.Dropout(attn_drop or 0.0)
-        self.proj = SuperLinear(input_dim, proj_dim)
-        self.proj_drop = nn.Dropout(proj_drop or 0.0)
+        self.attn_drop = SuperDrop(attn_drop, [-1, -1, -1, -1], recover=True)
+        if proj_dim is not None:
+            self.proj = SuperLinear(input_dim, proj_dim)
+            self.proj_drop = SuperDropout(proj_drop or 0.0)
+        else:
+            self.proj = None
 
     @property
     def num_heads(self):
@@ -63,7 +68,6 @@ class SuperSelfAttention(SuperModule):
         space_q = self.q_fc.abstract_search_space
         space_k = self.k_fc.abstract_search_space
         space_v = self.v_fc.abstract_search_space
-        space_proj = self.proj.abstract_search_space
         if not spaces.is_determined(self._num_heads):
             root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
         if not spaces.is_determined(space_q):
@@ -72,6 +76,8 @@ class SuperSelfAttention(SuperModule):
             root_node.append("k_fc", space_k)
         if not spaces.is_determined(space_v):
             root_node.append("v_fc", space_v)
-        if not spaces.is_determined(space_proj):
-            root_node.append("proj", space_proj)
+        if self.proj is not None:
+            space_proj = self.proj.abstract_search_space
+            if not spaces.is_determined(space_proj):
+                root_node.append("proj", space_proj)
         return root_node
@@ -121,18 +127,7 @@ class SuperSelfAttention(SuperModule):
         attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * N
         attn_v1 = self.attn_drop(attn_v1)
         feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
-        if C == head_dim * num_head:
-            feats = feats_v1
-        else:  # The channels can not be divided by num_head, the remainder forms an additional head
-            q_v2 = q[:, :, num_head * head_dim :]
-            k_v2 = k[:, :, num_head * head_dim :]
-            v_v2 = v[:, :, num_head * head_dim :]
-            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
-            attn_v2 = attn_v2.softmax(dim=-1)
-            attn_v2 = self.attn_drop(attn_v2)
-            feats_v2 = attn_v2 @ v_v2
-            feats = torch.cat([feats_v1, feats_v2], dim=-1)
-        return feats
+        return feats_v1
 
     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
         # check the num_heads:
@@ -141,12 +136,18 @@ class SuperSelfAttention(SuperModule):
         else:
             num_heads = spaces.get_determined_value(self._num_heads)
         feats = self.forward_qkv(input, num_heads)
-        outs = self.proj(feats)
-        outs = self.proj_drop(outs)
-        return outs
+        if self.proj is None:
+            return feats
+        else:
+            outs = self.proj(feats)
+            outs = self.proj_drop(outs)
+            return outs
 
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
         feats = self.forward_qkv(input, self.num_heads)
-        outs = self.proj(feats)
-        outs = self.proj_drop(outs)
-        return outs
+        if self.proj is None:
+            return feats
+        else:
+            outs = self.proj(feats)
+            outs = self.proj_drop(outs)
+            return outs
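The reason the remainder-head branch in forward_qkv could be deleted is the new mul_head_dim width for q/k/v: it is always an exact multiple of num_heads, so the reshape into heads never leaves leftover channels. A small arithmetic sketch with made-up numbers (not taken from the repository):

    # Hypothetical sizes, for illustration only.
    input_dim, num_heads = 100, 8
    head_dim = input_dim // num_heads      # 12
    mul_head_dim = head_dim * num_heads    # 96 -> the new q/k/v output width
    # Before this change q/k/v were input_dim (100) wide, so 100 - 96 = 4 channels
    # were left over and handled by the extra "v2" head that is removed above.
    assert mul_head_dim % num_heads == 0   # the new width always splits evenly across heads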
@@ -37,7 +37,8 @@ class SuperTransformerEncoderLayer(SuperModule):
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         mlp_hidden_multiplier: IntSpaceType = 4,
-        drop: Optional[float] = None,
+        dropout: Optional[float] = None,
+        att_dropout: Optional[float] = None,
         norm_affine: bool = True,
         act_layer: Callable[[], nn.Module] = nn.GELU,
         order: LayerOrder = LayerOrder.PreNorm,
@@ -49,8 +50,8 @@ class SuperTransformerEncoderLayer(SuperModule):
             d_model,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
-            attn_drop=drop,
-            proj_drop=drop,
+            attn_drop=att_dropout,
+            proj_drop=None,
             use_mask=use_mask,
         )
         mlp = SuperMLPv2(
@@ -58,21 +59,20 @@ class SuperTransformerEncoderLayer(SuperModule):
             hidden_multiplier=mlp_hidden_multiplier,
             out_features=d_model,
             act_layer=act_layer,
-            drop=drop,
+            drop=dropout,
         )
         if order is LayerOrder.PreNorm:
             self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mha = mha
-            self.drop1 = nn.Dropout(drop or 0.0)
+            self.drop = nn.Dropout(dropout or 0.0)
             self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mlp = mlp
-            self.drop2 = nn.Dropout(drop or 0.0)
         elif order is LayerOrder.PostNorm:
             self.mha = mha
-            self.drop1 = nn.Dropout(drop or 0.0)
+            self.drop1 = nn.Dropout(dropout or 0.0)
             self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mlp = mlp
-            self.drop2 = nn.Dropout(drop or 0.0)
+            self.drop2 = nn.Dropout(dropout or 0.0)
             self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
         else:
             raise ValueError("Unknown order: {:}".format(order))
@@ -99,23 +99,29 @@ class SuperTransformerEncoderLayer(SuperModule):
             if key in abstract_child:
                 getattr(self, key).apply_candidate(abstract_child[key])
 
-    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
-        return self.forward_raw(input)
+    def forward_candidate(self, inputs: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(inputs)
 
-    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_raw(self, inputs: torch.Tensor) -> torch.Tensor:
         if self._order is LayerOrder.PreNorm:
-            x = self.norm1(input)
-            x = x + self.drop1(self.mha(x))
-            x = self.norm2(x)
-            x = x + self.drop2(self.mlp(x))
+            # https://github.com/google-research/vision_transformer/blob/master/vit_jax/models.py#L135
+            x = self.norm1(inputs)
+            x = self.mha(x)
+            x = self.drop(x)
+            x = x + inputs
+            # feed-forward layer -- MLP
+            y = self.norm2(x)
+            outs = x + self.mlp(y)
         elif self._order is LayerOrder.PostNorm:
+            # https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoder
             # multi-head attention
-            x = self.mha(input)
-            x = x + self.drop1(x)
+            x = self.mha(inputs)
+            x = inputs + self.drop1(x)
             x = self.norm1(x)
-            # feed-forward layer
-            x = x + self.drop2(self.mlp(x))
-            x = self.norm2(x)
+            # feed-forward layer -- MLP
+            y = self.mlp(x)
+            y = x + self.drop2(y)
+            outs = self.norm2(y)
         else:
             raise ValueError("Unknown order: {:}".format(self._order))
-        return x
+        return outs
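The rewritten forward_raw follows the two standard residual orderings; a condensed, functionally equivalent sketch (stand-in modules and made-up sizes, not the actual SuperModule classes) is:

    import torch
    from torch import nn

    def pre_norm_block(x, norm1, mha, drop, norm2, mlp):
        # PreNorm (ViT-style), as rewritten above: norm -> attention -> dropout -> residual,
        # then norm -> MLP -> residual (no extra dropout on the MLP branch).
        x = x + drop(mha(norm1(x)))
        return x + mlp(norm2(x))

    def post_norm_block(x, mha, drop1, norm1, mlp, drop2, norm2):
        # PostNorm (as in torch.nn.TransformerEncoderLayer): attention -> dropout -> residual -> norm,
        # then the same pattern for the MLP branch.
        x = norm1(x + drop1(mha(x)))
        return norm2(x + drop2(mlp(x)))

    # Quick shape check with identity/linear stand-ins.
    d = 8
    x = torch.rand(2, 4, d)
    out = pre_norm_block(x, nn.LayerNorm(d), nn.Identity(), nn.Dropout(0.1), nn.LayerNorm(d), nn.Linear(d, d))
    print(tuple(out.shape))  # (2, 4, 8)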
@@ -3,7 +3,7 @@
 #####################################################
 # Vision Transformer: arxiv.org/pdf/2010.11929.pdf  #
 #####################################################
-import math
+import copy, math
 from functools import partial
 from typing import Optional, Text, List
 
@@ -35,42 +35,69 @@ def _init_weights(m):
 
 
 name2config = {
-    "vit-base": dict(
+    "vit-cifar10-p4-d4-h4-c32": dict(
         type="vit",
-        image_size=256,
+        image_size=32,
+        patch_size=4,
+        num_classes=10,
+        dim=32,
+        depth=4,
+        heads=4,
+        dropout=0.1,
+        att_dropout=0.0,
+    ),
+    "vit-base-16": dict(
+        type="vit",
+        image_size=224,
         patch_size=16,
         num_classes=1000,
         dim=768,
         depth=12,
         heads=12,
         dropout=0.1,
-        emb_dropout=0.1,
+        att_dropout=0.0,
     ),
-    "vit-large": dict(
+    "vit-large-16": dict(
         type="vit",
-        image_size=256,
+        image_size=224,
         patch_size=16,
         num_classes=1000,
         dim=1024,
         depth=24,
         heads=16,
         dropout=0.1,
-        emb_dropout=0.1,
+        att_dropout=0.0,
    ),
-    "vit-huge": dict(
+    "vit-huge-14": dict(
         type="vit",
-        image_size=256,
-        patch_size=16,
+        image_size=224,
+        patch_size=14,
         num_classes=1000,
         dim=1280,
         depth=32,
         heads=16,
         dropout=0.1,
-        emb_dropout=0.1,
+        att_dropout=0.0,
     ),
 }
 
 
+def extend_cifar100(configs):
+    new_configs = dict()
+    for name, config in configs.items():
+        new_configs[name] = config
+        if "cifar10" in name and "cifar100" not in name:
+            config = copy.deepcopy(config)
+            config["num_classes"] = 100
+            a, b = name.split("cifar10")
+            new_name = "{:}cifar100{:}".format(a, b)
+            new_configs[new_name] = config
+    return new_configs
+
+
+name2config = extend_cifar100(name2config)
+
+
 class SuperViT(xlayers.SuperModule):
     """The super model for transformer."""
 
@@ -85,7 +112,7 @@ class SuperViT(xlayers.SuperModule):
         mlp_multiplier=4,
         channels=3,
         dropout=0.0,
-        emb_dropout=0.0,
+        att_dropout=0.0,
     ):
         super(SuperViT, self).__init__()
         image_height, image_width = pair(image_size)
@@ -107,14 +134,19 @@ class SuperViT(xlayers.SuperModule):
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.dropout = nn.Dropout(emb_dropout)
+        self.dropout = nn.Dropout(dropout)
 
         # build the transformer encode layers
         layers = []
         for ilayer in range(depth):
             layers.append(
                 xlayers.SuperTransformerEncoderLayer(
-                    dim, heads, False, mlp_multiplier, dropout
+                    dim,
+                    heads,
+                    False,
+                    mlp_multiplier,
+                    dropout=dropout,
+                    att_dropout=att_dropout,
                 )
             )
         self.backbone = xlayers.SuperSequential(*layers)
@@ -167,7 +199,7 @@ def get_transformer(config):
             depth=config.get("depth"),
             heads=config.get("heads"),
             dropout=config.get("dropout"),
-            emb_dropout=config.get("emb_dropout"),
+            att_dropout=config.get("att_dropout"),
         )
     else:
         raise ValueError("Unknown model type: {:}".format(model_type))
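With extend_cifar100 in place, every cifar10 entry gains a cifar100 twin whose only difference is num_classes=100, so the registry should now expose five configurations. An illustrative lookup, using only the calls already shown in this commit (the cifar100 name is derived from the split logic above):

    from xautodl.xmodels import transformers

    print(len(transformers.name2config))  # expected: 5 (cifar10, derived cifar100, and three ImageNet ViTs)
    base = transformers.name2config["vit-cifar10-p4-d4-h4-c32"]
    twin = transformers.name2config["vit-cifar100-p4-d4-h4-c32"]
    assert base["num_classes"] == 10 and twin["num_classes"] == 100
    model = transformers.get_transformer("vit-cifar100-p4-d4-h4-c32")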