Complete xlayers.rearrange
@@ -29,6 +29,7 @@ class TestSuperSelfAttention(unittest.TestCase):
         abstract_child = abstract_space.random(reuse_last=True)
         print("The abstract child program is:\n{:}".format(abstract_child))
         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)
         outputs = model(inputs)
         return abstract_child, outputs

@@ -25,6 +25,7 @@ def _internal_func(inputs, model):
     abstract_space.clean_last()
     abstract_child = abstract_space.random(reuse_last=True)
     print("The abstract child program is:\n{:}".format(abstract_child))
+    model.enable_candidate()
     model.set_super_run_type(super_core.SuperRunMode.Candidate)
     model.apply_candidate(abstract_child)
     outputs = model(inputs)

@@ -37,6 +37,7 @@ class TestSuperLinear(unittest.TestCase):
         print("The abstract child program:\n{:}".format(abstract_child))

         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)

         output_shape = (20, abstract_child["_out_features"].value)
@@ -77,6 +78,7 @@ class TestSuperLinear(unittest.TestCase):
         )

         mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
+        mlp.enable_candidate()
         mlp.apply_candidate(abstract_child)
         outputs = mlp(inputs)
         output_shape = (4, abstract_child["fc2"]["_out_features"].value)
@@ -103,6 +105,7 @@ class TestSuperLinear(unittest.TestCase):
         print("The abstract child program is:\n{:}".format(abstract_child))

         mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
+        mlp.enable_candidate()
         mlp.apply_candidate(abstract_child)
         outputs = mlp(inputs)
         output_shape = (4, abstract_child["_out_features"].value)
@@ -120,6 +123,7 @@ class TestSuperLinear(unittest.TestCase):
         print("The abstract child program:\n{:}".format(abstract_child))

         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)
         outputs = model(inputs)
         output_shape = (4, 60, abstract_child["_embed_dim"].value)

@@ -38,6 +38,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
         print("The abstract child program:\n{:}".format(abstract_child))

         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)

         output_shape = (20, abstract_child["1"]["_out_features"].value)
@@ -70,6 +71,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
         print("The abstract child program:\n{:}".format(abstract_child))

         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)

         output_shape = (20, abstract_child["2"]["_out_features"].value)

@@ -5,12 +5,6 @@
 #####################################################
 import sys
 import unittest
-from pathlib import Path
-
-lib_dir = (Path(__file__).parent / "..").resolve()
-print("LIB-DIR: {:}".format(lib_dir))
-if str(lib_dir) not in sys.path:
-    sys.path.insert(0, str(lib_dir))

 import torch
 from xautodl import xlayers
@@ -28,3 +22,4 @@ class TestSuperReArrange(unittest.TestCase):
         print(layer)
         outs = layer(tensor)
         print("The output tensor shape: {:}".format(outs.shape))
+        assert tuple(outs.shape) == (8, 32 * 32 // 16, 4 * 4 * 4)

@@ -1,36 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
-#####################################################
-# pytest ./tests/test_super_model.py -s             #
-#####################################################
-import unittest
-
-import torch
-from xautodl.xlayers.super_core import SuperRunMode
-from xautodl.trade_models import get_transformer
-
-
-class TestSuperTransformer(unittest.TestCase):
-    """Test the super transformer."""
-
-    def test_super_transformer(self):
-        model = get_transformer(None)
-        model.apply_verbose(False)
-        print(model)
-
-        inputs = torch.rand(10, 360)
-        print("Input shape: {:}".format(inputs.shape))
-        outputs = model(inputs)
-        self.assertEqual(tuple(outputs.shape), (10,))
-
-        abstract_space = model.abstract_search_space
-        abstract_space.clean_last()
-        abstract_child = abstract_space.random(reuse_last=True)
-        print("The abstract searc space:\n{:}".format(abstract_space))
-        print("The abstract child program:\n{:}".format(abstract_child))
-
-        model.set_super_run_type(SuperRunMode.Candidate)
-        model.apply_candidate(abstract_child)
-
-        outputs = model(inputs)
-        self.assertEqual(tuple(outputs.shape), (10,))
@@ -1,4 +1,28 @@
 # borrowed from https://github.com/arogozhnikov/einops/blob/master/einops/parsing.py
+import warnings
+import keyword
+from typing import List
+
+
+class AnonymousAxis:
+    """Important thing: all instances of this class are not equal to each other"""
+
+    def __init__(self, value: str):
+        self.value = int(value)
+        if self.value <= 1:
+            if self.value == 1:
+                raise EinopsError(
+                    "No need to create anonymous axis of length 1. Report this as an issue"
+                )
+            else:
+                raise EinopsError(
+                    "Anonymous axis should have positive length, not {}".format(
+                        self.value
+                    )
+                )
+
+    def __repr__(self):
+        return "{}-axis".format(str(self.value))


 class ParsedExpression:
@@ -8,24 +32,13 @@ class ParsedExpression:
     """

     def __init__(self, expression):
-        self.has_ellipsis = False
-        self.has_ellipsis_parenthesized = None
         self.identifiers = set()
         # that's axes like 2, 3 or 5. Axes with size 1 are exceptional and replaced with empty composition
         self.has_non_unitary_anonymous_axes = False
         # composition keeps structure of composite axes, see how different corner cases are handled in tests
         self.composition = []
         if "." in expression:
-            if "..." not in expression:
-                raise ValueError(
-                    "Expression may contain dots only inside ellipsis (...)"
-                )
-            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
-                raise ValueError(
-                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
-                )
-            expression = expression.replace("...", _ellipsis)
-            self.has_ellipsis = True
+            raise ValueError("Does not support . in the expression.")

         bracket_group = None

@@ -37,37 +50,28 @@ class ParsedExpression:
                             x
                         )
                     )
-                if x == _ellipsis:
-                    self.identifiers.add(_ellipsis)
+                is_number = str.isdecimal(x)
+                if is_number and int(x) == 1:
+                    # handling the case of anonymous axis of length 1
                     if bracket_group is None:
-                        self.composition.append(_ellipsis)
-                        self.has_ellipsis_parenthesized = False
+                        self.composition.append([])
                     else:
-                        bracket_group.append(_ellipsis)
-                        self.has_ellipsis_parenthesized = True
-                else:
-                    is_number = str.isdecimal(x)
-                    if is_number and int(x) == 1:
-                        # handling the case of anonymous axis of length 1
-                        if bracket_group is None:
-                            self.composition.append([])
-                        else:
-                            pass  # no need to think about 1s inside parenthesis
-                        return
-                    is_axis_name, reason = self.check_axis_name(x, return_reason=True)
-                    if not (is_number or is_axis_name):
-                        raise ValueError(
-                            "Invalid axis identifier: {}\n{}".format(x, reason)
-                        )
-                    if is_number:
-                        x = AnonymousAxis(x)
-                    self.identifiers.add(x)
-                    if is_number:
-                        self.has_non_unitary_anonymous_axes = True
-                    if bracket_group is None:
-                        self.composition.append([x])
-                    else:
-                        bracket_group.append(x)
+                        pass  # no need to think about 1s inside parenthesis
+                    return
+                is_axis_name, reason = self.check_axis_name(x, return_reason=True)
+                if not (is_number or is_axis_name):
+                    raise ValueError(
+                        "Invalid axis identifier: {}\n{}".format(x, reason)
+                    )
+                if is_number:
+                    x = AnonymousAxis(x)
+                self.identifiers.add(x)
+                if is_number:
+                    self.has_non_unitary_anonymous_axes = True
+                if bracket_group is None:
+                    self.composition.append([x])
+                else:
+                    bracket_group.append(x)

         current_identifier = None
         for char in expression:
@@ -85,7 +89,7 @@ class ParsedExpression:
                         raise ValueError("Brackets are not balanced")
                     self.composition.append(bracket_group)
                     bracket_group = None
-            elif str.isalnum(char) or char in ["_", _ellipsis]:
+            elif str.isalnum(char) or char == "_":
                 if current_identifier is None:
                     current_identifier = char
                 else:
@@ -143,3 +147,8 @@ class ParsedExpression:
             return result
         else:
             return result[0]
+
+    def __repr__(self) -> str:
+        return "{name}({composition})".format(
+            name=self.__class__.__name__, composition=self.composition
+        )

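Note on the rewritten parser (illustrative, not part of the commit): ellipsis support is dropped, numeric axes become AnonymousAxis objects, and a literal 1 contributes an empty composition group. Assuming the module path xautodl.xlayers.misc_utils (the file imported as .misc_utils further below), a quick sanity check looks like:

    # Illustrative sketch only; exercises ParsedExpression as defined in the diff above.
    from xautodl.xlayers.misc_utils import ParsedExpression  # assumed module path

    expr = ParsedExpression("b 1 (h w) 2")
    # expr.composition -> [['b'], [], ['h', 'w'], [<2-axis>]]
    #   'b'     : a named axis in its own group
    #   '1'     : an empty group [] (unitary axes keep no real dimension)
    #   '(h w)' : one composite axis built from 'h' and 'w'
    #   '2'     : an AnonymousAxis of length 2 (its repr is "2-axis")
    # expr.identifiers -> {'b', 'h', 'w', <2-axis>}
    # expr.has_non_unitary_anonymous_axes -> True
    print(expr)  # uses the new __repr__: "ParsedExpression([...])"
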
@@ -21,6 +21,8 @@ from .super_utils import ShapeContainer
 BEST_DIR_KEY = "best_model_dir"
 BEST_NAME_KEY = "best_model_name"
 BEST_SCORE_KEY = "best_model_score"
+ENABLE_CANDIDATE = 0
+DISABLE_CANDIDATE = 1


 class SuperModule(abc.ABC, nn.Module):
@@ -32,6 +34,7 @@ class SuperModule(abc.ABC, nn.Module):
         self._abstract_child = None
         self._verbose = False
         self._meta_info = {}
+        self._candidate_mode = DISABLE_CANDIDATE

     def set_super_run_type(self, super_run_type):
         def _reset_super_run(m):
@@ -65,6 +68,20 @@ class SuperModule(abc.ABC, nn.Module):
             )
         self._abstract_child = abstract_child

+    def enable_candidate(self):
+        def _enable_candidate(m):
+            if isinstance(m, SuperModule):
+                m._candidate_mode = ENABLE_CANDIDATE
+
+        self.apply(_enable_candidate)
+
+    def disable_candidate(self):
+        def _disable_candidate(m):
+            if isinstance(m, SuperModule):
+                m._candidate_mode = DISABLE_CANDIDATE
+
+        self.apply(_disable_candidate)
+
     def get_w_container(self):
         container = TensorContainer()
         for name, param in self.named_parameters():
@@ -191,9 +208,11 @@ class SuperModule(abc.ABC, nn.Module):
         if self.super_run_type == SuperRunMode.FullModel:
             outputs = self.forward_raw(*inputs)
         elif self.super_run_type == SuperRunMode.Candidate:
+            if self._candidate_mode == DISABLE_CANDIDATE:
+                raise ValueError("candidate mode is disabled")
             outputs = self.forward_candidate(*inputs)
         else:
-            raise ModeError(
+            raise ValueError(
                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
             )
         if self.verbose:

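The practical effect of the SuperModule changes above: running in Candidate mode is now an explicit opt-in, and forward() raises a ValueError when the candidate flag has not been set. A minimal sketch of the calling order, reusing the names from the updated tests (model, inputs, abstract_child, super_core):

    # Sketch of the new calling convention for any SuperModule; no API beyond the diff above.
    model.set_super_run_type(super_core.SuperRunMode.Candidate)
    model.enable_candidate()               # recursively sets _candidate_mode via self.apply(...)
    model.apply_candidate(abstract_child)
    outputs = model(inputs)                # without enable_candidate(), this raises ValueError
    model.disable_candidate()              # symmetric switch back to the guarded state
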
@@ -8,10 +8,14 @@ import torch.nn as nn
 import torch.nn.functional as F

 import math
+import numpy as np
+import itertools
+import functools
+from collections import OrderedDict
 from typing import Optional, Callable

 from xautodl import spaces
-from .misc_utils import ParsedExpression
+from .misc_utils import ParsedExpression, AnonymousAxis
 from .super_module import SuperModule
 from .super_module import IntSpaceType
 from .super_module import BoolSpaceType
@@ -31,11 +35,133 @@ class SuperReArrange(SuperModule):
         left, right = pattern.split("->")
         left = ParsedExpression(left)
         right = ParsedExpression(right)
-
-        import pdb
-
-        pdb.set_trace()
-        print("-")
+        difference = set.symmetric_difference(left.identifiers, right.identifiers)
+        if difference:
+            raise ValueError(
+                "Identifiers only on one side of expression (should be on both): {}".format(
+                    difference
+                )
+            )
+
+        # parsing all dimensions to find out lengths
+        axis_name2known_length = OrderedDict()
+        for composite_axis in left.composition:
+            for axis_name in composite_axis:
+                if isinstance(axis_name, AnonymousAxis):
+                    axis_name2known_length[axis_name] = axis_name.value
+                else:
+                    axis_name2known_length[axis_name] = None
+        for axis_name in right.identifiers:
+            if axis_name not in axis_name2known_length:
+                if isinstance(axis_name, AnonymousAxis):
+                    axis_name2known_length[axis_name] = axis_name.value
+                else:
+                    axis_name2known_length[axis_name] = None
+
+        axis_name2position = {
+            name: position for position, name in enumerate(axis_name2known_length)
+        }
+        for elementary_axis, axis_length in axes_lengths:
+            if not ParsedExpression.check_axis_name(elementary_axis):
+                raise ValueError("Invalid name for an axis", elementary_axis)
+            if elementary_axis not in axis_name2known_length:
+                raise ValueError(
+                    "Axis {} is not used in transform".format(elementary_axis)
+                )
+            axis_name2known_length[elementary_axis] = axis_length
+
+        input_composite_axes = []
+        # some of shapes will be inferred later - all information is prepared for faster inference
+        for composite_axis in left.composition:
+            known = {
+                axis
+                for axis in composite_axis
+                if axis_name2known_length[axis] is not None
+            }
+            unknown = {
+                axis for axis in composite_axis if axis_name2known_length[axis] is None
+            }
+            if len(unknown) > 1:
+                raise ValueError("Could not infer sizes for {}".format(unknown))
+            assert len(unknown) + len(known) == len(composite_axis)
+            input_composite_axes.append(
+                (
+                    [axis_name2position[axis] for axis in known],
+                    [axis_name2position[axis] for axis in unknown],
+                )
+            )
+
+        axis_position_after_reduction = {}
+        for axis_name in itertools.chain(*left.composition):
+            if axis_name in right.identifiers:
+                axis_position_after_reduction[axis_name] = len(
+                    axis_position_after_reduction
+                )
+
+        result_axes_grouping = []
+        for composite_axis in right.composition:
+            result_axes_grouping.append(
+                [axis_name2position[axis] for axis in composite_axis]
+            )
+
+        ordered_axis_right = list(itertools.chain(*right.composition))
+        axes_permutation = tuple(
+            axis_position_after_reduction[axis]
+            for axis in ordered_axis_right
+            if axis in left.identifiers
+        )
+        #
+        self.input_composite_axes = input_composite_axes
+        self.output_composite_axes = result_axes_grouping
+        self.elementary_axes_lengths = list(axis_name2known_length.values())
+        self.axes_permutation = axes_permutation
+
+    @functools.lru_cache(maxsize=1024)
+    def reconstruct_from_shape(self, shape):
+        if len(shape) != len(self.input_composite_axes):
+            raise ValueError(
+                "Expected {} dimensions, got {}".format(
+                    len(self.input_composite_axes), len(shape)
+                )
+            )
+        axes_lengths = list(self.elementary_axes_lengths)
+        for input_axis, (known_axes, unknown_axes) in enumerate(
+            self.input_composite_axes
+        ):
+            length = shape[input_axis]
+            known_product = 1
+            for axis in known_axes:
+                known_product *= axes_lengths[axis]
+            if len(unknown_axes) == 0:
+                if (
+                    isinstance(length, int)
+                    and isinstance(known_product, int)
+                    and length != known_product
+                ):
+                    raise ValueError(
+                        "Shape mismatch, {} != {}".format(length, known_product)
+                    )
+            else:
+                if (
+                    isinstance(length, int)
+                    and isinstance(known_product, int)
+                    and length % known_product != 0
+                ):
+                    raise ValueError(
+                        "Shape mismatch, can't divide axis of length {} in chunks of {}".format(
+                            length, known_product
+                        )
+                    )
+
+                (unknown_axis,) = unknown_axes
+                axes_lengths[unknown_axis] = length // known_product
+        # at this point all axes_lengths are computed (either have values or variables, but not Nones)
+        final_shape = []
+        for output_axis, grouping in enumerate(self.output_composite_axes):
+            lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
+            final_shape.append(int(np.prod(lengths)))
+        axes_reordering = self.axes_permutation
+        return axes_lengths, axes_reordering, final_shape

     @property
     def abstract_search_space(self):
@@ -46,10 +172,13 @@ class SuperReArrange(SuperModule):
         self.forward_raw(input)

     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        import pdb
-
-        pdb.set_trace()
-        raise NotImplementedError
+        init_shape, axes_reordering, final_shape = self.reconstruct_from_shape(
+            tuple(input.shape)
+        )
+        tensor = torch.reshape(input, init_shape)
+        tensor = tensor.permute(axes_reordering)
+        tensor = torch.reshape(tensor, final_shape)
+        return tensor

     def extra_repr(self) -> str:
         params = repr(self._pattern)

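With reconstruct_from_shape in place, forward_raw is the usual einops-style three-step transform: reshape so every elementary axis gets its own dimension, permute into the output order, then reshape into the final grouping. A hypothetical usage consistent with the test assertion above (the pattern, axis sizes, and import path are assumptions for illustration, not taken from the diff):

    import torch
    from xautodl.xlayers.super_core import SuperReArrange  # assumed export path

    layer = SuperReArrange("b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=4, p2=4)
    x = torch.rand(8, 4, 32, 32)
    # reconstruct_from_shape((8, 4, 32, 32)) infers h = w = 8 from the known p1 = p2 = 4;
    # forward_raw then reshapes to (8, 4, 8, 4, 8, 4), permutes the axes to
    # (b, h, w, c, p1, p2), and reshapes to the grouped output:
    out = layer(x)
    assert tuple(out.shape) == (8, 32 * 32 // 16, 4 * 4 * 4)  # i.e. (8, 64, 64)
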
@@ -1,8 +1,6 @@
-opyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
 #####################################################
-from __future__ import division
-from __future__ import print_function
-
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
 import math
 from functools import partial
 from typing import Optional, Text, List