From 0ddc5c0dc41f6b231c1e637dcb3d090b4ff051cb Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Wed, 9 Jun 2021 02:16:56 -0700
Subject: [PATCH] Update ViT

---
 tests/test_super_vit.py         |  29 +++
 xautodl/xlayers/super_mlp.py    | 319 ++++++++++++++++++++++++++++++++
 xautodl/xmodels/__init__.py     |   2 +
 xautodl/xmodels/transformers.py | 271 +++++++++++++--------------
 4 files changed, 475 insertions(+), 146 deletions(-)
 create mode 100644 tests/test_super_vit.py
 create mode 100644 xautodl/xlayers/super_mlp.py

diff --git a/tests/test_super_vit.py b/tests/test_super_vit.py
new file mode 100644
index 0000000..903f71c
--- /dev/null
+++ b/tests/test_super_vit.py
@@ -0,0 +1,29 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+# pytest ./tests/test_super_vit.py -s               #
+#####################################################
+import sys
+import unittest
+
+import torch
+from xautodl.xmodels import transformers
+from xautodl.utils.flop_benchmark import count_parameters
+
+class TestSuperViT(unittest.TestCase):
+    """Test the super vision transformer (SuperViT)."""
+
+    def test_super_vit(self):
+        model = transformers.get_transformer("vit-base")
+        tensor = torch.rand((16, 3, 256, 256))
+        print("The tensor shape: {:}".format(tensor.shape))
+        print(model)
+        outs = model(tensor)
+        print("The output tensor shape: {:}".format(outs.shape))
+
+    def test_model_size(self):
+        name2config = transformers.name2config
+        for name, config in name2config.items():
+            model = transformers.get_transformer(config)
+            size = count_parameters(model, "mb", True)
+            print("{:10s} : size={:.2f}MB".format(name, size))
diff --git a/xautodl/xlayers/super_mlp.py b/xautodl/xlayers/super_mlp.py
new file mode 100644
index 0000000..f33a6b2
--- /dev/null
+++ b/xautodl/xlayers/super_mlp.py
@@ -0,0 +1,319 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+from typing import Optional, Callable
+
+from xautodl import spaces
+from .super_module import SuperModule
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
+
+
+class SuperLinear(SuperModule):
+    """Applies a linear transformation to the incoming data: :math:`y = xA^T + b`"""
+
+    def __init__(
+        self,
+        in_features: IntSpaceType,
+        out_features: IntSpaceType,
+        bias: BoolSpaceType = True,
+    ) -> None:
+        super(SuperLinear, self).__init__()
+
+        # the raw input args
+        self._in_features = in_features
+        self._out_features = out_features
+        self._bias = bias
+        # weights to be optimized
+        self.register_parameter(
+            "_super_weight",
+            torch.nn.Parameter(torch.Tensor(self.out_features, self.in_features)),
+        )
+        if self.bias:
+            self.register_parameter(
+                "_super_bias", torch.nn.Parameter(torch.Tensor(self.out_features))
+            )
+        else:
+            self.register_parameter("_super_bias", None)
+        self.reset_parameters()
+
+    @property
+    def in_features(self):
+        return spaces.get_max(self._in_features)
+
+    @property
+    def out_features(self):
+        return spaces.get_max(self._out_features)
+
+    @property
+    def bias(self):
+        return spaces.has_categorical(self._bias, True)
+
+    @property
+    def abstract_search_space(self):
+        root_node = spaces.VirtualNode(id(self))
+        if not spaces.is_determined(self._in_features):
+            root_node.append(
+                "_in_features", self._in_features.abstract(reuse_last=True)
+            )
+ if not spaces.is_determined(self._out_features): + root_node.append( + "_out_features", self._out_features.abstract(reuse_last=True) + ) + if not spaces.is_determined(self._bias): + root_node.append("_bias", self._bias.abstract(reuse_last=True)) + return root_node + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5)) + if self.bias: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._super_weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self._super_bias, -bound, bound) + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + # check inputs -> + if not spaces.is_determined(self._in_features): + expected_input_dim = self.abstract_child["_in_features"].value + else: + expected_input_dim = spaces.get_determined_value(self._in_features) + if input.size(-1) != expected_input_dim: + raise ValueError( + "Expect the input dim of {:} instead of {:}".format( + expected_input_dim, input.size(-1) + ) + ) + # create the weight matrix + if not spaces.is_determined(self._out_features): + out_dim = self.abstract_child["_out_features"].value + else: + out_dim = spaces.get_determined_value(self._out_features) + candidate_weight = self._super_weight[:out_dim, :expected_input_dim] + # create the bias matrix + if not spaces.is_determined(self._bias): + if self.abstract_child["_bias"].value: + candidate_bias = self._super_bias[:out_dim] + else: + candidate_bias = None + else: + if spaces.get_determined_value(self._bias): + candidate_bias = self._super_bias[:out_dim] + else: + candidate_bias = None + return F.linear(input, candidate_weight, candidate_bias) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self._super_weight, self._super_bias) + + def extra_repr(self) -> str: + return "in_features={:}, out_features={:}, bias={:}".format( + self._in_features, self._out_features, self._bias + ) + + def forward_with_container(self, input, container, prefix=[]): + super_weight_name = ".".join(prefix + ["_super_weight"]) + super_weight = container.query(super_weight_name) + super_bias_name = ".".join(prefix + ["_super_bias"]) + if container.has(super_bias_name): + super_bias = container.query(super_bias_name) + else: + super_bias = None + return F.linear(input, super_weight, super_bias) + + +class SuperMLPv1(SuperModule): + """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.""" + + def __init__( + self, + in_features: IntSpaceType, + hidden_features: IntSpaceType, + out_features: IntSpaceType, + act_layer: Callable[[], nn.Module] = nn.GELU, + drop: Optional[float] = None, + ): + super(SuperMLPv1, self).__init__() + self._in_features = in_features + self._hidden_features = hidden_features + self._out_features = out_features + self._drop_rate = drop + self.fc1 = SuperLinear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = SuperLinear(hidden_features, out_features) + self.drop = nn.Dropout(drop or 0.0) + + @property + def abstract_search_space(self): + root_node = spaces.VirtualNode(id(self)) + space_fc1 = self.fc1.abstract_search_space + space_fc2 = self.fc2.abstract_search_space + if not spaces.is_determined(space_fc1): + root_node.append("fc1", space_fc1) + if not spaces.is_determined(space_fc2): + root_node.append("fc2", space_fc2) + return root_node + + def apply_candidate(self, abstract_child: spaces.VirtualNode): + super(SuperMLPv1, self).apply_candidate(abstract_child) + if "fc1" in abstract_child: + self.fc1.apply_candidate(abstract_child["fc1"]) + if "fc2" in 
abstract_child: + self.fc2.apply_candidate(abstract_child["fc2"]) + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + return self.forward_raw(input) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + x = self.fc1(input) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + def extra_repr(self) -> str: + return "in_features={:}, hidden_features={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format( + self._in_features, + self._hidden_features, + self._out_features, + self._drop_rate, + ) + + +class SuperMLPv2(SuperModule): + """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.""" + + def __init__( + self, + in_features: IntSpaceType, + hidden_multiplier: IntSpaceType, + out_features: IntSpaceType, + act_layer: Callable[[], nn.Module] = nn.GELU, + drop: Optional[float] = None, + ): + super(SuperMLPv2, self).__init__() + self._in_features = in_features + self._hidden_multiplier = hidden_multiplier + self._out_features = out_features + self._drop_rate = drop + self._params = nn.ParameterDict({}) + + self._create_linear( + "fc1", self.in_features, int(self.in_features * self.hidden_multiplier) + ) + self._create_linear( + "fc2", int(self.in_features * self.hidden_multiplier), self.out_features + ) + self.act = act_layer() + self.drop = nn.Dropout(drop or 0.0) + self.reset_parameters() + + @property + def in_features(self): + return spaces.get_max(self._in_features) + + @property + def hidden_multiplier(self): + return spaces.get_max(self._hidden_multiplier) + + @property + def out_features(self): + return spaces.get_max(self._out_features) + + def _create_linear(self, name, inC, outC): + self._params["{:}_super_weight".format(name)] = torch.nn.Parameter( + torch.Tensor(outC, inC) + ) + self._params["{:}_super_bias".format(name)] = torch.nn.Parameter( + torch.Tensor(outC) + ) + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self._params["fc1_super_weight"], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self._params["fc2_super_weight"], a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self._params["fc1_super_weight"] + ) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self._params["fc1_super_bias"], -bound, bound) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self._params["fc2_super_weight"] + ) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self._params["fc2_super_bias"], -bound, bound) + + @property + def abstract_search_space(self): + root_node = spaces.VirtualNode(id(self)) + if not spaces.is_determined(self._in_features): + root_node.append( + "_in_features", self._in_features.abstract(reuse_last=True) + ) + if not spaces.is_determined(self._hidden_multiplier): + root_node.append( + "_hidden_multiplier", self._hidden_multiplier.abstract(reuse_last=True) + ) + if not spaces.is_determined(self._out_features): + root_node.append( + "_out_features", self._out_features.abstract(reuse_last=True) + ) + return root_node + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + # check inputs -> + if not spaces.is_determined(self._in_features): + expected_input_dim = self.abstract_child["_in_features"].value + else: + expected_input_dim = spaces.get_determined_value(self._in_features) + if input.size(-1) != expected_input_dim: + raise ValueError( + "Expect the input dim of {:} instead of {:}".format( + expected_input_dim, input.size(-1) + ) + ) + # create the weight and bias matrix for fc1 + if not 
spaces.is_determined(self._hidden_multiplier):
+            hmul = self.abstract_child["_hidden_multiplier"].value  # the raw multiplier; applied to the input dim below
+        else:
+            hmul = spaces.get_determined_value(self._hidden_multiplier)
+        hidden_dim = int(expected_input_dim * hmul)
+        _fc1_weight = self._params["fc1_super_weight"][:hidden_dim, :expected_input_dim]
+        _fc1_bias = self._params["fc1_super_bias"][:hidden_dim]
+        x = F.linear(input, _fc1_weight, _fc1_bias)
+        x = self.act(x)
+        x = self.drop(x)
+        # create the weight and bias matrix for fc2
+        if not spaces.is_determined(self._out_features):
+            out_dim = self.abstract_child["_out_features"].value
+        else:
+            out_dim = spaces.get_determined_value(self._out_features)
+        _fc2_weight = self._params["fc2_super_weight"][:out_dim, :hidden_dim]
+        _fc2_bias = self._params["fc2_super_bias"][:out_dim]
+        x = F.linear(x, _fc2_weight, _fc2_bias)
+        x = self.drop(x)
+        return x
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        x = F.linear(
+            input, self._params["fc1_super_weight"], self._params["fc1_super_bias"]
+        )
+        x = self.act(x)
+        x = self.drop(x)
+        x = F.linear(
+            x, self._params["fc2_super_weight"], self._params["fc2_super_bias"]
+        )
+        x = self.drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return "in_features={:}, hidden_multiplier={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format(
+            self._in_features,
+            self._hidden_multiplier,
+            self._out_features,
+            self._drop_rate,
+        )
diff --git a/xautodl/xmodels/__init__.py b/xautodl/xmodels/__init__.py
index d7d6635..04f21fe 100644
--- a/xautodl/xmodels/__init__.py
+++ b/xautodl/xmodels/__init__.py
@@ -3,3 +3,5 @@
 #####################################################
 # The models in this folder is written with xlayers #
 #####################################################
+
+from .transformers import get_transformer
diff --git a/xautodl/xmodels/transformers.py b/xautodl/xmodels/transformers.py
index 09dea0a..bc67b37 100644
--- a/xautodl/xmodels/transformers.py
+++ b/xautodl/xmodels/transformers.py
@@ -1,6 +1,8 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
 #####################################################
+# Vision Transformer: arxiv.org/pdf/2010.11929.pdf  #
+#####################################################
 import math
 from functools import partial
 from typing import Optional, Text, List
@@ -10,186 +12,163 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from xautodl import spaces
-from xautodl.xlayers import trunc_normal_
-from xautodl.xlayers import super_core
+from xautodl import xlayers
+from xautodl.xlayers import weight_init
 
-__all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"]
 
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
 
-def _get_mul_specs(candidates, num):
-    results = []
-    for i in range(num):
-        results.append(spaces.Categorical(*candidates))
-    return results
 
+def _init_weights(m):
+    if isinstance(m, nn.Linear):
+        weight_init.trunc_normal_(m.weight, std=0.02)
+        if isinstance(m, nn.Linear) and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, xlayers.SuperLinear):
+        weight_init.trunc_normal_(m._super_weight, std=0.02)
+        if m._super_bias is not None:
+            nn.init.constant_(m._super_bias, 0)
+    elif isinstance(m, xlayers.SuperLayerNorm1D):
+        nn.init.constant_(m.weight, 1.0)
+        nn.init.constant_(m.bias, 0)
 
-def _get_list_mul(num, multipler):
-    results = []
-    for i in range(1, num + 1):
-        results.append(i * multipler)
-    return results
 
+name2config = {
"vit-base": dict( + type="vit", + image_size=256, + patch_size=16, + num_classes=1000, + dim=768, + depth=12, + heads=12, + dropout=0.1, + emb_dropout=0.1, + ), + "vit-large": dict( + type="vit", + image_size=256, + patch_size=16, + num_classes=1000, + dim=1024, + depth=24, + heads=16, + dropout=0.1, + emb_dropout=0.1, + ), + "vit-huge": dict( + type="vit", + image_size=256, + patch_size=16, + num_classes=1000, + dim=1280, + depth=32, + heads=16, + dropout=0.1, + emb_dropout=0.1, + ), +} -def _assert_types(x, expected_types): - if not isinstance(x, expected_types): - raise TypeError( - "The type [{:}] is expected to be {:}.".format(type(x), expected_types) - ) - - -DEFAULT_NET_CONFIG = None -_default_max_depth = 5 -DefaultSearchSpace = dict( - d_feat=6, - embed_dim=spaces.Categorical(*_get_list_mul(8, 16)), - num_heads=_get_mul_specs((1, 2, 4, 8), _default_max_depth), - mlp_hidden_multipliers=_get_mul_specs((0.5, 1, 2, 4, 8), _default_max_depth), - qkv_bias=True, - pos_drop=0.0, - other_drop=0.0, -) - - -class SuperTransformer(super_core.SuperModule): +class SuperViT(xlayers.SuperModule): """The super model for transformer.""" def __init__( self, - d_feat: int = 6, - embed_dim: List[super_core.IntSpaceType] = DefaultSearchSpace["embed_dim"], - num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"], - mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[ - "mlp_hidden_multipliers" - ], - qkv_bias: bool = DefaultSearchSpace["qkv_bias"], - pos_drop: float = DefaultSearchSpace["pos_drop"], - other_drop: float = DefaultSearchSpace["other_drop"], - max_seq_len: int = 65, + image_size, + patch_size, + num_classes, + dim, + depth, + heads, + mlp_multiplier=4, + channels=3, + dropout=0.0, + emb_dropout=0.0, ): - super(SuperTransformer, self).__init__() - self._embed_dim = embed_dim - self._num_heads = num_heads - self._mlp_hidden_multipliers = mlp_hidden_multipliers + super(SuperViT, self).__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(patch_size) - # the stem part - self.input_embed = super_core.SuperAlphaEBDv1(d_feat, embed_dim) - self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - self.pos_embed = super_core.SuperPositionalEncoder( - d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop + if image_height % patch_height != 0 or image_width % patch_width != 0: + raise ValueError("Image dimensions must be divisible by the patch size.") + + num_patches = (image_height // patch_height) * (image_width // patch_width) + patch_dim = channels * patch_height * patch_width + self.to_patch_embedding = xlayers.SuperSequential( + xlayers.SuperReArrange( + "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", + p1=patch_height, + p2=patch_width, + ), + xlayers.SuperLinear(patch_dim, dim), ) - # build the transformer encode layers -->> check params - _assert_types(num_heads, (tuple, list)) - _assert_types(mlp_hidden_multipliers, (tuple, list)) - assert len(num_heads) == len(mlp_hidden_multipliers), "{:} vs {:}".format( - len(num_heads), len(mlp_hidden_multipliers) - ) - # build the transformer encode layers -->> backbone + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + # build the transformer encode layers layers = [] - for num_head, mlp_hidden_multiplier in zip(num_heads, mlp_hidden_multipliers): - layer = super_core.SuperTransformerEncoderLayer( - embed_dim, - num_head, - qkv_bias, - 
mlp_hidden_multiplier, - other_drop, + for ilayer in range(depth): + layers.append( + xlayers.SuperTransformerEncoderLayer( + dim, heads, False, mlp_multiplier, dropout + ) ) - layers.append(layer) - self.backbone = super_core.SuperSequential(*layers) - - # the regression head - self.head = super_core.SuperSequential( - super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1) + self.backbone = xlayers.SuperSequential(*layers) + self.cls_head = xlayers.SuperSequential( + xlayers.SuperLayerNorm1D(dim), xlayers.SuperLinear(dim, num_classes) ) - trunc_normal_(self.cls_token, std=0.02) - self.apply(self._init_weights) - @property - def embed_dim(self): - return spaces.get_max(self._embed_dim) + weight_init.trunc_normal_(self.cls_token, std=0.02) + self.apply(_init_weights) @property def abstract_search_space(self): - root_node = spaces.VirtualNode(id(self)) - if not spaces.is_determined(self._embed_dim): - root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True)) - xdict = dict( - input_embed=self.input_embed.abstract_search_space, - pos_embed=self.pos_embed.abstract_search_space, - backbone=self.backbone.abstract_search_space, - head=self.head.abstract_search_space, - ) - for key, space in xdict.items(): - if not spaces.is_determined(space): - root_node.append(key, space) - return root_node + raise NotImplementedError def apply_candidate(self, abstract_child: spaces.VirtualNode): - super(SuperTransformer, self).apply_candidate(abstract_child) - xkeys = ("input_embed", "pos_embed", "backbone", "head") - for key in xkeys: - if key in abstract_child: - getattr(self, key).apply_candidate(abstract_child[key]) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, super_core.SuperLinear): - trunc_normal_(m._super_weight, std=0.02) - if m._super_bias is not None: - nn.init.constant_(m._super_bias, 0) - elif isinstance(m, super_core.SuperLayerNorm1D): - nn.init.constant_(m.weight, 1.0) - nn.init.constant_(m.bias, 0) + super(SuperViT, self).apply_candidate(abstract_child) + raise NotImplementedError def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: - batch, flatten_size = input.shape - feats = self.input_embed(input) # batch * 60 * 64 - if not spaces.is_determined(self._embed_dim): - embed_dim = self.abstract_child["_embed_dim"].value - else: - embed_dim = spaces.get_determined_value(self._embed_dim) - cls_tokens = self.cls_token.expand(batch, -1, -1) - cls_tokens = F.interpolate( - cls_tokens, size=(embed_dim), mode="linear", align_corners=True - ) - feats_w_ct = torch.cat((cls_tokens, feats), dim=1) - feats_w_tp = self.pos_embed(feats_w_ct) - xfeats = self.backbone(feats_w_tp) - xfeats = xfeats[:, 0, :] # use the feature for the first token - predicts = self.head(xfeats).squeeze(-1) - return predicts + raise NotImplementedError def forward_raw(self, input: torch.Tensor) -> torch.Tensor: - batch, flatten_size = input.shape - feats = self.input_embed(input) # batch * 60 * 64 + tensors = self.to_patch_embedding(input) + batch, seq, _ = tensors.shape + cls_tokens = self.cls_token.expand(batch, -1, -1) - feats_w_ct = torch.cat((cls_tokens, feats), dim=1) - feats_w_tp = self.pos_embed(feats_w_ct) - xfeats = self.backbone(feats_w_tp) - xfeats = xfeats[:, 0, :] # use the feature for the first token - predicts = self.head(xfeats).squeeze(-1) - return predicts + feats = torch.cat((cls_tokens, tensors), dim=1) + 
feats = feats + self.pos_embedding[:, : seq + 1, :] + feats = self.dropout(feats) + + feats = self.backbone(feats) + + x = feats[:, 0] # the features for cls-token + + return self.cls_head(x) def get_transformer(config): - if config is None: - return SuperTransformer(6) + if isinstance(config, str) and config.lower() in name2config: + config = name2config[config.lower()] if not isinstance(config, dict): raise ValueError("Invalid Configuration: {:}".format(config)) - name = config.get("name", "basic") - if name == "basic": - model = SuperTransformer( - d_feat=config.get("d_feat"), - embed_dim=config.get("embed_dim"), - num_heads=config.get("num_heads"), - mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"), - qkv_bias=config.get("qkv_bias"), - pos_drop=config.get("pos_drop"), - other_drop=config.get("other_drop"), + model_type = config.get("type", "vit").lower() + if model_type == "vit": + model = SuperViT( + image_size=config.get("image_size"), + patch_size=config.get("patch_size"), + num_classes=config.get("num_classes"), + dim=config.get("dim"), + depth=config.get("depth"), + heads=config.get("heads"), + dropout=config.get("dropout"), + emb_dropout=config.get("emb_dropout"), ) else: - raise ValueError("Unknown model name: {:}".format(name)) + raise ValueError("Unknown model type: {:}".format(model_type)) return model
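

A quick smoke test of the new entry point, mirroring tests/test_super_vit.py (this sketch assumes the patch is applied and the xautodl package is importable; the batch size of 2 is arbitrary):

import torch
from xautodl.xmodels import transformers

model = transformers.get_transformer("vit-base")  # looks up name2config["vit-base"]
images = torch.rand(2, 3, 256, 256)               # image_size=256 must be divisible by patch_size=16
logits = model(images)
print(logits.shape)                               # torch.Size([2, 1000]), i.e. (batch, num_classes)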
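For reference, the patch-embedding rearrange in SuperViT, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", can be spelled out with plain reshape/permute. This is an illustrative sketch of what that pattern computes, not the SuperReArrange implementation itself:

import torch

b, c, H, W, p = 2, 3, 256, 256, 16
x = torch.rand(b, c, H, W)
h, w = H // p, W // p                     # 16 x 16 grid -> num_patches = 256
x = x.reshape(b, c, h, p, w, p)           # split both spatial axes into (grid, patch)
x = x.permute(0, 2, 4, 3, 5, 1)           # b h w p1 p2 c
patches = x.reshape(b, h * w, p * p * c)  # patch_dim = 16 * 16 * 3 = 768
assert patches.shape == (2, 256, 768)     # matches the input of SuperLinear(patch_dim, dim)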
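The candidate-forward path in SuperLinear and SuperMLPv2 works by over-allocating the weight for the largest candidate and slicing its top-left corner per sampled candidate. A minimal plain-torch sketch of that mechanic, with hypothetical dimensions:

import torch
import torch.nn.functional as F

max_out, max_in = 32, 64               # the super weight covers the largest candidate
weight = torch.randn(max_out, max_in)  # analogous to _super_weight
bias = torch.randn(max_out)            # analogous to _super_bias

x = torch.randn(4, 48)                 # a sampled candidate with in_dim=48, out_dim=16
y = F.linear(x, weight[:16, :48], bias[:16])
assert y.shape == (4, 16)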