From bc42ab3c083a3ceadaf590f3f3be4d4dff78691d Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 22 May 2021 16:41:54 +0800 Subject: [PATCH] Fix bugs in xlayers --- exps/LFNA/lfna.py | 3 +- exps/LFNA/lfna_meta_model.py | 54 +++--- setup.py | 2 +- xautodl/xlayers/super_attention.py | 158 +++++++++++++++++- xautodl/xlayers/super_core.py | 3 +- xautodl/xlayers/super_positional_embedding.py | 12 +- xautodl/xlayers/super_transformer.py | 4 +- 7 files changed, 197 insertions(+), 39 deletions(-) diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index 2ab32cb..e9b5abf 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -10,7 +10,8 @@ from tqdm import tqdm from copy import deepcopy from pathlib import Path -lib_dir = (Path(__file__).parent / "..").resolve() +lib_dir = (Path(__file__).parent / ".." / "..").resolve() +print("LIB-DIR: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index e80a75a..66fbcd5 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule): layer_embedding, time_embedding, meta_timestamps, - mha_depth: int = 2, + mha_depth: int = 1, dropout: float = 0.1, ): super(LFNA_Meta, self).__init__() @@ -44,8 +44,21 @@ class LFNA_Meta(super_core.SuperModule): self._append_meta_embed = dict(fixed=None, learnt=None) self._append_meta_timestamps = dict(fixed=None, learnt=None) - self._time_prob_drop = super_core.SuperDrop(dropout, (-1, 1), recover=False) + self._tscalar_embed = super_core.SuperDynamicPositionE( + time_embedding, scale=100 + ) + # build transformer + self._trans_att = super_core.SuperQKVAttention( + time_embedding, + time_embedding, + time_embedding, + time_embedding, + 4, + True, + attn_drop=None, + proj_drop=dropout, + ) layers = [] for ilayer in range(mha_depth): layers.append( @@ -74,15 +87,9 @@ class LFNA_Meta(super_core.SuperModule): self._generator = get_model(**model_kwargs) # print("generator: {:}".format(self._generator)) - # unknown token - self.register_parameter( - "_unknown_token", - torch.nn.Parameter(torch.Tensor(1, time_embedding)), - ) - # initialization trunc_normal_( - [self._super_layer_embed, self._super_meta_embed, self._unknown_token], + [self._super_layer_embed, self._super_meta_embed], std=0.02, ) @@ -136,28 +143,21 @@ class LFNA_Meta(super_core.SuperModule): (self._append_meta_embed["fixed"], meta_embed), dim=0 ) - def forward_raw(self, timestamps): + def _obtain_time_embed(self, timestamps): # timestamps is a batch of sequence of timestamps batch, seq = timestamps.shape - timestamps = timestamps.unsqueeze(dim=-1) - meta_timestamps = self.meta_timestamps.view(1, 1, -1) - time_diffs = timestamps - meta_timestamps - time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1) - # select corresponding meta-knowledge - meta_match = torch.index_select( - self.super_meta_embed, dim=0, index=time_match_i.view(-1) + timestamp_q_embed = self._tscalar_embed(timestamps) + timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1)) + timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0) + timestamp_embeds = self._trans_att( + timestamp_q_embed, timestamp_k_embed, timestamp_v_embed ) - meta_match = meta_match.view(batch, seq, -1) - # create the probability - time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1) + corrected_embeds = self.meta_corrector(timestamp_embeds) + return corrected_embeds - x_time_probs = 
self._time_prob_drop(time_probs) - # if self.training: - # time_probs[:, -1, :] = 0 - unknown_token = self._unknown_token.view(1, 1, -1) - raw_meta_embed = x_time_probs * meta_match + (1 - x_time_probs) * unknown_token - - meta_embed = self.meta_corrector(raw_meta_embed) + def forward_raw(self, timestamps): + batch, seq = timestamps.shape + meta_embed = self._obtain_time_embed(timestamps) # create joint embed num_layer, _ = self._super_layer_embed.shape meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1) diff --git a/setup.py b/setup.py index 3021cf7..c83cbdb 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ # # TODO(xuanyidong): upload it to conda # -# [2021.05.18] v1.0 +# [2021.05.21] v0.9.9 import os from setuptools import setup, find_packages diff --git a/xautodl/xlayers/super_attention.py b/xautodl/xlayers/super_attention.py index 8c913a5..924cdd0 100644 --- a/xautodl/xlayers/super_attention.py +++ b/xautodl/xlayers/super_attention.py @@ -20,7 +20,7 @@ from .super_module import BoolSpaceType from .super_linear import SuperLinear -class SuperAttention(SuperModule): +class SuperSelfAttention(SuperModule): """The super model for attention layer.""" def __init__( @@ -32,7 +32,7 @@ class SuperAttention(SuperModule): attn_drop: Optional[float] = None, proj_drop: Optional[float] = None, ): - super(SuperAttention, self).__init__() + super(SuperSelfAttention, self).__init__() self._input_dim = input_dim self._proj_dim = proj_dim self._num_heads = num_heads @@ -150,3 +150,157 @@ class SuperAttention(SuperModule): return "input_dim={:}, proj_dim={:}, num_heads={:}".format( self._input_dim, self._proj_dim, self._num_heads ) + + +class SuperQKVAttention(SuperModule): + """The super model for attention layer.""" + + def __init__( + self, + in_q_dim: IntSpaceType, + in_k_dim: IntSpaceType, + in_v_dim: IntSpaceType, + proj_dim: IntSpaceType, + num_heads: IntSpaceType, + qkv_bias: BoolSpaceType = False, + attn_drop: Optional[float] = None, + proj_drop: Optional[float] = None, + ): + super(SuperQKVAttention, self).__init__() + self._in_v_dim = in_v_dim + self._in_q_dim = in_q_dim + self._in_k_dim = in_k_dim + self._proj_dim = proj_dim + self._num_heads = num_heads + self._qkv_bias = qkv_bias + + self.q_fc = SuperLinear(in_q_dim, proj_dim, bias=qkv_bias) + self.k_fc = SuperLinear(in_k_dim, proj_dim, bias=qkv_bias) + self.v_fc = SuperLinear(in_v_dim, proj_dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop or 0.0) + self.proj = SuperLinear(proj_dim, proj_dim) + self.proj_drop = nn.Dropout(proj_drop or 0.0) + + @property + def num_heads(self): + return spaces.get_max(self._num_heads) + + @property + def in_v_dim(self): + return spaces.get_max(self._in_v_dim) + + @property + def in_q_dim(self): + return spaces.get_max(self._in_q_dim) + + @property + def in_k_dim(self): + return spaces.get_max(self._in_k_dim) + + @property + def proj_dim(self): + return spaces.get_max(self._proj_dim) + + @property + def abstract_search_space(self): + root_node = spaces.VirtualNode(id(self)) + space_q = self.q_fc.abstract_search_space + space_k = self.k_fc.abstract_search_space + space_v = self.v_fc.abstract_search_space + space_proj = self.proj.abstract_search_space + if not spaces.is_determined(self._num_heads): + root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True)) + if not spaces.is_determined(space_q): + root_node.append("q_fc", space_q) + if not spaces.is_determined(space_k): + root_node.append("k_fc", space_k) + if not spaces.is_determined(space_v): + 
root_node.append("v_fc", space_v)
+        if not spaces.is_determined(space_proj):
+            root_node.append("proj", space_proj)
+        return root_node
+
+    def apply_candidate(self, abstract_child: spaces.VirtualNode):
+        super(SuperQKVAttention, self).apply_candidate(abstract_child)
+        if "q_fc" in abstract_child:
+            self.q_fc.apply_candidate(abstract_child["q_fc"])
+        if "k_fc" in abstract_child:
+            self.k_fc.apply_candidate(abstract_child["k_fc"])
+        if "v_fc" in abstract_child:
+            self.v_fc.apply_candidate(abstract_child["v_fc"])
+        if "proj" in abstract_child:
+            self.proj.apply_candidate(abstract_child["proj"])
+
+    def forward_qkv(self, q_tensor, k_tensor, v_tensor, num_head: int) -> torch.Tensor:
+        q = self.q_fc(q_tensor)
+        B, N, C = q.shape
+
+        k = self.k_fc(k_tensor)
+        B0, S, _ = k.shape
+
+        v = self.v_fc(v_tensor)
+        assert B0 == v.shape[0] and S == v.shape[1]
+
+        head_dim = C // num_head
+        if num_head > C:
+            raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
+        q_v1 = (
+            q[:, :, : num_head * head_dim]
+            .reshape(B, N, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        k_v1 = (
+            k[:, :, : num_head * head_dim]
+            .reshape(B0, S, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        # compute the attention map
+        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
+        attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * S
+        attn_v1 = self.attn_drop(attn_v1)
+
+        v_v1 = (
+            v[:, :, : num_head * head_dim]
+            .reshape(B0, S, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
+        # process the first [num_head * head_dim] part
+        if C == head_dim * num_head:
+            feats = feats_v1
+        else:  # The channels cannot be divided by num_head; the remainder forms an additional head
+            # [might have bugs, did not check yet]
+            q_v2 = q[:, :, num_head * head_dim :]
+            k_v2 = k[:, :, num_head * head_dim :]
+            v_v2 = v[:, :, num_head * head_dim :]
+            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
+            attn_v2 = attn_v2.softmax(dim=-1)
+            attn_v2 = self.attn_drop(attn_v2)
+            feats_v2 = attn_v2 @ v_v2
+            feats = torch.cat([feats_v1, feats_v2], dim=-1)
+        return feats
+
+    def forward_candidate(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
+        # check the num_heads:
+        if not spaces.is_determined(self._num_heads):
+            num_heads = self.abstract_child["_num_heads"].value
+        else:
+            num_heads = spaces.get_determined_value(self._num_heads)
+        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads)
+        outs = self.proj(feats)
+        outs = self.proj_drop(outs)
+        return outs
+
+    def forward_raw(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
+        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads)
+        outs = self.proj(feats)
+        outs = self.proj_drop(outs)
+        return outs
+
+    def extra_repr(self) -> str:
+        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
+            (self.in_q_dim, self.in_k_dim, self.in_v_dim),
+            self._proj_dim,
+            self._num_heads,
+        )
diff --git a/xautodl/xlayers/super_core.py b/xautodl/xlayers/super_core.py
index 0055f95..4544280 100644
--- a/xautodl/xlayers/super_core.py
+++ b/xautodl/xlayers/super_core.py
@@ -24,7 +24,8 @@ super_name2norm = {
     "identity": SuperIdentity,
 }
 
-from .super_attention import SuperAttention
+from .super_attention import SuperSelfAttention
+from .super_attention import SuperQKVAttention
 from .super_transformer import SuperTransformerEncoderLayer
 from .super_activations import SuperReLU
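A minimal usage sketch of the SuperQKVAttention block added above (illustration only, not part of the patch; the toy shapes and the dropout value are made up). It relies only on the constructor order (in_q_dim, in_k_dim, in_v_dim, proj_dim, num_heads, qkv_bias) shown above and on the module being callable with three tensors, exactly as the patched lfna_meta_model.py does via self._trans_att(q, k, v); it assumes the patched xautodl package is importable:

import torch
from xautodl.xlayers.super_core import SuperQKVAttention

# 2 sequences, 5 query positions attending over 9 key/value positions;
# everything is projected to 32 channels split across 4 heads (head_dim = 8).
attn = SuperQKVAttention(32, 32, 32, 32, 4, True, attn_drop=None, proj_drop=0.1)
q = torch.rand(2, 5, 32)
k = torch.rand(2, 9, 32)
v = torch.rand(2, 9, 32)  # values stay aligned with the keys
out = attn(q, k, v)
print(out.shape)  # expected: torch.Size([2, 5, 32]), one projected feature per query position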
diff --git a/xautodl/xlayers/super_positional_embedding.py b/xautodl/xlayers/super_positional_embedding.py
index d1b013d..4ee7f28 100644
--- a/xautodl/xlayers/super_positional_embedding.py
+++ b/xautodl/xlayers/super_positional_embedding.py
@@ -35,11 +35,13 @@ class SuperDynamicPositionE(SuperModule):
         return self.forward_raw(input)
 
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        import pdb
-
-        pdb.set_trace()
-        print("---")
-        return F.linear(input, self._super_weight, self._super_bias)
+        positions = torch.unsqueeze(input * self._scale, dim=-1)
+        divisions = torch.reshape(
+            self._div_term, [1] * input.ndim + [self._div_term.numel()]
+        )
+        values = positions / divisions
+        embeds = torch.cat((torch.sin(values), torch.cos(values)), dim=-1)
+        return embeds
 
     def extra_repr(self) -> str:
         return "scale={:}, dim={:}".format(self._scale, self._dimension)
diff --git a/xautodl/xlayers/super_transformer.py b/xautodl/xlayers/super_transformer.py
index 793d04e..ef43879 100644
--- a/xautodl/xlayers/super_transformer.py
+++ b/xautodl/xlayers/super_transformer.py
@@ -19,7 +19,7 @@ from .super_module import LayerOrder
 from .super_module import SuperModule
 from .super_linear import SuperMLPv2
 from .super_norm import SuperLayerNorm1D
-from .super_attention import SuperAttention
+from .super_attention import SuperSelfAttention
 
 
 class SuperTransformerEncoderLayer(SuperModule):
@@ -47,7 +47,7 @@ class SuperTransformerEncoderLayer(SuperModule):
         order: LayerOrder = LayerOrder.PreNorm,
     ):
         super(SuperTransformerEncoderLayer, self).__init__()
-        mha = SuperAttention(
+        mha = SuperSelfAttention(
            d_model,
            d_model,
            num_heads=num_heads,
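For reference, a self-contained sketch of what the rewritten SuperDynamicPositionE.forward_raw above computes: each raw scalar position (a timestamp, scaled by self._scale, which lfna_meta_model.py sets to 100) is expanded into a sinusoidal vector. This is illustrative only; the div_term schedule below is an assumption (the common 10000 ** (2i / d) recipe), since the patch does not show how the class builds self._div_term:

import torch

def dynamic_position_embed(x: torch.Tensor, dim: int, scale: float = 100.0) -> torch.Tensor:
    # x holds raw scalar positions, e.g. a (batch, seq) tensor of timestamps
    div_term = torch.pow(10000.0, torch.arange(0, dim, 2).float() / dim)  # assumed schedule
    positions = torch.unsqueeze(x * scale, dim=-1)                        # (..., 1)
    divisions = torch.reshape(div_term, [1] * x.ndim + [div_term.numel()])
    values = positions / divisions                                        # (..., dim // 2)
    return torch.cat((torch.sin(values), torch.cos(values)), dim=-1)      # (..., dim)

embeds = dynamic_position_embed(torch.rand(2, 5), dim=16)
print(embeds.shape)  # expected: torch.Size([2, 5, 16])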