# borrowed from https://github.com/arogozhnikov/einops/blob/master/einops/parsing.py
import keyword
import warnings
from typing import List

# Single unicode character substituted for "..." while parsing, so that an
# ellipsis can be handled as one ordinary identifier character.
_ellipsis = "…"


class AnonymousAxis:
    """Axis given by its length only (e.g. the ``2`` in ``'b (c 2)'``).

    Instances use the default identity-based equality and hashing, so two
    anonymous axes of the same length are still distinct identifiers.
    """

    def __init__(self, value: str):
        self.value = int(value)
        if self.value <= 1:
            if self.value == 1:
                raise ValueError(
                    "No need to create anonymous axis of length 1. Report this as an issue"
                )
            raise ValueError(
                "Anonymous axis should have positive length, not {}".format(
                    self.value
                )
            )

    def __repr__(self):
        return "{}-axis".format(str(self.value))


class ParsedExpression:
    """Parsed form of one side of a rearrange pattern (e.g. ``'b c (h w)'``).

    All attributes are filled in by ``__init__`` and are meant to be treated
    as read-only afterwards:

    * ``has_ellipsis`` -- whether the expression contains ``...``.
    * ``has_ellipsis_parenthesized`` -- ``None`` without an ellipsis, else
      whether the ellipsis sits inside parentheses.
    * ``identifiers`` -- set of named axes (``str``), anonymous axes
      (``AnonymousAxis``) and possibly the ellipsis marker.
    * ``has_non_unitary_anonymous_axes`` -- whether a numeric axis other than
      ``1`` occurs.
    * ``composition`` -- one entry per top-level group: a list of axes, an
      empty list for a bare ``1``, or the ellipsis marker itself when the
      ellipsis is not parenthesized.
    """

    def __init__(self, expression):
        """Parse ``expression``.

        Raises:
            ValueError: on stray dots, more than one ellipsis, duplicate axis
                names, invalid identifiers, unknown characters, or nested /
                unbalanced parentheses.
        """
        self.has_ellipsis = False
        self.has_ellipsis_parenthesized = None
        self.identifiers = set()
        # that's axes like 2, 3 or 5. Axes with size 1 are exceptional and
        # replaced with empty composition
        self.has_non_unitary_anonymous_axes = False
        # composition keeps structure of composite axes
        self.composition = []
        if "." in expression:
            if "..." not in expression:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...)"
                )
            if expression.count("...") != 1 or expression.count(".") != 3:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
                )
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True

        # ``None`` while outside parentheses, otherwise the list collecting
        # the axes of the currently open group.
        bracket_group = None

        def add_axis_name(x):
            """Register the finished token ``x`` (``None`` means no token)."""
            if x is None:
                return
            if x in self.identifiers:
                raise ValueError(
                    'Indexing expression contains duplicate dimension "{}"'.format(x)
                )
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
                return

            is_number = str.isdecimal(x)
            if is_number and int(x) == 1:
                # anonymous axis of length 1: an empty group at top level,
                # and simply dropped inside parentheses
                if bracket_group is None:
                    self.composition.append([])
                return
            is_axis_name, reason = self.check_axis_name(x, return_reason=True)
            if not (is_number or is_axis_name):
                raise ValueError("Invalid axis identifier: {}\n{}".format(x, reason))
            if is_number:
                x = AnonymousAxis(x)
                self.has_non_unitary_anonymous_axes = True
            self.identifiers.add(x)
            if bracket_group is None:
                self.composition.append([x])
            else:
                bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in "() ":
                # a separator terminates the token collected so far
                add_axis_name(current_identifier)
                current_identifier = None
                if char == "(":
                    if bracket_group is not None:
                        raise ValueError(
                            "Axis composition is one-level (brackets inside brackets not allowed)"
                        )
                    bracket_group = []
                elif char == ")":
                    if bracket_group is None:
                        raise ValueError("Brackets are not balanced")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ["_", _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise ValueError("Unknown character '{}'".format(char))

        if bracket_group is not None:
            raise ValueError(
                'Imbalanced parentheses in expression: "{}"'.format(expression)
            )
        # flush the trailing token, if any
        add_axis_name(current_identifier)

    def flat_axes_order(self) -> List:
        """Return all axes in order of appearance, with groups flattened.

        Asserts that every composition entry is a list, i.e. that there is no
        un-parenthesized ellipsis.
        """
        result = []
        for composed_axis in self.composition:
            assert isinstance(composed_axis, list), "does not work with ellipsis"
            result.extend(composed_axis)
        return result

    def has_composed_axes(self) -> bool:
        """Return True if any group holds more than one axis."""
        # this will ignore 1 inside brackets
        return any(
            isinstance(axes, list) and len(axes) > 1 for axes in self.composition
        )

    @staticmethod
    def check_axis_name(name: str, return_reason=False):
        """Check whether ``name`` is an acceptable axis name.

        Valid axes names are python identifiers that neither start nor end
        with an underscore. Keywords and the name ``axis`` are accepted but
        trigger a warning.

        Returns:
            ``bool`` by default; ``(bool, reason_or_None)`` when
            ``return_reason`` is true.
        """
        if not str.isidentifier(name):
            result = False, "not a valid python identifier"
        elif name[0] == "_" or name[-1] == "_":
            result = False, "axis name should not start or end with underscore"
        else:
            if keyword.iskeyword(name):
                warnings.warn(
                    "It is discouraged to use axes names that are keywords: {}".format(
                        name
                    ),
                    RuntimeWarning,
                )
            if name in ["axis"]:
                warnings.warn(
                    "It is discouraged to use 'axis' as an axis name "
                    "and will raise an error in future",
                    FutureWarning,
                )
            result = True, None
        if return_reason:
            return result
        return result[0]
# Methods of SuperReArrange: forward_candidate, forward_raw, extra_repr.
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
    """Run the layer on ``input`` by delegating to ``forward_raw``.

    Returns whatever ``forward_raw`` returns; the result must be propagated,
    otherwise callers receive ``None``.
    """
    return self.forward_raw(input)


def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
    """Apply the rearrangement to ``input``.

    Raises:
        NotImplementedError: always; the operation is not implemented yet.
    """
    # No interactive debugger breakpoint here: it would block every
    # non-interactive run (tests, training jobs) before the error surfaces.
    raise NotImplementedError


def extra_repr(self) -> str:
    """Return ``repr(pattern)`` followed by ``, axis=length`` for each axis.

    The class name is deliberately not included in the returned string.
    """
    params = repr(self._pattern)
    for axis, length in self._axes_lengths.items():
        params += ", {}={}".format(axis, length)
    return "{:}".format(params)