diff --git a/.github/workflows/super_model_test.yml b/.github/workflows/super_model_test.yml index 0bc049c..9723ebc 100644 --- a/.github/workflows/super_model_test.yml +++ b/.github/workflows/super_model_test.yml @@ -29,5 +29,5 @@ jobs: python -m pip install pytest numpy python -m pip install parameterized python -m pip install torch torchvision torchaudio - python -m pytest ./tests/test_super_model.py -s + python -m pytest ./tests/test_super_*.py -s shell: bash diff --git a/.latent-data/init-configs/README.md b/.latent-data/init-configs/README.md index 9a22786..742e459 100644 --- a/.latent-data/init-configs/README.md +++ b/.latent-data/init-configs/README.md @@ -16,3 +16,8 @@ python -m black __init__.py -l 120 pytest -W ignore::DeprecationWarning qlib/tests/test_all_pipeline.py ``` + + +``` +conda update --all +``` diff --git a/lib/xlayers/super_container.py b/lib/xlayers/super_container.py new file mode 100644 index 0000000..e479a0c --- /dev/null +++ b/lib/xlayers/super_container.py @@ -0,0 +1,111 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +import torch + +from itertools import islice +import operator + +from collections import OrderedDict +from typing import Optional, Union, Callable, TypeVar, Iterator + +import spaces +from .super_module import SuperModule + + +T = TypeVar("T", bound=SuperModule) + + +class SuperSequential(SuperModule): + """A sequential container wrapped with 'Super' ability. + + Modules will be added to it in the order they are passed in the constructor. + Alternatively, an ordered dict of modules can also be passed in. + To make it easier to understand, here is a small example:: + # Example of using Sequential + model = SuperSequential( + nn.Conv2d(1,20,5), + nn.ReLU(), + nn.Conv2d(20,64,5), + nn.ReLU() + ) + # Example of using Sequential with OrderedDict + model = nn.Sequential(OrderedDict([ + ('conv1', nn.Conv2d(1,20,5)), + ('relu1', nn.ReLU()), + ('conv2', nn.Conv2d(20,64,5)), + ('relu2', nn.ReLU()) + ])) + """ + + def __init__(self, *args): + super(SuperSequential, self).__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + if not isinstance(args, (list, tuple)): + raise ValueError("Invalid input type: {:}".format(type(args))) + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator, idx) -> T: + """Get the idx-th item of the iterator""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError("index {} is out of range".format(idx)) + idx %= size + return next(islice(iterator, idx, None)) + + def __getitem__(self, idx) -> Union["SuperSequential", T]: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: SuperModule) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: Union[slice, int]) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + + def __len__(self) -> int: + return len(self._modules) + + def __dir__(self): + keys = super(SuperSequential, self).__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def __iter__(self) -> Iterator[SuperModule]: + return iter(self._modules.values()) + + @property + def abstract_search_space(self): + root_node = spaces.VirtualNode(id(self)) + for index, module in enumerate(self): + space = module.abstract_search_space + if not spaces.is_determined(space): + root_node.append(str(index), space) + return root_node + + def apply_candidate(self, abstract_child: spaces.VirtualNode): + super(SuperSequential, self).apply_candidate(abstract_child) + for index in range(len(self)): + if str(index) in abstract_child: + self.__getitem__(index).apply_candidate(abstract_child[str(index)]) + + def forward_candidate(self, input): + return self.forward_raw(input) + + def forward_raw(self, input): + for module in self: + input = module(input) + return input diff --git a/lib/xlayers/super_core.py b/lib/xlayers/super_core.py index 36a14a5..ea6bfdb 100644 --- a/lib/xlayers/super_core.py +++ b/lib/xlayers/super_core.py @@ -3,6 +3,7 @@ ##################################################### from .super_module import SuperRunMode from .super_module import SuperModule +from .super_container import SuperSequential from .super_linear import SuperLinear from .super_linear import SuperMLPv1, SuperMLPv2 from .super_norm import SuperLayerNorm1D diff --git a/tests/test_basic_space.py b/tests/test_basic_space.py index 713ec2e..8406dad 100644 --- a/tests/test_basic_space.py +++ b/tests/test_basic_space.py @@ -1,6 +1,8 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # ##################################################### +# pytest tests/test_basic_space.py -s # +##################################################### import sys, random import unittest import pytest diff --git a/tests/test_super_att.py b/tests/test_super_att.py index 3886f8d..c4a0900 100644 --- a/tests/test_super_att.py +++ b/tests/test_super_att.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # ##################################################### -# pytest ./tests/test_super_model.py -s # +# pytest ./tests/test_super_att.py -s # ##################################################### import sys, random import unittest diff --git a/tests/test_super_container.py b/tests/test_super_container.py new file mode 100644 index 0000000..56bb77a --- /dev/null +++ b/tests/test_super_container.py @@ -0,0 +1,68 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +# pytest ./tests/test_super_container.py -s # +##################################################### +import sys, random +import unittest +import pytest +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +print("library path: {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +import torch +from xlayers import super_core +import spaces + + +"""Test the super container layers.""" + + +def _internal_func(inputs, model): + outputs = model(inputs) + abstract_space = model.abstract_search_space + print( + "The abstract search space for SuperAttention is:\n{:}".format(abstract_space) + ) + abstract_space.clean_last() + abstract_child = abstract_space.random(reuse_last=True) + print("The abstract child program is:\n{:}".format(abstract_child)) + model.set_super_run_type(super_core.SuperRunMode.Candidate) + model.apply_candidate(abstract_child) + outputs = model(inputs) + return abstract_child, outputs + + +def _create_stel(input_dim, output_dim): + return super_core.SuperTransformerEncoderLayer( + input_dim, + output_dim, + num_heads=spaces.Categorical(2, 4, 6), + mlp_hidden_multiplier=spaces.Categorical(1, 2, 4), + ) + + +@pytest.mark.parametrize("batch", (1, 2, 4)) +@pytest.mark.parametrize("seq_dim", (1, 10, 30)) +@pytest.mark.parametrize("input_dim", (6, 12, 24, 27)) +def test_super_sequential(batch, seq_dim, input_dim): + out1_dim = spaces.Categorical(12, 24, 36) + out2_dim = spaces.Categorical(24, 36, 48) + out3_dim = spaces.Categorical(36, 72, 100) + layer1 = _create_stel(input_dim, out1_dim) + layer2 = _create_stel(out1_dim, out2_dim) + layer3 = _create_stel(out2_dim, out3_dim) + model = super_core.SuperSequential(layer1, layer2, layer3) + print(model) + model.apply_verbose(True) + inputs = torch.rand(batch, seq_dim, input_dim) + abstract_child, outputs = _internal_func(inputs, model) + output_shape = ( + batch, + seq_dim, + out3_dim.abstract(reuse_last=True).random(reuse_last=True).value, + ) + assert tuple(outputs.shape) == output_shape