Fix bugs in xlayers
@@ -10,7 +10,8 @@ from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
 
-lib_dir = (Path(__file__).parent / "..").resolve()
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))

@@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
         layer_embedding,
         time_embedding,
         meta_timestamps,
-        mha_depth: int = 2,
+        mha_depth: int = 1,
         dropout: float = 0.1,
     ):
         super(LFNA_Meta, self).__init__()
@@ -44,8 +44,21 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_embed = dict(fixed=None, learnt=None)
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
 
-        self._time_prob_drop = super_core.SuperDrop(dropout, (-1, 1), recover=False)
+        self._tscalar_embed = super_core.SuperDynamicPositionE(
+            time_embedding, scale=100
+        )
 
         # build transformer
+        self._trans_att = super_core.SuperQKVAttention(
+            time_embedding,
+            time_embedding,
+            time_embedding,
+            time_embedding,
+            4,
+            True,
+            attn_drop=None,
+            proj_drop=dropout,
+        )
         layers = []
         for ilayer in range(mha_depth):
             layers.append(
@@ -74,15 +87,9 @@ class LFNA_Meta(super_core.SuperModule):
         self._generator = get_model(**model_kwargs)
         # print("generator: {:}".format(self._generator))
 
-        # unknown token
-        self.register_parameter(
-            "_unknown_token",
-            torch.nn.Parameter(torch.Tensor(1, time_embedding)),
-        )
-
         # initialization
         trunc_normal_(
-            [self._super_layer_embed, self._super_meta_embed, self._unknown_token],
+            [self._super_layer_embed, self._super_meta_embed],
             std=0.02,
         )
@@ -136,28 +143,21 @@ class LFNA_Meta(super_core.SuperModule):
                     (self._append_meta_embed["fixed"], meta_embed), dim=0
                 )
 
-    def forward_raw(self, timestamps):
+    def _obtain_time_embed(self, timestamps):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
-        timestamps = timestamps.unsqueeze(dim=-1)
-        meta_timestamps = self.meta_timestamps.view(1, 1, -1)
-        time_diffs = timestamps - meta_timestamps
-        time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
-        # select corresponding meta-knowledge
-        meta_match = torch.index_select(
-            self.super_meta_embed, dim=0, index=time_match_i.view(-1)
+        timestamp_q_embed = self._tscalar_embed(timestamps)
+        timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1))
+        timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0)
+        timestamp_embeds = self._trans_att(
+            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
         )
-        meta_match = meta_match.view(batch, seq, -1)
-        # create the probability
-        time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
+        corrected_embeds = self.meta_corrector(timestamp_embeds)
+        return corrected_embeds
 
-        x_time_probs = self._time_prob_drop(time_probs)
-        # if self.training:
-        #    time_probs[:, -1, :] = 0
-        unknown_token = self._unknown_token.view(1, 1, -1)
-        raw_meta_embed = x_time_probs * meta_match + (1 - x_time_probs) * unknown_token
-
-        meta_embed = self.meta_corrector(raw_meta_embed)
+    def forward_raw(self, timestamps):
+        batch, seq = timestamps.shape
+        meta_embed = self._obtain_time_embed(timestamps)
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
         meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
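For context, the reworked path above replaces the nearest-timestamp lookup with a small cross-attention over the stored meta-knowledge: queries come from scalar embeddings of the requested timestamps, keys from embeddings of `meta_timestamps`, and values from the learnt `super_meta_embed`. Below is a minimal, self-contained shape sketch in plain PyTorch; the sizes and the stand-in `tscalar_embed` are illustrative only and not part of this commit.

import torch

batch, seq, num_meta, dim = 2, 5, 100, 32         # illustrative sizes only

timestamps = torch.rand(batch, seq)               # query timestamps
meta_timestamps = torch.rand(num_meta)            # timestamps of stored meta-knowledge
super_meta_embed = torch.randn(num_meta, dim)     # learnt meta embeddings (the values)


def tscalar_embed(x):
    # stand-in for the SuperDynamicPositionE module used above: scalar -> dim-d vector
    return torch.randn(*x.shape, dim)


q = tscalar_embed(timestamps)                     # (batch, seq, dim)
k = tscalar_embed(meta_timestamps.view(1, -1))    # (1, num_meta, dim)
v = super_meta_embed.unsqueeze(dim=0)             # (1, num_meta, dim)

# cross-attention: each query timestamp attends over all meta timestamps;
# the leading 1 in k/v broadcasts against the batch dimension of q
attn = torch.softmax(q @ k.transpose(-2, -1) / dim ** 0.5, dim=-1)  # (batch, seq, num_meta)
out = attn @ v                                    # (batch, seq, dim)
print(out.shape)                                  # torch.Size([2, 5, 32])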
setup.py
@@ -16,7 +16,7 @@
 #
 # TODO(xuanyidong): upload it to conda
 #
-# [2021.05.18] v1.0
+# [2021.05.21] v0.9.9
 import os
 from setuptools import setup, find_packages

@@ -20,7 +20,7 @@ from .super_module import BoolSpaceType
 from .super_linear import SuperLinear
 
 
-class SuperAttention(SuperModule):
+class SuperSelfAttention(SuperModule):
     """The super model for attention layer."""
 
     def __init__(
@@ -32,7 +32,7 @@ class SuperAttention(SuperModule):
         attn_drop: Optional[float] = None,
         proj_drop: Optional[float] = None,
     ):
-        super(SuperAttention, self).__init__()
+        super(SuperSelfAttention, self).__init__()
         self._input_dim = input_dim
         self._proj_dim = proj_dim
         self._num_heads = num_heads
@@ -150,3 +150,157 @@ class SuperAttention(SuperModule):
         return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
             self._input_dim, self._proj_dim, self._num_heads
         )
+
+
+class SuperQKVAttention(SuperModule):
+    """The super model for QKV (cross) attention layer."""
+
+    def __init__(
+        self,
+        in_q_dim: IntSpaceType,
+        in_k_dim: IntSpaceType,
+        in_v_dim: IntSpaceType,
+        proj_dim: IntSpaceType,
+        num_heads: IntSpaceType,
+        qkv_bias: BoolSpaceType = False,
+        attn_drop: Optional[float] = None,
+        proj_drop: Optional[float] = None,
+    ):
+        super(SuperQKVAttention, self).__init__()
+        self._in_v_dim = in_v_dim
+        self._in_q_dim = in_q_dim
+        self._in_k_dim = in_k_dim
+        self._proj_dim = proj_dim
+        self._num_heads = num_heads
+        self._qkv_bias = qkv_bias
+
+        self.q_fc = SuperLinear(in_q_dim, proj_dim, bias=qkv_bias)
+        self.k_fc = SuperLinear(in_k_dim, proj_dim, bias=qkv_bias)
+        self.v_fc = SuperLinear(in_v_dim, proj_dim, bias=qkv_bias)
+
+        self.attn_drop = nn.Dropout(attn_drop or 0.0)
+        self.proj = SuperLinear(proj_dim, proj_dim)
+        self.proj_drop = nn.Dropout(proj_drop or 0.0)
+
+    @property
+    def num_heads(self):
+        return spaces.get_max(self._num_heads)
+
+    @property
+    def in_v_dim(self):
+        return spaces.get_max(self._in_v_dim)
+
+    @property
+    def in_q_dim(self):
+        return spaces.get_max(self._in_q_dim)
+
+    @property
+    def in_k_dim(self):
+        return spaces.get_max(self._in_k_dim)
+
+    @property
+    def proj_dim(self):
+        return spaces.get_max(self._proj_dim)
+
+    @property
+    def abstract_search_space(self):
+        root_node = spaces.VirtualNode(id(self))
+        space_q = self.q_fc.abstract_search_space
+        space_k = self.k_fc.abstract_search_space
+        space_v = self.v_fc.abstract_search_space
+        space_proj = self.proj.abstract_search_space
+        if not spaces.is_determined(self._num_heads):
+            root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
+        if not spaces.is_determined(space_q):
+            root_node.append("q_fc", space_q)
+        if not spaces.is_determined(space_k):
+            root_node.append("k_fc", space_k)
+        if not spaces.is_determined(space_v):
+            root_node.append("v_fc", space_v)
+        if not spaces.is_determined(space_proj):
+            root_node.append("proj", space_proj)
+        return root_node
+
+    def apply_candidate(self, abstract_child: spaces.VirtualNode):
+        super(SuperQKVAttention, self).apply_candidate(abstract_child)
+        if "q_fc" in abstract_child:
+            self.q_fc.apply_candidate(abstract_child["q_fc"])
+        if "k_fc" in abstract_child:
+            self.k_fc.apply_candidate(abstract_child["k_fc"])
+        if "v_fc" in abstract_child:
+            self.v_fc.apply_candidate(abstract_child["v_fc"])
+        if "proj" in abstract_child:
+            self.proj.apply_candidate(abstract_child["proj"])
+
+    def forward_qkv(self, q_tensor, k_tensor, v_tensor, num_head: int) -> torch.Tensor:
+        q = self.q_fc(q_tensor)
+        B, N, C = q.shape
+
+        k = self.k_fc(k_tensor)
+        B0, S, _ = k.shape
+
+        v = self.v_fc(v_tensor)
+        assert B0 == v.shape[0] and S == v.shape[1]
+
+        head_dim = C // num_head
+        if num_head > C:
+            raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
+        q_v1 = (
+            q[:, :, : num_head * head_dim]
+            .reshape(B, N, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        k_v1 = (
+            k[:, :, : num_head * head_dim]
+            .reshape(B0, S, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        # compute the scaled dot-product attention map
+        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) / math.sqrt(head_dim)
+        attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * S
+        attn_v1 = self.attn_drop(attn_v1)
+
+        v_v1 = (
+            v[:, :, : num_head * head_dim]
+            .reshape(B0, S, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        # process the first [num_head * head_dim] channels
+        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
+        if C == head_dim * num_head:
+            feats = feats_v1
+        else:  # the channels cannot be divided by num_head; the remainder forms an additional head
+            # [might have bugs, did not check yet]
+            q_v2 = q[:, :, num_head * head_dim :]
+            k_v2 = k[:, :, num_head * head_dim :]
+            v_v2 = v[:, :, num_head * head_dim :]
+            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) / math.sqrt(q_v2.shape[-1])
+            attn_v2 = attn_v2.softmax(dim=-1)
+            attn_v2 = self.attn_drop(attn_v2)
+            feats_v2 = attn_v2 @ v_v2
+            feats = torch.cat([feats_v1, feats_v2], dim=-1)
+        return feats
+
+    def forward_candidate(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
+        # check the num_heads:
+        if not spaces.is_determined(self._num_heads):
+            num_heads = self.abstract_child["_num_heads"].value
+        else:
+            num_heads = spaces.get_determined_value(self._num_heads)
+        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads)
+        outs = self.proj(feats)
+        outs = self.proj_drop(outs)
+        return outs
+
+    def forward_raw(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
+        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads)
+        outs = self.proj(feats)
+        outs = self.proj_drop(outs)
+        return outs
+
+    def extra_repr(self) -> str:
+        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
+            (self.in_q_dim, self.in_k_dim, self.in_v_dim),
+            self._proj_dim,
+            self._num_heads,
+        )

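A hedged usage sketch of the new SuperQKVAttention layer (not part of the diff). The import path is an assumption based on the `super_core.SuperQKVAttention` reference in the LFNA code above and may differ in the actual package layout; the shapes are illustrative.

import torch
from xautodl.xlayers import super_core  # assumed import path; adjust to the actual layout

# query/key/value can come from different sources and have different sequence lengths
layer = super_core.SuperQKVAttention(
    in_q_dim=32, in_k_dim=32, in_v_dim=32, proj_dim=32,
    num_heads=4, qkv_bias=True, attn_drop=None, proj_drop=0.1,
)
q = torch.randn(2, 5, 32)   # (batch, #queries, in_q_dim)
k = torch.randn(2, 9, 32)   # (batch, #keys, in_k_dim)
v = torch.randn(2, 9, 32)   # (batch, #keys, in_v_dim)
out = layer.forward_raw(q, k, v)   # plain (non-candidate) forward; shape (2, 5, 32)
print(out.shape)

Calling `forward_raw` exercises the fixed-architecture path directly; in the LFNA model above the module is simply called, which presumably dispatches to `forward_raw` or `forward_candidate` through the SuperModule machinery.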
@@ -24,7 +24,8 @@ super_name2norm = {
     "identity": SuperIdentity,
 }
 
-from .super_attention import SuperAttention
+from .super_attention import SuperSelfAttention
+from .super_attention import SuperQKVAttention
 from .super_transformer import SuperTransformerEncoderLayer
 
 from .super_activations import SuperReLU

@@ -35,11 +35,13 @@ class SuperDynamicPositionE(SuperModule):
         return self.forward_raw(input)
 
     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        import pdb
-
-        pdb.set_trace()
-        print("---")
-        return F.linear(input, self._super_weight, self._super_bias)
+        positions = torch.unsqueeze(input * self._scale, dim=-1)
+        divisions = torch.reshape(
+            self._div_term, [1] * input.ndim + [self._div_term.numel()]
+        )
+        values = positions / divisions
+        embeds = torch.cat((torch.sin(values), torch.cos(values)), dim=-1)
+        return embeds
 
     def extra_repr(self) -> str:
         return "scale={:}, dim={:}".format(self._scale, self._dimension)

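The rewritten `forward_raw` above is a sinusoidal encoding of scalar inputs. The diff does not show how `self._div_term` is constructed, so the standalone sketch below assumes a standard Transformer-style frequency schedule; it only mirrors the shape logic of the code above.

import math
import torch


def dynamic_position_embed(x, dim=8, scale=100.0):
    # assumed frequency schedule (periods growing from 1 toward 10000, as in the
    # standard Transformer encoding); the real _div_term may differ
    half = dim // 2
    div_term = torch.exp(torch.arange(half).float() * (math.log(10000.0) / half))
    positions = torch.unsqueeze(x * scale, dim=-1)                     # (..., 1)
    divisions = torch.reshape(div_term, [1] * x.ndim + [half])         # broadcastable (1, ..., half)
    values = positions / divisions                                     # (..., half)
    return torch.cat((torch.sin(values), torch.cos(values)), dim=-1)   # (..., dim)


timestamps = torch.rand(2, 5)                      # (batch, seq) scalar timestamps
print(dynamic_position_embed(timestamps).shape)    # torch.Size([2, 5, 8])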
@@ -19,7 +19,7 @@ from .super_module import LayerOrder
 from .super_module import SuperModule
 from .super_linear import SuperMLPv2
 from .super_norm import SuperLayerNorm1D
-from .super_attention import SuperAttention
+from .super_attention import SuperSelfAttention
 
 
 class SuperTransformerEncoderLayer(SuperModule):
@@ -47,7 +47,7 @@ class SuperTransformerEncoderLayer(SuperModule):
         order: LayerOrder = LayerOrder.PreNorm,
     ):
         super(SuperTransformerEncoderLayer, self).__init__()
-        mha = SuperAttention(
+        mha = SuperSelfAttention(
             d_model,
             d_model,
             num_heads=num_heads,