Update SuperViT
@@ -10,20 +10,31 @@ import torch
 from xautodl.xmodels import transformers
 from xautodl.utils.flop_benchmark import count_parameters
 
+
 class TestSuperViT(unittest.TestCase):
     """Test the super re-arrange layer."""
 
     def test_super_vit(self):
-        model = transformers.get_transformer("vit-base")
-        tensor = torch.rand((16, 3, 256, 256))
+        model = transformers.get_transformer("vit-base-16")
+        tensor = torch.rand((16, 3, 224, 224))
         print("The tensor shape: {:}".format(tensor.shape))
-        print(model)
+        # print(model)
         outs = model(tensor)
         print("The output tensor shape: {:}".format(outs.shape))
 
-    def test_model_size(self):
+    def test_imagenet(self):
         name2config = transformers.name2config
+        print("There are {:} models in total.".format(len(name2config)))
         for name, config in name2config.items():
+            if "cifar" in name:
+                tensor = torch.rand((16, 3, 32, 32))
+            else:
+                tensor = torch.rand((16, 3, 224, 224))
             model = transformers.get_transformer(config)
+            outs = model(tensor)
             size = count_parameters(model, "mb", True)
-            print('{:10s} : size={:.2f}MB'.format(name, size))
+            print(
+                "{:10s} : size={:.2f}MB, out-shape: {:}".format(
+                    name, size, tuple(outs.shape)
+                )
+            )
@@ -13,6 +13,7 @@ from xautodl import spaces
 from .super_module import SuperModule
 from .super_module import IntSpaceType
 from .super_module import BoolSpaceType
+from .super_dropout import SuperDropout, SuperDrop
 from .super_linear import SuperLinear
 
 
@@ -22,7 +23,7 @@ class SuperSelfAttention(SuperModule):
     def __init__(
         self,
         input_dim: IntSpaceType,
-        proj_dim: IntSpaceType,
+        proj_dim: Optional[IntSpaceType],
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         attn_drop: Optional[float] = None,
@@ -37,13 +38,17 @@ class SuperSelfAttention(SuperModule):
         self._use_mask = use_mask
         self._infinity = 1e9
 
-        self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
-        self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
-        self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
+        mul_head_dim = (input_dim // num_heads) * num_heads
+        self.q_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
+        self.k_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
+        self.v_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
 
-        self.attn_drop = nn.Dropout(attn_drop or 0.0)
-        self.proj = SuperLinear(input_dim, proj_dim)
-        self.proj_drop = nn.Dropout(proj_drop or 0.0)
+        self.attn_drop = SuperDrop(attn_drop, [-1, -1, -1, -1], recover=True)
+        if proj_dim is not None:
+            self.proj = SuperLinear(input_dim, proj_dim)
+            self.proj_drop = SuperDropout(proj_drop or 0.0)
+        else:
+            self.proj = None
 
     @property
     def num_heads(self):
@@ -63,7 +68,6 @@ class SuperSelfAttention(SuperModule):
         space_q = self.q_fc.abstract_search_space
         space_k = self.k_fc.abstract_search_space
         space_v = self.v_fc.abstract_search_space
-        space_proj = self.proj.abstract_search_space
         if not spaces.is_determined(self._num_heads):
             root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
         if not spaces.is_determined(space_q):
@@ -72,8 +76,10 @@ class SuperSelfAttention(SuperModule):
             root_node.append("k_fc", space_k)
         if not spaces.is_determined(space_v):
             root_node.append("v_fc", space_v)
-        if not spaces.is_determined(space_proj):
-            root_node.append("proj", space_proj)
+        if self.proj is not None:
+            space_proj = self.proj.abstract_search_space
+            if not spaces.is_determined(space_proj):
+                root_node.append("proj", space_proj)
         return root_node
 
     def apply_candidate(self, abstract_child: spaces.VirtualNode):
@@ -121,18 +127,7 @@ class SuperSelfAttention(SuperModule):
         attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * N
         attn_v1 = self.attn_drop(attn_v1)
         feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
-        if C == head_dim * num_head:
-            feats = feats_v1
-        else:  # The channels can not be divided by num_head, the remainder forms an additional head
-            q_v2 = q[:, :, num_head * head_dim :]
-            k_v2 = k[:, :, num_head * head_dim :]
-            v_v2 = v[:, :, num_head * head_dim :]
-            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
-            attn_v2 = attn_v2.softmax(dim=-1)
-            attn_v2 = self.attn_drop(attn_v2)
-            feats_v2 = attn_v2 @ v_v2
-            feats = torch.cat([feats_v1, feats_v2], dim=-1)
-        return feats
+        return feats_v1
 
     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
         # check the num_heads:
@@ -141,15 +136,21 @@ class SuperSelfAttention(SuperModule):
         else:
             num_heads = spaces.get_determined_value(self._num_heads)
         feats = self.forward_qkv(input, num_heads)
-        outs = self.proj(feats)
-        outs = self.proj_drop(outs)
-        return outs
+        if self.proj is None:
+            return feats
+        else:
+            outs = self.proj(feats)
+            outs = self.proj_drop(outs)
+            return outs
 
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
         feats = self.forward_qkv(input, self.num_heads)
-        outs = self.proj(feats)
-        outs = self.proj_drop(outs)
-        return outs
+        if self.proj is None:
+            return feats
+        else:
+            outs = self.proj(feats)
+            outs = self.proj_drop(outs)
+            return outs
 
     def extra_repr(self) -> str:
         return (
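Note: a quick sketch (not part of the commit) of why the removed remainder-head branch is no longer needed. Since q_fc/k_fc/v_fc now project to mul_head_dim = (input_dim // num_heads) * num_heads, the per-head reshape in forward_qkv always divides evenly; the toy numbers below (input_dim=30, num_heads=4) are made up for illustration.

import torch

B, N, input_dim, num_heads = 2, 5, 30, 4
mul_head_dim = (input_dim // num_heads) * num_heads  # 28, always divisible by num_heads
q = torch.rand(B, N, mul_head_dim)                   # shape of what q_fc now produces
head_dim = mul_head_dim // num_heads                 # 7, no leftover channels
q_per_head = q.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
print(q_per_head.shape)                              # torch.Size([2, 4, 5, 7])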
@@ -37,7 +37,8 @@ class SuperTransformerEncoderLayer(SuperModule):
         num_heads: IntSpaceType,
         qkv_bias: BoolSpaceType = False,
         mlp_hidden_multiplier: IntSpaceType = 4,
-        drop: Optional[float] = None,
+        dropout: Optional[float] = None,
+        att_dropout: Optional[float] = None,
         norm_affine: bool = True,
         act_layer: Callable[[], nn.Module] = nn.GELU,
         order: LayerOrder = LayerOrder.PreNorm,
@@ -49,8 +50,8 @@ class SuperTransformerEncoderLayer(SuperModule):
             d_model,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
-            attn_drop=drop,
-            proj_drop=drop,
+            attn_drop=att_dropout,
+            proj_drop=None,
             use_mask=use_mask,
         )
         mlp = SuperMLPv2(
@@ -58,21 +59,20 @@ class SuperTransformerEncoderLayer(SuperModule):
             hidden_multiplier=mlp_hidden_multiplier,
             out_features=d_model,
             act_layer=act_layer,
-            drop=drop,
+            drop=dropout,
         )
         if order is LayerOrder.PreNorm:
             self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mha = mha
-            self.drop1 = nn.Dropout(drop or 0.0)
+            self.drop = nn.Dropout(dropout or 0.0)
             self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mlp = mlp
-            self.drop2 = nn.Dropout(drop or 0.0)
         elif order is LayerOrder.PostNorm:
             self.mha = mha
-            self.drop1 = nn.Dropout(drop or 0.0)
+            self.drop1 = nn.Dropout(dropout or 0.0)
             self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
             self.mlp = mlp
-            self.drop2 = nn.Dropout(drop or 0.0)
+            self.drop2 = nn.Dropout(dropout or 0.0)
             self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
         else:
             raise ValueError("Unknown order: {:}".format(order))
@@ -99,23 +99,29 @@ class SuperTransformerEncoderLayer(SuperModule):
             if key in abstract_child:
                 getattr(self, key).apply_candidate(abstract_child[key])
 
-    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
-        return self.forward_raw(input)
+    def forward_candidate(self, inputs: torch.Tensor) -> torch.Tensor:
+        return self.forward_raw(inputs)
 
-    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_raw(self, inputs: torch.Tensor) -> torch.Tensor:
         if self._order is LayerOrder.PreNorm:
-            x = self.norm1(input)
-            x = x + self.drop1(self.mha(x))
-            x = self.norm2(x)
-            x = x + self.drop2(self.mlp(x))
+            # https://github.com/google-research/vision_transformer/blob/master/vit_jax/models.py#L135
+            x = self.norm1(inputs)
+            x = self.mha(x)
+            x = self.drop(x)
+            x = x + inputs
+            # feed-forward layer -- MLP
+            y = self.norm2(x)
+            outs = x + self.mlp(y)
         elif self._order is LayerOrder.PostNorm:
+            # https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoder
             # multi-head attention
-            x = self.mha(input)
-            x = x + self.drop1(x)
+            x = self.mha(inputs)
+            x = inputs + self.drop1(x)
             x = self.norm1(x)
-            # feed-forward layer
-            x = x + self.drop2(self.mlp(x))
-            x = self.norm2(x)
+            # feed-forward layer -- MLP
+            y = self.mlp(x)
+            y = x + self.drop2(y)
+            outs = self.norm2(y)
         else:
             raise ValueError("Unknown order: {:}".format(self._order))
-        return x
+        return outs
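Note: a minimal sketch (not from the commit; plain callables stand in for the SuperLayerNorm1D, SuperSelfAttention, SuperMLPv2 and dropout modules) of the two residual orderings that forward_raw now implements.

def pre_norm_block(x, norm1, attn, norm2, mlp):
    # LayerOrder.PreNorm: normalize, run the sub-layer, then add the residual
    x = x + attn(norm1(x))
    return x + mlp(norm2(x))

def post_norm_block(x, norm1, attn, norm2, mlp):
    # LayerOrder.PostNorm: run the sub-layer, add the residual, then normalize
    x = norm1(x + attn(x))
    return norm2(x + mlp(x))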
@@ -3,7 +3,7 @@
 #####################################################
 # Vision Transformer: arxiv.org/pdf/2010.11929.pdf  #
 #####################################################
-import math
+import copy, math
 from functools import partial
 from typing import Optional, Text, List
 
@@ -35,42 +35,69 @@ def _init_weights(m):
 
 
 name2config = {
-    "vit-base": dict(
+    "vit-cifar10-p4-d4-h4-c32": dict(
         type="vit",
-        image_size=256,
+        image_size=32,
+        patch_size=4,
+        num_classes=10,
+        dim=32,
+        depth=4,
+        heads=4,
+        dropout=0.1,
+        att_dropout=0.0,
+    ),
+    "vit-base-16": dict(
+        type="vit",
+        image_size=224,
         patch_size=16,
         num_classes=1000,
         dim=768,
         depth=12,
         heads=12,
         dropout=0.1,
-        emb_dropout=0.1,
+        att_dropout=0.0,
     ),
-    "vit-large": dict(
+    "vit-large-16": dict(
         type="vit",
-        image_size=256,
+        image_size=224,
         patch_size=16,
         num_classes=1000,
         dim=1024,
         depth=24,
         heads=16,
         dropout=0.1,
-        emb_dropout=0.1,
+        att_dropout=0.0,
    ),
-    "vit-huge": dict(
+    "vit-huge-14": dict(
         type="vit",
-        image_size=256,
-        patch_size=16,
+        image_size=224,
+        patch_size=14,
         num_classes=1000,
         dim=1280,
         depth=32,
         heads=16,
         dropout=0.1,
-        emb_dropout=0.1,
+        att_dropout=0.0,
     ),
 }
 
 
+def extend_cifar100(configs):
+    new_configs = dict()
+    for name, config in configs.items():
+        new_configs[name] = config
+        if "cifar10" in name and "cifar100" not in name:
+            config = copy.deepcopy(config)
+            config["num_classes"] = 100
+            a, b = name.split("cifar10")
+            new_name = "{:}cifar100{:}".format(a, b)
+            new_configs[new_name] = config
+    return new_configs
+
+
+name2config = extend_cifar100(name2config)
+
+
 class SuperViT(xlayers.SuperModule):
     """The super model for transformer."""
 
@@ -85,7 +112,7 @@ class SuperViT(xlayers.SuperModule):
         mlp_multiplier=4,
         channels=3,
         dropout=0.0,
-        emb_dropout=0.0,
+        att_dropout=0.0,
     ):
         super(SuperViT, self).__init__()
         image_height, image_width = pair(image_size)
@@ -107,14 +134,19 @@ class SuperViT(xlayers.SuperModule):
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.dropout = nn.Dropout(emb_dropout)
+        self.dropout = nn.Dropout(dropout)
 
         # build the transformer encode layers
         layers = []
         for ilayer in range(depth):
             layers.append(
                 xlayers.SuperTransformerEncoderLayer(
-                    dim, heads, False, mlp_multiplier, dropout
+                    dim,
+                    heads,
+                    False,
+                    mlp_multiplier,
+                    dropout=dropout,
+                    att_dropout=att_dropout,
                 )
             )
         self.backbone = xlayers.SuperSequential(*layers)
@@ -167,7 +199,7 @@ def get_transformer(config):
             depth=config.get("depth"),
             heads=config.get("heads"),
             dropout=config.get("dropout"),
-            emb_dropout=config.get("emb_dropout"),
+            att_dropout=config.get("att_dropout"),
         )
     else:
         raise ValueError("Unknown model type: {:}".format(model_type))
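Note: a hedged usage sketch of the reworked config registry. The model names come straight from the diff above; the expected output shape is inferred from the classifier head (num_classes), not a verified run.

import torch
from xautodl.xmodels import transformers

# extend_cifar100 auto-derives "vit-cifar100-p4-d4-h4-c32" from the CIFAR-10 entry,
# alongside the ImageNet configs "vit-base-16", "vit-large-16" and "vit-huge-14".
print(sorted(transformers.name2config.keys()))

# get_transformer accepts either a registered name or a config dict (both are used in the test).
model = transformers.get_transformer("vit-cifar10-p4-d4-h4-c32")
logits = model(torch.rand(2, 3, 32, 32))   # CIFAR-sized input
print(tuple(logits.shape))                 # expected: (2, 10)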