diff --git a/tests/test_super_rearrange.py b/tests/test_super_rearrange.py new file mode 100644 index 0000000..eabf0fe --- /dev/null +++ b/tests/test_super_rearrange.py @@ -0,0 +1,28 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +# pytest ./tests/test_super_rearrange.py -s # +##################################################### +import sys +import unittest +from pathlib import Path + +lib_dir = (Path(__file__).parent / "..").resolve() +print("LIB-DIR: {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +import torch +from xautodl import xlayers + + +class TestSuperReArrange(unittest.TestCase): + """Test the super re-arrange layer.""" + + def test_super_re_arrange(self): + layer = xlayers.SuperReArrange( + "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=4, p2=4 + ) + tensor = torch.rand((8, 4, 32, 32)) + print("The tensor shape: {:}".format(tensor.shape)) + print(layer) diff --git a/xautodl/xlayers/super_core.py b/xautodl/xlayers/super_core.py index 7c026a6..6dacc48 100644 --- a/xautodl/xlayers/super_core.py +++ b/xautodl/xlayers/super_core.py @@ -47,3 +47,5 @@ super_name2activation = { from .super_trade_stem import SuperAlphaEBDv1 from .super_positional_embedding import SuperDynamicPositionE from .super_positional_embedding import SuperPositionalEncoder + +from .super_rearrange import SuperReArrange diff --git a/xautodl/xlayers/super_rearrange.py b/xautodl/xlayers/super_rearrange.py index 9af818b..8f7da5a 100644 --- a/xautodl/xlayers/super_rearrange.py +++ b/xautodl/xlayers/super_rearrange.py @@ -16,15 +16,14 @@ from .super_module import IntSpaceType from .super_module import BoolSpaceType -class SuperRearrange(SuperModule): +class SuperReArrange(SuperModule): """Applies the rearrange operation.""" def __init__(self, pattern, **axes_lengths): - super(SuperRearrange, self).__init__() + super(SuperReArrange, self).__init__() self._pattern = pattern self._axes_lengths = axes_lengths - self.reset_parameters() @property def abstract_search_space(self):