Add SuperAttention
Files changed:

  lib/trade_models/naive_v1_model.py     (0)    Executable file → Normal file
  lib/trade_models/naive_v2_model.py     (0)    Executable file → Normal file
  lib/trade_models/quant_transformer.py  (0)    Executable file → Normal file
  lib/trade_models/transformers.py       (6)    Executable file → Normal file
  lib/xlayers/super_attention.py         (155)  New file

lib/trade_models/transformers.py
@@ -1,6 +1,6 @@
-##################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
-##################################################
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
 from __future__ import division
 from __future__ import print_function

lib/xlayers/super_attention.py (new file, 155 lines)
@@ -0,0 +1,155 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from __future__ import division
from __future__ import print_function

import math
from functools import partial
from typing import Optional, Text

import torch
import torch.nn as nn
import torch.nn.functional as F


import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
from .super_linear import SuperLinear


class SuperAttention(SuperModule):
    """The super model for attention layer."""

    def __init__(
        self,
        input_dim: IntSpaceType,
        proj_dim: IntSpaceType,
        num_heads: IntSpaceType,
        qkv_bias: BoolSpaceType = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super(SuperAttention, self).__init__()
        self._input_dim = input_dim
        self._proj_dim = proj_dim
        self._num_heads = num_heads
        self._qkv_bias = qkv_bias
        # head_dim = dim // num_heads
        # self.scale = qk_scale or math.sqrt(head_dim)

        # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
        self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
        self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = SuperLinear(input_dim, proj_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @property
    def num_heads(self):
        return spaces.get_max(self._num_heads)

    @property
    def input_dim(self):
        return spaces.get_max(self._input_dim)

    @property
    def proj_dim(self):
        return spaces.get_max(self._proj_dim)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        space_q = self.q_fc.abstract_search_space
        space_k = self.k_fc.abstract_search_space
        space_v = self.v_fc.abstract_search_space
        space_proj = self.proj.abstract_search_space
        if not spaces.is_determined(self._num_heads):
            root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
        if not spaces.is_determined(space_q):
            root_node.append("q_fc", space_q)
        if not spaces.is_determined(space_k):
            root_node.append("k_fc", space_k)
        if not spaces.is_determined(space_v):
            root_node.append("v_fc", space_v)
        if not spaces.is_determined(space_proj):
            root_node.append("proj", space_proj)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperAttention, self).apply_candidate(abstract_child)
        if "q_fc" in abstract_child:
            self.q_fc.apply_candidate(abstract_child["q_fc"])
        if "k_fc" in abstract_child:
            self.k_fc.apply_candidate(abstract_child["k_fc"])
        if "v_fc" in abstract_child:
            self.v_fc.apply_candidate(abstract_child["v_fc"])
        if "proj" in abstract_child:
            self.proj.apply_candidate(abstract_child["proj"])

    def forward_qkv(self, input: torch.Tensor, num_head: int) -> torch.Tensor:
        B, N, C = input.shape
        q = self.q_fc(input)
        k = self.k_fc(input)
        v = self.v_fc(input)
        if num_head > C:
            raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
        head_dim = C // num_head
        # process the first [num_head * head_dim] part
        q_v1 = (
            q[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        k_v1 = (
            k[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        v_v1 = (
            v[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
        attn_v1 = attn_v1.softmax(dim=-1)
        attn_v1 = self.attn_drop(attn_v1)
        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
        if C == head_dim * num_head:
            feats = feats_v1
        else:  # The channels can not be divided by num_head, the remainder forms an additional head
            q_v2 = q[:, :, num_head * head_dim :]
            k_v2 = k[:, :, num_head * head_dim :]
            v_v2 = v[:, :, num_head * head_dim :]
            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
            attn_v2 = attn_v2.softmax(dim=-1)
            attn_v2 = self.attn_drop(attn_v2)
            feats_v2 = attn_v2 @ v_v2
            feats = torch.cat([feats_v1, feats_v2], dim=-1)
        return feats

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # check the num_heads:
        if not spaces.is_determined(self._num_heads):
            num_heads = self.abstract_child["_num_heads"].value
        else:
            num_heads = spaces.get_determined_value(self._num_heads)
        feats = self.forward_qkv(input, num_heads)
        outs = self.proj(feats)
        outs = self.proj_drop(outs)
        return outs

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        feats = self.forward_qkv(input, self.num_heads)
        outs = self.proj(feats)
        outs = self.proj_drop(outs)
        return outs

    def extra_repr(self) -> str:
        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
            self._input_dim, self._proj_dim, self._num_heads
        )
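
A minimal usage sketch of the new layer (illustrative only, not part of the commit). It assumes lib/ is on the Python path and reuses the super_core entry point that the tests below rely on; the printed shape is what the default full-model run mode is expected to produce.

# Illustrative sketch; the import paths are assumptions based on this repo layout.
import torch
import spaces
from xlayers import super_core  # hypothetical import path

proj_dim = spaces.Categorical(12, 24, 36)
num_heads = spaces.Categorical(2, 4, 6)
attn = super_core.SuperAttention(10, proj_dim, num_heads)

x = torch.rand(4, 20, 10)  # (batch, sequence, channel)
y = attn(x)                # default run mode forwards with the largest candidates
print(tuple(y.shape))      # expected: (4, 20, 36), i.e. the max of proj_dim

# Remainder-head behaviour of forward_qkv: with C = 10 and num_head = 4,
# head_dim = 10 // 4 = 2, so the first 8 channels form 4 regular heads and the
# trailing 2 channels are attended to as one extra head; no channels are
# dropped when C is not divisible by num_head.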

@@ -5,3 +5,4 @@ from .super_module import SuperRunMode
 from .super_module import SuperModule
 from .super_linear import SuperLinear
 from .super_linear import SuperMLP
+from .super_attention import SuperAttention

lib/xlayers/super_linear.py
@@ -6,14 +6,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 import math
-from typing import Optional, Union, Callable
+from typing import Optional, Callable
 
 import spaces
 from .super_module import SuperModule
-from .super_module import SuperRunMode
-
-IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
-BoolSpaceType = Union[bool, spaces.Categorical]
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
 
 
 class SuperLinear(SuperModule):
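
The two alias types now live alongside SuperModule so that every super layer (SuperLinear, SuperMLP, and the new SuperAttention) can import them from one place. A small sketch of what they accept; the absolute import path is an assumption based on the relative imports above.

# Illustrative only; the module path is assumed from the package layout.
import spaces
from xlayers.super_module import IntSpaceType, BoolSpaceType  # hypothetical path

out_features: IntSpaceType = spaces.Categorical(12, 24, 36)  # searchable size
in_features: IntSpaceType = 10                               # or a plain int
bias: BoolSpaceType = spaces.Categorical(True, False)        # searchable flag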

lib/xlayers/super_module.py
@@ -1,13 +1,18 @@
 #####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
 #####################################################
 
 import abc
+from typing import Optional, Union, Callable
+import torch
 import torch.nn as nn
 from enum import Enum
 
 import spaces
 
+IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
+BoolSpaceType = Union[bool, spaces.Categorical]
+
+
 class SuperRunMode(Enum):
     """This class defines the enumerations for Super Model Running Mode."""
@@ -24,6 +29,7 @@ class SuperModule(abc.ABC, nn.Module):
         super(SuperModule, self).__init__()
         self._super_run_type = SuperRunMode.Default
         self._abstract_child = None
+        self._verbose = False
 
     def set_super_run_type(self, super_run_type):
         def _reset_super_run(m):
@@ -32,6 +38,13 @@ class SuperModule(abc.ABC, nn.Module):
 
         self.apply(_reset_super_run)
 
+    def apply_verbose(self, verbose):
+        def _reset_verbose(m):
+            if isinstance(m, SuperModule):
+                m._verbose = verbose
+
+        self.apply(_reset_verbose)
+
     def apply_candidate(self, abstract_child):
         if not isinstance(abstract_child, spaces.VirtualNode):
             raise ValueError(
@@ -51,6 +64,10 @@ class SuperModule(abc.ABC, nn.Module):
     def abstract_child(self):
         return self._abstract_child
 
+    @property
+    def verbose(self):
+        return self._verbose
+
     @abc.abstractmethod
     def forward_raw(self, *inputs):
         """Use the largest candidate for forward. Similar to the original PyTorch model."""
@@ -60,12 +77,41 @@ class SuperModule(abc.ABC, nn.Module):
     def forward_candidate(self, *inputs):
         raise NotImplementedError
 
+    @property
+    def name_with_id(self):
+        return "name={:}, id={:}".format(self.__class__.__name__, id(self))
+
+    def get_shape_str(self, tensors):
+        if isinstance(tensors, (list, tuple)):
+            shapes = [self.get_shape_str(tensor) for tensor in tensors]
+            if len(shapes) == 1:
+                return shapes[0]
+            else:
+                return ", ".join(shapes)
+        elif isinstance(tensors, (torch.Tensor, nn.Parameter)):
+            return str(tuple(tensors.shape))
+        else:
+            raise TypeError("Invalid input type: {:}.".format(type(tensors)))
+
     def forward(self, *inputs):
+        if self.verbose:
+            print(
+                "[{:}] inputs shape: {:}".format(
+                    self.name_with_id, self.get_shape_str(inputs)
+                )
+            )
         if self.super_run_type == SuperRunMode.FullModel:
-            return self.forward_raw(*inputs)
+            outputs = self.forward_raw(*inputs)
         elif self.super_run_type == SuperRunMode.Candidate:
-            return self.forward_candidate(*inputs)
+            outputs = self.forward_candidate(*inputs)
         else:
             raise ModeError(
                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
             )
+        if self.verbose:
+            print(
+                "[{:}] outputs shape: {:}".format(
+                    self.name_with_id, self.get_shape_str(outputs)
+                )
+            )
+        return outputs
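
The verbose plumbing added above recursively toggles shape logging on every SuperModule in a tree, which the new tests exercise via apply_verbose(True). A rough sketch of the effect (illustrative; the import path and printed ids are assumptions):

# Illustrative sketch; assumes lib/ is on the Python path.
import torch
import spaces
from xlayers import super_core  # hypothetical import path

mlp = super_core.SuperMLP(10, spaces.Categorical(12, 24), spaces.Categorical(24, 36))
mlp.apply_verbose(True)   # recursively sets _verbose on every SuperModule

out = mlp(torch.rand(4, 10))
# Expected console output, roughly:
#   [name=SuperMLP, id=...] inputs shape: (4, 10)
#   [name=SuperLinear, id=...] inputs shape: (4, 10)
#   ...
#   [name=SuperMLP, id=...] outputs shape: (4, 36)
mlp.apply_verbose(False)  # turn the logging back off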

Unit tests (TestSuperLinear)
@@ -26,6 +26,7 @@ class TestSuperLinear(unittest.TestCase):
         bias = spaces.Categorical(True, False)
         model = super_core.SuperLinear(10, out_features, bias=bias)
         print("The simple super linear module is:\n{:}".format(model))
+        model.apply_verbose(True)
 
         print(model.super_run_type)
         self.assertTrue(model.bias)
@@ -55,6 +56,7 @@ class TestSuperLinear(unittest.TestCase):
         out_features = spaces.Categorical(24, 36, 48)
         mlp = super_core.SuperMLP(10, hidden_features, out_features)
         print(mlp)
+        mlp.apply_verbose(True)
         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
 
         inputs = torch.rand(4, 10)
@@ -85,3 +87,29 @@ class TestSuperLinear(unittest.TestCase):
         outputs = mlp(inputs)
         output_shape = (4, abstract_child["fc2"]["_out_features"].value)
         self.assertEqual(tuple(outputs.shape), output_shape)
+
+    def test_super_attention(self):
+        proj_dim = spaces.Categorical(12, 24, 36)
+        num_heads = spaces.Categorical(2, 4, 6)
+        model = super_core.SuperAttention(10, proj_dim, num_heads)
+        print(model)
+        model.apply_verbose(True)
+
+        inputs = torch.rand(4, 20, 10)  # batch size, sequence length, channel
+        outputs = model(inputs)
+
+        abstract_space = model.abstract_search_space
+        print(
+            "The abstract search space for SuperAttention is:\n{:}".format(
+                abstract_space
+            )
+        )
+        abstract_space.clean_last()
+        abstract_child = abstract_space.random(reuse_last=True)
+        print("The abstract child program is:\n{:}".format(abstract_child))
+
+        model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.apply_candidate(abstract_child)
+        outputs = model(inputs)
+        output_shape = (4, 20, abstract_child["proj"]["_out_features"].value)
+        self.assertEqual(tuple(outputs.shape), output_shape)