diff --git a/tests/test_super_vit.py b/tests/test_super_vit.py index 903f71c..1b5d390 100644 --- a/tests/test_super_vit.py +++ b/tests/test_super_vit.py @@ -10,20 +10,31 @@ import torch from xautodl.xmodels import transformers from xautodl.utils.flop_benchmark import count_parameters + class TestSuperViT(unittest.TestCase): """Test the super re-arrange layer.""" def test_super_vit(self): - model = transformers.get_transformer("vit-base") - tensor = torch.rand((16, 3, 256, 256)) + model = transformers.get_transformer("vit-base-16") + tensor = torch.rand((16, 3, 224, 224)) print("The tensor shape: {:}".format(tensor.shape)) - print(model) + # print(model) outs = model(tensor) print("The output tensor shape: {:}".format(outs.shape)) - def test_model_size(self): + def test_imagenet(self): name2config = transformers.name2config + print("There are {:} models in total.".format(len(name2config))) for name, config in name2config.items(): + if "cifar" in name: + tensor = torch.rand((16, 3, 32, 32)) + else: + tensor = torch.rand((16, 3, 224, 224)) model = transformers.get_transformer(config) + outs = model(tensor) size = count_parameters(model, "mb", True) - print('{:10s} : size={:.2f}MB'.format(name, size)) + print( + "{:10s} : size={:.2f}MB, out-shape: {:}".format( + name, size, tuple(outs.shape) + ) + ) diff --git a/xautodl/xlayers/super_attention.py b/xautodl/xlayers/super_attention.py index 2c3c591..c70ea38 100644 --- a/xautodl/xlayers/super_attention.py +++ b/xautodl/xlayers/super_attention.py @@ -13,6 +13,7 @@ from xautodl import spaces from .super_module import SuperModule from .super_module import IntSpaceType from .super_module import BoolSpaceType +from .super_dropout import SuperDropout, SuperDrop from .super_linear import SuperLinear @@ -22,7 +23,7 @@ class SuperSelfAttention(SuperModule): def __init__( self, input_dim: IntSpaceType, - proj_dim: IntSpaceType, + proj_dim: Optional[IntSpaceType], num_heads: IntSpaceType, qkv_bias: BoolSpaceType = False, attn_drop: Optional[float] = None, @@ -37,13 +38,17 @@ class SuperSelfAttention(SuperModule): self._use_mask = use_mask self._infinity = 1e9 - self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) - self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) - self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) + mul_head_dim = (input_dim // num_heads) * num_heads + self.q_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias) + self.k_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias) + self.v_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop or 0.0) - self.proj = SuperLinear(input_dim, proj_dim) - self.proj_drop = nn.Dropout(proj_drop or 0.0) + self.attn_drop = SuperDrop(attn_drop, [-1, -1, -1, -1], recover=True) + if proj_dim is None: + self.proj = SuperLinear(input_dim, proj_dim) + self.proj_drop = SuperDropout(proj_drop or 0.0) + else: + self.proj = None @property def num_heads(self): @@ -63,7 +68,6 @@ class SuperSelfAttention(SuperModule): space_q = self.q_fc.abstract_search_space space_k = self.k_fc.abstract_search_space space_v = self.v_fc.abstract_search_space - space_proj = self.proj.abstract_search_space if not spaces.is_determined(self._num_heads): root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True)) if not spaces.is_determined(space_q): @@ -72,8 +76,10 @@ class SuperSelfAttention(SuperModule): root_node.append("k_fc", space_k) if not spaces.is_determined(space_v): root_node.append("v_fc", space_v) - if not spaces.is_determined(space_proj): - root_node.append("proj", space_proj) + if self.proj is not None: + space_proj = self.proj.abstract_search_space + if not spaces.is_determined(space_proj): + root_node.append("proj", space_proj) return root_node def apply_candidate(self, abstract_child: spaces.VirtualNode): @@ -121,18 +127,7 @@ class SuperSelfAttention(SuperModule): attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * N attn_v1 = self.attn_drop(attn_v1) feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1) - if C == head_dim * num_head: - feats = feats_v1 - else: # The channels can not be divided by num_head, the remainder forms an additional head - q_v2 = q[:, :, num_head * head_dim :] - k_v2 = k[:, :, num_head * head_dim :] - v_v2 = v[:, :, num_head * head_dim :] - attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1]) - attn_v2 = attn_v2.softmax(dim=-1) - attn_v2 = self.attn_drop(attn_v2) - feats_v2 = attn_v2 @ v_v2 - feats = torch.cat([feats_v1, feats_v2], dim=-1) - return feats + return feats_v1 def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: # check the num_heads: @@ -141,15 +136,21 @@ class SuperSelfAttention(SuperModule): else: num_heads = spaces.get_determined_value(self._num_heads) feats = self.forward_qkv(input, num_heads) - outs = self.proj(feats) - outs = self.proj_drop(outs) - return outs + if self.proj is None: + return feats + else: + outs = self.proj(feats) + outs = self.proj_drop(outs) + return outs def forward_raw(self, input: torch.Tensor) -> torch.Tensor: feats = self.forward_qkv(input, self.num_heads) - outs = self.proj(feats) - outs = self.proj_drop(outs) - return outs + if self.proj is None: + return feats + else: + outs = self.proj(feats) + outs = self.proj_drop(outs) + return outs def extra_repr(self) -> str: return ( diff --git a/xautodl/xlayers/super_transformer.py b/xautodl/xlayers/super_transformer.py index 326188e..b45ca51 100644 --- a/xautodl/xlayers/super_transformer.py +++ b/xautodl/xlayers/super_transformer.py @@ -37,7 +37,8 @@ class SuperTransformerEncoderLayer(SuperModule): num_heads: IntSpaceType, qkv_bias: BoolSpaceType = False, mlp_hidden_multiplier: IntSpaceType = 4, - drop: Optional[float] = None, + dropout: Optional[float] = None, + att_dropout: Optional[float] = None, norm_affine: bool = True, act_layer: Callable[[], nn.Module] = nn.GELU, order: LayerOrder = LayerOrder.PreNorm, @@ -49,8 +50,8 @@ class SuperTransformerEncoderLayer(SuperModule): d_model, num_heads=num_heads, qkv_bias=qkv_bias, - attn_drop=drop, - proj_drop=drop, + attn_drop=att_dropout, + proj_drop=None, use_mask=use_mask, ) mlp = SuperMLPv2( @@ -58,21 +59,20 @@ class SuperTransformerEncoderLayer(SuperModule): hidden_multiplier=mlp_hidden_multiplier, out_features=d_model, act_layer=act_layer, - drop=drop, + drop=dropout, ) if order is LayerOrder.PreNorm: self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine) self.mha = mha - self.drop1 = nn.Dropout(drop or 0.0) + self.drop = nn.Dropout(dropout or 0.0) self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine) self.mlp = mlp - self.drop2 = nn.Dropout(drop or 0.0) elif order is LayerOrder.PostNorm: self.mha = mha - self.drop1 = nn.Dropout(drop or 0.0) + self.drop1 = nn.Dropout(dropout or 0.0) self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine) self.mlp = mlp - self.drop2 = nn.Dropout(drop or 0.0) + self.drop2 = nn.Dropout(dropout or 0.0) self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine) else: raise ValueError("Unknown order: {:}".format(order)) @@ -99,23 +99,29 @@ class SuperTransformerEncoderLayer(SuperModule): if key in abstract_child: getattr(self, key).apply_candidate(abstract_child[key]) - def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: - return self.forward_raw(input) + def forward_candidate(self, inputs: torch.Tensor) -> torch.Tensor: + return self.forward_raw(inputs) - def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + def forward_raw(self, inputs: torch.Tensor) -> torch.Tensor: if self._order is LayerOrder.PreNorm: - x = self.norm1(input) - x = x + self.drop1(self.mha(x)) - x = self.norm2(x) - x = x + self.drop2(self.mlp(x)) + # https://github.com/google-research/vision_transformer/blob/master/vit_jax/models.py#L135 + x = self.norm1(inputs) + x = self.mha(x) + x = self.drop(x) + x = x + inputs + # feed-forward layer -- MLP + y = self.norm2(x) + outs = x + self.mlp(y) elif self._order is LayerOrder.PostNorm: + # https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoder # multi-head attention - x = self.mha(input) - x = x + self.drop1(x) + x = self.mha(inputs) + x = inputs + self.drop1(x) x = self.norm1(x) - # feed-forward layer - x = x + self.drop2(self.mlp(x)) - x = self.norm2(x) + # feed-forward layer -- MLP + y = self.mlp(x) + y = x + self.drop2(y) + outs = self.norm2(y) else: raise ValueError("Unknown order: {:}".format(self._order)) - return x + return outs diff --git a/xautodl/xmodels/transformers.py b/xautodl/xmodels/transformers.py index bc67b37..a4204d2 100644 --- a/xautodl/xmodels/transformers.py +++ b/xautodl/xmodels/transformers.py @@ -3,7 +3,7 @@ ##################################################### # Vision Transformer: arxiv.org/pdf/2010.11929.pdf # ##################################################### -import math +import copy, math from functools import partial from typing import Optional, Text, List @@ -35,42 +35,69 @@ def _init_weights(m): name2config = { - "vit-base": dict( + "vit-cifar10-p4-d4-h4-c32": dict( type="vit", - image_size=256, + image_size=32, + patch_size=4, + num_classes=10, + dim=32, + depth=4, + heads=4, + dropout=0.1, + att_dropout=0.0, + ), + "vit-base-16": dict( + type="vit", + image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12, dropout=0.1, - emb_dropout=0.1, + att_dropout=0.0, ), - "vit-large": dict( + "vit-large-16": dict( type="vit", - image_size=256, + image_size=224, patch_size=16, num_classes=1000, dim=1024, depth=24, heads=16, dropout=0.1, - emb_dropout=0.1, + att_dropout=0.0, ), - "vit-huge": dict( + "vit-huge-14": dict( type="vit", - image_size=256, - patch_size=16, + image_size=224, + patch_size=14, num_classes=1000, dim=1280, depth=32, heads=16, dropout=0.1, - emb_dropout=0.1, + att_dropout=0.0, ), } +def extend_cifar100(configs): + new_configs = dict() + for name, config in configs.items(): + new_configs[name] = config + if "cifar10" in name and "cifar100" not in name: + config = copy.deepcopy(config) + config["num_classes"] = 100 + a, b = name.split("cifar10") + new_name = "{:}cifar100{:}".format(a, b) + new_configs[new_name] = config + return new_configs + + +name2config = extend_cifar100(name2config) + + class SuperViT(xlayers.SuperModule): """The super model for transformer.""" @@ -85,7 +112,7 @@ class SuperViT(xlayers.SuperModule): mlp_multiplier=4, channels=3, dropout=0.0, - emb_dropout=0.0, + att_dropout=0.0, ): super(SuperViT, self).__init__() image_height, image_width = pair(image_size) @@ -107,14 +134,19 @@ class SuperViT(xlayers.SuperModule): self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) - self.dropout = nn.Dropout(emb_dropout) + self.dropout = nn.Dropout(dropout) # build the transformer encode layers layers = [] for ilayer in range(depth): layers.append( xlayers.SuperTransformerEncoderLayer( - dim, heads, False, mlp_multiplier, dropout + dim, + heads, + False, + mlp_multiplier, + dropout=dropout, + att_dropout=att_dropout, ) ) self.backbone = xlayers.SuperSequential(*layers) @@ -167,7 +199,7 @@ def get_transformer(config): depth=config.get("depth"), heads=config.get("heads"), dropout=config.get("dropout"), - emb_dropout=config.get("emb_dropout"), + att_dropout=config.get("att_dropout"), ) else: raise ValueError("Unknown model type: {:}".format(model_type))