Add SuperAttention

author D-X-Y, 2021-03-20 15:56:37 +08:00
parent 0c56a729ad
commit e023a53c75
9 changed files with 239 additions and 11 deletions

lib/trade_models/naive_v1_model.py (0 changes, Executable file → Normal file)
lib/trade_models/naive_v2_model.py (0 changes, Executable file → Normal file)
lib/trade_models/quant_transformer.py (0 changes, Executable file → Normal file)
lib/trade_models/transformers.py (6 changes, Executable file → Normal file)

@@ -1,6 +1,6 @@
-##################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
-##################################################
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
 from __future__ import division
 from __future__ import print_function

super_attention.py (new file)

@@ -0,0 +1,155 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from __future__ import division
from __future__ import print_function
import math

import torch
import torch.nn as nn

import spaces

from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
from .super_linear import SuperLinear


class SuperAttention(SuperModule):
    """The super model for attention layer."""

    def __init__(
        self,
        input_dim: IntSpaceType,
        proj_dim: IntSpaceType,
        num_heads: IntSpaceType,
        qkv_bias: BoolSpaceType = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super(SuperAttention, self).__init__()
        self._input_dim = input_dim
        self._proj_dim = proj_dim
        self._num_heads = num_heads
        self._qkv_bias = qkv_bias
        # Separate Q/K/V projections (instead of one fused qkv linear), so that
        # each projection can expose its own search space.
        self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
        self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
        self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = SuperLinear(input_dim, proj_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @property
    def num_heads(self):
        return spaces.get_max(self._num_heads)

    @property
    def input_dim(self):
        return spaces.get_max(self._input_dim)

    @property
    def proj_dim(self):
        return spaces.get_max(self._proj_dim)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        space_q = self.q_fc.abstract_search_space
        space_k = self.k_fc.abstract_search_space
        space_v = self.v_fc.abstract_search_space
        space_proj = self.proj.abstract_search_space
        if not spaces.is_determined(self._num_heads):
            root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
        if not spaces.is_determined(space_q):
            root_node.append("q_fc", space_q)
        if not spaces.is_determined(space_k):
            root_node.append("k_fc", space_k)
        if not spaces.is_determined(space_v):
            root_node.append("v_fc", space_v)
        if not spaces.is_determined(space_proj):
            root_node.append("proj", space_proj)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperAttention, self).apply_candidate(abstract_child)
        if "q_fc" in abstract_child:
            self.q_fc.apply_candidate(abstract_child["q_fc"])
        if "k_fc" in abstract_child:
            self.k_fc.apply_candidate(abstract_child["k_fc"])
        if "v_fc" in abstract_child:
            self.v_fc.apply_candidate(abstract_child["v_fc"])
        if "proj" in abstract_child:
            self.proj.apply_candidate(abstract_child["proj"])

    def forward_qkv(self, input: torch.Tensor, num_head: int) -> torch.Tensor:
        B, N, C = input.shape
        q = self.q_fc(input)
        k = self.k_fc(input)
        v = self.v_fc(input)
        if num_head > C:
            raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
        head_dim = C // num_head
        # Process the first [num_head * head_dim] channels with standard
        # multi-head scaled dot-product attention.
        q_v1 = (
            q[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        k_v1 = (
            k[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        v_v1 = (
            v[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        # Scaled dot-product attention divides the logits by sqrt(head_dim).
        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) / math.sqrt(head_dim)
        attn_v1 = attn_v1.softmax(dim=-1)
        attn_v1 = self.attn_drop(attn_v1)
        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
        if C == head_dim * num_head:
            feats = feats_v1
        else:
            # The channels are not evenly divisible by num_head; the remaining
            # channels form one additional (narrower) head.
            q_v2 = q[:, :, num_head * head_dim :]
            k_v2 = k[:, :, num_head * head_dim :]
            v_v2 = v[:, :, num_head * head_dim :]
            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) / math.sqrt(q_v2.shape[-1])
            attn_v2 = attn_v2.softmax(dim=-1)
            attn_v2 = self.attn_drop(attn_v2)
            feats_v2 = attn_v2 @ v_v2
            feats = torch.cat([feats_v1, feats_v2], dim=-1)
        return feats

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # Check the num_heads for this sampled candidate.
        if not spaces.is_determined(self._num_heads):
            num_heads = self.abstract_child["_num_heads"].value
        else:
            num_heads = spaces.get_determined_value(self._num_heads)
        feats = self.forward_qkv(input, num_heads)
        outs = self.proj(feats)
        outs = self.proj_drop(outs)
        return outs

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        feats = self.forward_qkv(input, self.num_heads)
        outs = self.proj(feats)
        outs = self.proj_drop(outs)
        return outs

    def extra_repr(self) -> str:
        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
            self._input_dim, self._proj_dim, self._num_heads
        )

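A quick standalone check of the head-splitting logic in forward_qkv above: when the channel count C is not divisible by num_head, the first num_head * head_dim channels go through regular multi-head attention and the leftover channels are attended as a single extra, narrower head, so the concatenated output preserves all C channels. A minimal sketch in plain PyTorch (shapes are illustrative and chosen to match the test below; no repo dependencies):

import math
import torch

B, N, C, num_head = 4, 20, 10, 4  # C = 10 is not divisible by num_head = 4
q = k = v = torch.rand(B, N, C)
head_dim = C // num_head  # 2; the regular heads cover 4 * 2 = 8 channels

# Regular multi-head path over the first num_head * head_dim channels.
qh = q[:, :, : num_head * head_dim].reshape(B, N, num_head, head_dim).permute(0, 2, 1, 3)
kh = k[:, :, : num_head * head_dim].reshape(B, N, num_head, head_dim).permute(0, 2, 1, 3)
vh = v[:, :, : num_head * head_dim].reshape(B, N, num_head, head_dim).permute(0, 2, 1, 3)
attn = ((qh @ kh.transpose(-2, -1)) / math.sqrt(head_dim)).softmax(dim=-1)
feats_v1 = (attn @ vh).permute(0, 2, 1, 3).reshape(B, N, -1)  # (4, 20, 8)

# The remaining C - num_head * head_dim = 2 channels act as one extra head.
qr, kr, vr = (t[:, :, num_head * head_dim :] for t in (q, k, v))
attn_r = ((qr @ kr.transpose(-2, -1)) / math.sqrt(qr.shape[-1])).softmax(dim=-1)
feats_v2 = attn_r @ vr  # (4, 20, 2)

feats = torch.cat([feats_v1, feats_v2], dim=-1)
assert feats.shape == (B, N, C)  # all 10 channels are preserved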
__init__.py

@@ -5,3 +5,4 @@ from .super_module import SuperRunMode
 from .super_module import SuperModule
 from .super_linear import SuperLinear
 from .super_linear import SuperMLP
+from .super_attention import SuperAttention

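With this re-export in place, the new layer is reachable from the package namespace; the unit tests in this commit access it as super_core. A minimal usage sketch, assuming super_core is importable the same way as in those tests:

import spaces
import super_core

attn = super_core.SuperAttention(
    10,                              # input_dim (fixed)
    spaces.Categorical(12, 24, 36),  # proj_dim search space
    spaces.Categorical(2, 4, 6),     # num_heads search space
)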
super_linear.py

@@ -6,14 +6,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
-from typing import Optional, Union, Callable
+from typing import Optional, Callable
 import spaces
 from .super_module import SuperModule
-from .super_module import SuperRunMode
-
-IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
-BoolSpaceType = Union[bool, spaces.Categorical]
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
 
 class SuperLinear(SuperModule):

super_module.py

@@ -1,13 +1,18 @@
 #####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
 #####################################################
 import abc
+from typing import Optional, Union, Callable
+import torch
 import torch.nn as nn
 from enum import Enum
 import spaces
 
+IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
+BoolSpaceType = Union[bool, spaces.Categorical]
+
 class SuperRunMode(Enum):
     """This class defines the enumerations for Super Model Running Mode."""
@@ -24,6 +29,7 @@ class SuperModule(abc.ABC, nn.Module):
         super(SuperModule, self).__init__()
         self._super_run_type = SuperRunMode.Default
         self._abstract_child = None
+        self._verbose = False
 
     def set_super_run_type(self, super_run_type):
         def _reset_super_run(m):
@@ -32,6 +38,13 @@ class SuperModule(abc.ABC, nn.Module):
         self.apply(_reset_super_run)
 
+    def apply_verbose(self, verbose):
+        def _reset_verbose(m):
+            if isinstance(m, SuperModule):
+                m._verbose = verbose
+
+        self.apply(_reset_verbose)
+
     def apply_candidate(self, abstract_child):
         if not isinstance(abstract_child, spaces.VirtualNode):
             raise ValueError(
@@ -51,6 +64,10 @@ class SuperModule(abc.ABC, nn.Module):
     def abstract_child(self):
         return self._abstract_child
 
+    @property
+    def verbose(self):
+        return self._verbose
+
     @abc.abstractmethod
     def forward_raw(self, *inputs):
         """Use the largest candidate for forward. Similar to the original PyTorch model."""
@@ -60,12 +77,41 @@ class SuperModule(abc.ABC, nn.Module):
     def forward_candidate(self, *inputs):
         raise NotImplementedError
 
+    @property
+    def name_with_id(self):
+        return "name={:}, id={:}".format(self.__class__.__name__, id(self))
+
+    def get_shape_str(self, tensors):
+        if isinstance(tensors, (list, tuple)):
+            shapes = [self.get_shape_str(tensor) for tensor in tensors]
+            if len(shapes) == 1:
+                return shapes[0]
+            else:
+                return ", ".join(shapes)
+        elif isinstance(tensors, (torch.Tensor, nn.Parameter)):
+            return str(tuple(tensors.shape))
+        else:
+            raise TypeError("Invalid input type: {:}.".format(type(tensors)))
+
     def forward(self, *inputs):
+        if self.verbose:
+            print(
+                "[{:}] inputs shape: {:}".format(
+                    self.name_with_id, self.get_shape_str(inputs)
+                )
+            )
         if self.super_run_type == SuperRunMode.FullModel:
-            return self.forward_raw(*inputs)
+            outputs = self.forward_raw(*inputs)
         elif self.super_run_type == SuperRunMode.Candidate:
-            return self.forward_candidate(*inputs)
+            outputs = self.forward_candidate(*inputs)
         else:
             raise ModeError(
                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
             )
+        if self.verbose:
+            print(
+                "[{:}] outputs shape: {:}".format(
+                    self.name_with_id, self.get_shape_str(outputs)
+                )
+            )
+        return outputs

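To see what the verbose plumbing added above does end to end, here is a hypothetical minimal subclass (SuperModule is abstract; the name TinySuper and its trivial search space are illustrative, not part of the commit). After apply_verbose(True), forward prints the input and output shapes around the usual forward_raw/forward_candidate dispatch:

import spaces
import torch
import torch.nn as nn
import super_core  # importable this way in the tests of this commit


class TinySuper(super_core.SuperModule):
    """A minimal SuperModule: a single fixed linear layer."""

    def __init__(self):
        super(TinySuper, self).__init__()
        self.fc = nn.Linear(10, 3)

    @property
    def abstract_search_space(self):
        return spaces.VirtualNode(id(self))  # nothing searchable here

    def forward_raw(self, input):
        return self.fc(input)

    def forward_candidate(self, input):
        return self.forward_raw(input)


model = TinySuper()
model.apply_verbose(True)  # sets _verbose on every SuperModule in the tree
outputs = model(torch.rand(4, 10))
# The default run type dispatches to forward_raw (as in the tests), so with
# verbose on this prints roughly:
#   [name=TinySuper, id=...] inputs shape: (4, 10)
#   [name=TinySuper, id=...] outputs shape: (4, 3)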
tests for the super layers (class TestSuperLinear)

@@ -26,6 +26,7 @@ class TestSuperLinear(unittest.TestCase):
         bias = spaces.Categorical(True, False)
         model = super_core.SuperLinear(10, out_features, bias=bias)
         print("The simple super linear module is:\n{:}".format(model))
+        model.apply_verbose(True)
 
         print(model.super_run_type)
         self.assertTrue(model.bias)
@@ -55,6 +56,7 @@ class TestSuperLinear(unittest.TestCase):
         out_features = spaces.Categorical(24, 36, 48)
         mlp = super_core.SuperMLP(10, hidden_features, out_features)
         print(mlp)
+        mlp.apply_verbose(True)
 
         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
         inputs = torch.rand(4, 10)
@@ -85,3 +87,29 @@ class TestSuperLinear(unittest.TestCase):
         outputs = mlp(inputs)
         output_shape = (4, abstract_child["fc2"]["_out_features"].value)
         self.assertEqual(tuple(outputs.shape), output_shape)
+
+    def test_super_attention(self):
+        proj_dim = spaces.Categorical(12, 24, 36)
+        num_heads = spaces.Categorical(2, 4, 6)
+        model = super_core.SuperAttention(10, proj_dim, num_heads)
+        print(model)
+        model.apply_verbose(True)
+
+        inputs = torch.rand(4, 20, 10)  # batch size, sequence length, channel
+        outputs = model(inputs)
+
+        abstract_space = model.abstract_search_space
+        print(
+            "The abstract search space for SuperAttention is:\n{:}".format(
+                abstract_space
+            )
+        )
+        abstract_space.clean_last()
+        abstract_child = abstract_space.random(reuse_last=True)
+        print("The abstract child program is:\n{:}".format(abstract_child))
+
+        model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.apply_candidate(abstract_child)
+        outputs = model(inputs)
+        output_shape = (4, 20, abstract_child["proj"]["_out_features"].value)
+        self.assertEqual(tuple(outputs.shape), output_shape)