Complete xlayers.rearrange
@@ -29,6 +29,7 @@ class TestSuperSelfAttention(unittest.TestCase):
        abstract_child = abstract_space.random(reuse_last=True)
        print("The abstract child program is:\n{:}".format(abstract_child))
        model.set_super_run_type(super_core.SuperRunMode.Candidate)
        model.enable_candidate()
        model.apply_candidate(abstract_child)
        outputs = model(inputs)
        return abstract_child, outputs

@@ -25,6 +25,7 @@ def _internal_func(inputs, model):
    abstract_space.clean_last()
    abstract_child = abstract_space.random(reuse_last=True)
    print("The abstract child program is:\n{:}".format(abstract_child))
    model.enable_candidate()
    model.set_super_run_type(super_core.SuperRunMode.Candidate)
    model.apply_candidate(abstract_child)
    outputs = model(inputs)

@@ -37,6 +37,7 @@ class TestSuperLinear(unittest.TestCase):
        print("The abstract child program:\n{:}".format(abstract_child))

        model.set_super_run_type(super_core.SuperRunMode.Candidate)
        model.enable_candidate()
        model.apply_candidate(abstract_child)

        output_shape = (20, abstract_child["_out_features"].value)
@@ -77,6 +78,7 @@ class TestSuperLinear(unittest.TestCase):
        )

        mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
        mlp.enable_candidate()
        mlp.apply_candidate(abstract_child)
        outputs = mlp(inputs)
        output_shape = (4, abstract_child["fc2"]["_out_features"].value)
@@ -103,6 +105,7 @@ class TestSuperLinear(unittest.TestCase):
        print("The abstract child program is:\n{:}".format(abstract_child))

        mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
        mlp.enable_candidate()
        mlp.apply_candidate(abstract_child)
        outputs = mlp(inputs)
        output_shape = (4, abstract_child["_out_features"].value)
@@ -120,6 +123,7 @@ class TestSuperLinear(unittest.TestCase):
        print("The abstract child program:\n{:}".format(abstract_child))

        model.set_super_run_type(super_core.SuperRunMode.Candidate)
        model.enable_candidate()
        model.apply_candidate(abstract_child)
        outputs = model(inputs)
        output_shape = (4, 60, abstract_child["_embed_dim"].value)

@@ -38,6 +38,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
        print("The abstract child program:\n{:}".format(abstract_child))

        model.set_super_run_type(super_core.SuperRunMode.Candidate)
        model.enable_candidate()
        model.apply_candidate(abstract_child)

        output_shape = (20, abstract_child["1"]["_out_features"].value)
@@ -70,6 +71,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
        print("The abstract child program:\n{:}".format(abstract_child))

        model.set_super_run_type(super_core.SuperRunMode.Candidate)
        model.enable_candidate()
        model.apply_candidate(abstract_child)

        output_shape = (20, abstract_child["2"]["_out_features"].value)

@@ -5,12 +5,6 @@
#####################################################
import sys
import unittest
from pathlib import Path

lib_dir = (Path(__file__).parent / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

import torch
from xautodl import xlayers
@@ -28,3 +22,4 @@ class TestSuperReArrange(unittest.TestCase):
        print(layer)
        outs = layer(tensor)
        print("The output tensor shape: {:}".format(outs.shape))
        assert tuple(outs.shape) == (8, 32 * 32 // 16, 4 * 4 * 4)

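For reference, a minimal usage sketch consistent with the shape asserted in the test above. The exact rearrange pattern, input shape, and constructor keywords are not visible in this hunk, so the values below are assumptions (a 4x4 patchification of an 8x4x32x32 tensor); only the asserted output shape is taken from the test:

import torch
from xautodl import xlayers

# assumed pattern and axis lengths; SuperReArrange is assumed to be re-exported by xlayers
layer = xlayers.SuperReArrange(
    "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=4, p2=4
)
tensor = torch.rand((8, 4, 32, 32))
outs = layer(tensor)
assert tuple(outs.shape) == (8, 32 * 32 // 16, 4 * 4 * 4)  # i.e. (8, 64, 64)
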
@@ -1,36 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest ./tests/test_super_model.py -s             #
#####################################################
import unittest

import torch
from xautodl.xlayers.super_core import SuperRunMode
from xautodl.trade_models import get_transformer


class TestSuperTransformer(unittest.TestCase):
    """Test the super transformer."""

    def test_super_transformer(self):
        model = get_transformer(None)
        model.apply_verbose(False)
        print(model)

        inputs = torch.rand(10, 360)
        print("Input shape: {:}".format(inputs.shape))
        outputs = model(inputs)
        self.assertEqual(tuple(outputs.shape), (10,))

        abstract_space = model.abstract_search_space
        abstract_space.clean_last()
        abstract_child = abstract_space.random(reuse_last=True)
        print("The abstract searc space:\n{:}".format(abstract_space))
        print("The abstract child program:\n{:}".format(abstract_child))

        model.set_super_run_type(SuperRunMode.Candidate)
        model.apply_candidate(abstract_child)

        outputs = model(inputs)
        self.assertEqual(tuple(outputs.shape), (10,))

@@ -1,4 +1,28 @@
# borrowed from https://github.com/arogozhnikov/einops/blob/master/einops/parsing.py
import warnings
import keyword
from typing import List


class AnonymousAxis:
    """Important thing: all instances of this class are not equal to each other"""

    def __init__(self, value: str):
        self.value = int(value)
        if self.value <= 1:
            if self.value == 1:
                raise EinopsError(
                    "No need to create anonymous axis of length 1. Report this as an issue"
                )
            else:
                raise EinopsError(
                    "Anonymous axis should have positive length, not {}".format(
                        self.value
                    )
                )

    def __repr__(self):
        return "{}-axis".format(str(self.value))


class ParsedExpression:
@@ -8,24 +32,13 @@ class ParsedExpression:
    """

    def __init__(self, expression):
        self.has_ellipsis = False
        self.has_ellipsis_parenthesized = None
        self.identifiers = set()
        # that's axes like 2, 3 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
        self.composition = []
        if "." in expression:
            if "..." not in expression:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...)"
                )
            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
                )
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True
            raise ValueError("Does not support . in the expression.")

        bracket_group = None

@@ -37,37 +50,28 @@ class ParsedExpression:
                            x
                        )
                    )
                if x == _ellipsis:
                    self.identifiers.add(_ellipsis)
                is_number = str.isdecimal(x)
                if is_number and int(x) == 1:
                    # handling the case of anonymous axis of length 1
                    if bracket_group is None:
                        self.composition.append(_ellipsis)
                        self.has_ellipsis_parenthesized = False
                        self.composition.append([])
                    else:
                        bracket_group.append(_ellipsis)
                        self.has_ellipsis_parenthesized = True
                        pass  # no need to think about 1s inside parenthesis
                    return
                is_axis_name, reason = self.check_axis_name(x, return_reason=True)
                if not (is_number or is_axis_name):
                    raise ValueError(
                        "Invalid axis identifier: {}\n{}".format(x, reason)
                    )
                if is_number:
                    x = AnonymousAxis(x)
                self.identifiers.add(x)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True
                if bracket_group is None:
                    self.composition.append([x])
                else:
                    is_number = str.isdecimal(x)
                    if is_number and int(x) == 1:
                        # handling the case of anonymous axis of length 1
                        if bracket_group is None:
                            self.composition.append([])
                        else:
                            pass  # no need to think about 1s inside parenthesis
                        return
                    is_axis_name, reason = self.check_axis_name(x, return_reason=True)
                    if not (is_number or is_axis_name):
                        raise ValueError(
                            "Invalid axis identifier: {}\n{}".format(x, reason)
                        )
                    if is_number:
                        x = AnonymousAxis(x)
                    self.identifiers.add(x)
                    if is_number:
                        self.has_non_unitary_anonymous_axes = True
                    if bracket_group is None:
                        self.composition.append([x])
                    else:
                        bracket_group.append(x)
                    bracket_group.append(x)

        current_identifier = None
        for char in expression:
@@ -85,7 +89,7 @@ class ParsedExpression:
                        raise ValueError("Brackets are not balanced")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ["_", _ellipsis]:
            elif str.isalnum(char) or char == "_":
                if current_identifier is None:
                    current_identifier = char
                else:
@@ -143,3 +147,8 @@ class ParsedExpression:
            return result
        else:
            return result[0]

    def __repr__(self) -> str:
        return "{name}({composition})".format(
            name=self.__class__.__name__, composition=self.composition
        )

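For orientation, a rough sketch of what this trimmed-down parser produces; the module path is assumed from the relative import used by super_rearrange below, and the expected values follow the parsing logic shown above:

from xautodl.xlayers.misc_utils import ParsedExpression  # assumed module path

expr = ParsedExpression("b c (h p1) (w p2)")
print(expr.identifiers)  # expected: {"b", "c", "h", "p1", "w", "p2"}
print(expr.composition)  # expected: [["b"], ["c"], ["h", "p1"], ["w", "p2"]]
# A decimal axis such as "4" becomes an AnonymousAxis (repr "4-axis"),
# a bare "1" outside parentheses appends an empty group [],
# and any "." in the expression now raises ValueError.
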
@@ -21,6 +21,8 @@ from .super_utils import ShapeContainer
BEST_DIR_KEY = "best_model_dir"
BEST_NAME_KEY = "best_model_name"
BEST_SCORE_KEY = "best_model_score"
ENABLE_CANDIDATE = 0
DISABLE_CANDIDATE = 1


class SuperModule(abc.ABC, nn.Module):
@@ -32,6 +34,7 @@ class SuperModule(abc.ABC, nn.Module):
        self._abstract_child = None
        self._verbose = False
        self._meta_info = {}
        self._candidate_mode = DISABLE_CANDIDATE

    def set_super_run_type(self, super_run_type):
        def _reset_super_run(m):
@@ -65,6 +68,20 @@ class SuperModule(abc.ABC, nn.Module):
            )
        self._abstract_child = abstract_child

    def enable_candidate(self):
        def _enable_candidate(m):
            if isinstance(m, SuperModule):
                m._candidate_mode = ENABLE_CANDIDATE

        self.apply(_enable_candidate)

    def disable_candidate(self):
        def _disable_candidate(m):
            if isinstance(m, SuperModule):
                m._candidate_mode = DISABLE_CANDIDATE

        self.apply(_disable_candidate)

    def get_w_container(self):
        container = TensorContainer()
        for name, param in self.named_parameters():
@@ -191,9 +208,11 @@ class SuperModule(abc.ABC, nn.Module):
        if self.super_run_type == SuperRunMode.FullModel:
            outputs = self.forward_raw(*inputs)
        elif self.super_run_type == SuperRunMode.Candidate:
            if self._candidate_mode == DISABLE_CANDIDATE:
                raise ValueError("candidate mode is disabled")
            outputs = self.forward_candidate(*inputs)
        else:
            raise ModeError(
            raise ValueError(
                "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
            )
        if self.verbose:

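With this change, running a super network in candidate mode requires an explicit opt-in: forward() raises ValueError("candidate mode is disabled") unless enable_candidate() has been called on the module tree. The call order below mirrors the updated tests (model, inputs, and abstract_child are placeholders from the test code):

model.set_super_run_type(super_core.SuperRunMode.Candidate)
model.enable_candidate()               # without this, forward() raises ValueError
model.apply_candidate(abstract_child)  # select one sampled child architecture
outputs = model(inputs)
model.disable_candidate()              # restore the guarded default state
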
@@ -8,10 +8,14 @@ import torch.nn as nn
import torch.nn.functional as F

import math
import numpy as np
import itertools
import functools
from collections import OrderedDict
from typing import Optional, Callable

from xautodl import spaces
from .misc_utils import ParsedExpression
from .misc_utils import ParsedExpression, AnonymousAxis
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
@@ -31,11 +35,133 @@ class SuperReArrange(SuperModule):
        left, right = pattern.split("->")
        left = ParsedExpression(left)
        right = ParsedExpression(right)
        difference = set.symmetric_difference(left.identifiers, right.identifiers)
        if difference:
            raise ValueError(
                "Identifiers only on one side of expression (should be on both): {}".format(
                    difference
                )
            )

        import pdb
        # parsing all dimensions to find out lengths
        axis_name2known_length = OrderedDict()
        for composite_axis in left.composition:
            for axis_name in composite_axis:
                if isinstance(axis_name, AnonymousAxis):
                    axis_name2known_length[axis_name] = axis_name.value
                else:
                    axis_name2known_length[axis_name] = None
        for axis_name in right.identifiers:
            if axis_name not in axis_name2known_length:
                if isinstance(axis_name, AnonymousAxis):
                    axis_name2known_length[axis_name] = axis_name.value
                else:
                    axis_name2known_length[axis_name] = None

        pdb.set_trace()
        print("-")
        axis_name2position = {
            name: position for position, name in enumerate(axis_name2known_length)
        }
        for elementary_axis, axis_length in axes_lengths:
            if not ParsedExpression.check_axis_name(elementary_axis):
                raise ValueError("Invalid name for an axis", elementary_axis)
            if elementary_axis not in axis_name2known_length:
                raise ValueError(
                    "Axis {} is not used in transform".format(elementary_axis)
                )
            axis_name2known_length[elementary_axis] = axis_length

        input_composite_axes = []
        # some of shapes will be inferred later - all information is prepared for faster inference
        for composite_axis in left.composition:
            known = {
                axis
                for axis in composite_axis
                if axis_name2known_length[axis] is not None
            }
            unknown = {
                axis for axis in composite_axis if axis_name2known_length[axis] is None
            }
            if len(unknown) > 1:
                raise ValueError("Could not infer sizes for {}".format(unknown))
            assert len(unknown) + len(known) == len(composite_axis)
            input_composite_axes.append(
                (
                    [axis_name2position[axis] for axis in known],
                    [axis_name2position[axis] for axis in unknown],
                )
            )

        axis_position_after_reduction = {}
        for axis_name in itertools.chain(*left.composition):
            if axis_name in right.identifiers:
                axis_position_after_reduction[axis_name] = len(
                    axis_position_after_reduction
                )

        result_axes_grouping = []
        for composite_axis in right.composition:
            result_axes_grouping.append(
                [axis_name2position[axis] for axis in composite_axis]
            )

        ordered_axis_right = list(itertools.chain(*right.composition))
        axes_permutation = tuple(
            axis_position_after_reduction[axis]
            for axis in ordered_axis_right
            if axis in left.identifiers
        )
        #
        self.input_composite_axes = input_composite_axes
        self.output_composite_axes = result_axes_grouping
        self.elementary_axes_lengths = list(axis_name2known_length.values())
        self.axes_permutation = axes_permutation

    @functools.lru_cache(maxsize=1024)
    def reconstruct_from_shape(self, shape):
        if len(shape) != len(self.input_composite_axes):
            raise ValueError(
                "Expected {} dimensions, got {}".format(
                    len(self.input_composite_axes), len(shape)
                )
            )
        axes_lengths = list(self.elementary_axes_lengths)
        for input_axis, (known_axes, unknown_axes) in enumerate(
            self.input_composite_axes
        ):
            length = shape[input_axis]
            known_product = 1
            for axis in known_axes:
                known_product *= axes_lengths[axis]
            if len(unknown_axes) == 0:
                if (
                    isinstance(length, int)
                    and isinstance(known_product, int)
                    and length != known_product
                ):
                    raise ValueError(
                        "Shape mismatch, {} != {}".format(length, known_product)
                    )
            else:
                if (
                    isinstance(length, int)
                    and isinstance(known_product, int)
                    and length % known_product != 0
                ):
                    raise ValueError(
                        "Shape mismatch, can't divide axis of length {} in chunks of {}".format(
                            length, known_product
                        )
                    )

                (unknown_axis,) = unknown_axes
                axes_lengths[unknown_axis] = length // known_product
        # at this point all axes_lengths are computed (either have values or variables, but not Nones)
        final_shape = []
        for output_axis, grouping in enumerate(self.output_composite_axes):
            lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
            final_shape.append(int(np.prod(lengths)))
        axes_reordering = self.axes_permutation
        return axes_lengths, axes_reordering, final_shape

    @property
    def abstract_search_space(self):
@@ -46,10 +172,13 @@ class SuperReArrange(SuperModule):
        self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        import pdb

        pdb.set_trace()
        raise NotImplementedError
        init_shape, axes_reordering, final_shape = self.reconstruct_from_shape(
            tuple(input.shape)
        )
        tensor = torch.reshape(input, init_shape)
        tensor = tensor.permute(axes_reordering)
        tensor = torch.reshape(tensor, final_shape)
        return tensor

    def extra_repr(self) -> str:
        params = repr(self._pattern)

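To make the reshape-permute-reshape pipeline in forward_raw concrete, here is an equivalent plain-torch trace under the same assumed patch pattern as the usage sketch earlier ("b c (h p1) (w p2) -> b (h w) (c p1 p2)" with p1 = p2 = 4; the pattern and input shape are assumptions, only the final shape matches the test assert):

import torch

x = torch.rand(8, 4, 32, 32)        # assumed input: b=8, c=4, h*p1=32, w*p2=32
t = x.reshape(8, 4, 8, 4, 8, 4)     # elementary axes in left-to-right order: (b, c, h, p1, w, p2)
t = t.permute(0, 2, 4, 1, 3, 5)     # reorder to the right-hand side order: (b, h, w, c, p1, p2)
t = t.reshape(8, 8 * 8, 4 * 4 * 4)  # regroup into b, (h w), (c p1 p2)
assert tuple(t.shape) == (8, 32 * 32 // 16, 4 * 4 * 4)
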
@@ -1,8 +1,6 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from __future__ import division
from __future__ import print_function

# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import math
from functools import partial
from typing import Optional, Text, List
