Add SuperSequential
This commit is contained in:
parent
32900797eb
commit
033878becb
2
.github/workflows/super_model_test.yml
vendored
2
.github/workflows/super_model_test.yml
vendored
@ -29,5 +29,5 @@ jobs:
|
|||||||
python -m pip install pytest numpy
|
python -m pip install pytest numpy
|
||||||
python -m pip install parameterized
|
python -m pip install parameterized
|
||||||
python -m pip install torch torchvision torchaudio
|
python -m pip install torch torchvision torchaudio
|
||||||
python -m pytest ./tests/test_super_model.py -s
|
python -m pytest ./tests/test_super_*.py -s
|
||||||
shell: bash
|
shell: bash
|
||||||
|
@ -16,3 +16,8 @@ python -m black __init__.py -l 120
|
|||||||
|
|
||||||
pytest -W ignore::DeprecationWarning qlib/tests/test_all_pipeline.py
|
pytest -W ignore::DeprecationWarning qlib/tests/test_all_pipeline.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
conda update --all
|
||||||
|
```
|
||||||
|
111
lib/xlayers/super_container.py
Normal file
111
lib/xlayers/super_container.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
|
#####################################################
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from itertools import islice
|
||||||
|
import operator
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Optional, Union, Callable, TypeVar, Iterator
|
||||||
|
|
||||||
|
import spaces
|
||||||
|
from .super_module import SuperModule
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=SuperModule)
|
||||||
|
|
||||||
|
|
||||||
|
class SuperSequential(SuperModule):
    """A sequential container wrapped with 'Super' ability.

    Modules are registered in the order they are passed to the constructor and
    are called in that same order by ``forward_raw`` / ``forward_candidate``.
    Alternatively, a single ``OrderedDict`` mapping names to modules can be
    passed in, in which case the dict keys are used as the module names.

    Every child is expected to expose ``abstract_search_space`` and
    ``apply_candidate`` because this container reads/calls them on each child.

    A small example::

        # Positional modules are registered under the names "0", "1", ...
        model = SuperSequential(layer1, layer2, layer3)

        # Named modules via an OrderedDict
        model = SuperSequential(OrderedDict([
            ("encoder", layer1),
            ("decoder", layer2),
        ]))

    Note that the search space is keyed by the *position* of each child
    (``"0"``, ``"1"``, ...), even when the children were registered by name.
    """

    def __init__(self, *args):
        """Register the given modules.

        Args:
            *args: either a single ``OrderedDict`` of ``name -> module`` or any
                number of modules, which are named by their position.
        """
        super(SuperSequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            # ``args`` comes from ``*args`` and is therefore always a tuple, so
            # no extra container-type validation is needed here.
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> T:
        """Return the ``idx``-th item of ``iterator``.

        Negative indices count from the end, based on ``len(self)``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError("index {} is out of range".format(idx))
        idx %= size  # map a negative index onto its non-negative position
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx) -> Union["SuperSequential", T]:
        """Return one child (integer index) or a new container (slice)."""
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: SuperModule) -> None:
        """Replace the child at position ``idx``, keeping its registered name."""
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        """Remove the child (or slice of children) at ``idx``."""
        if isinstance(idx, slice):
            # Materialise the keys first: deleting while iterating the live
            # key view would mutate the dict during iteration.
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)

    def __len__(self) -> int:
        """Return the number of registered children."""
        return len(self._modules)

    def __dir__(self):
        """Return the parent's ``dir`` listing without purely numeric names."""
        keys = super(SuperSequential, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def __iter__(self) -> Iterator[SuperModule]:
        """Iterate over the children in registration order."""
        return iter(self._modules.values())

    @property
    def abstract_search_space(self):
        """Collect the undetermined search spaces of all children.

        Returns:
            A ``spaces.VirtualNode`` holding, under the key ``str(position)``,
            the search space of every child whose space is not yet determined.
        """
        root_node = spaces.VirtualNode(id(self))
        for index, module in enumerate(self):
            space = module.abstract_search_space
            if not spaces.is_determined(space):
                root_node.append(str(index), space)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        """Apply ``abstract_child`` to this container and to each child.

        A child at position ``i`` receives ``abstract_child[str(i)]`` when that
        key is present; children without an entry are left untouched.
        """
        super(SuperSequential, self).apply_candidate(abstract_child)
        # Walk the children once instead of looking each one up by index,
        # which would rescan the module dict from the start on every step.
        for index, module in enumerate(self):
            key = str(index)
            if key in abstract_child:
                module.apply_candidate(abstract_child[key])

    def forward_candidate(self, input):
        """Run the children in order; identical to ``forward_raw``."""
        return self.forward_raw(input)

    def forward_raw(self, input):
        """Feed ``input`` through every child in order and return the result."""
        for module in self:
            input = module(input)
        return input
|
@ -3,6 +3,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
from .super_module import SuperRunMode
|
from .super_module import SuperRunMode
|
||||||
from .super_module import SuperModule
|
from .super_module import SuperModule
|
||||||
|
from .super_container import SuperSequential
|
||||||
from .super_linear import SuperLinear
|
from .super_linear import SuperLinear
|
||||||
from .super_linear import SuperMLPv1, SuperMLPv2
|
from .super_linear import SuperMLPv1, SuperMLPv2
|
||||||
from .super_norm import SuperLayerNorm1D
|
from .super_norm import SuperLayerNorm1D
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
#####################################################
|
#####################################################
|
||||||
|
# pytest tests/test_basic_space.py -s #
|
||||||
|
#####################################################
|
||||||
import sys, random
|
import sys, random
|
||||||
import unittest
|
import unittest
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
#####################################################
|
#####################################################
|
||||||
# pytest ./tests/test_super_model.py -s #
|
# pytest ./tests/test_super_att.py -s #
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, random
|
import sys, random
|
||||||
import unittest
|
import unittest
|
||||||
|
68
tests/test_super_container.py
Normal file
68
tests/test_super_container.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
|
#####################################################
|
||||||
|
# pytest ./tests/test_super_container.py -s #
|
||||||
|
#####################################################
|
||||||
|
import sys, random
|
||||||
|
import unittest
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Resolve <this file's directory>/../lib and put it at the front of sys.path
# (only if it is not already listed) so the project imports below can be found.
lib_dir = Path(__file__).parent.joinpath("..", "lib").resolve()
print("library path: {:}".format(lib_dir))
if sys.path.count(str(lib_dir)) == 0:
    sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from xlayers import super_core
|
||||||
|
import spaces
|
||||||
|
|
||||||
|
|
||||||
|
"""Test the super container layers."""
|
||||||
|
|
||||||
|
|
||||||
|
def _internal_func(inputs, model):
    """Run ``model`` once as-is, then again after applying a sampled candidate.

    Args:
        inputs: the value passed to ``model(...)`` on both forward calls.
        model: an object exposing ``abstract_search_space``,
            ``set_super_run_type`` and ``apply_candidate``.

    Returns:
        A tuple ``(abstract_child, outputs)``: the candidate sampled from the
        model's abstract search space and the result of the second forward
        call, made after switching to ``SuperRunMode.Candidate``.
    """
    # Forward once before the run type is switched; the result is not needed,
    # the call only checks that this path executes.
    model(inputs)
    abstract_space = model.abstract_search_space
    print(
        "The abstract search space for {:} is:\n{:}".format(
            type(model).__name__, abstract_space
        )
    )
    # NOTE(review): presumably clears previously sampled values so the draw
    # below is fresh -- confirm against the spaces package.
    abstract_space.clean_last()
    abstract_child = abstract_space.random(reuse_last=True)
    print("The abstract child program is:\n{:}".format(abstract_child))
    model.set_super_run_type(super_core.SuperRunMode.Candidate)
    model.apply_candidate(abstract_child)
    outputs = model(inputs)
    return abstract_child, outputs
|
||||||
|
|
||||||
|
|
||||||
|
def _create_stel(input_dim, output_dim):
    """Build a ``SuperTransformerEncoderLayer`` from ``input_dim`` to ``output_dim``.

    The number of heads is searchable over (2, 4, 6) and the MLP hidden
    multiplier over (1, 2, 4).
    """
    searchable = {
        "num_heads": spaces.Categorical(2, 4, 6),
        "mlp_hidden_multiplier": spaces.Categorical(1, 2, 4),
    }
    return super_core.SuperTransformerEncoderLayer(input_dim, output_dim, **searchable)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch", (1, 2, 4))
@pytest.mark.parametrize("seq_dim", (1, 10, 30))
@pytest.mark.parametrize("input_dim", (6, 12, 24, 27))
def test_super_sequential(batch, seq_dim, input_dim):
    """Stack three searchable encoder layers and check the candidate output shape."""
    # Output widths of the three stages; each stage's output feeds the next.
    stage_dims = [
        spaces.Categorical(12, 24, 36),
        spaces.Categorical(24, 36, 48),
        spaces.Categorical(36, 72, 100),
    ]
    stage_inputs = [input_dim] + stage_dims[:-1]
    layers = [
        _create_stel(in_dim, out_dim)
        for in_dim, out_dim in zip(stage_inputs, stage_dims)
    ]
    model = super_core.SuperSequential(*layers)
    print(model)
    model.apply_verbose(True)

    inputs = torch.rand(batch, seq_dim, input_dim)
    _, outputs = _internal_func(inputs, model)

    # The last stage's width is read back with reuse_last=True on both calls.
    last_dim = stage_dims[-1].abstract(reuse_last=True).random(reuse_last=True).value
    assert tuple(outputs.shape) == (batch, seq_dim, last_dim)
|
Loading…
Reference in New Issue
Block a user