Add SuperAttention
This commit is contained in:
parent
0c56a729ad
commit
e023a53c75
0
lib/trade_models/naive_v1_model.py
Executable file → Normal file
0
lib/trade_models/naive_v1_model.py
Executable file → Normal file
0
lib/trade_models/naive_v2_model.py
Executable file → Normal file
0
lib/trade_models/naive_v2_model.py
Executable file → Normal file
0
lib/trade_models/quant_transformer.py
Executable file → Normal file
0
lib/trade_models/quant_transformer.py
Executable file → Normal file
6
lib/trade_models/transformers.py
Executable file → Normal file
6
lib/trade_models/transformers.py
Executable file → Normal file
@ -1,6 +1,6 @@
|
|||||||
##################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
##################################################
|
#####################################################
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
155
lib/xlayers/super_attention.py
Normal file
155
lib/xlayers/super_attention.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
|
#####################################################
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional, Text
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
import spaces
|
||||||
|
from .super_module import SuperModule
|
||||||
|
from .super_module import IntSpaceType
|
||||||
|
from .super_module import BoolSpaceType
|
||||||
|
from .super_linear import SuperLinear
|
||||||
|
|
||||||
|
|
||||||
|
class SuperAttention(SuperModule):
|
||||||
|
"""The super model for attention layer."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: IntSpaceType,
|
||||||
|
proj_dim: IntSpaceType,
|
||||||
|
num_heads: IntSpaceType,
|
||||||
|
qkv_bias: BoolSpaceType = False,
|
||||||
|
attn_drop: float = 0.0,
|
||||||
|
proj_drop: float = 0.0,
|
||||||
|
):
|
||||||
|
super(SuperAttention, self).__init__()
|
||||||
|
self._input_dim = input_dim
|
||||||
|
self._proj_dim = proj_dim
|
||||||
|
self._num_heads = num_heads
|
||||||
|
self._qkv_bias = qkv_bias
|
||||||
|
# head_dim = dim // num_heads
|
||||||
|
# self.scale = qk_scale or math.sqrt(head_dim)
|
||||||
|
|
||||||
|
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
|
||||||
|
self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
|
||||||
|
self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
|
||||||
|
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = SuperLinear(input_dim, proj_dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_heads(self):
|
||||||
|
return spaces.get_max(self._num_heads)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_dim(self):
|
||||||
|
return spaces.get_max(self._input_dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def proj_dim(self):
|
||||||
|
return spaces.get_max(self._proj_dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def abstract_search_space(self):
|
||||||
|
root_node = spaces.VirtualNode(id(self))
|
||||||
|
space_q = self.q_fc.abstract_search_space
|
||||||
|
space_k = self.k_fc.abstract_search_space
|
||||||
|
space_v = self.v_fc.abstract_search_space
|
||||||
|
space_proj = self.proj.abstract_search_space
|
||||||
|
if not spaces.is_determined(self._num_heads):
|
||||||
|
root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
|
||||||
|
if not spaces.is_determined(space_q):
|
||||||
|
root_node.append("q_fc", space_q)
|
||||||
|
if not spaces.is_determined(space_k):
|
||||||
|
root_node.append("k_fc", space_k)
|
||||||
|
if not spaces.is_determined(space_v):
|
||||||
|
root_node.append("v_fc", space_v)
|
||||||
|
if not spaces.is_determined(space_proj):
|
||||||
|
root_node.append("proj", space_proj)
|
||||||
|
return root_node
|
||||||
|
|
||||||
|
def apply_candidate(self, abstract_child: spaces.VirtualNode):
|
||||||
|
super(SuperAttention, self).apply_candidate(abstract_child)
|
||||||
|
if "q_fc" in abstract_child:
|
||||||
|
self.q_fc.apply_candidate(abstract_child["q_fc"])
|
||||||
|
if "k_fc" in abstract_child:
|
||||||
|
self.k_fc.apply_candidate(abstract_child["k_fc"])
|
||||||
|
if "v_fc" in abstract_child:
|
||||||
|
self.v_fc.apply_candidate(abstract_child["v_fc"])
|
||||||
|
if "proj" in abstract_child:
|
||||||
|
self.proj.apply_candidate(abstract_child["proj"])
|
||||||
|
|
||||||
|
def forward_qkv(self, input: torch.Tensor, num_head: int) -> torch.Tensor:
|
||||||
|
B, N, C = input.shape
|
||||||
|
q = self.q_fc(input)
|
||||||
|
k = self.k_fc(input)
|
||||||
|
v = self.v_fc(input)
|
||||||
|
if num_head > C:
|
||||||
|
raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
|
||||||
|
head_dim = C // num_head
|
||||||
|
# process the first [num_head * head_dim] part
|
||||||
|
q_v1 = (
|
||||||
|
q[:, :, : num_head * head_dim]
|
||||||
|
.reshape(B, N, num_head, head_dim)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
)
|
||||||
|
k_v1 = (
|
||||||
|
k[:, :, : num_head * head_dim]
|
||||||
|
.reshape(B, N, num_head, head_dim)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
)
|
||||||
|
v_v1 = (
|
||||||
|
v[:, :, : num_head * head_dim]
|
||||||
|
.reshape(B, N, num_head, head_dim)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
)
|
||||||
|
attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
|
||||||
|
attn_v1 = attn_v1.softmax(dim=-1)
|
||||||
|
attn_v1 = self.attn_drop(attn_v1)
|
||||||
|
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
|
||||||
|
if C == head_dim * num_head:
|
||||||
|
feats = feats_v1
|
||||||
|
else: # The channels can not be divided by num_head, the remainder forms an additional head
|
||||||
|
q_v2 = q[:, :, num_head * head_dim :]
|
||||||
|
k_v2 = k[:, :, num_head * head_dim :]
|
||||||
|
v_v2 = v[:, :, num_head * head_dim :]
|
||||||
|
attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
|
||||||
|
attn_v2 = attn_v2.softmax(dim=-1)
|
||||||
|
attn_v2 = self.attn_drop(attn_v2)
|
||||||
|
feats_v2 = attn_v2 @ v_v2
|
||||||
|
feats = torch.cat([feats_v1, feats_v2], dim=-1)
|
||||||
|
return feats
|
||||||
|
|
||||||
|
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
# check the num_heads:
|
||||||
|
if not spaces.is_determined(self._num_heads):
|
||||||
|
num_heads = self.abstract_child["_num_heads"].value
|
||||||
|
else:
|
||||||
|
num_heads = spaces.get_determined_value(self._num_heads)
|
||||||
|
feats = self.forward_qkv(input, num_heads)
|
||||||
|
outs = self.proj(feats)
|
||||||
|
outs = self.proj_drop(outs)
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
feats = self.forward_qkv(input, self.num_heads)
|
||||||
|
outs = self.proj(feats)
|
||||||
|
outs = self.proj_drop(outs)
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
|
||||||
|
self._input_dim, self._proj_dim, self._num_heads
|
||||||
|
)
|
@ -5,3 +5,4 @@ from .super_module import SuperRunMode
|
|||||||
from .super_module import SuperModule
|
from .super_module import SuperModule
|
||||||
from .super_linear import SuperLinear
|
from .super_linear import SuperLinear
|
||||||
from .super_linear import SuperMLP
|
from .super_linear import SuperMLP
|
||||||
|
from .super_attention import SuperAttention
|
||||||
|
@ -6,14 +6,12 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Union, Callable
|
from typing import Optional, Callable
|
||||||
|
|
||||||
import spaces
|
import spaces
|
||||||
from .super_module import SuperModule
|
from .super_module import SuperModule
|
||||||
from .super_module import SuperRunMode
|
from .super_module import IntSpaceType
|
||||||
|
from .super_module import BoolSpaceType
|
||||||
IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
|
|
||||||
BoolSpaceType = Union[bool, spaces.Categorical]
|
|
||||||
|
|
||||||
|
|
||||||
class SuperLinear(SuperModule):
|
class SuperLinear(SuperModule):
|
||||||
|
@ -1,13 +1,18 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
#####################################################
|
#####################################################
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
from typing import Optional, Union, Callable
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import spaces
|
import spaces
|
||||||
|
|
||||||
|
IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
|
||||||
|
BoolSpaceType = Union[bool, spaces.Categorical]
|
||||||
|
|
||||||
|
|
||||||
class SuperRunMode(Enum):
|
class SuperRunMode(Enum):
|
||||||
"""This class defines the enumerations for Super Model Running Mode."""
|
"""This class defines the enumerations for Super Model Running Mode."""
|
||||||
@ -24,6 +29,7 @@ class SuperModule(abc.ABC, nn.Module):
|
|||||||
super(SuperModule, self).__init__()
|
super(SuperModule, self).__init__()
|
||||||
self._super_run_type = SuperRunMode.Default
|
self._super_run_type = SuperRunMode.Default
|
||||||
self._abstract_child = None
|
self._abstract_child = None
|
||||||
|
self._verbose = False
|
||||||
|
|
||||||
def set_super_run_type(self, super_run_type):
|
def set_super_run_type(self, super_run_type):
|
||||||
def _reset_super_run(m):
|
def _reset_super_run(m):
|
||||||
@ -32,6 +38,13 @@ class SuperModule(abc.ABC, nn.Module):
|
|||||||
|
|
||||||
self.apply(_reset_super_run)
|
self.apply(_reset_super_run)
|
||||||
|
|
||||||
|
def apply_verbose(self, verbose):
|
||||||
|
def _reset_verbose(m):
|
||||||
|
if isinstance(m, SuperModule):
|
||||||
|
m._verbose = verbose
|
||||||
|
|
||||||
|
self.apply(_reset_verbose)
|
||||||
|
|
||||||
def apply_candidate(self, abstract_child):
|
def apply_candidate(self, abstract_child):
|
||||||
if not isinstance(abstract_child, spaces.VirtualNode):
|
if not isinstance(abstract_child, spaces.VirtualNode):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -51,6 +64,10 @@ class SuperModule(abc.ABC, nn.Module):
|
|||||||
def abstract_child(self):
|
def abstract_child(self):
|
||||||
return self._abstract_child
|
return self._abstract_child
|
||||||
|
|
||||||
|
@property
|
||||||
|
def verbose(self):
|
||||||
|
return self._verbose
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def forward_raw(self, *inputs):
|
def forward_raw(self, *inputs):
|
||||||
"""Use the largest candidate for forward. Similar to the original PyTorch model."""
|
"""Use the largest candidate for forward. Similar to the original PyTorch model."""
|
||||||
@ -60,12 +77,41 @@ class SuperModule(abc.ABC, nn.Module):
|
|||||||
def forward_candidate(self, *inputs):
|
def forward_candidate(self, *inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name_with_id(self):
|
||||||
|
return "name={:}, id={:}".format(self.__class__.__name__, id(self))
|
||||||
|
|
||||||
|
def get_shape_str(self, tensors):
|
||||||
|
if isinstance(tensors, (list, tuple)):
|
||||||
|
shapes = [self.get_shape_str(tensor) for tensor in tensors]
|
||||||
|
if len(shapes) == 1:
|
||||||
|
return shapes[0]
|
||||||
|
else:
|
||||||
|
return ", ".join(shapes)
|
||||||
|
elif isinstance(tensors, (torch.Tensor, nn.Parameter)):
|
||||||
|
return str(tuple(tensors.shape))
|
||||||
|
else:
|
||||||
|
raise TypeError("Invalid input type: {:}.".format(type(tensors)))
|
||||||
|
|
||||||
def forward(self, *inputs):
|
def forward(self, *inputs):
|
||||||
|
if self.verbose:
|
||||||
|
print(
|
||||||
|
"[{:}] inputs shape: {:}".format(
|
||||||
|
self.name_with_id, self.get_shape_str(inputs)
|
||||||
|
)
|
||||||
|
)
|
||||||
if self.super_run_type == SuperRunMode.FullModel:
|
if self.super_run_type == SuperRunMode.FullModel:
|
||||||
return self.forward_raw(*inputs)
|
outputs = self.forward_raw(*inputs)
|
||||||
elif self.super_run_type == SuperRunMode.Candidate:
|
elif self.super_run_type == SuperRunMode.Candidate:
|
||||||
return self.forward_candidate(*inputs)
|
outputs = self.forward_candidate(*inputs)
|
||||||
else:
|
else:
|
||||||
raise ModeError(
|
raise ModeError(
|
||||||
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
|
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
|
||||||
)
|
)
|
||||||
|
if self.verbose:
|
||||||
|
print(
|
||||||
|
"[{:}] outputs shape: {:}".format(
|
||||||
|
self.name_with_id, self.get_shape_str(outputs)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return outputs
|
||||||
|
@ -26,6 +26,7 @@ class TestSuperLinear(unittest.TestCase):
|
|||||||
bias = spaces.Categorical(True, False)
|
bias = spaces.Categorical(True, False)
|
||||||
model = super_core.SuperLinear(10, out_features, bias=bias)
|
model = super_core.SuperLinear(10, out_features, bias=bias)
|
||||||
print("The simple super linear module is:\n{:}".format(model))
|
print("The simple super linear module is:\n{:}".format(model))
|
||||||
|
model.apply_verbose(True)
|
||||||
|
|
||||||
print(model.super_run_type)
|
print(model.super_run_type)
|
||||||
self.assertTrue(model.bias)
|
self.assertTrue(model.bias)
|
||||||
@ -55,6 +56,7 @@ class TestSuperLinear(unittest.TestCase):
|
|||||||
out_features = spaces.Categorical(24, 36, 48)
|
out_features = spaces.Categorical(24, 36, 48)
|
||||||
mlp = super_core.SuperMLP(10, hidden_features, out_features)
|
mlp = super_core.SuperMLP(10, hidden_features, out_features)
|
||||||
print(mlp)
|
print(mlp)
|
||||||
|
mlp.apply_verbose(True)
|
||||||
self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
|
self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
|
||||||
|
|
||||||
inputs = torch.rand(4, 10)
|
inputs = torch.rand(4, 10)
|
||||||
@ -85,3 +87,29 @@ class TestSuperLinear(unittest.TestCase):
|
|||||||
outputs = mlp(inputs)
|
outputs = mlp(inputs)
|
||||||
output_shape = (4, abstract_child["fc2"]["_out_features"].value)
|
output_shape = (4, abstract_child["fc2"]["_out_features"].value)
|
||||||
self.assertEqual(tuple(outputs.shape), output_shape)
|
self.assertEqual(tuple(outputs.shape), output_shape)
|
||||||
|
|
||||||
|
def test_super_attention(self):
|
||||||
|
proj_dim = spaces.Categorical(12, 24, 36)
|
||||||
|
num_heads = spaces.Categorical(2, 4, 6)
|
||||||
|
model = super_core.SuperAttention(10, proj_dim, num_heads)
|
||||||
|
print(model)
|
||||||
|
model.apply_verbose(True)
|
||||||
|
|
||||||
|
inputs = torch.rand(4, 20, 10) # batch size, sequence length, channel
|
||||||
|
outputs = model(inputs)
|
||||||
|
|
||||||
|
abstract_space = model.abstract_search_space
|
||||||
|
print(
|
||||||
|
"The abstract search space for SuperAttention is:\n{:}".format(
|
||||||
|
abstract_space
|
||||||
|
)
|
||||||
|
)
|
||||||
|
abstract_space.clean_last()
|
||||||
|
abstract_child = abstract_space.random(reuse_last=True)
|
||||||
|
print("The abstract child program is:\n{:}".format(abstract_child))
|
||||||
|
|
||||||
|
model.set_super_run_type(super_core.SuperRunMode.Candidate)
|
||||||
|
model.apply_candidate(abstract_child)
|
||||||
|
outputs = model(inputs)
|
||||||
|
output_shape = (4, 20, abstract_child["proj"]["_out_features"].value)
|
||||||
|
self.assertEqual(tuple(outputs.shape), output_shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user