Update LFNA
This commit is contained in:
		| @@ -107,11 +107,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): | |||||||
|             base_model.eval() |             base_model.eval() | ||||||
|             time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device) |             time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device) | ||||||
|             [seq_containers], _ = meta_model(time_seqs, None) |             [seq_containers], _ = meta_model(time_seqs, None) | ||||||
|             future_container = seq_containers[-2] |             # For Debug | ||||||
|             _, (future_x, future_y) = env(time_seqs[0, -2].item()) |             for idx in range(time_seqs.numel()): | ||||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) |                 future_container = seq_containers[idx] | ||||||
|             future_y_hat = base_model.forward_with_container(future_x, future_container) |                 _, (future_x, future_y) = env(time_seqs[0, idx].item()) | ||||||
|             future_loss = criterion(future_y_hat, future_y) |                 future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|  |                 future_y_hat = base_model.forward_with_container( | ||||||
|  |                     future_x, future_container | ||||||
|  |                 ) | ||||||
|  |                 future_loss = criterion(future_y_hat, future_y) | ||||||
|  |                 logger.log( | ||||||
|  |                     "--> time={:.4f} -> loss={:.4f}".format( | ||||||
|  |                         time_seqs[0, idx].item(), future_loss.item() | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|             logger.log( |             logger.log( | ||||||
|                 "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( |                 "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( | ||||||
|                     idx, len(env), future_loss.item() |                     idx, len(env), future_loss.item() | ||||||
|   | |||||||
| @@ -47,17 +47,17 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) |         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||||
|  |  | ||||||
|         self._tscalar_embed = super_core.SuperDynamicPositionE( |         self._tscalar_embed = super_core.SuperDynamicPositionE( | ||||||
|             time_embedding, scale=100 |             time_embedding, scale=500 | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # build transformer |         # build transformer | ||||||
|         self._trans_att = super_core.SuperQKVAttention( |         self._trans_att = super_core.SuperQKVAttentionV2( | ||||||
|             time_embedding, |             qk_att_dim=time_embedding, | ||||||
|             time_embedding, |             in_v_dim=time_embedding, | ||||||
|             time_embedding, |             hidden_dim=time_embedding, | ||||||
|             time_embedding, |             num_heads=4, | ||||||
|             4, |             proj_dim=time_embedding, | ||||||
|             True, |             qkv_bias=True, | ||||||
|             attn_drop=None, |             attn_drop=None, | ||||||
|             proj_drop=dropout, |             proj_drop=dropout, | ||||||
|         ) |         ) | ||||||
| @@ -166,9 +166,12 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|         # timestamps is a batch of sequence of timestamps |         # timestamps is a batch of sequence of timestamps | ||||||
|         batch, seq = timestamps.shape |         batch, seq = timestamps.shape | ||||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed |         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||||
|         timestamp_q_embed = self._tscalar_embed(timestamps) |         # timestamp_q_embed = self._tscalar_embed(timestamps) | ||||||
|         timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) |         # timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) | ||||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) |         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||||
|  |         timestamp_qk_att_embed = self._tscalar_embed( | ||||||
|  |             torch.unsqueeze(timestamps, dim=-1) - meta_timestamps | ||||||
|  |         ) | ||||||
|         # create the mask |         # create the mask | ||||||
|         mask = ( |         mask = ( | ||||||
|             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) |             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) | ||||||
| @@ -179,11 +182,13 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|             > self._thresh |             > self._thresh | ||||||
|         ) |         ) | ||||||
|         timestamp_embeds = self._trans_att( |         timestamp_embeds = self._trans_att( | ||||||
|             timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask |             timestamp_qk_att_embed, timestamp_v_embed, mask | ||||||
|  |         ) | ||||||
|  |         relative_timestamps = timestamps - timestamps[:, :1] | ||||||
|  |         relative_pos_embeds = self._tscalar_embed(relative_timestamps) | ||||||
|  |         init_timestamp_embeds = torch.cat( | ||||||
|  |             (timestamp_embeds, relative_pos_embeds), dim=-1 | ||||||
|         ) |         ) | ||||||
|         # relative_timestamps = timestamps - timestamps[:, :1] |  | ||||||
|         # relative_pos_embeds = self._tscalar_embed(relative_timestamps) |  | ||||||
|         init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1) |  | ||||||
|         corrected_embeds = self._meta_corrector(init_timestamp_embeds) |         corrected_embeds = self._meta_corrector(init_timestamp_embeds) | ||||||
|         return corrected_embeds |         return corrected_embeds | ||||||
|  |  | ||||||
|   | |||||||
| @@ -238,7 +238,7 @@ class SuperQKVAttention(SuperModule): | |||||||
|         return root_node |         return root_node | ||||||
|  |  | ||||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): |     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||||
|         super(SuperQVKAttention, self).apply_candidate(abstract_child) |         super(SuperQKVAttention, self).apply_candidate(abstract_child) | ||||||
|         if "q_fc" in abstract_child: |         if "q_fc" in abstract_child: | ||||||
|             self.q_fc.apply_candidate(abstract_child["q_fc"]) |             self.q_fc.apply_candidate(abstract_child["q_fc"]) | ||||||
|         if "k_fc" in abstract_child: |         if "k_fc" in abstract_child: | ||||||
|   | |||||||
							
								
								
									
										117
									
								
								xautodl/xlayers/super_attention_v2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								xautodl/xlayers/super_attention_v2.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,117 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | from __future__ import division | ||||||
|  | from __future__ import print_function | ||||||
|  |  | ||||||
|  | import math | ||||||
|  | from functools import partial | ||||||
|  | from typing import Optional, Text | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  |  | ||||||
|  |  | ||||||
|  | from xautodl import spaces | ||||||
|  | from .super_module import SuperModule | ||||||
|  | from .super_module import IntSpaceType | ||||||
|  | from .super_module import BoolSpaceType | ||||||
|  | from .super_linear import SuperLinear | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SuperQKVAttentionV2(SuperModule): | ||||||
|  |     """The super model for attention layer.""" | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         qk_att_dim: int, | ||||||
|  |         in_v_dim: int, | ||||||
|  |         hidden_dim: int, | ||||||
|  |         num_heads: int, | ||||||
|  |         proj_dim: int, | ||||||
|  |         qkv_bias: bool = False, | ||||||
|  |         attn_drop: Optional[float] = None, | ||||||
|  |         proj_drop: Optional[float] = None, | ||||||
|  |     ): | ||||||
|  |         super(SuperQKVAttentionV2, self).__init__() | ||||||
|  |         self._in_v_dim = in_v_dim | ||||||
|  |         self._qk_att_dim = qk_att_dim | ||||||
|  |         self._proj_dim = proj_dim | ||||||
|  |         self._hidden_dim = hidden_dim | ||||||
|  |         self._num_heads = num_heads | ||||||
|  |         self._qkv_bias = qkv_bias | ||||||
|  |  | ||||||
|  |         self.qk_fc = SuperLinear(qk_att_dim, num_heads, bias=qkv_bias) | ||||||
|  |         self.v_fc = SuperLinear(in_v_dim, hidden_dim * num_heads, bias=qkv_bias) | ||||||
|  |  | ||||||
|  |         self.attn_drop = nn.Dropout(attn_drop or 0.0) | ||||||
|  |         self.proj = SuperLinear(hidden_dim * num_heads, proj_dim) | ||||||
|  |         self.proj_drop = nn.Dropout(proj_drop or 0.0) | ||||||
|  |         self._infinity = 1e9 | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def num_heads(self): | ||||||
|  |         return spaces.get_max(self._num_heads) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def in_v_dim(self): | ||||||
|  |         return spaces.get_max(self._in_v_dim) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def qk_att_dim(self): | ||||||
|  |         return spaces.get_max(self._qk_att_dim) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def hidden_dim(self): | ||||||
|  |         return spaces.get_max(self._hidden_dim) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def proj_dim(self): | ||||||
|  |         return spaces.get_max(self._proj_dim) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         root_node = spaces.VirtualNode(id(self)) | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||||
|  |         super(SuperQKVAttentionV2, self).apply_candidate(abstract_child) | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def forward_qkv( | ||||||
|  |         self, qk_att_tensor, v_tensor, num_head: int, mask=None | ||||||
|  |     ) -> torch.Tensor: | ||||||
|  |         qk_att = self.qk_fc(qk_att_tensor) | ||||||
|  |         B, N, S, _ = qk_att.shape | ||||||
|  |         assert _ == num_head | ||||||
|  |         attn_v1 = qk_att.permute(0, 3, 1, 2) | ||||||
|  |         if mask is not None: | ||||||
|  |             mask = torch.unsqueeze(mask, dim=1) | ||||||
|  |             attn_v1 = attn_v1.masked_fill(mask, -self._infinity) | ||||||
|  |         attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * S | ||||||
|  |         attn_v1 = self.attn_drop(attn_v1) | ||||||
|  |  | ||||||
|  |         v = self.v_fc(v_tensor) | ||||||
|  |         B0, _, _ = v.shape | ||||||
|  |         v_v1 = v.reshape(B0, S, num_head, -1).permute(0, 2, 1, 3) | ||||||
|  |         feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1) | ||||||
|  |         return feats_v1 | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, qk_att_tensor, v_tensor, mask=None) -> torch.Tensor: | ||||||
|  |         return self.forward_raw(qk_att_tensor, v_tensor, mask) | ||||||
|  |  | ||||||
|  |     def forward_raw(self, qk_att_tensor, v_tensor, mask=None) -> torch.Tensor: | ||||||
|  |         feats = self.forward_qkv(qk_att_tensor, v_tensor, self.num_heads, mask) | ||||||
|  |         outs = self.proj(feats) | ||||||
|  |         outs = self.proj_drop(outs) | ||||||
|  |         return outs | ||||||
|  |  | ||||||
|  |     def extra_repr(self) -> str: | ||||||
|  |         return "input_dim={:}, hidden_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format( | ||||||
|  |             (self.qk_att_dim, self.in_v_dim), | ||||||
|  |             self._hidden_dim, | ||||||
|  |             self._proj_dim, | ||||||
|  |             self._num_heads, | ||||||
|  |             self._infinity, | ||||||
|  |         ) | ||||||
| @@ -26,6 +26,7 @@ super_name2norm = { | |||||||
|  |  | ||||||
| from .super_attention import SuperSelfAttention | from .super_attention import SuperSelfAttention | ||||||
| from .super_attention import SuperQKVAttention | from .super_attention import SuperQKVAttention | ||||||
|  | from .super_attention_v2 import SuperQKVAttentionV2 | ||||||
| from .super_transformer import SuperTransformerEncoderLayer | from .super_transformer import SuperTransformerEncoderLayer | ||||||
|  |  | ||||||
| from .super_activations import SuperReLU | from .super_activations import SuperReLU | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user