Add SuperAttention

This commit is contained in:
D-X-Y 2021-03-20 15:56:37 +08:00
parent 0c56a729ad
commit e023a53c75
9 changed files with 239 additions and 11 deletions

0
lib/trade_models/naive_v1_model.py Executable file → Normal file
View File

0
lib/trade_models/naive_v2_model.py Executable file → Normal file
View File

0
lib/trade_models/quant_transformer.py Executable file → Normal file
View File

6
lib/trade_models/transformers.py Executable file → Normal file
View File

@ -1,6 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
##################################################
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from __future__ import division
from __future__ import print_function

View File

@ -0,0 +1,155 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from __future__ import division
from __future__ import print_function
import math
from functools import partial
from typing import Optional, Text
import torch
import torch.nn as nn
import torch.nn.functional as F
import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
from .super_linear import SuperLinear
class SuperAttention(SuperModule):
"""The super model for attention layer."""
def __init__(
self,
input_dim: IntSpaceType,
proj_dim: IntSpaceType,
num_heads: IntSpaceType,
qkv_bias: BoolSpaceType = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
super(SuperAttention, self).__init__()
self._input_dim = input_dim
self._proj_dim = proj_dim
self._num_heads = num_heads
self._qkv_bias = qkv_bias
# head_dim = dim // num_heads
# self.scale = qk_scale or math.sqrt(head_dim)
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = SuperLinear(input_dim, proj_dim)
self.proj_drop = nn.Dropout(proj_drop)
@property
def num_heads(self):
return spaces.get_max(self._num_heads)
@property
def input_dim(self):
return spaces.get_max(self._input_dim)
@property
def proj_dim(self):
return spaces.get_max(self._proj_dim)
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
space_q = self.q_fc.abstract_search_space
space_k = self.k_fc.abstract_search_space
space_v = self.v_fc.abstract_search_space
space_proj = self.proj.abstract_search_space
if not spaces.is_determined(self._num_heads):
root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
if not spaces.is_determined(space_q):
root_node.append("q_fc", space_q)
if not spaces.is_determined(space_k):
root_node.append("k_fc", space_k)
if not spaces.is_determined(space_v):
root_node.append("v_fc", space_v)
if not spaces.is_determined(space_proj):
root_node.append("proj", space_proj)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperAttention, self).apply_candidate(abstract_child)
if "q_fc" in abstract_child:
self.q_fc.apply_candidate(abstract_child["q_fc"])
if "k_fc" in abstract_child:
self.k_fc.apply_candidate(abstract_child["k_fc"])
if "v_fc" in abstract_child:
self.v_fc.apply_candidate(abstract_child["v_fc"])
if "proj" in abstract_child:
self.proj.apply_candidate(abstract_child["proj"])
def forward_qkv(self, input: torch.Tensor, num_head: int) -> torch.Tensor:
B, N, C = input.shape
q = self.q_fc(input)
k = self.k_fc(input)
v = self.v_fc(input)
if num_head > C:
raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
head_dim = C // num_head
# process the first [num_head * head_dim] part
q_v1 = (
q[:, :, : num_head * head_dim]
.reshape(B, N, num_head, head_dim)
.permute(0, 2, 1, 3)
)
k_v1 = (
k[:, :, : num_head * head_dim]
.reshape(B, N, num_head, head_dim)
.permute(0, 2, 1, 3)
)
v_v1 = (
v[:, :, : num_head * head_dim]
.reshape(B, N, num_head, head_dim)
.permute(0, 2, 1, 3)
)
attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
attn_v1 = attn_v1.softmax(dim=-1)
attn_v1 = self.attn_drop(attn_v1)
feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
if C == head_dim * num_head:
feats = feats_v1
else: # The channels can not be divided by num_head, the remainder forms an additional head
q_v2 = q[:, :, num_head * head_dim :]
k_v2 = k[:, :, num_head * head_dim :]
v_v2 = v[:, :, num_head * head_dim :]
attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
attn_v2 = attn_v2.softmax(dim=-1)
attn_v2 = self.attn_drop(attn_v2)
feats_v2 = attn_v2 @ v_v2
feats = torch.cat([feats_v1, feats_v2], dim=-1)
return feats
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
# check the num_heads:
if not spaces.is_determined(self._num_heads):
num_heads = self.abstract_child["_num_heads"].value
else:
num_heads = spaces.get_determined_value(self._num_heads)
feats = self.forward_qkv(input, num_heads)
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
feats = self.forward_qkv(input, self.num_heads)
outs = self.proj(feats)
outs = self.proj_drop(outs)
return outs
def extra_repr(self) -> str:
return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
self._input_dim, self._proj_dim, self._num_heads
)

View File

@ -5,3 +5,4 @@ from .super_module import SuperRunMode
from .super_module import SuperModule
from .super_linear import SuperLinear
from .super_linear import SuperMLP
from .super_attention import SuperAttention

