Fix bugs in xlayers
This commit is contained in:
parent
97717d826e
commit
bc42ab3c08
@ -10,7 +10,8 @@ from tqdm import tqdm
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / "..").resolve()
|
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||||
|
print("LIB-DIR: {:}".format(lib_dir))
|
||||||
if str(lib_dir) not in sys.path:
|
if str(lib_dir) not in sys.path:
|
||||||
sys.path.insert(0, str(lib_dir))
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
layer_embedding,
|
layer_embedding,
|
||||||
time_embedding,
|
time_embedding,
|
||||||
meta_timestamps,
|
meta_timestamps,
|
||||||
mha_depth: int = 2,
|
mha_depth: int = 1,
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
):
|
):
|
||||||
super(LFNA_Meta, self).__init__()
|
super(LFNA_Meta, self).__init__()
|
||||||
@ -44,8 +44,21 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
self._append_meta_embed = dict(fixed=None, learnt=None)
|
self._append_meta_embed = dict(fixed=None, learnt=None)
|
||||||
self._append_meta_timestamps = dict(fixed=None, learnt=None)
|
self._append_meta_timestamps = dict(fixed=None, learnt=None)
|
||||||
|
|
||||||
self._time_prob_drop = super_core.SuperDrop(dropout, (-1, 1), recover=False)
|
self._tscalar_embed = super_core.SuperDynamicPositionE(
|
||||||
|
time_embedding, scale=100
|
||||||
|
)
|
||||||
|
|
||||||
# build transformer
|
# build transformer
|
||||||
|
self._trans_att = super_core.SuperQKVAttention(
|
||||||
|
time_embedding,
|
||||||
|
time_embedding,
|
||||||
|
time_embedding,
|
||||||
|
time_embedding,
|
||||||
|
4,
|
||||||
|
True,
|
||||||
|
attn_drop=None,
|
||||||
|
proj_drop=dropout,
|
||||||
|
)
|
||||||
layers = []
|
layers = []
|
||||||
for ilayer in range(mha_depth):
|
for ilayer in range(mha_depth):
|
||||||
layers.append(
|
layers.append(
|
||||||
@ -74,15 +87,9 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
self._generator = get_model(**model_kwargs)
|
self._generator = get_model(**model_kwargs)
|
||||||
# print("generator: {:}".format(self._generator))
|
# print("generator: {:}".format(self._generator))
|
||||||
|
|
||||||
# unknown token
|
|
||||||
self.register_parameter(
|
|
||||||
"_unknown_token",
|
|
||||||
torch.nn.Parameter(torch.Tensor(1, time_embedding)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialization
|
# initialization
|
||||||
trunc_normal_(
|
trunc_normal_(
|
||||||
[self._super_layer_embed, self._super_meta_embed, self._unknown_token],
|
[self._super_layer_embed, self._super_meta_embed],
|
||||||
std=0.02,
|
std=0.02,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -136,28 +143,21 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
(self._append_meta_embed["fixed"], meta_embed), dim=0
|
(self._append_meta_embed["fixed"], meta_embed), dim=0
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_raw(self, timestamps):
|
def _obtain_time_embed(self, timestamps):
|
||||||
# timestamps is a batch of sequence of timestamps
|
# timestamps is a batch of sequence of timestamps
|
||||||
batch, seq = timestamps.shape
|
batch, seq = timestamps.shape
|
||||||
timestamps = timestamps.unsqueeze(dim=-1)
|
timestamp_q_embed = self._tscalar_embed(timestamps)
|
||||||
meta_timestamps = self.meta_timestamps.view(1, 1, -1)
|
timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1))
|
||||||
time_diffs = timestamps - meta_timestamps
|
timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0)
|
||||||
time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
|
timestamp_embeds = self._trans_att(
|
||||||
# select corresponding meta-knowledge
|
timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
|
||||||
meta_match = torch.index_select(
|
|
||||||
self.super_meta_embed, dim=0, index=time_match_i.view(-1)
|
|
||||||
)
|
)
|
||||||
meta_match = meta_match.view(batch, seq, -1)
|
corrected_embeds = self.meta_corrector(timestamp_embeds)
|
||||||
# create the probability
|
return corrected_embeds
|
||||||
time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
|
|
||||||
|
|
||||||
x_time_probs = self._time_prob_drop(time_probs)
|
def forward_raw(self, timestamps):
|
||||||
# if self.training:
|
batch, seq = timestamps.shape
|
||||||
# time_probs[:, -1, :] = 0
|
meta_embed = self._obtain_time_embed(timestamps)
|
||||||
unknown_token = self._unknown_token.view(1, 1, -1)
|
|
||||||
raw_meta_embed = x_time_probs * meta_match + (1 - x_time_probs) * unknown_token
|
|
||||||
|
|
||||||
meta_embed = self.meta_corrector(raw_meta_embed)
|
|
||||||
# create joint embed
|
# create joint embed
|
||||||
num_layer, _ = self._super_layer_embed.shape
|
num_layer, _ = self._super_layer_embed.shape
|
||||||
meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
|
meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
|
||||||
|
2
setup.py
2
setup.py
@ -16,7 +16,7 @@
|
|||||||
#
|
#
|
||||||
# TODO(xuanyidong): upload it to conda
|
# TODO(xuanyidong): upload it to conda
|
||||||
#
|
#
|
||||||
# [2021.05.18] v1.0
|
# [2021.05.21] v0.9.9
|
||||||
import os
|
import os
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ from .super_module import BoolSpaceType
|
|||||||
from .super_linear import SuperLinear
|
from .super_linear import SuperLinear
|
||||||
|
|
||||||
|
|
||||||
class SuperAttention(SuperModule):
|
class SuperSelfAttention(SuperModule):
|
||||||
"""The super model for attention layer."""
|
"""The super model for attention layer."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -32,7 +32,7 @@ class SuperAttention(SuperModule):
|
|||||||
attn_drop: Optional[float] = None,
|
attn_drop: Optional[float] = None,
|
||||||
proj_drop: Optional[float] = None,
|
proj_drop: Optional[float] = None,
|
||||||
):
|
):
|
||||||
super(SuperAttention, self).__init__()
|
super(SuperSelfAttention, self).__init__()
|
||||||
self._input_dim = input_dim
|
self._input_dim = input_dim
|
||||||
self._proj_dim = proj_dim
|
self._proj_dim = proj_dim
|
||||||
self._num_heads = num_heads
|
self._num_heads = num_heads
|
||||||
@ -150,3 +150,157 @@ class SuperAttention(SuperModule):
|
|||||||
return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
|
return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
|
||||||
self._input_dim, self._proj_dim, self._num_heads
|
self._input_dim, self._proj_dim, self._num_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SuperQKVAttention(SuperModule):
|
||||||
|
"""The super model for attention layer."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_q_dim: IntSpaceType,
|
||||||
|
in_k_dim: IntSpaceType,
|
||||||
|
in_v_dim: IntSpaceType,
|
||||||
|
proj_dim: IntSpaceType,
|
||||||
|
num_heads: IntSpaceType,
|
||||||
|
qkv_bias: BoolSpaceType = False,
|
||||||
|
attn_drop: Optional[float] = None,
|
||||||
|
proj_drop: Optional[float] = None,
|
||||||
|
):
|
||||||
|
super(SuperQKVAttention, self).__init__()
|
||||||
|
self._in_v_dim = in_v_dim
|
||||||
|
self._in_q_dim = in_q_dim
|
||||||
|
self._in_k_dim = in_k_dim
|
||||||
|
self._proj_dim = proj_dim
|
||||||
|
self._num_heads = num_heads
|
||||||
|
self._qkv_bias = qkv_bias
|
||||||
|
|
||||||
|
self.q_fc = SuperLinear(in_q_dim, proj_dim, bias=qkv_bias)
|
||||||
|
self.k_fc = SuperLinear(in_k_dim, proj_dim, bias=qkv_bias)
|
||||||
|
self.v_fc = SuperLinear(in_v_dim, proj_dim, bias=qkv_bias)
|
||||||
|
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop or 0.0)
|
||||||
|
self.proj = SuperLinear(proj_dim, proj_dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop or 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_heads(self):
|
||||||
|
return spaces.get_max(self._num_heads)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def in_v_dim(self):
|
||||||
|
return spaces.get_max(self._in_v_dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def in_q_dim(self):
|
||||||
|
return spaces.get_max(self._in_q_dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def in_k_dim(self):
|
||||||
|
return spaces.get_max(self._in_k_dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def proj_dim(self):
|
||||||
|
return spaces.get_max(self._proj_dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def abstract_search_space(self):
|
||||||
|
root_node = spaces.VirtualNode(id(self))
|
||||||
|
space_q = self.q_fc.abstract_search_space
|
||||||
|
space_k = self.k_fc.abstract_search_space
|
||||||
|
space_v = self.v_fc.abstract_search_space
|
||||||
|
space_proj = self.proj.abstract_search_space
|
||||||
|
if not spaces.is_determined(self._num_heads):
|
||||||
|
root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
|
||||||
|
if not spaces.is_determined(space_q):
|
||||||
|
root_node.append("q_fc", space_q)
|
||||||
|
if not spaces.is_determined(space_k):
|
||||||
|
root_node.append("k_fc", space_k)
|
||||||
|
if not spaces.is_determined(space_v):
|
||||||
|
root_node.append("v_fc", space_v)
|
||||||
|
if not spaces.is_determined(space_proj):
|
||||||
|
root_node.append("proj", space_proj)
|
||||||
|
return root_node
|
||||||
|
|
||||||
|
def apply_candidate(self, abstract_child: spaces.VirtualNode):
|
||||||
|
super(SuperAttention, self).apply_candidate(abstract_child)
|
||||||
|
if "q_fc" in abstract_child:
|
||||||
|
self.q_fc.apply_candidate(abstract_child["q_fc"])
|
||||||
|
if "k_fc" in abstract_child:
|
||||||
|
self.k_fc.apply_candidate(abstract_child["k_fc"])
|
||||||
|
if "v_fc" in abstract_child:
|
||||||
|
self.v_fc.apply_candidate(abstract_child["v_fc"])
|
||||||
|
if "proj" in abstract_child:
|
||||||
|
self.proj.apply_candidate(abstract_child["proj"])
|
||||||
|
|
||||||
|
def forward_qkv(self, q_tensor, k_tensor, v_tensor, num_head: int) -> torch.Tensor:
|
||||||
|
q = self.q_fc(q_tensor)
|
||||||
|
B, N, C = q.shape
|
||||||
|
|
||||||
|
k = self.k_fc(k_tensor)
|
||||||
|
B0, S, _ = k.shape
|
||||||
|
|
||||||
|
v = self.v_fc(v_tensor)
|
||||||
|
assert B0 == v.shape[0] and S == v.shape[1]
|
||||||
|
|
||||||
|
head_dim = C // num_head
|
||||||
|
if num_head > C:
|
||||||
|
raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
|
||||||
|
q_v1 = (
|
||||||
|
q[:, :, : num_head * head_dim]
|
||||||
|
.reshape(B, N, num_head, head_dim)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
)
|
||||||
|
k_v1 = (
|
||||||
|
k[:, :, : num_head * head_dim]
|
||||||
|
.reshape(B0, S, num_head, head_dim)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
)
|
||||||
|
# compute the attention map
|
||||||
|
attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
|
||||||
|
attn_v1 = attn_v1.softmax(dim=-1) # B * #head * N * S
|
||||||
|
attn_v1 = self.attn_drop(attn_v1)
|
||||||
|
|
||||||
|
v_v1 = (
|
||||||
|
v[:, :, : num_head * head_dim]
|
||||||
|
.reshape(B0, S, num_head, head_dim)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
)
|
||||||
|
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
|
||||||
|
# process the first [num_head * head_dim] part
|
||||||
|
if C == head_dim * num_head:
|
||||||
|
feats = feats_v1
|
||||||
|
else: # The channels can not be divided by num_head, the remainder forms an additional head
|
||||||
|
# [might have bugs, did not check yet]
|
||||||
|
q_v2 = q[:, :, num_head * head_dim :]
|
||||||
|
k_v2 = k[:, :, num_head * head_dim :]
|
||||||
|
v_v2 = v[:, :, num_head * head_dim :]
|
||||||
|
attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
|
||||||
|
attn_v2 = attn_v2.softmax(dim=-1)
|
||||||
|
attn_v2 = self.attn_drop(attn_v2)
|
||||||
|
feats_v2 = attn_v2 @ v_v2
|
||||||
|
feats = torch.cat([feats_v1, feats_v2], dim=-1)
|
||||||
|
return feats
|
||||||
|
|
||||||
|
def forward_candidate(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
|
||||||
|
# check the num_heads:
|
||||||
|
if not spaces.is_determined(self._num_heads):
|
||||||
|
num_heads = self.abstract_child["_num_heads"].value
|
||||||
|
else:
|
||||||
|
num_heads = spaces.get_determined_value(self._num_heads)
|
||||||
|
feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads)
|
||||||
|
outs = self.proj(feats)
|
||||||
|
outs = self.proj_drop(outs)
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def forward_raw(self, q_tensor, k_tensor, v_tensor) -> torch.Tensor:
|
||||||
|
feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads)
|
||||||
|
outs = self.proj(feats)
|
||||||
|
outs = self.proj_drop(outs)
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
|
||||||
|
(self.in_q_dim, self.in_k_dim, self.in_v_dim),
|
||||||
|
self._proj_dim,
|
||||||
|
self._num_heads,
|
||||||
|
)
|
||||||
|
@ -24,7 +24,8 @@ super_name2norm = {
|
|||||||
"identity": SuperIdentity,
|
"identity": SuperIdentity,
|
||||||
}
|
}
|
||||||
|
|
||||||
from .super_attention import SuperAttention
|
from .super_attention import SuperSelfAttention
|
||||||
|
from .super_attention import SuperQKVAttention
|
||||||
from .super_transformer import SuperTransformerEncoderLayer
|
from .super_transformer import SuperTransformerEncoderLayer
|
||||||
|
|
||||||
from .super_activations import SuperReLU
|
from .super_activations import SuperReLU
|
||||||
|
@ -35,11 +35,13 @@ class SuperDynamicPositionE(SuperModule):
|
|||||||
return self.forward_raw(input)
|
return self.forward_raw(input)
|
||||||
|
|
||||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
import pdb
|
positions = torch.unsqueeze(input * self._scale, dim=-1)
|
||||||
|
divisions = torch.reshape(
|
||||||
pdb.set_trace()
|
self._div_term, [1] * input.ndim + [self._div_term.numel()]
|
||||||
print("---")
|
)
|
||||||
return F.linear(input, self._super_weight, self._super_bias)
|
values = positions / divisions
|
||||||
|
embeds = torch.cat((torch.sin(values), torch.cos(values)), dim=-1)
|
||||||
|
return embeds
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
return "scale={:}, dim={:}".format(self._scale, self._dimension)
|
return "scale={:}, dim={:}".format(self._scale, self._dimension)
|
||||||
|
@ -19,7 +19,7 @@ from .super_module import LayerOrder
|
|||||||
from .super_module import SuperModule
|
from .super_module import SuperModule
|
||||||
from .super_linear import SuperMLPv2
|
from .super_linear import SuperMLPv2
|
||||||
from .super_norm import SuperLayerNorm1D
|
from .super_norm import SuperLayerNorm1D
|
||||||
from .super_attention import SuperAttention
|
from .super_attention import SuperSelfAttention
|
||||||
|
|
||||||
|
|
||||||
class SuperTransformerEncoderLayer(SuperModule):
|
class SuperTransformerEncoderLayer(SuperModule):
|
||||||
@ -47,7 +47,7 @@ class SuperTransformerEncoderLayer(SuperModule):
|
|||||||
order: LayerOrder = LayerOrder.PreNorm,
|
order: LayerOrder = LayerOrder.PreNorm,
|
||||||
):
|
):
|
||||||
super(SuperTransformerEncoderLayer, self).__init__()
|
super(SuperTransformerEncoderLayer, self).__init__()
|
||||||
mha = SuperAttention(
|
mha = SuperSelfAttention(
|
||||||
d_model,
|
d_model,
|
||||||
d_model,
|
d_model,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
|
Loading…
Reference in New Issue
Block a user