Update LFNA
parent 2a864ae705
commit 25dc78a7ce
@@ -107,11 +107,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger):
     base_model.eval()
     time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
     [seq_containers], _ = meta_model(time_seqs, None)
-    future_container = seq_containers[-2]
-    _, (future_x, future_y) = env(time_seqs[0, -2].item())
-    future_x, future_y = future_x.to(args.device), future_y.to(args.device)
-    future_y_hat = base_model.forward_with_container(future_x, future_container)
-    future_loss = criterion(future_y_hat, future_y)
+    # For Debug
+    for idx in range(time_seqs.numel()):
+        future_container = seq_containers[idx]
+        _, (future_x, future_y) = env(time_seqs[0, idx].item())
+        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
+        future_y_hat = base_model.forward_with_container(
+            future_x, future_container
+        )
+        future_loss = criterion(future_y_hat, future_y)
+        logger.log(
+            "--> time={:.4f} -> loss={:.4f}".format(
+                time_seqs[0, idx].item(), future_loss.item()
+            )
+        )
     logger.log(
         "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
             idx, len(env), future_loss.item()
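The hunk above replaces the single second-to-last-step evaluation with a loop that scores every timestamp in time_seqs and logs a per-step loss before the existing "[ONLINE]" summary line. Below is a minimal, self-contained sketch of that control flow; toy_env, toy_predict, and the sizes are illustrative stand-ins, not code from the repository.

import torch

# Hypothetical stand-ins: a toy "env" that returns (x, y) for a timestamp, and a toy predictor.
def toy_env(t):
    x = torch.full((4, 1), t)
    return None, (x, 2.0 * x)

def toy_predict(x):
    return 1.5 * x

criterion = torch.nn.MSELoss()
time_seqs = torch.linspace(0.0, 1.0, steps=5).view(1, -1)

# Mirrors the new per-timestamp evaluation/logging structure of online_evaluate.
for idx in range(time_seqs.numel()):
    _, (future_x, future_y) = toy_env(time_seqs[0, idx].item())
    future_y_hat = toy_predict(future_x)
    future_loss = criterion(future_y_hat, future_y)
    print("--> time={:.4f} -> loss={:.4f}".format(time_seqs[0, idx].item(), future_loss.item()))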
@@ -47,17 +47,17 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_timestamps = dict(fixed=None, learnt=None)

         self._tscalar_embed = super_core.SuperDynamicPositionE(
-            time_embedding, scale=100
+            time_embedding, scale=500
         )

         # build transformer
-        self._trans_att = super_core.SuperQKVAttention(
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            time_embedding,
-            4,
-            True,
+        self._trans_att = super_core.SuperQKVAttentionV2(
+            qk_att_dim=time_embedding,
+            in_v_dim=time_embedding,
+            hidden_dim=time_embedding,
+            num_heads=4,
+            proj_dim=time_embedding,
+            qkv_bias=True,
             attn_drop=None,
             proj_drop=dropout,
         )
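This hunk raises the SuperDynamicPositionE scale from 100 to 500 and swaps the attention block from SuperQKVAttention to the new SuperQKVAttentionV2, which consumes a pairwise query-key feature tensor instead of separate query and key embeddings, with every argument now passed by keyword. Below is a usage sketch of the new module, importing it straight from the file added later in this commit; it assumes xautodl with this commit is installed, and the tensor sizes (B=2, N=5, S=7, dim 64) are illustrative.

import torch
from xautodl.xlayers.super_attention_v2 import SuperQKVAttentionV2

att = SuperQKVAttentionV2(
    qk_att_dim=64,
    in_v_dim=64,
    hidden_dim=64,
    num_heads=4,
    proj_dim=64,
    qkv_bias=True,
    attn_drop=None,
    proj_drop=0.1,
)
qk_att = torch.rand(2, 5, 7, 64)               # B x N x S x qk_att_dim: one feature vector per (query, key) pair
values = torch.rand(1, 7, 64)                  # 1 x S x in_v_dim: value memory, broadcast across the batch
mask = torch.zeros(2, 5, 7, dtype=torch.bool)  # True entries would be masked out
out = att(qk_att, values, mask)                # invoked directly, as LFNA_Meta calls self._trans_att(...)
print(out.shape)                               # torch.Size([2, 5, 64]) -> B x N x proj_dim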
@@ -166,9 +166,12 @@ class LFNA_Meta(super_core.SuperModule):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
-        timestamp_q_embed = self._tscalar_embed(timestamps)
-        timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
+        # timestamp_q_embed = self._tscalar_embed(timestamps)
+        # timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
         timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
+        timestamp_qk_att_embed = self._tscalar_embed(
+            torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
+        )
         # create the mask
         mask = (
             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
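Instead of embedding query and key timestamps separately, the forward pass now embeds the pairwise offsets between each query timestamp and every stored meta timestamp, and the mask built right below hides meta timestamps that are not strictly in the past (the next hunk's context also shows a "> self._thresh" term in the mask). A plain-torch sketch of the broadcasting, independent of xautodl; the numbers are illustrative.

import torch

timestamps = torch.tensor([[0.10, 0.40, 0.90]])           # batch=1, seq=3 query timestamps
meta_timestamps = torch.tensor([0.00, 0.25, 0.50, 1.00])  # 4 stored anchors

# batch x seq x num_anchors pairwise offsets, later fed to _tscalar_embed
offsets = torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
# same shape: True where an anchor is not strictly in the past of the query
mask = torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)

print(offsets.shape)  # torch.Size([1, 3, 4])
print(mask[0])        # for t=0.10, the anchors at 0.25, 0.50, and 1.00 are masked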
@@ -179,11 +182,13 @@ class LFNA_Meta(super_core.SuperModule):
             > self._thresh
         )
         timestamp_embeds = self._trans_att(
-            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
-        )
-        # relative_timestamps = timestamps - timestamps[:, :1]
-        # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
-        init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
+            timestamp_qk_att_embed, timestamp_v_embed, mask
+        )
+        relative_timestamps = timestamps - timestamps[:, :1]
+        relative_pos_embeds = self._tscalar_embed(relative_timestamps)
+        init_timestamp_embeds = torch.cat(
+            (timestamp_embeds, relative_pos_embeds), dim=-1
+        )
         corrected_embeds = self._meta_corrector(init_timestamp_embeds)
         return corrected_embeds

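Since the query embedding no longer exists, the corrector input becomes the attention output concatenated with an embedding of each sequence's timestamps relative to its first step (the previously commented-out lines, re-enabled). A small shape sketch with stand-in tensors; timestamp_embeds and relative_pos_embeds are random placeholders for the real attention output and _tscalar_embed output, and the sizes are illustrative.

import torch

B, seq, d = 2, 3, 8
timestamps = torch.tensor([[0.2, 0.5, 0.9], [1.0, 1.1, 1.4]])
timestamp_embeds = torch.rand(B, seq, d)              # stand-in for the attention output
relative_timestamps = timestamps - timestamps[:, :1]  # each sequence now starts at 0.0
relative_pos_embeds = torch.rand(B, seq, d)           # stand-in for self._tscalar_embed(relative_timestamps)
init_timestamp_embeds = torch.cat((timestamp_embeds, relative_pos_embeds), dim=-1)
print(relative_timestamps)
print(init_timestamp_embeds.shape)  # torch.Size([2, 3, 16]) -> the corrector sees 2 * d features per step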
@@ -238,7 +238,7 @@ class SuperQKVAttention(SuperModule):
         return root_node

     def apply_candidate(self, abstract_child: spaces.VirtualNode):
-        super(SuperQVKAttention, self).apply_candidate(abstract_child)
+        super(SuperQKVAttention, self).apply_candidate(abstract_child)
         if "q_fc" in abstract_child:
             self.q_fc.apply_candidate(abstract_child["q_fc"])
         if "k_fc" in abstract_child:
xautodl/xlayers/super_attention_v2.py  (new file, 117 lines)
@@ -0,0 +1,117 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+from __future__ import division
+from __future__ import print_function
+
+import math
+from functools import partial
+from typing import Optional, Text
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from xautodl import spaces
+from .super_module import SuperModule
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
+from .super_linear import SuperLinear
+
+
+class SuperQKVAttentionV2(SuperModule):
+    """The super model for attention layer."""
+
+    def __init__(
+        self,
+        qk_att_dim: int,
+        in_v_dim: int,
+        hidden_dim: int,
+        num_heads: int,
+        proj_dim: int,
+        qkv_bias: bool = False,
+        attn_drop: Optional[float] = None,
+        proj_drop: Optional[float] = None,
+    ):
+        super(SuperQKVAttentionV2, self).__init__()
+        self._in_v_dim = in_v_dim
+        self._qk_att_dim = qk_att_dim
+        self._proj_dim = proj_dim
+        self._hidden_dim = hidden_dim
+        self._num_heads = num_heads
+        self._qkv_bias = qkv_bias
+
+        self.qk_fc = SuperLinear(qk_att_dim, num_heads, bias=qkv_bias)
+        self.v_fc = SuperLinear(in_v_dim, hidden_dim * num_heads, bias=qkv_bias)
+
+        self.attn_drop = nn.Dropout(attn_drop or 0.0)
+        self.proj = SuperLinear(hidden_dim * num_heads, proj_dim)
+        self.proj_drop = nn.Dropout(proj_drop or 0.0)
+        self._infinity = 1e9
+
+    @property
+    def num_heads(self):
+        return spaces.get_max(self._num_heads)
+
+    @property
+    def in_v_dim(self):
+        return spaces.get_max(self._in_v_dim)
+
+    @property
+    def qk_att_dim(self):
+        return spaces.get_max(self._qk_att_dim)
+
+    @property
+    def hidden_dim(self):
+        return spaces.get_max(self._hidden_dim)
+
+    @property
+    def proj_dim(self):
+        return spaces.get_max(self._proj_dim)
+
+    @property
+    def abstract_search_space(self):
+        root_node = spaces.VirtualNode(id(self))
+        raise NotImplementedError
+
+    def apply_candidate(self, abstract_child: spaces.VirtualNode):
+        super(SuperQKVAttentionV2, self).apply_candidate(abstract_child)
+        raise NotImplementedError
+
+    def forward_qkv(
+        self, qk_att_tensor, v_tensor, num_head: int, mask=None
+    ) -> torch.Tensor:
+        qk_att = self.qk_fc(qk_att_tensor)
+        B, N, S, _ = qk_att.shape
+        assert _ == num_head
+        attn_v1 = qk_att.permute(0, 3, 1, 2)
+        if mask is not None:
+            mask = torch.unsqueeze(mask, dim=1)
+            attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
+        attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * S
+        attn_v1 = self.attn_drop(attn_v1)
+
+        v = self.v_fc(v_tensor)
+        B0, _, _ = v.shape
+        v_v1 = v.reshape(B0, S, num_head, -1).permute(0, 2, 1, 3)
+        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
+        return feats_v1
+
+    def forward_candidate(self, qk_att_tensor, v_tensor, mask=None) -> torch.Tensor:
+        return self.forward_raw(qk_att_tensor, v_tensor, mask)
+
+    def forward_raw(self, qk_att_tensor, v_tensor, mask=None) -> torch.Tensor:
+        feats = self.forward_qkv(qk_att_tensor, v_tensor, self.num_heads, mask)
+        outs = self.proj(feats)
+        outs = self.proj_drop(outs)
+        return outs
+
+    def extra_repr(self) -> str:
+        return "input_dim={:}, hidden_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format(
+            (self.qk_att_dim, self.in_v_dim),
+            self._hidden_dim,
+            self._proj_dim,
+            self._num_heads,
+            self._infinity,
+        )
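The new module departs from a standard QKV block: there are no query/key projections and no dot product; the per-pair attention logit for each head comes directly from a linear layer applied to the pairwise feature tensor, and only the values are projected. Below is a plain-torch restatement of forward_qkv's tensor algebra (nn.Linear stands in for SuperLinear; the sizes are illustrative, not from the repository).

import torch

B, N, S, D, H, hidden = 2, 5, 7, 64, 4, 16
qk_att_tensor = torch.rand(B, N, S, D)  # one feature vector per (query, key) pair
v_tensor = torch.rand(1, S, D)          # value memory shared across the batch
qk_fc = torch.nn.Linear(D, H)           # stands in for SuperLinear(qk_att_dim, num_heads)
v_fc = torch.nn.Linear(D, hidden * H)   # stands in for SuperLinear(in_v_dim, hidden_dim * num_heads)

attn = qk_fc(qk_att_tensor).permute(0, 3, 1, 2).softmax(dim=-1)  # B x H x N x S attention weights
v = v_fc(v_tensor).reshape(1, S, H, hidden).permute(0, 2, 1, 3)  # 1 x H x S x hidden, broadcast over B
feats = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, -1)         # B x N x (H * hidden)
print(feats.shape)  # torch.Size([2, 5, 64]); the real module then applies proj to get B x N x proj_dim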
@@ -26,6 +26,7 @@ super_name2norm = {

 from .super_attention import SuperSelfAttention
 from .super_attention import SuperQKVAttention
+from .super_attention_v2 import SuperQKVAttentionV2
 from .super_transformer import SuperTransformerEncoderLayer

 from .super_activations import SuperReLU
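Exporting the class here is what lets the LFNA_Meta hunk above refer to it as super_core.SuperQKVAttentionV2; the super_name2norm context line suggests this hunk edits the xlayers super_core module, though the filename is not shown in this excerpt. Under that assumption, a one-line check that the name resolves the way the LFNA code uses it:

# Assumes xautodl with this commit installed; the module path is inferred, not shown in the diff.
from xautodl.xlayers import super_core

assert hasattr(super_core, "SuperQKVAttentionV2")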