View File

@ -6,14 +6,12 @@ import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Union, Callable
from typing import Optional, Callable
import spaces
from .super_module import SuperModule
from .super_module import SuperRunMode
IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
BoolSpaceType = Union[bool, spaces.Categorical]
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
class SuperLinear(SuperModule):

View File

@ -1,13 +1,18 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import abc
from typing import Optional, Union, Callable
import torch
import torch.nn as nn
from enum import Enum
import spaces
IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
BoolSpaceType = Union[bool, spaces.Categorical]
class SuperRunMode(Enum):
"""This class defines the enumerations for Super Model Running Mode."""
@ -24,6 +29,7 @@ class SuperModule(abc.ABC, nn.Module):
super(SuperModule, self).__init__()
self._super_run_type = SuperRunMode.Default
self._abstract_child = None
self._verbose = False
def set_super_run_type(self, super_run_type):
def _reset_super_run(m):
@ -32,6 +38,13 @@ class SuperModule(abc.ABC, nn.Module):
self.apply(_reset_super_run)
def apply_verbose(self, verbose):
def _reset_verbose(m):
if isinstance(m, SuperModule):
m._verbose = verbose
self.apply(_reset_verbose)
def apply_candidate(self, abstract_child):
if not isinstance(abstract_child, spaces.VirtualNode):
raise ValueError(
@ -51,6 +64,10 @@ class SuperModule(abc.ABC, nn.Module):
def abstract_child(self):
return self._abstract_child
@property
def verbose(self):
return self._verbose
@abc.abstractmethod
def forward_raw(self, *inputs):
"""Use the largest candidate for forward. Similar to the original PyTorch model."""
@ -60,12 +77,41 @@ class SuperModule(abc.ABC, nn.Module):
def forward_candidate(self, *inputs):
raise NotImplementedError
@property
def name_with_id(self):
return "name={:}, id={:}".format(self.__class__.__name__, id(self))
def get_shape_str(self, tensors):
if isinstance(tensors, (list, tuple)):
shapes = [self.get_shape_str(tensor) for tensor in tensors]
if len(shapes) == 1:
return shapes[0]
else:
return ", ".join(shapes)
elif isinstance(tensors, (torch.Tensor, nn.Parameter)):
return str(tuple(tensors.shape))
else:
raise TypeError("Invalid input type: {:}.".format(type(tensors)))
def forward(self, *inputs):
if self.verbose:
print(
"[{:}] inputs shape: {:}".format(
self.name_with_id, self.get_shape_str(inputs)
)
)
if self.super_run_type == SuperRunMode.FullModel:
return self.forward_raw(*inputs)
outputs = self.forward_raw(*inputs)
elif self.super_run_type == SuperRunMode.Candidate:
return self.forward_candidate(*inputs)
outputs = self.forward_candidate(*inputs)
else:
raise ModeError(
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
)
if self.verbose:
print(
"[{:}] outputs shape: {:}".format(
self.name_with_id, self.get_shape_str(outputs)
)
)
return outputs

View File

@ -26,6 +26,7 @@ class TestSuperLinear(unittest.TestCase):
bias = spaces.Categorical(True, False)
model = super_core.SuperLinear(10, out_features, bias=bias)
print("The simple super linear module is:\n{:}".format(model))
model.apply_verbose(True)
print(model.super_run_type)
self.assertTrue(model.bias)
@ -55,6 +56,7 @@ class TestSuperLinear(unittest.TestCase):
out_features = spaces.Categorical(24, 36, 48)
mlp = super_core.SuperMLP(10, hidden_features, out_features)
print(mlp)
mlp.apply_verbose(True)
self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
inputs = torch.rand(4, 10)
@ -85,3 +87,29 @@ class TestSuperLinear(unittest.TestCase):
outputs = mlp(inputs)
output_shape = (4, abstract_child["fc2"]["_out_features"].value)
self.assertEqual(tuple(outputs.shape), output_shape)
def test_super_attention(self):
proj_dim = spaces.Categorical(12, 24, 36)
num_heads = spaces.Categorical(2, 4, 6)
model = super_core.SuperAttention(10, proj_dim, num_heads)
print(model)
model.apply_verbose(True)
inputs = torch.rand(4, 20, 10) # batch size, sequence length, channel
outputs = model(inputs)
abstract_space = model.abstract_search_space
print(
"The abstract search space for SuperAttention is:\n{:}".format(
abstract_space
)
)
abstract_space.clean_last()
abstract_child = abstract_space.random(reuse_last=True)
print("The abstract child program is:\n{:}".format(abstract_child))
model.set_super_run_type(super_core.SuperRunMode.Candidate)
model.apply_candidate(abstract_child)
outputs = model(inputs)
output_shape = (4, 20, abstract_child["proj"]["_out_features"].value)
self.assertEqual(tuple(outputs.shape), output_shape)