From 25dc78a7ce96212a5ca82f339b1c22af058a64aa Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Sun, 23 May 2021 08:21:31 +0000
Subject: [PATCH] Update LFNA

---
 exps/LFNA/lfna.py                     |  19 +++--
 exps/LFNA/lfna_meta_model.py          |  33 +++++---
 xautodl/xlayers/super_attention.py    |   2 +-
 xautodl/xlayers/super_attention_v2.py | 117 ++++++++++++++++++++++++++
 xautodl/xlayers/super_core.py         |   1 +
 5 files changed, 152 insertions(+), 20 deletions(-)
 create mode 100644 xautodl/xlayers/super_attention_v2.py

diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py
index ec7d3c1..959edc1 100644
--- a/exps/LFNA/lfna.py
+++ b/exps/LFNA/lfna.py
@@ -107,11 +107,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
         base_model.eval()
         time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
         [seq_containers], _ = meta_model(time_seqs, None)
-        future_container = seq_containers[-2]
-        _, (future_x, future_y) = env(time_seqs[0, -2].item())
-        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
-        future_y_hat = base_model.forward_with_container(future_x, future_container)
-        future_loss = criterion(future_y_hat, future_y)
+        # For Debug
+        for idx in range(time_seqs.numel()):
+            future_container = seq_containers[idx]
+            _, (future_x, future_y) = env(time_seqs[0, idx].item())
+            future_x, future_y = future_x.to(args.device), future_y.to(args.device)
+            future_y_hat = base_model.forward_with_container(
+                future_x, future_container
+            )
+            future_loss = criterion(future_y_hat, future_y)
+            logger.log(
+                "--> time={:.4f} -> loss={:.4f}".format(
+                    time_seqs[0, idx].item(), future_loss.item()
+                )
+            )
         logger.log(
             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
                 idx, len(env), future_loss.item()
diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py
index 765c890..5516b50 100644
--- a/exps/LFNA/lfna_meta_model.py
+++ b/exps/LFNA/lfna_meta_model.py
@@ -47,17 +47,17 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_timestamps = dict(fixed=None, learnt=None)
 
         self._tscalar_embed = super_core.SuperDynamicPositionE(
-            time_embedding, scale=100
+            time_embedding, scale=500
         )
 
         # build transformer
-        self._trans_att = super_core.SuperQKVAttention(
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            4,
-            True,
+        self._trans_att = super_core.SuperQKVAttentionV2(
+            qk_att_dim=time_embedding,
+            in_v_dim=time_embedding,
+            hidden_dim=time_embedding,
+            num_heads=4,
+            proj_dim=time_embedding,
+            qkv_bias=True,
             attn_drop=None,
             proj_drop=dropout,
         )
@@ -166,9 +166,12 @@ class LFNA_Meta(super_core.SuperModule):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
-        timestamp_q_embed = self._tscalar_embed(timestamps)
-        timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
+        # timestamp_q_embed = self._tscalar_embed(timestamps)
+        # timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
         timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
+        timestamp_qk_att_embed = self._tscalar_embed(
+            torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
+        )
         # create the mask
         mask = (
             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
@@ -179,11 +182,13 @@ class LFNA_Meta(super_core.SuperModule):
             > self._thresh
         )
         timestamp_embeds = self._trans_att(
-            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
+            timestamp_qk_att_embed, timestamp_v_embed, mask
+        )
+        relative_timestamps = timestamps - timestamps[:, :1]
+        relative_pos_embeds = self._tscalar_embed(relative_timestamps)
+        init_timestamp_embeds = torch.cat(
+            (timestamp_embeds, relative_pos_embeds), dim=-1
         )
-        # relative_timestamps = timestamps - timestamps[:, :1]
-        # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
-        init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
         corrected_embeds = self._meta_corrector(init_timestamp_embeds)
         return corrected_embeds
 
diff --git a/xautodl/xlayers/super_attention.py b/xautodl/xlayers/super_attention.py
index 1adfae6..77075a4 100644
--- a/xautodl/xlayers/super_attention.py
+++ b/xautodl/xlayers/super_attention.py
@@ -238,7 +238,7 @@ class SuperQKVAttention(SuperModule):
         return root_node
 
     def apply_candidate(self, abstract_child: spaces.VirtualNode):
-        super(SuperQVKAttention, self).apply_candidate(abstract_child)
+        super(SuperQKVAttention, self).apply_candidate(abstract_child)
         if "q_fc" in abstract_child:
             self.q_fc.apply_candidate(abstract_child["q_fc"])
         if "k_fc" in abstract_child:
diff --git a/xautodl/xlayers/super_attention_v2.py b/xautodl/xlayers/super_attention_v2.py
new file mode 100644
index 0000000..3f0f49f
--- /dev/null
+++ b/xautodl/xlayers/super_attention_v2.py
@@ -0,0 +1,117 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+from __future__ import division
+from __future__ import print_function
+
+import math
+from functools import partial
+from typing import Optional, Text
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from xautodl import spaces
+from .super_module import SuperModule
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
+from .super_linear import SuperLinear
+
+
+class SuperQKVAttentionV2(SuperModule):
+    """The super model for attention layer."""
+
+    def __init__(
+        self,
+        qk_att_dim: int,
+        in_v_dim: int,
+        hidden_dim: int,
+        num_heads: int,
+        proj_dim: int,
+        qkv_bias: bool = False,
+        attn_drop: Optional[float] = None,
+        proj_drop: Optional[float] = None,
+    ):
+        super(SuperQKVAttentionV2, self).__init__()
+        self._in_v_dim = in_v_dim
+        self._qk_att_dim = qk_att_dim
+        self._proj_dim = proj_dim
+        self._hidden_dim = hidden_dim
+        self._num_heads = num_heads
+        self._qkv_bias = qkv_bias
+
+        self.qk_fc = SuperLinear(qk_att_dim, num_heads, bias=qkv_bias)
+        self.v_fc = SuperLinear(in_v_dim, hidden_dim * num_heads, bias=qkv_bias)
+
+        self.attn_drop = nn.Dropout(attn_drop or 0.0)
+        self.proj = SuperLinear(hidden_dim * num_heads, proj_dim)
+        self.proj_drop = nn.Dropout(proj_drop or 0.0)
+        self._infinity = 1e9
+
+    @property
+    def num_heads(self):
+        return spaces.get_max(self._num_heads)
+
+    @property
+    def in_v_dim(self):
+        return spaces.get_max(self._in_v_dim)
+
+    @property
+    def qk_att_dim(self):
+        return spaces.get_max(self._qk_att_dim)
+
+    @property
+    def hidden_dim(self):
+        return spaces.get_max(self._hidden_dim)
+
+    @property
+    def proj_dim(self):
+        return spaces.get_max(self._proj_dim)
+
+    @property
+    def abstract_search_space(self):
+        root_node = spaces.VirtualNode(id(self))
+        raise NotImplementedError
+
+    def apply_candidate(self, abstract_child: spaces.VirtualNode):
+        super(SuperQKVAttentionV2, self).apply_candidate(abstract_child)
+        raise NotImplementedError
+
+    def forward_qkv(
+        self, qk_att_tensor, v_tensor, num_head: int, mask=None
+    ) -> torch.Tensor:
+        qk_att = self.qk_fc(qk_att_tensor)
+        B, N, S, _ = qk_att.shape
+        assert _ == num_head
+        attn_v1 = qk_att.permute(0, 3, 1, 2)
+        if mask is not None:
+            mask = torch.unsqueeze(mask, dim=1)
+            attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
+        attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * S
+        attn_v1 = self.attn_drop(attn_v1)
+
+        v = self.v_fc(v_tensor)
+        B0, _, _ = v.shape
+        v_v1 = v.reshape(B0, S, num_head, -1).permute(0, 2, 1, 3)
+        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
+        return feats_v1
+
+    def forward_candidate(self, qk_att_tensor, v_tensor, mask=None) -> torch.Tensor:
+        return self.forward_raw(qk_att_tensor, v_tensor, mask)
+
+    def forward_raw(self, qk_att_tensor, v_tensor, mask=None) -> torch.Tensor:
+        feats = self.forward_qkv(qk_att_tensor, v_tensor, self.num_heads, mask)
+        outs = self.proj(feats)
+        outs = self.proj_drop(outs)
+        return outs
+
+    def extra_repr(self) -> str:
+        return "input_dim={:}, hidden_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format(
+            (self.qk_att_dim, self.in_v_dim),
+            self._hidden_dim,
+            self._proj_dim,
+            self._num_heads,
+            self._infinity,
+        )
diff --git a/xautodl/xlayers/super_core.py b/xautodl/xlayers/super_core.py
index 4544280..7c026a6 100644
--- a/xautodl/xlayers/super_core.py
+++ b/xautodl/xlayers/super_core.py
@@ -26,6 +26,7 @@ super_name2norm = {
 
 from .super_attention import SuperSelfAttention
 from .super_attention import SuperQKVAttention
+from .super_attention_v2 import SuperQKVAttentionV2
 from .super_transformer import SuperTransformerEncoderLayer
 
 from .super_activations import SuperReLU
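
Note (not part of the patch): a minimal, hypothetical usage sketch of the new SuperQKVAttentionV2 layer, mirroring how LFNA_Meta calls it in the hunks above. The qk_fc projection maps a pairwise query-key feature of size qk_att_dim to one attention score per head, so the first argument is a 4-D tensor of shape (batch, N, S, qk_att_dim); the value tensor is 3-D and is broadcast across the batch; mask entries set to True are filled with a large negative constant before the softmax. All tensor sizes below are made up for illustration, and the snippet assumes the patched xautodl package is importable.

import torch

from xautodl.xlayers.super_core import SuperQKVAttentionV2

# Hypothetical sizes: a batch of 2 query sequences of length 5, each attending over 7 meta tokens.
batch, N, S, dim = 2, 5, 7, 16
att = SuperQKVAttentionV2(
    qk_att_dim=dim,
    in_v_dim=dim,
    hidden_dim=dim,
    num_heads=4,
    proj_dim=dim,
    qkv_bias=True,
    attn_drop=None,
    proj_drop=0.1,
)
qk_att_tensor = torch.rand(batch, N, S, dim)  # pairwise query-key features
v_tensor = torch.rand(1, S, dim)  # shared value sequence, broadcast over the batch
mask = torch.zeros(batch, N, S, dtype=torch.bool)  # True entries are masked out
out = att(qk_att_tensor, v_tensor, mask)  # expected to dispatch to forward_raw in the default run mode
print(out.shape)  # expected: torch.Size([2, 5, 16]), i.e. (batch, N, proj_dim)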