Update misc
This commit is contained in:
		| @@ -26,3 +26,5 @@ class TestSuperReArrange(unittest.TestCase): | ||||
|         tensor = torch.rand((8, 4, 32, 32)) | ||||
|         print("The tensor shape: {:}".format(tensor.shape)) | ||||
|         print(layer) | ||||
|         outs = layer(tensor) | ||||
|         print("The output tensor shape: {:}".format(outs.shape)) | ||||
|   | ||||
							
								
								
									
										145
									
								
								xautodl/xlayers/misc_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										145
									
								
								xautodl/xlayers/misc_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,145 @@ | ||||
| # borrowed from https://github.com/arogozhnikov/einops/blob/master/einops/parsing.py | ||||
|  | ||||
|  | ||||
| class ParsedExpression: | ||||
|     """ | ||||
|     non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)') | ||||
|     and keeps some information important for downstream | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, expression): | ||||
|         self.has_ellipsis = False | ||||
|         self.has_ellipsis_parenthesized = None | ||||
|         self.identifiers = set() | ||||
|         # that's axes like 2, 3 or 5. Axes with size 1 are exceptional and replaced with empty composition | ||||
|         self.has_non_unitary_anonymous_axes = False | ||||
|         # composition keeps structure of composite axes, see how different corner cases are handled in tests | ||||
|         self.composition = [] | ||||
|         if "." in expression: | ||||
|             if "..." not in expression: | ||||
|                 raise ValueError( | ||||
|                     "Expression may contain dots only inside ellipsis (...)" | ||||
|                 ) | ||||
|             if str.count(expression, "...") != 1 or str.count(expression, ".") != 3: | ||||
|                 raise ValueError( | ||||
|                     "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor " | ||||
|                 ) | ||||
|             expression = expression.replace("...", _ellipsis) | ||||
|             self.has_ellipsis = True | ||||
|  | ||||
|         bracket_group = None | ||||
|  | ||||
|         def add_axis_name(x): | ||||
|             if x is not None: | ||||
|                 if x in self.identifiers: | ||||
|                     raise ValueError( | ||||
|                         'Indexing expression contains duplicate dimension "{}"'.format( | ||||
|                             x | ||||
|                         ) | ||||
|                     ) | ||||
|                 if x == _ellipsis: | ||||
|                     self.identifiers.add(_ellipsis) | ||||
|                     if bracket_group is None: | ||||
|                         self.composition.append(_ellipsis) | ||||
|                         self.has_ellipsis_parenthesized = False | ||||
|                     else: | ||||
|                         bracket_group.append(_ellipsis) | ||||
|                         self.has_ellipsis_parenthesized = True | ||||
|                 else: | ||||
|                     is_number = str.isdecimal(x) | ||||
|                     if is_number and int(x) == 1: | ||||
|                         # handling the case of anonymous axis of length 1 | ||||
|                         if bracket_group is None: | ||||
|                             self.composition.append([]) | ||||
|                         else: | ||||
|                             pass  # no need to think about 1s inside parenthesis | ||||
|                         return | ||||
|                     is_axis_name, reason = self.check_axis_name(x, return_reason=True) | ||||
|                     if not (is_number or is_axis_name): | ||||
|                         raise ValueError( | ||||
|                             "Invalid axis identifier: {}\n{}".format(x, reason) | ||||
|                         ) | ||||
|                     if is_number: | ||||
|                         x = AnonymousAxis(x) | ||||
|                     self.identifiers.add(x) | ||||
|                     if is_number: | ||||
|                         self.has_non_unitary_anonymous_axes = True | ||||
|                     if bracket_group is None: | ||||
|                         self.composition.append([x]) | ||||
|                     else: | ||||
|                         bracket_group.append(x) | ||||
|  | ||||
|         current_identifier = None | ||||
|         for char in expression: | ||||
|             if char in "() ": | ||||
|                 add_axis_name(current_identifier) | ||||
|                 current_identifier = None | ||||
|                 if char == "(": | ||||
|                     if bracket_group is not None: | ||||
|                         raise ValueError( | ||||
|                             "Axis composition is one-level (brackets inside brackets not allowed)" | ||||
|                         ) | ||||
|                     bracket_group = [] | ||||
|                 elif char == ")": | ||||
|                     if bracket_group is None: | ||||
|                         raise ValueError("Brackets are not balanced") | ||||
|                     self.composition.append(bracket_group) | ||||
|                     bracket_group = None | ||||
|             elif str.isalnum(char) or char in ["_", _ellipsis]: | ||||
|                 if current_identifier is None: | ||||
|                     current_identifier = char | ||||
|                 else: | ||||
|                     current_identifier += char | ||||
|             else: | ||||
|                 raise ValueError("Unknown character '{}'".format(char)) | ||||
|  | ||||
|         if bracket_group is not None: | ||||
|             raise ValueError( | ||||
|                 'Imbalanced parentheses in expression: "{}"'.format(expression) | ||||
|             ) | ||||
|         add_axis_name(current_identifier) | ||||
|  | ||||
|     def flat_axes_order(self) -> List: | ||||
|         result = [] | ||||
|         for composed_axis in self.composition: | ||||
|             assert isinstance(composed_axis, list), "does not work with ellipsis" | ||||
|             for axis in composed_axis: | ||||
|                 result.append(axis) | ||||
|         return result | ||||
|  | ||||
|     def has_composed_axes(self) -> bool: | ||||
|         # this will ignore 1 inside brackets | ||||
|         for axes in self.composition: | ||||
|             if isinstance(axes, list) and len(axes) > 1: | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     @staticmethod | ||||
|     def check_axis_name(name: str, return_reason=False): | ||||
|         """ | ||||
|         Valid axes names are python identifiers except keywords, | ||||
|         and additionally should not start or end with underscore | ||||
|         """ | ||||
|         if not str.isidentifier(name): | ||||
|             result = False, "not a valid python identifier" | ||||
|         elif name[0] == "_" or name[-1] == "_": | ||||
|             result = False, "axis name should should not start or end with underscore" | ||||
|         else: | ||||
|             if keyword.iskeyword(name): | ||||
|                 warnings.warn( | ||||
|                     "It is discouraged to use axes names that are keywords: {}".format( | ||||
|                         name | ||||
|                     ), | ||||
|                     RuntimeWarning, | ||||
|                 ) | ||||
|             if name in ["axis"]: | ||||
|                 warnings.warn( | ||||
|                     "It is discouraged to use 'axis' as an axis name " | ||||
|                     "and will raise an error in future", | ||||
|                     FutureWarning, | ||||
|                 ) | ||||
|             result = True, None | ||||
|         if return_reason: | ||||
|             return result | ||||
|         else: | ||||
|             return result[0] | ||||
| @@ -11,6 +11,7 @@ import math | ||||
| from typing import Optional, Callable | ||||
|  | ||||
| from xautodl import spaces | ||||
| from .misc_utils import ParsedExpression | ||||
| from .super_module import SuperModule | ||||
| from .super_module import IntSpaceType | ||||
| from .super_module import BoolSpaceType | ||||
| @@ -24,6 +25,17 @@ class SuperReArrange(SuperModule): | ||||
|  | ||||
|         self._pattern = pattern | ||||
|         self._axes_lengths = axes_lengths | ||||
|         axes_lengths = tuple(sorted(self._axes_lengths.items())) | ||||
|         # Perform initial parsing of pattern and provided supplementary info | ||||
|         # axes_lengths is a tuple of tuples (axis_name, axis_length) | ||||
|         left, right = pattern.split("->") | ||||
|         left = ParsedExpression(left) | ||||
|         right = ParsedExpression(right) | ||||
|  | ||||
|         import pdb | ||||
|  | ||||
|         pdb.set_trace() | ||||
|         print("-") | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
| @@ -31,13 +43,16 @@ class SuperReArrange(SuperModule): | ||||
|         return root_node | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         raise NotImplementedError | ||||
|         self.forward_raw(input) | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         import pdb | ||||
|  | ||||
|         pdb.set_trace() | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         params = repr(self._pattern) | ||||
|         for axis, length in self._axes_lengths.items(): | ||||
|             params += ", {}={}".format(axis, length) | ||||
|         return "{}({})".format(self.__class__.__name__, params) | ||||
|         return "{:}".format(params) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user