Update SuperViT
This commit is contained in:
parent
0ddc5c0dc4
commit
d4546cfe3f
@ -10,20 +10,31 @@ import torch
|
||||
from xautodl.xmodels import transformers
|
||||
from xautodl.utils.flop_benchmark import count_parameters
|
||||
|
||||
|
||||
class TestSuperViT(unittest.TestCase):
|
||||
"""Test the super re-arrange layer."""
|
||||
|
||||
def test_super_vit(self):
|
||||
model = transformers.get_transformer("vit-base")
|
||||
tensor = torch.rand((16, 3, 256, 256))
|
||||
model = transformers.get_transformer("vit-base-16")
|
||||
tensor = torch.rand((16, 3, 224, 224))
|
||||
print("The tensor shape: {:}".format(tensor.shape))
|
||||
print(model)
|
||||
# print(model)
|
||||
outs = model(tensor)
|
||||
print("The output tensor shape: {:}".format(outs.shape))
|
||||
|
||||
def test_model_size(self):
|
||||
def test_imagenet(self):
|
||||
name2config = transformers.name2config
|
||||
print("There are {:} models in total.".format(len(name2config)))
|
||||
for name, config in name2config.items():
|
||||
if "cifar" in name:
|
||||
tensor = torch.rand((16, 3, 32, 32))
|
||||
else:
|
||||
tensor = torch.rand((16, 3, 224, 224))
|
||||
model = transformers.get_transformer(config)
|
||||
outs = model(tensor)
|
||||
size = count_parameters(model, "mb", True)
|
||||
print('{:10s} : size={:.2f}MB'.format(name, size))
|
||||
print(
|
||||
"{:10s} : size={:.2f}MB, out-shape: {:}".format(
|
||||
name, size, tuple(outs.shape)
|
||||
)
|
||||
)
|
||||
|
@ -13,6 +13,7 @@ from xautodl import spaces
|
||||
from .super_module import SuperModule
|
||||
from .super_module import IntSpaceType
|
||||
from .super_module import BoolSpaceType
|
||||
from .super_dropout import SuperDropout, SuperDrop
|
||||
from .super_linear import SuperLinear
|
||||
|
||||
|
||||
@ -22,7 +23,7 @@ class SuperSelfAttention(SuperModule):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: IntSpaceType,
|
||||
proj_dim: IntSpaceType,
|
||||
proj_dim: Optional[IntSpaceType],
|
||||
num_heads: IntSpaceType,
|
||||
qkv_bias: BoolSpaceType = False,
|
||||
attn_drop: Optional[float] = None,
|
||||
@ -37,13 +38,17 @@ class SuperSelfAttention(SuperModule):
|
||||
self._use_mask = use_mask
|
||||
self._infinity = 1e9
|
||||
|
||||
self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
|
||||
self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
|
||||
self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
|
||||
mul_head_dim = (input_dim // num_heads) * num_heads
|
||||
self.q_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
|
||||
self.k_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
|
||||
self.v_fc = SuperLinear(input_dim, mul_head_dim, bias=qkv_bias)
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop or 0.0)
|
||||
self.proj = SuperLinear(input_dim, proj_dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop or 0.0)
|
||||
self.attn_drop = SuperDrop(attn_drop, [-1, -1, -1, -1], recover=True)
|
||||
if proj_dim is None:
|
||||
self.proj = SuperLinear(input_dim, proj_dim)
|
||||
self.proj_drop = SuperDropout(proj_drop or 0.0)
|
||||
else:
|
||||
self.proj = None
|
||||
|
||||
@property
|
||||
def num_heads(self):
|
||||
@ -63,7 +68,6 @@ class SuperSelfAttention(SuperModule):
|
||||
space_q = self.q_fc.abstract_search_space
|
||||
space_k = self.k_fc.abstract_search_space
|
||||
space_v = self.v_fc.abstract_search_space
|
||||
space_proj = self.proj.abstract_search_space
|
||||
if not spaces.is_determined(self._num_heads):
|
||||
root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
|
||||
if not spaces.is_determined(space_q):
|
||||
@ -72,8 +76,10 @@ class SuperSelfAttention(SuperModule):
|
||||
root_node.append("k_fc", space_k)
|
||||
if not spaces.is_determined(space_v):
|
||||
root_node.append("v_fc", space_v)
|
||||
if not spaces.is_determined(space_proj):
|
||||
root_node.append("proj", space_proj)
|
||||
if self.proj is not None:
|
||||
space_proj = self.proj.abstract_search_space
|
||||
if not spaces.is_determined(space_proj):
|
||||
root_node.append("proj", space_proj)
|
||||
return root_node
|
||||
|
||||
def apply_candidate(self, abstract_child: spaces.VirtualNode):
|
||||
@ -121,18 +127,7 @@ class SuperSelfAttention(SuperModule):
|
||||
attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * N
|
||||
attn_v1 = self.attn_drop(attn_v1)
|
||||
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
|
||||
if C == head_dim * num_head:
|
||||
feats = feats_v1
|
||||
else: # The channels can not be divided by num_head, the remainder forms an additional head
|
||||
q_v2 = q[:, :, num_head * head_dim :]
|
||||
k_v2 = k[:, :, num_head * head_dim :]
|
||||
v_v2 = v[:, :, num_head * head_dim :]
|
||||
attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
|
||||
attn_v2 = attn_v2.softmax(dim=-1)
|
||||
attn_v2 = self.attn_drop(attn_v2)
|
||||
feats_v2 = attn_v2 @ v_v2
|
||||
feats = torch.cat([feats_v1, feats_v2], dim=-1)
|
||||
return feats
|
||||
return feats_v1
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# check the num_heads:
|
||||
@ -141,15 +136,21 @@ class SuperSelfAttention(SuperModule):
|
||||
else:
|
||||
num_heads = spaces.get_determined_value(self._num_heads)
|
||||
feats = self.forward_qkv(input, num_heads)
|
||||
outs = self.proj(feats)
|
||||
outs = self.proj_drop(outs)
|
||||
return outs
|
||||
if self.proj is None:
|
||||
return feats
|
||||
else:
|
||||
outs = self.proj(feats)
|
||||
outs = self.proj_drop(outs)
|
||||
return outs
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
feats = self.forward_qkv(input, self.num_heads)
|
||||
outs = self.proj(feats)
|
||||
outs = self.proj_drop(outs)
|
||||
return outs
|
||||
if self.proj is None:
|
||||
return feats
|
||||
else:
|
||||
outs = self.proj(feats)
|
||||
outs = self.proj_drop(outs)
|
||||
return outs
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
|
@ -37,7 +37,8 @@ class SuperTransformerEncoderLayer(SuperModule):
|
||||
num_heads: IntSpaceType,
|
||||
qkv_bias: BoolSpaceType = False,
|
||||
mlp_hidden_multiplier: IntSpaceType = 4,
|
||||
drop: Optional[float] = None,
|
||||
dropout: Optional[float] = None,
|
||||
att_dropout: Optional[float] = None,
|
||||
norm_affine: bool = True,
|
||||
act_layer: Callable[[], nn.Module] = nn.GELU,
|
||||
order: LayerOrder = LayerOrder.PreNorm,
|
||||
@ -49,8 +50,8 @@ class SuperTransformerEncoderLayer(SuperModule):
|
||||
d_model,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=drop,
|
||||
proj_drop=drop,
|
||||
attn_drop=att_dropout,
|
||||
proj_drop=None,
|
||||
use_mask=use_mask,
|
||||
)
|
||||
mlp = SuperMLPv2(
|
||||
@ -58,21 +59,20 @@ class SuperTransformerEncoderLayer(SuperModule):
|
||||
hidden_multiplier=mlp_hidden_multiplier,
|
||||
out_features=d_model,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=dropout,
|
||||
)
|
||||
if order is LayerOrder.PreNorm:
|
||||
self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
|
||||
self.mha = mha
|
||||
self.drop1 = nn.Dropout(drop or 0.0)
|
||||
self.drop = nn.Dropout(dropout or 0.0)
|
||||
self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
|
||||
self.mlp = mlp
|
||||
self.drop2 = nn.Dropout(drop or 0.0)
|
||||
elif order is LayerOrder.PostNorm:
|
||||
self.mha = mha
|
||||
self.drop1 = nn.Dropout(drop or 0.0)
|
||||
self.drop1 = nn.Dropout(dropout or 0.0)
|
||||
self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
|
||||
self.mlp = mlp
|
||||
self.drop2 = nn.Dropout(drop or 0.0)
|
||||
self.drop2 = nn.Dropout(dropout or 0.0)
|
||||
self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine)
|
||||
else:
|
||||
raise ValueError("Unknown order: {:}".format(order))
|
||||
@ -99,23 +99,29 @@ class SuperTransformerEncoderLayer(SuperModule):
|
||||
if key in abstract_child:
|
||||
getattr(self, key).apply_candidate(abstract_child[key])
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return self.forward_raw(input)
|
||||
def forward_candidate(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
return self.forward_raw(inputs)
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
def forward_raw(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
if self._order is LayerOrder.PreNorm:
|
||||
x = self.norm1(input)
|
||||
x = x + self.drop1(self.mha(x))
|
||||
x = self.norm2(x)
|
||||
x = x + self.drop2(self.mlp(x))
|
||||
# https://github.com/google-research/vision_transformer/blob/master/vit_jax/models.py#L135
|
||||
x = self.norm1(inputs)
|
||||
x = self.mha(x)
|
||||
x = self.drop(x)
|
||||
x = x + inputs
|
||||
# feed-forward layer -- MLP
|
||||
y = self.norm2(x)
|
||||
outs = x + self.mlp(y)
|
||||
elif self._order is LayerOrder.PostNorm:
|
||||
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoder
|
||||
# multi-head attention
|
||||
x = self.mha(input)
|
||||
x = x + self.drop1(x)
|
||||
x = self.mha(inputs)
|
||||
x = inputs + self.drop1(x)
|
||||
x = self.norm1(x)
|
||||
# feed-forward layer
|
||||
x = x + self.drop2(self.mlp(x))
|
||||
x = self.norm2(x)
|
||||
# feed-forward layer -- MLP
|
||||
y = self.mlp(x)
|
||||
y = x + self.drop2(y)
|
||||
outs = self.norm2(y)
|
||||
else:
|
||||
raise ValueError("Unknown order: {:}".format(self._order))
|
||||
return x
|
||||
return outs
|
||||
|
@ -3,7 +3,7 @@
|
||||
#####################################################
|
||||
# Vision Transformer: arxiv.org/pdf/2010.11929.pdf #
|
||||
#####################################################
|
||||
import math
|
||||
import copy, math
|
||||
from functools import partial
|
||||
from typing import Optional, Text, List
|
||||
|
||||
@ -35,42 +35,69 @@ def _init_weights(m):
|
||||
|
||||
|
||||
name2config = {
|
||||
"vit-base": dict(
|
||||
"vit-cifar10-p4-d4-h4-c32": dict(
|
||||
type="vit",
|
||||
image_size=256,
|
||||
image_size=32,
|
||||
patch_size=4,
|
||||
num_classes=10,
|
||||
dim=32,
|
||||
depth=4,
|
||||
heads=4,
|
||||
dropout=0.1,
|
||||
att_dropout=0.0,
|
||||
),
|
||||
"vit-base-16": dict(
|
||||
type="vit",
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
num_classes=1000,
|
||||
dim=768,
|
||||
depth=12,
|
||||
heads=12,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1,
|
||||
att_dropout=0.0,
|
||||
),
|
||||
"vit-large": dict(
|
||||
"vit-large-16": dict(
|
||||
type="vit",
|
||||
image_size=256,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
num_classes=1000,
|
||||
dim=1024,
|
||||
depth=24,
|
||||
heads=16,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1,
|
||||
att_dropout=0.0,
|
||||
),
|
||||
"vit-huge": dict(
|
||||
"vit-huge-14": dict(
|
||||
type="vit",
|
||||
image_size=256,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
num_classes=1000,
|
||||
dim=1280,
|
||||
depth=32,
|
||||
heads=16,
|
||||
dropout=0.1,
|
||||
emb_dropout=0.1,
|
||||
att_dropout=0.0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def extend_cifar100(configs):
|
||||
new_configs = dict()
|
||||
for name, config in configs.items():
|
||||
new_configs[name] = config
|
||||
if "cifar10" in name and "cifar100" not in name:
|
||||
config = copy.deepcopy(config)
|
||||
config["num_classes"] = 100
|
||||
a, b = name.split("cifar10")
|
||||
new_name = "{:}cifar100{:}".format(a, b)
|
||||
new_configs[new_name] = config
|
||||
return new_configs
|
||||
|
||||
|
||||
name2config = extend_cifar100(name2config)
|
||||
|
||||
|
||||
class SuperViT(xlayers.SuperModule):
|
||||
"""The super model for transformer."""
|
||||
|
||||
@ -85,7 +112,7 @@ class SuperViT(xlayers.SuperModule):
|
||||
mlp_multiplier=4,
|
||||
channels=3,
|
||||
dropout=0.0,
|
||||
emb_dropout=0.0,
|
||||
att_dropout=0.0,
|
||||
):
|
||||
super(SuperViT, self).__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@ -107,14 +134,19 @@ class SuperViT(xlayers.SuperModule):
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# build the transformer encode layers
|
||||
layers = []
|
||||
for ilayer in range(depth):
|
||||
layers.append(
|
||||
xlayers.SuperTransformerEncoderLayer(
|
||||
dim, heads, False, mlp_multiplier, dropout
|
||||
dim,
|
||||
heads,
|
||||
False,
|
||||
mlp_multiplier,
|
||||
dropout=dropout,
|
||||
att_dropout=att_dropout,
|
||||
)
|
||||
)
|
||||
self.backbone = xlayers.SuperSequential(*layers)
|
||||
@ -167,7 +199,7 @@ def get_transformer(config):
|
||||
depth=config.get("depth"),
|
||||
heads=config.get("heads"),
|
||||
dropout=config.get("dropout"),
|
||||
emb_dropout=config.get("emb_dropout"),
|
||||
att_dropout=config.get("att_dropout"),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown model type: {:}".format(model_type))
|
||||
|
Loading…
Reference in New Issue
Block a user