diff --git a/.github/workflows/basic_test.yml b/.github/workflows/basic_test.yml index 3053656..07b8a45 100644 --- a/.github/workflows/basic_test.yml +++ b/.github/workflows/basic_test.yml @@ -40,6 +40,7 @@ jobs: - name: Test Search Space run: | python -m pip install pytest numpy + python -m pip install parameterized echo $PWD echo "Show what we have here:" ls diff --git a/.github/workflows/super_model_test.yml b/.github/workflows/super_model_test.yml index 7a0e9be..0bc049c 100644 --- a/.github/workflows/super_model_test.yml +++ b/.github/workflows/super_model_test.yml @@ -27,6 +27,7 @@ jobs: - name: Test Super Model run: | python -m pip install pytest numpy + python -m pip install parameterized python -m pip install torch torchvision torchaudio python -m pytest ./tests/test_super_model.py -s shell: bash diff --git a/lib/xlayers/super_attention.py b/lib/xlayers/super_attention.py index e072700..485625f 100644 --- a/lib/xlayers/super_attention.py +++ b/lib/xlayers/super_attention.py @@ -29,8 +29,8 @@ class SuperAttention(SuperModule): proj_dim: IntSpaceType, num_heads: IntSpaceType, qkv_bias: BoolSpaceType = False, - attn_drop: float = 0.0, - proj_drop: float = 0.0, + attn_drop: Optional[float] = None, + proj_drop: Optional[float] = None, ): super(SuperAttention, self).__init__() self._input_dim = input_dim @@ -45,9 +45,9 @@ class SuperAttention(SuperModule): self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) + self.attn_drop = nn.Dropout(attn_drop or 0.0) self.proj = SuperLinear(input_dim, proj_dim) - self.proj_drop = nn.Dropout(proj_drop) + self.proj_drop = nn.Dropout(proj_drop or 0.0) @property def num_heads(self): diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py index 7d799cf..36a14a5 100644 --- a/lib/xlayers/super_core.py +++ b/lib/xlayers/super_core.py @@ -4,5 +4,7 @@ from .super_module import SuperRunMode from .super_module import SuperModule from .super_linear import SuperLinear -from .super_linear import SuperMLP +from .super_linear import SuperMLPv1, SuperMLPv2 +from .super_norm import SuperLayerNorm1D from .super_attention import SuperAttention +from .super_transformer import SuperTransformerEncoderLayer diff --git a/lib/xlayers/super_linear.py b/lib/xlayers/super_linear.py index 97c411b..78f280e 100644 --- a/lib/xlayers/super_linear.py +++ b/lib/xlayers/super_linear.py @@ -113,7 +113,7 @@ class SuperLinear(SuperModule): ) -class SuperMLP(SuperModule): +class SuperMLPv1(SuperModule): """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.""" def __init__( @@ -124,7 +124,7 @@ class SuperMLP(SuperModule): act_layer: Callable[[], nn.Module] = nn.GELU, drop: Optional[float] = None, ): - super(SuperMLP, self).__init__() + super(SuperMLPv1, self).__init__() self._in_features = in_features self._hidden_features = hidden_features self._out_features = out_features @@ -146,20 +146,17 @@ class SuperMLP(SuperModule): return root_node def apply_candidate(self, abstract_child: spaces.VirtualNode): - super(SuperMLP, self).apply_candidate(abstract_child) + super(SuperMLPv1, self).apply_candidate(abstract_child) if "fc1" in abstract_child: self.fc1.apply_candidate(abstract_child["fc1"]) if "fc2" in abstract_child: self.fc2.apply_candidate(abstract_child["fc2"]) def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: - return self._unified_forward(input) + return self.forward_raw(input) def forward_raw(self, input: torch.Tensor) -> torch.Tensor: - return self._unified_forward(input) - - def _unified_forward(self, x): - x = self.fc1(x) + x = self.fc1(input) x = self.act(x) x = self.drop(x) x = self.fc2(x) @@ -173,3 +170,137 @@ class SuperMLP(SuperModule): self._out_features, self._drop_rate, ) + + +class SuperMLPv2(SuperModule): + """An MLP layer: FC -> Activation -> Drop -> FC -> Drop.""" + + def __init__( + self, + in_features: IntSpaceType, + hidden_multiplier: IntSpaceType, + out_features: IntSpaceType, + act_layer: Callable[[], nn.Module] = nn.GELU, + drop: Optional[float] = None, + ): + super(SuperMLPv2, self).__init__() + self._in_features = in_features + self._hidden_multiplier = hidden_multiplier + self._out_features = out_features + self._drop_rate = drop + self._params = nn.ParameterDict({}) + + self._create_linear( + "fc1", self.in_features, int(self.in_features * self.hidden_multiplier) + ) + self._create_linear( + "fc2", int(self.in_features * self.hidden_multiplier), self.out_features + ) + self.act = act_layer() + self.drop = nn.Dropout(drop or 0.0) + self.reset_parameters() + + @property + def in_features(self): + return spaces.get_max(self._in_features) + + @property + def hidden_multiplier(self): + return spaces.get_max(self._hidden_multiplier) + + @property + def out_features(self): + return spaces.get_max(self._out_features) + + def _create_linear(self, name, inC, outC): + self._params["{:}_super_weight".format(name)] = torch.nn.Parameter( + torch.Tensor(outC, inC) + ) + self._params["{:}_super_bias".format(name)] = torch.nn.Parameter( + torch.Tensor(outC) + ) + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self._params["fc1_super_weight"], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self._params["fc2_super_weight"], a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self._params["fc1_super_weight"] + ) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self._params["fc1_super_bias"], -bound, bound) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self._params["fc2_super_weight"] + ) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self._params["fc2_super_bias"], -bound, bound) + + @property + def abstract_search_space(self): + root_node = spaces.VirtualNode(id(self)) + if not spaces.is_determined(self._in_features): + root_node.append( + "_in_features", self._in_features.abstract(reuse_last=True) + ) + if not spaces.is_determined(self._hidden_multiplier): + root_node.append( + "_hidden_multiplier", self._hidden_multiplier.abstract(reuse_last=True) + ) + if not spaces.is_determined(self._out_features): + root_node.append( + "_out_features", self._out_features.abstract(reuse_last=True) + ) + return root_node + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + # check inputs -> + if not spaces.is_determined(self._in_features): + expected_input_dim = self.abstract_child["_in_features"].value + else: + expected_input_dim = spaces.get_determined_value(self._in_features) + if input.size(-1) != expected_input_dim: + raise ValueError( + "Expect the input dim of {:} instead of {:}".format( + expected_input_dim, input.size(-1) + ) + ) + # create the weight and bias matrix for fc1 + if not spaces.is_determined(self._hidden_multiplier): + hmul = self.abstract_child["_hidden_multiplier"].value * expected_input_dim + else: + hmul = spaces.get_determined_value(self._hidden_multiplier) + hidden_dim = int(expected_input_dim * hmul) + _fc1_weight = self._params["fc1_super_weight"][:hidden_dim, :expected_input_dim] + _fc1_bias = self._params["fc1_super_bias"][:hidden_dim] + x = F.linear(input, _fc1_weight, _fc1_bias) + x = self.act(x) + x = self.drop(x) + # create the weight and bias matrix for fc2 + if not spaces.is_determined(self._out_features): + out_dim = self.abstract_child["_out_features"].value + else: + out_dim = spaces.get_determined_value(self._out_features) + _fc2_weight = self._params["fc2_super_weight"][:out_dim, :hidden_dim] + _fc2_bias = self._params["fc2_super_bias"][:out_dim] + x = F.linear(x, _fc2_weight, _fc2_bias) + x = self.drop(x) + return x + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + x = F.linear( + input, self._params["fc1_super_weight"], self._params["fc1_super_bias"] + ) + x = self.act(x) + x = self.drop(x) + x = F.linear( + x, self._params["fc2_super_weight"], self._params["fc2_super_bias"] + ) + x = self.drop(x) + return x + + def extra_repr(self) -> str: + return "in_features={:}, hidden_multiplier={:}, out_features={:}, drop={:}, fc1 -> act -> drop -> fc2 -> drop,".format( + self._in_features, + self._hidden_multiplier, + self._out_features, + self._drop_rate, + ) diff --git a/lib/xlayers/super_norm.py b/lib/xlayers/super_norm.py new file mode 100644 index 0000000..3db1471 --- /dev/null +++ b/lib/xlayers/super_norm.py @@ -0,0 +1,82 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math +from typing import Optional, Callable + +import spaces +from .super_module import SuperModule +from .super_module import IntSpaceType +from .super_module import BoolSpaceType + + +class SuperLayerNorm1D(SuperModule): + """Super Layer Norm.""" + + def __init__( + self, dim: IntSpaceType, eps: float = 1e-5, elementwise_affine: bool = True + ) -> None: + super(SuperLayerNorm1D, self).__init__() + self._in_dim = dim + self._eps = eps + self._elementwise_affine = elementwise_affine + if self._elementwise_affine: + self.weight = nn.Parameter(torch.Tensor(self.in_dim)) + self.bias = nn.Parameter(torch.Tensor(self.in_dim)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + self.reset_parameters() + + @property + def in_dim(self): + return spaces.get_max(self._in_dim) + + @property + def eps(self): + return self._eps + + def reset_parameters(self) -> None: + if self._elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + @property + def abstract_search_space(self): + root_node = spaces.VirtualNode(id(self)) + if not spaces.is_determined(self._in_dim): + root_node.append("_in_dim", self._in_dim.abstract(reuse_last=True)) + return root_node + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + # check inputs -> + if not spaces.is_determined(self._in_dim): + expected_input_dim = self.abstract_child["_in_dim"].value + else: + expected_input_dim = spaces.get_determined_value(self._in_dim) + if input.size(-1) != expected_input_dim: + raise ValueError( + "Expect the input dim of {:} instead of {:}".format( + expected_input_dim, input.size(-1) + ) + ) + if self._elementwise_affine: + weight = self.weight[:expected_input_dim] + bias = self.bias[:expected_input_dim] + else: + weight, bias = None, None + return F.layer_norm(input, (expected_input_dim,), weight, bias, self.eps) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + return F.layer_norm(input, (self.in_dim,), self.weight, self.bias, self.eps) + + def extra_repr(self) -> str: + return "{in_dim}, eps={eps}, " "elementwise_affine={elementwise_affine}".format( + in_dim=self._in_dim, + eps=self._eps, + elementwise_affine=self._elementwise_affine, + ) diff --git a/lib/xlayers/super_transformer.py b/lib/xlayers/super_transformer.py new file mode 100644 index 0000000..aa11511 --- /dev/null +++ b/lib/xlayers/super_transformer.py @@ -0,0 +1,100 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +from __future__ import division +from __future__ import print_function + +import math +from functools import partial +from typing import Optional, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import spaces +from .super_module import IntSpaceType +from .super_module import BoolSpaceType +from .super_module import SuperModule +from .super_linear import SuperMLPv2 +from .super_norm import SuperLayerNorm1D +from .super_attention import SuperAttention + + +class SuperTransformerEncoderLayer(SuperModule): + """TransformerEncoderLayer is made up of self-attn and feedforward network. + This is a super model for TransformerEncoderLayer that can support search for the transformer encoder layer. + + Reference: + - Paper: Attention Is All You Need, NeurIPS 2017 + - PyTorch Implementation: https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer + + Details: + MHA -> residual -> norm -> MLP -> residual -> norm + """ + + def __init__( + self, + input_dim: IntSpaceType, + output_dim: IntSpaceType, + num_heads: IntSpaceType, + qkv_bias: BoolSpaceType = False, + mlp_hidden_multiplier: IntSpaceType = 4, + drop: Optional[float] = None, + act_layer: Callable[[], nn.Module] = nn.GELU, + ): + super(SuperTransformerEncoderLayer, self).__init__() + self.mha = SuperAttention( + input_dim, + input_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=drop, + proj_drop=drop, + ) + self.drop1 = nn.Dropout(drop or 0.0) + self.norm1 = SuperLayerNorm1D(input_dim) + self.mlp = SuperMLPv2( + input_dim, + hidden_multiplier=mlp_hidden_multiplier, + out_features=output_dim, + act_layer=act_layer, + drop=drop, + ) + self.drop2 = nn.Dropout(drop or 0.0) + self.norm2 = SuperLayerNorm1D(output_dim) + + @property + def abstract_search_space(self): + root_node = spaces.VirtualNode(id(self)) + xdict = dict( + mha=self.mha.abstract_search_space, + norm1=self.norm1.abstract_search_space, + mlp=self.mlp.abstract_search_space, + norm2=self.norm2.abstract_search_space, + ) + for key, space in xdict.items(): + if not spaces.is_determined(space): + root_node.append(key, space) + return root_node + + def apply_candidate(self, abstract_child: spaces.VirtualNode): + super(SuperTransformerEncoderLayer, self).apply_candidate(abstract_child) + valid_keys = ["mha", "norm1", "mlp", "norm2"] + for key in valid_keys: + if key in abstract_child: + getattr(self, key).apply_candidate(abstract_child[key]) + + def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: + return self.forward_raw(input) + + def forward_raw(self, input: torch.Tensor) -> torch.Tensor: + # multi-head attention + x = self.mha(input) + x = x + self.drop1(x) + x = self.norm1(x) + # feed-forward layer + x = self.mlp(x) + x = x + self.drop2(x) + x = self.norm2(x) + return x diff --git a/notebooks/spaces/random-search-mlp.ipynb b/notebooks/spaces/random-search-mlp.ipynb deleted file mode 100644 index 18a9324..0000000 --- a/notebooks/spaces/random-search-mlp.ipynb +++ /dev/null @@ -1,93 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n" - ] - } - ], - "source": [ - "#####################################################\n", - "# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #\n", - "#####################################################\n", - "import abc, os, sys\n", - "from pathlib import Path\n", - "\n", - "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", - "\n", - "lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n", - "print(\"library path: {:}\".format(lib_dir))\n", - "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", - "if str(lib_dir) not in sys.path:\n", - " sys.path.insert(0, str(lib_dir))" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "default", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m~/Desktop/XAutoDL/notebooks/spaces\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mout_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspaces\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCategorical\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m12\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m24\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m36\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspaces\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCategorical\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSuperLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_features\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/XAutoDL/lib/layers/super_mlp.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, in_features, out_features, bias)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mBoolSpaceType\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m ) -> None:\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mSuperLinear\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;31m# the raw input args\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Desktop/XAutoDL/lib/layers/super_module.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mSuperModule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_super_run_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSuperRunMode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefault\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mabc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabstractmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda3/lib/python3.8/enum.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(cls, name)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_member_map_\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 340\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 341\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 342\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: default" - ] - } - ], - "source": [ - "# Test the Linear layer\n", - "import spaces\n", - "from layers.super_core import SuperLinear\n", - "from layers.super_module import SuperRunMode\n", - "\n", - "out_features = spaces.Categorical(12, 24, 36)\n", - "bias = spaces.Categorical(True, False)\n", - "model = SuperLinear(10, out_features, bias=bias)\n", - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/spaces/random-search-transformer.ipynb b/notebooks/spaces/random-search-transformer.ipynb new file mode 100644 index 0000000..688b4f3 --- /dev/null +++ b/notebooks/spaces/random-search-transformer.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n" + ] + } + ], + "source": [ + "#####################################################\n", + "# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #\n", + "#####################################################\n", + "import abc, os, sys\n", + "from pathlib import Path\n", + "\n", + "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", + "\n", + "lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n", + "print(\"library path: {:}\".format(lib_dir))\n", + "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", + "if str(lib_dir) not in sys.path:\n", + " sys.path.insert(0, str(lib_dir))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.7.0\n", + "True\n", + "OrderedDict()\n", + "OrderedDict()\n", + "set()\n", + "OrderedDict()\n", + "OrderedDict()\n", + "OrderedDict()\n", + "OrderedDict()\n", + "OrderedDict()\n", + "OrderedDict()\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/xuanyidong/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py:551: UserWarning: Setting attributes on ParameterDict is not supported.\n", + " warnings.warn(\"Setting attributes on ParameterDict is not supported.\")\n" + ] + } + ], + "source": [ + "# Test the Linear layer\n", + "import spaces\n", + "import torch\n", + "from xlayers import super_core\n", + "\n", + "print(torch.__version__)\n", + "mlp = super_core.SuperMLPv2(10, 12, 32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/test_super_att.py b/tests/test_super_att.py new file mode 100644 index 0000000..3886f8d --- /dev/null +++ b/tests/test_super_att.py @@ -0,0 +1,71 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +# pytest ./tests/test_super_model.py -s # +##################################################### +import sys, random +import unittest +from parameterized import parameterized +import pytest +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +print("library path: {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +import torch +from xlayers import super_core +import spaces + + +class TestSuperAttention(unittest.TestCase): + """Test the super attention layer.""" + + def _internal_func(self, inputs, model): + outputs = model(inputs) + abstract_space = model.abstract_search_space + print( + "The abstract search space for SuperAttention is:\n{:}".format( + abstract_space + ) + ) + abstract_space.clean_last() + abstract_child = abstract_space.random(reuse_last=True) + print("The abstract child program is:\n{:}".format(abstract_child)) + model.set_super_run_type(super_core.SuperRunMode.Candidate) + model.apply_candidate(abstract_child) + outputs = model(inputs) + return abstract_child, outputs + + def test_super_attention(self): + proj_dim = spaces.Categorical(12, 24, 36) + num_heads = spaces.Categorical(2, 4, 6) + model = super_core.SuperAttention(10, proj_dim, num_heads) + print(model) + model.apply_verbose(True) + + inputs = torch.rand(4, 20, 10) # batch size, sequence length, channel + abstract_child, outputs = self._internal_func(inputs, model) + output_shape = (4, 20, abstract_child["proj"]["_out_features"].value) + self.assertEqual(tuple(outputs.shape), output_shape) + + @parameterized.expand([[6], [12], [24], [48]]) + def test_transformer_encoder(self, input_dim): + output_dim = spaces.Categorical(12, 24, 36) + model = super_core.SuperTransformerEncoderLayer( + input_dim, + output_dim=output_dim, + num_heads=spaces.Categorical(2, 4, 6), + mlp_hidden_multiplier=spaces.Categorical(1, 2, 4), + ) + print(model) + model.apply_verbose(True) + inputs = torch.rand(4, 20, input_dim) + abstract_child, outputs = self._internal_func(inputs, model) + output_shape = ( + 4, + 20, + output_dim.abstract(reuse_last=True).random(reuse_last=True).value, + ) + self.assertEqual(tuple(outputs.shape), output_shape) diff --git a/tests/test_super_model.py b/tests/test_super_model.py index 88be2fe..de89a2c 100644 --- a/tests/test_super_model.py +++ b/tests/test_super_model.py @@ -51,10 +51,10 @@ class TestSuperLinear(unittest.TestCase): outputs = model(inputs) self.assertEqual(tuple(outputs.shape), output_shape) - def test_super_mlp(self): + def test_super_mlp_v1(self): hidden_features = spaces.Categorical(12, 24, 36) out_features = spaces.Categorical(24, 36, 48) - mlp = super_core.SuperMLP(10, hidden_features, out_features) + mlp = super_core.SuperMLPv1(10, hidden_features, out_features) print(mlp) mlp.apply_verbose(True) self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features) @@ -64,7 +64,9 @@ class TestSuperLinear(unittest.TestCase): self.assertEqual(tuple(outputs.shape), (4, 48)) abstract_space = mlp.abstract_search_space - print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space)) + print( + "The abstract search space for SuperMLPv1 is:\n{:}".format(abstract_space) + ) self.assertEqual( abstract_space["fc1"]["_out_features"], abstract_space["fc2"]["_in_features"], @@ -88,28 +90,28 @@ class TestSuperLinear(unittest.TestCase): output_shape = (4, abstract_child["fc2"]["_out_features"].value) self.assertEqual(tuple(outputs.shape), output_shape) - def test_super_attention(self): - proj_dim = spaces.Categorical(12, 24, 36) - num_heads = spaces.Categorical(2, 4, 6) - model = super_core.SuperAttention(10, proj_dim, num_heads) - print(model) - model.apply_verbose(True) + def test_super_mlp_v2(self): + hidden_multiplier = spaces.Categorical(1.0, 2.0, 3.0) + out_features = spaces.Categorical(24, 36, 48) + mlp = super_core.SuperMLPv2(10, hidden_multiplier, out_features) + print(mlp) + mlp.apply_verbose(True) - inputs = torch.rand(4, 20, 10) # batch size, sequence length, channel - outputs = model(inputs) + inputs = torch.rand(4, 10) + outputs = mlp(inputs) + self.assertEqual(tuple(outputs.shape), (4, 48)) - abstract_space = model.abstract_search_space + abstract_space = mlp.abstract_search_space print( - "The abstract search space for SuperAttention is:\n{:}".format( - abstract_space - ) + "The abstract search space for SuperMLPv2 is:\n{:}".format(abstract_space) ) + abstract_space.clean_last() abstract_child = abstract_space.random(reuse_last=True) print("The abstract child program is:\n{:}".format(abstract_child)) - model.set_super_run_type(super_core.SuperRunMode.Candidate) - model.apply_candidate(abstract_child) - outputs = model(inputs) - output_shape = (4, 20, abstract_child["proj"]["_out_features"].value) + mlp.set_super_run_type(super_core.SuperRunMode.Candidate) + mlp.apply_candidate(abstract_child) + outputs = mlp(inputs) + output_shape = (4, abstract_child["_out_features"].value) self.assertEqual(tuple(outputs.shape), output_shape)