Add SuperTransformer
.gitignore (vendored, 2 changes)
							| @@ -130,3 +130,5 @@ TEMP-L.sh | |||||||
| .vscode | .vscode | ||||||
| mlruns | mlruns | ||||||
| outputs | outputs | ||||||
|  |  | ||||||
|  | pytest_cache | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
| import abc | import abc | ||||||
|   | |||||||
| @@ -1 +1,4 @@ | |||||||
| from .quant_transformer import QuantTransformer | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | from .transformers import get_transformer | ||||||
|   | |||||||
| @@ -6,236 +6,186 @@ from __future__ import print_function | |||||||
|  |  | ||||||
| import math | import math | ||||||
| from functools import partial | from functools import partial | ||||||
| from typing import Optional, Text | from typing import Optional, Text, List | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
|  |  | ||||||
| import xlayers | import spaces | ||||||
|  | from xlayers import trunc_normal_ | ||||||
|  | from xlayers import super_core | ||||||
|  |  | ||||||
|  |  | ||||||
| DEFAULT_NET_CONFIG = dict( | __all__ = ["DefaultSearchSpace"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_mul_specs(candidates, num): | ||||||
|  |     results = [] | ||||||
|  |     for i in range(num): | ||||||
|  |         results.append(spaces.Categorical(*candidates)) | ||||||
|  |     return results | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_list_mul(num, multiplier): | ||||||
|  |     results = [] | ||||||
|  |     for i in range(1, num + 1): | ||||||
|  |         results.append(i * multiplier) | ||||||
|  |     return results | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _assert_types(x, expected_types): | ||||||
|  |     if not isinstance(x, expected_types): | ||||||
|  |         raise TypeError( | ||||||
|  |             "The type [{:}] is expected to be {:}.".format(type(x), expected_types) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _default_max_depth = 5 | ||||||
|  | DefaultSearchSpace = dict( | ||||||
|     d_feat=6, |     d_feat=6, | ||||||
|     embed_dim=64, |     stem_dim=spaces.Categorical(*_get_list_mul(8, 16)), | ||||||
|     depth=5, |     embed_dims=_get_mul_specs(_get_list_mul(8, 16), _default_max_depth), | ||||||
|     num_heads=4, |     num_heads=_get_mul_specs((1, 2, 4, 8), _default_max_depth), | ||||||
|     mlp_ratio=4.0, |     mlp_hidden_multipliers=_get_mul_specs((0.5, 1, 2, 4, 8), _default_max_depth), | ||||||
|     qkv_bias=True, |     qkv_bias=True, | ||||||
|     pos_drop=0.0, |     pos_drop=0.0, | ||||||
|     mlp_drop_rate=0.0, |     other_drop=0.0, | ||||||
|     attn_drop_rate=0.0, |  | ||||||
|     drop_path_rate=0.0, |  | ||||||
| ) | ) | ||||||
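The two helpers above build the per-layer choices that DefaultSearchSpace holds. As a quick illustration (a standalone sketch, assuming the repo's `spaces` package is importable), the defaults expand to:

```python
import spaces

# _get_list_mul(8, 16): candidate widths 16, 32, ..., 128 (multiples of 16).
widths = [i * 16 for i in range(1, 9)]
print(widths)  # [16, 32, 48, 64, 80, 96, 112, 128]

# _get_mul_specs(widths, 5): five independent Categorical choices, one per
# (potential) encoder layer, so each layer can pick its own embedding width.
embed_dims = [spaces.Categorical(*widths) for _ in range(5)]
print(len(embed_dims))  # 5
```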
|  |  | ||||||
|  |  | ||||||
| # Real Model | class SuperTransformer(super_core.SuperModule): | ||||||
|  |     """The super model for transformer.""" | ||||||
|  |  | ||||||
|  |  | ||||||
| class Attention(nn.Module): |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         dim, |  | ||||||
|         num_heads=8, |  | ||||||
|         qkv_bias=False, |  | ||||||
|         qk_scale=None, |  | ||||||
|         attn_drop=0.0, |  | ||||||
|         proj_drop=0.0, |  | ||||||
|     ): |  | ||||||
|         super(Attention, self).__init__() |  | ||||||
|         self.num_heads = num_heads |  | ||||||
|         head_dim = dim // num_heads |  | ||||||
|         self.scale = qk_scale or math.sqrt(head_dim) |  | ||||||
|  |  | ||||||
|         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |  | ||||||
|         self.attn_drop = nn.Dropout(attn_drop) |  | ||||||
|         self.proj = nn.Linear(dim, dim) |  | ||||||
|         self.proj_drop = nn.Dropout(proj_drop) |  | ||||||
|  |  | ||||||
|     def forward(self, x): |  | ||||||
|         B, N, C = x.shape |  | ||||||
|         qkv = ( |  | ||||||
|             self.qkv(x) |  | ||||||
|             .reshape(B, N, 3, self.num_heads, C // self.num_heads) |  | ||||||
|             .permute(2, 0, 3, 1, 4) |  | ||||||
|         ) |  | ||||||
|         q, k, v = ( |  | ||||||
|             qkv[0], |  | ||||||
|             qkv[1], |  | ||||||
|             qkv[2], |  | ||||||
|         )  # make torchscript happy (cannot use tensor as tuple) |  | ||||||
|  |  | ||||||
|         attn = (q @ k.transpose(-2, -1)) * self.scale |  | ||||||
|         attn = attn.softmax(dim=-1) |  | ||||||
|         attn = self.attn_drop(attn) |  | ||||||
|  |  | ||||||
|         x = (attn @ v).transpose(1, 2).reshape(B, N, C) |  | ||||||
|         x = self.proj(x) |  | ||||||
|         x = self.proj_drop(x) |  | ||||||
|         return x |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Block(nn.Module): |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         dim, |  | ||||||
|         num_heads, |  | ||||||
|         mlp_ratio=4.0, |  | ||||||
|         qkv_bias=False, |  | ||||||
|         qk_scale=None, |  | ||||||
|         attn_drop=0.0, |  | ||||||
|         mlp_drop=0.0, |  | ||||||
|         drop_path=0.0, |  | ||||||
|         act_layer=nn.GELU, |  | ||||||
|         norm_layer=nn.LayerNorm, |  | ||||||
|     ): |  | ||||||
|         super(Block, self).__init__() |  | ||||||
|         self.norm1 = norm_layer(dim) |  | ||||||
|         self.attn = Attention( |  | ||||||
|             dim, |  | ||||||
|             num_heads=num_heads, |  | ||||||
|             qkv_bias=qkv_bias, |  | ||||||
|             qk_scale=qk_scale, |  | ||||||
|             attn_drop=attn_drop, |  | ||||||
|             proj_drop=mlp_drop, |  | ||||||
|         ) |  | ||||||
|         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here |  | ||||||
|         self.drop_path = ( |  | ||||||
|             xlayers.DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |  | ||||||
|         ) |  | ||||||
|         self.norm2 = norm_layer(dim) |  | ||||||
|         mlp_hidden_dim = int(dim * mlp_ratio) |  | ||||||
|         self.mlp = xlayers.MLP( |  | ||||||
|             in_features=dim, |  | ||||||
|             hidden_features=mlp_hidden_dim, |  | ||||||
|             act_layer=act_layer, |  | ||||||
|             drop=mlp_drop, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def forward(self, x): |  | ||||||
|         x = x + self.drop_path(self.attn(self.norm1(x))) |  | ||||||
|         x = x + self.drop_path(self.mlp(self.norm2(x))) |  | ||||||
|         return x |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SimpleEmbed(nn.Module): |  | ||||||
|     def __init__(self, d_feat, embed_dim): |  | ||||||
|         super(SimpleEmbed, self).__init__() |  | ||||||
|         self.d_feat = d_feat |  | ||||||
|         self.embed_dim = embed_dim |  | ||||||
|         self.proj = nn.Linear(d_feat, embed_dim) |  | ||||||
|  |  | ||||||
|     def forward(self, x): |  | ||||||
|         x = x.reshape(len(x), self.d_feat, -1)  # [N, F*T] -> [N, F, T] |  | ||||||
|         x = x.permute(0, 2, 1)  # [N, F, T] -> [N, T, F] |  | ||||||
|         out = self.proj(x) * math.sqrt(self.embed_dim) |  | ||||||
|         return out |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TransformerModel(nn.Module): |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         d_feat: int = 6, |         d_feat: int = 6, | ||||||
|         embed_dim: int = 64, |         stem_dim: super_core.IntSpaceType = DefaultSearchSpace["stem_dim"], | ||||||
|         depth: int = 4, |         embed_dims: List[super_core.IntSpaceType] = DefaultSearchSpace["embed_dims"], | ||||||
|         num_heads: int = 4, |         num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"], | ||||||
|         mlp_ratio: float = 4.0, |         mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[ | ||||||
|         qkv_bias: bool = True, |             "mlp_hidden_multipliers" | ||||||
|         qk_scale: Optional[float] = None, |         ], | ||||||
|         pos_drop: float = 0.0, |         qkv_bias: bool = DefaultSearchSpace["qkv_bias"], | ||||||
|         mlp_drop_rate: float = 0.0, |         pos_drop: float = DefaultSearchSpace["pos_drop"], | ||||||
|         attn_drop_rate: float = 0.0, |         other_drop: float = DefaultSearchSpace["other_drop"], | ||||||
|         drop_path_rate: float = 0.0, |  | ||||||
|         norm_layer: Optional[nn.Module] = None, |  | ||||||
|         max_seq_len: int = 65, |         max_seq_len: int = 65, | ||||||
|     ): |     ): | ||||||
|         """ |         super(SuperTransformer, self).__init__() | ||||||
|         Args: |         self._embed_dims = embed_dims | ||||||
|           d_feat (int, tuple): input image size |         self._stem_dim = stem_dim | ||||||
|           embed_dim (int): embedding dimension |         self._num_heads = num_heads | ||||||
|           depth (int): depth of transformer |         self._mlp_hidden_multipliers = mlp_hidden_multipliers | ||||||
|           num_heads (int): number of attention heads |  | ||||||
|           mlp_ratio (int): ratio of mlp hidden dim to embedding dim |  | ||||||
|           qkv_bias (bool): enable bias for qkv if True |  | ||||||
|           qk_scale (float): override default qk scale of head_dim ** -0.5 if set |  | ||||||
|           pos_drop (float): dropout rate for the positional embedding |  | ||||||
|           mlp_drop_rate (float): the dropout rate for MLP layers in a block |  | ||||||
|           attn_drop_rate (float): attention dropout rate |  | ||||||
|           drop_path_rate (float): stochastic depth rate |  | ||||||
|           norm_layer: (nn.Module): normalization layer |  | ||||||
|         """ |  | ||||||
|         super(TransformerModel, self).__init__() |  | ||||||
|         self.embed_dim = embed_dim |  | ||||||
|         self.num_features = embed_dim |  | ||||||
|         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |  | ||||||
|  |  | ||||||
|         self.input_embed = SimpleEmbed(d_feat, embed_dim=embed_dim) |         # the stem part | ||||||
|  |         self.input_embed = super_core.SuperAlphaEBDv1(d_feat, stem_dim) | ||||||
|         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.stem_dim)) | ||||||
|         self.pos_embed = xlayers.PositionalEncoder( |         self.pos_embed = super_core.SuperPositionalEncoder( | ||||||
|             d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop |             d_model=stem_dim, max_seq_len=max_seq_len, dropout=pos_drop | ||||||
|         ) |         ) | ||||||
|  |         # build the transformer encode layers -->> check params | ||||||
|         dpr = [ |         _assert_types(embed_dims, (tuple, list)) | ||||||
|             x.item() for x in torch.linspace(0, drop_path_rate, depth) |         _assert_types(num_heads, (tuple, list)) | ||||||
|         ]  # stochastic depth decay rule |         _assert_types(mlp_hidden_multipliers, (tuple, list)) | ||||||
|         self.blocks = nn.ModuleList( |         num_layers = len(embed_dims) | ||||||
|             [ |         assert ( | ||||||
|                 Block( |             num_layers == len(num_heads) == len(mlp_hidden_multipliers) | ||||||
|                     dim=embed_dim, |         ), "{:} vs {:} vs {:}".format( | ||||||
|                     num_heads=num_heads, |             num_layers, len(num_heads), len(mlp_hidden_multipliers) | ||||||
|                     mlp_ratio=mlp_ratio, |  | ||||||
|                     qkv_bias=qkv_bias, |  | ||||||
|                     qk_scale=qk_scale, |  | ||||||
|                     attn_drop=attn_drop_rate, |  | ||||||
|                     mlp_drop=mlp_drop_rate, |  | ||||||
|                     drop_path=dpr[i], |  | ||||||
|                     norm_layer=norm_layer, |  | ||||||
|                 ) |  | ||||||
|                 for i in range(depth) |  | ||||||
|             ] |  | ||||||
|         ) |         ) | ||||||
|         self.norm = norm_layer(embed_dim) |         # build the transformer encode layers -->> backbone | ||||||
|  |         layers, input_dim = [], stem_dim | ||||||
|  |         for embed_dim, num_head, mlp_hidden_multiplier in zip( | ||||||
|  |             embed_dims, num_heads, mlp_hidden_multipliers | ||||||
|  |         ): | ||||||
|  |             layer = super_core.SuperTransformerEncoderLayer( | ||||||
|  |                 input_dim, | ||||||
|  |                 embed_dim, | ||||||
|  |                 num_head, | ||||||
|  |                 qkv_bias, | ||||||
|  |                 mlp_hidden_multiplier, | ||||||
|  |                 other_drop, | ||||||
|  |             ) | ||||||
|  |             layers.append(layer) | ||||||
|  |             input_dim = embed_dim | ||||||
|  |         self.backbone = super_core.SuperSequential(*layers) | ||||||
|  |  | ||||||
|         # regression head |         # the regression head | ||||||
|         self.head = nn.Linear(self.num_features, 1) |         self.head = super_core.SuperLinear(self._embed_dims[-1], 1) | ||||||
|  |         trunc_normal_(self.cls_token, std=0.02) | ||||||
|         xlayers.trunc_normal_(self.cls_token, std=0.02) |  | ||||||
|         self.apply(self._init_weights) |         self.apply(self._init_weights) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def stem_dim(self): | ||||||
|  |         return spaces.get_max(self._stem_dim) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         root_node = spaces.VirtualNode(id(self)) | ||||||
|  |         xdict = dict( | ||||||
|  |             input_embed=self.input_embed.abstract_search_space, | ||||||
|  |             pos_embed=self.pos_embed.abstract_search_space, | ||||||
|  |             backbone=self.backbone.abstract_search_space, | ||||||
|  |             head=self.head.abstract_search_space, | ||||||
|  |         ) | ||||||
|  |         if not spaces.is_determined(self._stem_dim): | ||||||
|  |             root_node.append("_stem_dim", self._stem_dim.abstract(reuse_last=True)) | ||||||
|  |         for key, space in xdict.items(): | ||||||
|  |             if not spaces.is_determined(space): | ||||||
|  |                 root_node.append(key, space) | ||||||
|  |         return root_node | ||||||
|  |  | ||||||
|  |     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||||
|  |         super(SuperTransformer, self).apply_candidate(abstract_child) | ||||||
|  |         xkeys = ("input_embed", "pos_embed", "backbone", "head") | ||||||
|  |         for key in xkeys: | ||||||
|  |             if key in abstract_child: | ||||||
|  |                 getattr(self, key).apply_candidate(abstract_child[key]) | ||||||
|  |  | ||||||
|     def _init_weights(self, m): |     def _init_weights(self, m): | ||||||
|         if isinstance(m, nn.Linear): |         if isinstance(m, nn.Linear): | ||||||
|             xlayers.trunc_normal_(m.weight, std=0.02) |             trunc_normal_(m.weight, std=0.02) | ||||||
|             if isinstance(m, nn.Linear) and m.bias is not None: |             if isinstance(m, nn.Linear) and m.bias is not None: | ||||||
|                 nn.init.constant_(m.bias, 0) |                 nn.init.constant_(m.bias, 0) | ||||||
|         elif isinstance(m, nn.LayerNorm): |         elif isinstance(m, super_core.SuperLinear): | ||||||
|             nn.init.constant_(m.bias, 0) |             trunc_normal_(m._super_weight, std=0.02) | ||||||
|  |             if m._super_bias is not None: | ||||||
|  |                 nn.init.constant_(m._super_bias, 0) | ||||||
|  |         elif isinstance(m, super_core.SuperLayerNorm1D): | ||||||
|             nn.init.constant_(m.weight, 1.0) |             nn.init.constant_(m.weight, 1.0) | ||||||
|  |             nn.init.constant_(m.bias, 0) | ||||||
|  |  | ||||||
|     def forward_features(self, x): |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|         batch, flatten_size = x.shape |         batch, flatten_size = input.shape | ||||||
|         feats = self.input_embed(x)  # batch * 60 * 64 |         feats = self.input_embed(input)  # batch * 60 * 64 | ||||||
|  |         if not spaces.is_determined(self._stem_dim): | ||||||
|         cls_tokens = self.cls_token.expand( |             stem_dim = self.abstract_child["_stem_dim"].value | ||||||
|             batch, -1, -1 |         else: | ||||||
|         )  # stole cls_tokens impl from Phil Wang, thanks |             stem_dim = spaces.get_determined_value(self._stem_dim) | ||||||
|  |         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||||
|  |         cls_tokens = F.interpolate(cls_tokens, size=(stem_dim), mode="linear", align_corners=True) | ||||||
|         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) |         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) | ||||||
|         feats_w_tp = self.pos_embed(feats_w_ct) |         feats_w_tp = self.pos_embed(feats_w_ct) | ||||||
|  |         xfeats = self.backbone(feats_w_tp) | ||||||
|  |         xfeats = xfeats[:, 0, :]  # use the feature for the first token | ||||||
|  |         predicts = self.head(xfeats).squeeze(-1) | ||||||
|  |         return predicts | ||||||
|  |  | ||||||
|         xfeats = feats_w_tp |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|         for block in self.blocks: |         batch, flatten_size = input.shape | ||||||
|             xfeats = block(xfeats) |         feats = self.input_embed(input)  # batch * 60 * 64 | ||||||
|  |         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||||
|         xfeats = self.norm(xfeats)[:, 0] |         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) | ||||||
|         return xfeats |         feats_w_tp = self.pos_embed(feats_w_ct) | ||||||
|  |         xfeats = self.backbone(feats_w_tp) | ||||||
|     def forward(self, x): |         xfeats = xfeats[:, 0, :]  # use the feature for the first token | ||||||
|         feats = self.forward_features(x) |         predicts = self.head(xfeats).squeeze(-1) | ||||||
|         predicts = self.head(feats).squeeze(-1) |  | ||||||
|         return predicts |         return predicts | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_transformer(config): | def get_transformer(config): | ||||||
|  |     if config is None: | ||||||
|  |         return SuperTransformer(6) | ||||||
|     if not isinstance(config, dict): |     if not isinstance(config, dict): | ||||||
|         raise ValueError("Invalid Configuration: {:}".format(config)) |         raise ValueError("Invalid Configuration: {:}".format(config)) | ||||||
|     name = config.get("name", "basic") |     name = config.get("name", "basic") | ||||||
|   | |||||||
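Taken together, the new SuperTransformer is a supernet that can also be narrowed to a sampled candidate. A minimal usage sketch, mirroring the test added later in this commit (it assumes `lib` is on `sys.path`, as the tests arrange):

```python
import torch

from trade_models import get_transformer
from xlayers.super_core import SuperRunMode

model = get_transformer(None)        # SuperTransformer(6) with the default search space
model.apply_verbose(False)
inputs = torch.rand(10, 360)         # 10 samples, 6 features x 60 time steps, flattened

# Supernet forward (forward_raw): every searchable dimension runs at its maximum.
preds = model(inputs)                # shape (10,)

# Sample one architecture from the search space and run it as a candidate (forward_candidate).
space = model.abstract_search_space
space.clean_last()
child = space.random(reuse_last=True)
model.set_super_run_type(SuperRunMode.Candidate)
model.apply_candidate(child)
preds = model(inputs)                # still shape (10,)
```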
| @@ -37,10 +37,7 @@ class SuperAttention(SuperModule): | |||||||
|         self._proj_dim = proj_dim |         self._proj_dim = proj_dim | ||||||
|         self._num_heads = num_heads |         self._num_heads = num_heads | ||||||
|         self._qkv_bias = qkv_bias |         self._qkv_bias = qkv_bias | ||||||
|         # head_dim = dim // num_heads |  | ||||||
|         # self.scale = qk_scale or math.sqrt(head_dim) |  | ||||||
|  |  | ||||||
|         # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |  | ||||||
|         self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) |         self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) | ||||||
|         self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) |         self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) | ||||||
|         self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) |         self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) | ||||||
|   | |||||||
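The hunk above drops the commented-out fused `qkv` projection for good: SuperAttention keeps separate q/k/v linear layers so each projection can carry its own search space. A small sketch of the shape equivalence, with plain `nn.Linear` standing in for `SuperLinear`:

```python
import torch
import torch.nn as nn

dim, batch, seq = 64, 2, 8
x = torch.rand(batch, seq, dim)

# Fused projection (the style removed above): one matmul, then split into q, k, v.
qkv = nn.Linear(dim, dim * 3, bias=True)
q, k, v = qkv(x).chunk(3, dim=-1)

# Separate projections (the SuperAttention style): three matmuls that can be
# resized independently when the layer is searched.
q_fc, k_fc, v_fc = (nn.Linear(dim, dim, bias=True) for _ in range(3))
q2, k2, v2 = q_fc(x), k_fc(x), v_fc(x)
assert q.shape == q2.shape == (batch, seq, dim)
```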
| @@ -2,6 +2,8 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
| ##################################################### | ##################################################### | ||||||
| from .super_module import SuperRunMode | from .super_module import SuperRunMode | ||||||
|  | from .super_module import IntSpaceType | ||||||
|  |  | ||||||
| from .super_module import SuperModule | from .super_module import SuperModule | ||||||
| from .super_container import SuperSequential | from .super_container import SuperSequential | ||||||
| from .super_linear import SuperLinear | from .super_linear import SuperLinear | ||||||
| @@ -9,3 +11,6 @@ from .super_linear import SuperMLPv1, SuperMLPv2 | |||||||
| from .super_norm import SuperLayerNorm1D | from .super_norm import SuperLayerNorm1D | ||||||
| from .super_attention import SuperAttention | from .super_attention import SuperAttention | ||||||
| from .super_transformer import SuperTransformerEncoderLayer | from .super_transformer import SuperTransformerEncoderLayer | ||||||
|  |  | ||||||
|  | from .super_trade_stem import SuperAlphaEBDv1 | ||||||
|  | from .super_positional_embedding import SuperPositionalEncoder | ||||||
|   | |||||||
| @@ -109,7 +109,7 @@ class SuperLinear(SuperModule): | |||||||
|  |  | ||||||
|     def extra_repr(self) -> str: |     def extra_repr(self) -> str: | ||||||
|         return "in_features={:}, out_features={:}, bias={:}".format( |         return "in_features={:}, out_features={:}, bias={:}".format( | ||||||
|             self.in_features, self.out_features, self.bias |             self._in_features, self._out_features, self._bias | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -75,8 +75,10 @@ class SuperLayerNorm1D(SuperModule): | |||||||
|         return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps) |         return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps) | ||||||
|  |  | ||||||
|     def extra_repr(self) -> str: |     def extra_repr(self) -> str: | ||||||
|         return "{in_dim}, eps={eps}, " "elementwise_affine={elementwise_affine}".format( |         return ( | ||||||
|             in_dim=self._in_dim, |             "shape={in_dim}, eps={eps}, elementwise_affine={elementwise_affine}".format( | ||||||
|             eps=self._eps, |                 in_dim=self._in_dim, | ||||||
|             elementwise_affine=self._elementwise_affine, |                 eps=self._eps, | ||||||
|  |                 elementwise_affine=self._elementwise_affine, | ||||||
|  |             ) | ||||||
|         ) |         ) | ||||||
|   | |||||||
lib/xlayers/super_positional_embedding.py (new file, 68 lines)
							| @@ -0,0 +1,68 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||||
|  | ##################################################### | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import math | ||||||
|  |  | ||||||
|  | import spaces | ||||||
|  | from .super_module import SuperModule | ||||||
|  | from .super_module import IntSpaceType | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SuperPositionalEncoder(SuperModule): | ||||||
|  |     """Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf | ||||||
|  |     https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65 | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, d_model: IntSpaceType, max_seq_len: int, dropout: float = 0.1): | ||||||
|  |         super(SuperPositionalEncoder, self).__init__() | ||||||
|  |         self._d_model = d_model | ||||||
|  |         # create constant 'pe' matrix with values dependent on | ||||||
|  |         # pos and i | ||||||
|  |         self.dropout = nn.Dropout(p=dropout) | ||||||
|  |         self.register_buffer("pe", self.create_pos_embed(max_seq_len, self.d_model)) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def d_model(self): | ||||||
|  |         return spaces.get_max(self._d_model) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         root_node = spaces.VirtualNode(id(self)) | ||||||
|  |         if not spaces.is_determined(self._d_model): | ||||||
|  |             root_node.append("_d_model", self._d_model.abstract(reuse_last=True)) | ||||||
|  |         return root_node | ||||||
|  |  | ||||||
|  |     def create_pos_embed(self, max_seq_len, d_model): | ||||||
|  |         pe = torch.zeros(max_seq_len, d_model) | ||||||
|  |         for pos in range(max_seq_len): | ||||||
|  |             for i in range(0, d_model): | ||||||
|  |                 div = 10000 ** ((i // 2) * 2 / d_model) | ||||||
|  |                 value = pos / div | ||||||
|  |                 if i % 2 == 0: | ||||||
|  |                     pe[pos, i] = math.sin(value) | ||||||
|  |                 else: | ||||||
|  |                     pe[pos, i] = math.cos(value) | ||||||
|  |         return pe.unsqueeze(0) | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         batch, seq, fdim = input.shape[:3] | ||||||
|  |         embeddings = self.pe[:, :seq] | ||||||
|  |         if not spaces.is_determined(self._d_model): | ||||||
|  |             expected_d_model = self.abstract_child["_d_model"].value | ||||||
|  |         else: | ||||||
|  |             expected_d_model = spaces.get_determined_value(self._d_model) | ||||||
|  |         assert fdim == expected_d_model, "{:} vs {:}".format(fdim, expected_d_model) | ||||||
|  |  | ||||||
|  |         embeddings = torch.nn.functional.interpolate( | ||||||
|  |             embeddings, size=(expected_d_model), mode="linear", align_corners=True | ||||||
|  |         ) | ||||||
|  |         outs = self.dropout(input + embeddings) | ||||||
|  |         return outs | ||||||
|  |  | ||||||
|  |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         batch, seq, fdim = input.shape[:3] | ||||||
|  |         embeddings = self.pe[:, :seq] | ||||||
|  |         outs = self.dropout(input + embeddings) | ||||||
|  |         return outs | ||||||
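The nested loops in `create_pos_embed` implement the standard sinusoidal table from "Attention Is All You Need". For reference, the same table can be built in vectorized form; this is only an equivalent sketch, and the helper name is illustrative rather than part of the commit:

```python
import torch

def sinusoidal_table(max_seq_len: int, d_model: int) -> torch.Tensor:
    # positions [L, 1] divided by 10000 ** ((i // 2) * 2 / d_model) for each channel i.
    position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
    i = torch.arange(d_model, dtype=torch.float32)
    div = torch.pow(torch.tensor(10000.0), (i // 2) * 2.0 / d_model)
    angles = position / div                                # [L, D]
    # even channels take sin, odd channels take cos, matching the loop above
    pe = torch.where(i.long() % 2 == 0, torch.sin(angles), torch.cos(angles))
    return pe.unsqueeze(0)                                 # [1, L, D], same as create_pos_embed
```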
lib/xlayers/super_trade_stem.py (new file, 63 lines)
							| @@ -0,0 +1,63 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | from __future__ import division | ||||||
|  | from __future__ import print_function | ||||||
|  |  | ||||||
|  | import math | ||||||
|  | from functools import partial | ||||||
|  | from typing import Optional, Text | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  |  | ||||||
|  | import spaces | ||||||
|  | from .super_linear import SuperLinear | ||||||
|  | from .super_module import SuperModule | ||||||
|  | from .super_module import IntSpaceType | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SuperAlphaEBDv1(SuperModule): | ||||||
|  |     """A simple layer to convert the raw trading data from 1-D to 2-D data and apply an FC layer.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, d_feat: int, embed_dim: IntSpaceType): | ||||||
|  |         super(SuperAlphaEBDv1, self).__init__() | ||||||
|  |         self._d_feat = d_feat | ||||||
|  |         self._embed_dim = embed_dim | ||||||
|  |         self.proj = SuperLinear(d_feat, embed_dim) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def embed_dim(self): | ||||||
|  |         return spaces.get_max(self._embed_dim) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         root_node = spaces.VirtualNode(id(self)) | ||||||
|  |         space = self.proj.abstract_search_space | ||||||
|  |         if not spaces.is_determined(space): | ||||||
|  |             root_node.append("proj", space) | ||||||
|  |         if not spaces.is_determined(self._embed_dim): | ||||||
|  |             root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True)) | ||||||
|  |         return root_node | ||||||
|  |  | ||||||
|  |     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||||
|  |         super(SuperAlphaEBDv1, self).apply_candidate(abstract_child) | ||||||
|  |         if "proj" in abstract_child: | ||||||
|  |             self.proj.apply_candidate(abstract_child["proj"]) | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         x = input.reshape(len(input), self._d_feat, -1)  # [N, F*T] -> [N, F, T] | ||||||
|  |         x = x.permute(0, 2, 1)  # [N, F, T] -> [N, T, F] | ||||||
|  |         if not spaces.is_determined(self._embed_dim): | ||||||
|  |             embed_dim = self.abstract_child["_embed_dim"].value | ||||||
|  |         else: | ||||||
|  |             embed_dim = spaces.get_determined_value(self._embed_dim) | ||||||
|  |         out = self.proj(x) * math.sqrt(embed_dim) | ||||||
|  |         return out | ||||||
|  |  | ||||||
|  |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         x = input.reshape(len(input), self._d_feat, -1)  # [N, F*T] -> [N, F, T] | ||||||
|  |         x = x.permute(0, 2, 1)  # [N, F, T] -> [N, T, F] | ||||||
|  |         out = self.proj(x) * math.sqrt(self.embed_dim) | ||||||
|  |         return out | ||||||
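SuperAlphaEBDv1 only reshapes the flat trading features and applies one searchable linear projection. A shape walkthrough with plain tensors, using the commit's default of 6 features over 60 time steps (the concrete numbers are illustrative):

```python
import math
import torch
import torch.nn as nn

batch, d_feat, seq_len, embed_dim = 4, 6, 60, 64
flat = torch.rand(batch, d_feat * seq_len)     # [N, F*T] = [4, 360], flattened features
x = flat.reshape(batch, d_feat, -1)            # [N, F, T] = [4, 6, 60]
x = x.permute(0, 2, 1)                         # [N, T, F] = [4, 60, 6]

# In the super module, proj is a SuperLinear whose out_features may be a search space;
# a fixed nn.Linear stands in here.
proj = nn.Linear(d_feat, embed_dim)
out = proj(x) * math.sqrt(embed_dim)           # [4, 60, 64], scaled as in forward_raw
print(out.shape)
```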
| @@ -56,7 +56,7 @@ class TestSuperLinear(unittest.TestCase): | |||||||
|         out_features = spaces.Categorical(24, 36, 48) |         out_features = spaces.Categorical(24, 36, 48) | ||||||
|         mlp = super_core.SuperMLPv1(10, hidden_features, out_features) |         mlp = super_core.SuperMLPv1(10, hidden_features, out_features) | ||||||
|         print(mlp) |         print(mlp) | ||||||
|         mlp.apply_verbose(True) |         mlp.apply_verbose(False) | ||||||
|         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features) |         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features) | ||||||
|  |  | ||||||
|         inputs = torch.rand(4, 10) |         inputs = torch.rand(4, 10) | ||||||
| @@ -95,7 +95,7 @@ class TestSuperLinear(unittest.TestCase): | |||||||
|         out_features = spaces.Categorical(24, 36, 48) |         out_features = spaces.Categorical(24, 36, 48) | ||||||
|         mlp = super_core.SuperMLPv2(10, hidden_multiplier, out_features) |         mlp = super_core.SuperMLPv2(10, hidden_multiplier, out_features) | ||||||
|         print(mlp) |         print(mlp) | ||||||
|         mlp.apply_verbose(True) |         mlp.apply_verbose(False) | ||||||
|  |  | ||||||
|         inputs = torch.rand(4, 10) |         inputs = torch.rand(4, 10) | ||||||
|         outputs = mlp(inputs) |         outputs = mlp(inputs) | ||||||
| @@ -115,3 +115,20 @@ class TestSuperLinear(unittest.TestCase): | |||||||
|         outputs = mlp(inputs) |         outputs = mlp(inputs) | ||||||
|         output_shape = (4, abstract_child["_out_features"].value) |         output_shape = (4, abstract_child["_out_features"].value) | ||||||
|         self.assertEqual(tuple(outputs.shape), output_shape) |         self.assertEqual(tuple(outputs.shape), output_shape) | ||||||
|  |  | ||||||
|  |     def test_super_stem(self): | ||||||
|  |         out_features = spaces.Categorical(24, 36, 48) | ||||||
|  |         model = super_core.SuperAlphaEBDv1(6, out_features) | ||||||
|  |         inputs = torch.rand(4, 360) | ||||||
|  |  | ||||||
|  |         abstract_space = model.abstract_search_space | ||||||
|  |         abstract_space.clean_last() | ||||||
|  |         abstract_child = abstract_space.random(reuse_last=True) | ||||||
|  |         print("The abstract search space:\n{:}".format(abstract_space)) | ||||||
|  |         print("The abstract child program:\n{:}".format(abstract_child)) | ||||||
|  |  | ||||||
|  |         model.set_super_run_type(super_core.SuperRunMode.Candidate) | ||||||
|  |         model.apply_candidate(abstract_child) | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         output_shape = (4, 60, abstract_child["_embed_dim"].value) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), output_shape) | ||||||
tests/test_super_transformer.py (new file, 44 lines)
							| @@ -0,0 +1,44 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | # pytest ./tests/test_super_transformer.py -s       # | ||||||
|  | ##################################################### | ||||||
|  | import sys, random | ||||||
|  | import unittest | ||||||
|  | import pytest | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||||
|  | print("library path: {:}".format(lib_dir)) | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from xlayers.super_core import SuperRunMode | ||||||
|  | from trade_models import get_transformer | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestSuperTransformer(unittest.TestCase): | ||||||
|  |     """Test the super transformer.""" | ||||||
|  |  | ||||||
|  |     def test_super_transformer(self): | ||||||
|  |         model = get_transformer(None) | ||||||
|  |         model.apply_verbose(False) | ||||||
|  |         print(model) | ||||||
|  |  | ||||||
|  |         inputs = torch.rand(10, 360) | ||||||
|  |         print("Input shape: {:}".format(inputs.shape)) | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), (10,)) | ||||||
|  |  | ||||||
|  |         abstract_space = model.abstract_search_space | ||||||
|  |         abstract_space.clean_last() | ||||||
|  |         abstract_child = abstract_space.random(reuse_last=True) | ||||||
|  |         print("The abstract search space:\n{:}".format(abstract_space)) | ||||||
|  |         print("The abstract child program:\n{:}".format(abstract_child)) | ||||||
|  |  | ||||||
|  |         model.set_super_run_type(SuperRunMode.Candidate) | ||||||
|  |         model.apply_candidate(abstract_child) | ||||||
|  |  | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), (10,)) | ||||||