Add SuperSequential

This commit is contained in:
D-X-Y 2021-03-21 13:26:52 +08:00
parent 32900797eb
commit 033878becb
7 changed files with 189 additions and 2 deletions

View File

@ -29,5 +29,5 @@ jobs:
python -m pip install pytest numpy
python -m pip install parameterized
python -m pip install torch torchvision torchaudio
python -m pytest ./tests/test_super_model.py -s
python -m pytest ./tests/test_super_*.py -s
shell: bash

View File

@ -16,3 +16,8 @@ python -m black __init__.py -l 120
pytest -W ignore::DeprecationWarning qlib/tests/test_all_pipeline.py
```
```
conda update --all
```

View File

@ -0,0 +1,111 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
from itertools import islice
import operator
from collections import OrderedDict
from typing import Optional, Union, Callable, TypeVar, Iterator
import spaces
from .super_module import SuperModule
T = TypeVar("T", bound=SuperModule)
class SuperSequential(SuperModule):
"""A sequential container wrapped with 'Super' ability.
Modules will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of modules can also be passed in.
To make it easier to understand, here is a small example::
# Example of using Sequential
model = SuperSequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
"""
def __init__(self, *args):
super(SuperSequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
if not isinstance(args, (list, tuple)):
raise ValueError("Invalid input type: {:}".format(type(args)))
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx) -> T:
"""Get the idx-th item of the iterator"""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError("index {} is out of range".format(idx))
idx %= size
return next(islice(iterator, idx, None))
def __getitem__(self, idx) -> Union["SuperSequential", T]:
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)
def __setitem__(self, idx: int, module: SuperModule) -> None:
key: str = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
def __delitem__(self, idx: Union[slice, int]) -> None:
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
def __len__(self) -> int:
return len(self._modules)
def __dir__(self):
keys = super(SuperSequential, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def __iter__(self) -> Iterator[SuperModule]:
return iter(self._modules.values())
@property
def abstract_search_space(self):
root_node = spaces.VirtualNode(id(self))
for index, module in enumerate(self):
space = module.abstract_search_space
if not spaces.is_determined(space):
root_node.append(str(index), space)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperSequential, self).apply_candidate(abstract_child)
for index in range(len(self)):
if str(index) in abstract_child:
self.__getitem__(index).apply_candidate(abstract_child[str(index)])
def forward_candidate(self, input):
return self.forward_raw(input)
def forward_raw(self, input):
for module in self:
input = module(input)
return input

View File

@ -3,6 +3,7 @@
#####################################################
from .super_module import SuperRunMode
from .super_module import SuperModule
from .super_container import SuperSequential
from .super_linear import SuperLinear
from .super_linear import SuperMLPv1, SuperMLPv2
from .super_norm import SuperLayerNorm1D

View File

@ -1,6 +1,8 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest tests/test_basic_space.py -s #
#####################################################
import sys, random
import unittest
import pytest

View File

@ -1,7 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest ./tests/test_super_model.py -s #
# pytest ./tests/test_super_att.py -s #
#####################################################
import sys, random
import unittest

View File

@ -0,0 +1,68 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest ./tests/test_super_container.py -s #
#####################################################
import sys, random
import unittest
import pytest
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
import torch
from xlayers import super_core
import spaces
"""Test the super container layers."""
def _internal_func(inputs, model):
outputs = model(inputs)
abstract_space = model.abstract_search_space
print(
"The abstract search space for SuperAttention is:\n{:}".format(abstract_space)
)
abstract_space.clean_last()
abstract_child = abstract_space.random(reuse_last=True)
print("The abstract child program is:\n{:}".format(abstract_child))
model.set_super_run_type(super_core.SuperRunMode.Candidate)
model.apply_candidate(abstract_child)
outputs = model(inputs)
return abstract_child, outputs
def _create_stel(input_dim, output_dim):
return super_core.SuperTransformerEncoderLayer(
input_dim,
output_dim,
num_heads=spaces.Categorical(2, 4, 6),
mlp_hidden_multiplier=spaces.Categorical(1, 2, 4),
)
@pytest.mark.parametrize("batch", (1, 2, 4))
@pytest.mark.parametrize("seq_dim", (1, 10, 30))
@pytest.mark.parametrize("input_dim", (6, 12, 24, 27))
def test_super_sequential(batch, seq_dim, input_dim):
out1_dim = spaces.Categorical(12, 24, 36)
out2_dim = spaces.Categorical(24, 36, 48)
out3_dim = spaces.Categorical(36, 72, 100)
layer1 = _create_stel(input_dim, out1_dim)
layer2 = _create_stel(out1_dim, out2_dim)
layer3 = _create_stel(out2_dim, out3_dim)
model = super_core.SuperSequential(layer1, layer2, layer3)
print(model)
model.apply_verbose(True)
inputs = torch.rand(batch, seq_dim, input_dim)
abstract_child, outputs = _internal_func(inputs, model)
output_shape = (
batch,
seq_dim,
out3_dim.abstract(reuse_last=True).random(reuse_last=True).value,
)
assert tuple(outputs.shape) == output_shape