Update misc
This commit is contained in:
		| @@ -26,3 +26,5 @@ class TestSuperReArrange(unittest.TestCase): | |||||||
|         tensor = torch.rand((8, 4, 32, 32)) |         tensor = torch.rand((8, 4, 32, 32)) | ||||||
|         print("The tensor shape: {:}".format(tensor.shape)) |         print("The tensor shape: {:}".format(tensor.shape)) | ||||||
|         print(layer) |         print(layer) | ||||||
|  |         outs = layer(tensor) | ||||||
|  |         print("The output tensor shape: {:}".format(outs.shape)) | ||||||
|   | |||||||
							
								
								
									
										145
									
								
								xautodl/xlayers/misc_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								xautodl/xlayers/misc_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,145 @@ | |||||||
|  | # borrowed from https://github.com/arogozhnikov/einops/blob/master/einops/parsing.py | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ParsedExpression: | ||||||
|  |     """ | ||||||
|  |     non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)') | ||||||
|  |     and keeps some information important for downstream | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, expression): | ||||||
|  |         self.has_ellipsis = False | ||||||
|  |         self.has_ellipsis_parenthesized = None | ||||||
|  |         self.identifiers = set() | ||||||
|  |         # that's axes like 2, 3 or 5. Axes with size 1 are exceptional and replaced with empty composition | ||||||
|  |         self.has_non_unitary_anonymous_axes = False | ||||||
|  |         # composition keeps structure of composite axes, see how different corner cases are handled in tests | ||||||
|  |         self.composition = [] | ||||||
|  |         if "." in expression: | ||||||
|  |             if "..." not in expression: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "Expression may contain dots only inside ellipsis (...)" | ||||||
|  |                 ) | ||||||
|  |             if str.count(expression, "...") != 1 or str.count(expression, ".") != 3: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor " | ||||||
|  |                 ) | ||||||
|  |             expression = expression.replace("...", _ellipsis) | ||||||
|  |             self.has_ellipsis = True | ||||||
|  |  | ||||||
|  |         bracket_group = None | ||||||
|  |  | ||||||
|  |         def add_axis_name(x): | ||||||
|  |             if x is not None: | ||||||
|  |                 if x in self.identifiers: | ||||||
|  |                     raise ValueError( | ||||||
|  |                         'Indexing expression contains duplicate dimension "{}"'.format( | ||||||
|  |                             x | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|  |                 if x == _ellipsis: | ||||||
|  |                     self.identifiers.add(_ellipsis) | ||||||
|  |                     if bracket_group is None: | ||||||
|  |                         self.composition.append(_ellipsis) | ||||||
|  |                         self.has_ellipsis_parenthesized = False | ||||||
|  |                     else: | ||||||
|  |                         bracket_group.append(_ellipsis) | ||||||
|  |                         self.has_ellipsis_parenthesized = True | ||||||
|  |                 else: | ||||||
|  |                     is_number = str.isdecimal(x) | ||||||
|  |                     if is_number and int(x) == 1: | ||||||
|  |                         # handling the case of anonymous axis of length 1 | ||||||
|  |                         if bracket_group is None: | ||||||
|  |                             self.composition.append([]) | ||||||
|  |                         else: | ||||||
|  |                             pass  # no need to think about 1s inside parenthesis | ||||||
|  |                         return | ||||||
|  |                     is_axis_name, reason = self.check_axis_name(x, return_reason=True) | ||||||
|  |                     if not (is_number or is_axis_name): | ||||||
|  |                         raise ValueError( | ||||||
|  |                             "Invalid axis identifier: {}\n{}".format(x, reason) | ||||||
|  |                         ) | ||||||
|  |                     if is_number: | ||||||
|  |                         x = AnonymousAxis(x) | ||||||
|  |                     self.identifiers.add(x) | ||||||
|  |                     if is_number: | ||||||
|  |                         self.has_non_unitary_anonymous_axes = True | ||||||
|  |                     if bracket_group is None: | ||||||
|  |                         self.composition.append([x]) | ||||||
|  |                     else: | ||||||
|  |                         bracket_group.append(x) | ||||||
|  |  | ||||||
|  |         current_identifier = None | ||||||
|  |         for char in expression: | ||||||
|  |             if char in "() ": | ||||||
|  |                 add_axis_name(current_identifier) | ||||||
|  |                 current_identifier = None | ||||||
|  |                 if char == "(": | ||||||
|  |                     if bracket_group is not None: | ||||||
|  |                         raise ValueError( | ||||||
|  |                             "Axis composition is one-level (brackets inside brackets not allowed)" | ||||||
|  |                         ) | ||||||
|  |                     bracket_group = [] | ||||||
|  |                 elif char == ")": | ||||||
|  |                     if bracket_group is None: | ||||||
|  |                         raise ValueError("Brackets are not balanced") | ||||||
|  |                     self.composition.append(bracket_group) | ||||||
|  |                     bracket_group = None | ||||||
|  |             elif str.isalnum(char) or char in ["_", _ellipsis]: | ||||||
|  |                 if current_identifier is None: | ||||||
|  |                     current_identifier = char | ||||||
|  |                 else: | ||||||
|  |                     current_identifier += char | ||||||
|  |             else: | ||||||
|  |                 raise ValueError("Unknown character '{}'".format(char)) | ||||||
|  |  | ||||||
|  |         if bracket_group is not None: | ||||||
|  |             raise ValueError( | ||||||
|  |                 'Imbalanced parentheses in expression: "{}"'.format(expression) | ||||||
|  |             ) | ||||||
|  |         add_axis_name(current_identifier) | ||||||
|  |  | ||||||
|  |     def flat_axes_order(self) -> List: | ||||||
|  |         result = [] | ||||||
|  |         for composed_axis in self.composition: | ||||||
|  |             assert isinstance(composed_axis, list), "does not work with ellipsis" | ||||||
|  |             for axis in composed_axis: | ||||||
|  |                 result.append(axis) | ||||||
|  |         return result | ||||||
|  |  | ||||||
|  |     def has_composed_axes(self) -> bool: | ||||||
|  |         # this will ignore 1 inside brackets | ||||||
|  |         for axes in self.composition: | ||||||
|  |             if isinstance(axes, list) and len(axes) > 1: | ||||||
|  |                 return True | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def check_axis_name(name: str, return_reason=False): | ||||||
|  |         """ | ||||||
|  |         Valid axes names are python identifiers except keywords, | ||||||
|  |         and additionally should not start or end with underscore | ||||||
|  |         """ | ||||||
|  |         if not str.isidentifier(name): | ||||||
|  |             result = False, "not a valid python identifier" | ||||||
|  |         elif name[0] == "_" or name[-1] == "_": | ||||||
|  |             result = False, "axis name should should not start or end with underscore" | ||||||
|  |         else: | ||||||
|  |             if keyword.iskeyword(name): | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "It is discouraged to use axes names that are keywords: {}".format( | ||||||
|  |                         name | ||||||
|  |                     ), | ||||||
|  |                     RuntimeWarning, | ||||||
|  |                 ) | ||||||
|  |             if name in ["axis"]: | ||||||
|  |                 warnings.warn( | ||||||
|  |                     "It is discouraged to use 'axis' as an axis name " | ||||||
|  |                     "and will raise an error in future", | ||||||
|  |                     FutureWarning, | ||||||
|  |                 ) | ||||||
|  |             result = True, None | ||||||
|  |         if return_reason: | ||||||
|  |             return result | ||||||
|  |         else: | ||||||
|  |             return result[0] | ||||||
| @@ -11,6 +11,7 @@ import math | |||||||
| from typing import Optional, Callable | from typing import Optional, Callable | ||||||
|  |  | ||||||
| from xautodl import spaces | from xautodl import spaces | ||||||
|  | from .misc_utils import ParsedExpression | ||||||
| from .super_module import SuperModule | from .super_module import SuperModule | ||||||
| from .super_module import IntSpaceType | from .super_module import IntSpaceType | ||||||
| from .super_module import BoolSpaceType | from .super_module import BoolSpaceType | ||||||
| @@ -24,6 +25,17 @@ class SuperReArrange(SuperModule): | |||||||
|  |  | ||||||
|         self._pattern = pattern |         self._pattern = pattern | ||||||
|         self._axes_lengths = axes_lengths |         self._axes_lengths = axes_lengths | ||||||
|  |         axes_lengths = tuple(sorted(self._axes_lengths.items())) | ||||||
|  |         # Perform initial parsing of pattern and provided supplementary info | ||||||
|  |         # axes_lengths is a tuple of tuples (axis_name, axis_length) | ||||||
|  |         left, right = pattern.split("->") | ||||||
|  |         left = ParsedExpression(left) | ||||||
|  |         right = ParsedExpression(right) | ||||||
|  |  | ||||||
|  |         import pdb | ||||||
|  |  | ||||||
|  |         pdb.set_trace() | ||||||
|  |         print("-") | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
| @@ -31,13 +43,16 @@ class SuperReArrange(SuperModule): | |||||||
|         return root_node |         return root_node | ||||||
|  |  | ||||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|         raise NotImplementedError |         self.forward_raw(input) | ||||||
|  |  | ||||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         import pdb | ||||||
|  |  | ||||||
|  |         pdb.set_trace() | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def extra_repr(self) -> str: |     def extra_repr(self) -> str: | ||||||
|         params = repr(self._pattern) |         params = repr(self._pattern) | ||||||
|         for axis, length in self._axes_lengths.items(): |         for axis, length in self._axes_lengths.items(): | ||||||
|             params += ", {}={}".format(axis, length) |             params += ", {}={}".format(axis, length) | ||||||
|         return "{}({})".format(self.__class__.__name__, params) |         return "{:}".format(params) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user