diff --git a/tests/test_super_att.py b/tests/test_super_att.py index fc55b1c..6df7e33 100644 --- a/tests/test_super_att.py +++ b/tests/test_super_att.py @@ -29,6 +29,7 @@ class TestSuperSelfAttention(unittest.TestCase): abstract_child = abstract_space.random(reuse_last=True) print("The abstract child program is:\n{:}".format(abstract_child)) model.set_super_run_type(super_core.SuperRunMode.Candidate) + model.enable_candidate() model.apply_candidate(abstract_child) outputs = model(inputs) return abstract_child, outputs diff --git a/tests/test_super_container.py b/tests/test_super_container.py index 6cdc687..37e4523 100644 --- a/tests/test_super_container.py +++ b/tests/test_super_container.py @@ -25,6 +25,7 @@ def _internal_func(inputs, model): abstract_space.clean_last() abstract_child = abstract_space.random(reuse_last=True) print("The abstract child program is:\n{:}".format(abstract_child)) + model.enable_candidate() model.set_super_run_type(super_core.SuperRunMode.Candidate) model.apply_candidate(abstract_child) outputs = model(inputs) diff --git a/tests/test_super_mlp.py b/tests/test_super_mlp.py index aba8792..b60a68c 100644 --- a/tests/test_super_mlp.py +++ b/tests/test_super_mlp.py @@ -37,6 +37,7 @@ class TestSuperLinear(unittest.TestCase): print("The abstract child program:\n{:}".format(abstract_child)) model.set_super_run_type(super_core.SuperRunMode.Candidate) + model.enable_candidate() model.apply_candidate(abstract_child) output_shape = (20, abstract_child["_out_features"].value) @@ -77,6 +78,7 @@ class TestSuperLinear(unittest.TestCase): ) mlp.set_super_run_type(super_core.SuperRunMode.Candidate) + mlp.enable_candidate() mlp.apply_candidate(abstract_child) outputs = mlp(inputs) output_shape = (4, abstract_child["fc2"]["_out_features"].value) @@ -103,6 +105,7 @@ class TestSuperLinear(unittest.TestCase): print("The abstract child program is:\n{:}".format(abstract_child)) mlp.set_super_run_type(super_core.SuperRunMode.Candidate) + mlp.enable_candidate() 
mlp.apply_candidate(abstract_child) outputs = mlp(inputs) output_shape = (4, abstract_child["_out_features"].value) @@ -120,6 +123,7 @@ class TestSuperLinear(unittest.TestCase): print("The abstract child program:\n{:}".format(abstract_child)) model.set_super_run_type(super_core.SuperRunMode.Candidate) + model.enable_candidate() model.apply_candidate(abstract_child) outputs = model(inputs) output_shape = (4, 60, abstract_child["_embed_dim"].value) diff --git a/tests/test_super_norm.py b/tests/test_super_norm.py index 01a309f..fc3a5ab 100644 --- a/tests/test_super_norm.py +++ b/tests/test_super_norm.py @@ -38,6 +38,7 @@ class TestSuperSimpleNorm(unittest.TestCase): print("The abstract child program:\n{:}".format(abstract_child)) model.set_super_run_type(super_core.SuperRunMode.Candidate) + model.enable_candidate() model.apply_candidate(abstract_child) output_shape = (20, abstract_child["1"]["_out_features"].value) @@ -70,6 +71,7 @@ class TestSuperSimpleNorm(unittest.TestCase): print("The abstract child program:\n{:}".format(abstract_child)) model.set_super_run_type(super_core.SuperRunMode.Candidate) + model.enable_candidate() model.apply_candidate(abstract_child) output_shape = (20, abstract_child["2"]["_out_features"].value) diff --git a/tests/test_super_rearrange.py b/tests/test_super_rearrange.py index bd2d35b..3b86d37 100644 --- a/tests/test_super_rearrange.py +++ b/tests/test_super_rearrange.py @@ -5,12 +5,6 @@ ##################################################### import sys import unittest -from pathlib import Path - -lib_dir = (Path(__file__).parent / "..").resolve() -print("LIB-DIR: {:}".format(lib_dir)) -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) import torch from xautodl import xlayers @@ -28,3 +22,4 @@ class TestSuperReArrange(unittest.TestCase): print(layer) outs = layer(tensor) print("The output tensor shape: {:}".format(outs.shape)) + assert tuple(outs.shape) == (8, 32 * 32 // 16, 4 * 4 * 4) diff --git 
class AnonymousAxis:
    """Anonymous (purely numeric) axis of a rearrange pattern.

    Important thing: all instances of this class are not equal to each other.
    The class defines neither ``__eq__`` nor ``__hash__``, so two axes created
    from the same number still compare (and hash) by identity.

    Args:
        value: decimal string holding the axis length; it must be > 1.

    Raises:
        ValueError: if the length is 1 or is not positive. ``int(value)`` also
            raises ``ValueError`` when ``value`` is not a valid integer.
    """

    def __init__(self, value: str):
        self.value = int(value)
        # NOTE: the previous version raised ``EinopsError``, a name that is not
        # defined or imported in this module, so invalid input produced a
        # NameError.  ValueError matches what ParsedExpression raises.
        if self.value == 1:
            raise ValueError(
                "No need to create anonymous axis of length 1. Report this as an issue"
            )
        if self.value < 1:
            raise ValueError(
                "Anonymous axis should have positive length, not {}".format(
                    self.value
                )
            )

    def __repr__(self):
        return "{}-axis".format(str(self.value))
in the expression.") bracket_group = None @@ -37,37 +50,28 @@ class ParsedExpression: x ) ) - if x == _ellipsis: - self.identifiers.add(_ellipsis) + is_number = str.isdecimal(x) + if is_number and int(x) == 1: + # handling the case of anonymous axis of length 1 if bracket_group is None: - self.composition.append(_ellipsis) - self.has_ellipsis_parenthesized = False + self.composition.append([]) else: - bracket_group.append(_ellipsis) - self.has_ellipsis_parenthesized = True + pass # no need to think about 1s inside parenthesis + return + is_axis_name, reason = self.check_axis_name(x, return_reason=True) + if not (is_number or is_axis_name): + raise ValueError( + "Invalid axis identifier: {}\n{}".format(x, reason) + ) + if is_number: + x = AnonymousAxis(x) + self.identifiers.add(x) + if is_number: + self.has_non_unitary_anonymous_axes = True + if bracket_group is None: + self.composition.append([x]) else: - is_number = str.isdecimal(x) - if is_number and int(x) == 1: - # handling the case of anonymous axis of length 1 - if bracket_group is None: - self.composition.append([]) - else: - pass # no need to think about 1s inside parenthesis - return - is_axis_name, reason = self.check_axis_name(x, return_reason=True) - if not (is_number or is_axis_name): - raise ValueError( - "Invalid axis identifier: {}\n{}".format(x, reason) - ) - if is_number: - x = AnonymousAxis(x) - self.identifiers.add(x) - if is_number: - self.has_non_unitary_anonymous_axes = True - if bracket_group is None: - self.composition.append([x]) - else: - bracket_group.append(x) + bracket_group.append(x) current_identifier = None for char in expression: @@ -85,7 +89,7 @@ class ParsedExpression: raise ValueError("Brackets are not balanced") self.composition.append(bracket_group) bracket_group = None - elif str.isalnum(char) or char in ["_", _ellipsis]: + elif str.isalnum(char) or char == "_": if current_identifier is None: current_identifier = char else: @@ -143,3 +147,8 @@ class ParsedExpression: return 
def enable_candidate(self):
    """Set ``_candidate_mode`` to ENABLE_CANDIDATE on every SuperModule reached by ``self.apply``."""

    def _switch_on(module):
        # Plain (non-super) modules carry no candidate flag; leave them alone.
        if not isinstance(module, SuperModule):
            return
        module._candidate_mode = ENABLE_CANDIDATE

    self.apply(_switch_on)


def disable_candidate(self):
    """Set ``_candidate_mode`` to DISABLE_CANDIDATE on every SuperModule reached by ``self.apply``."""

    def _switch_off(module):
        # Plain (non-super) modules carry no candidate flag; leave them alone.
        if not isinstance(module, SuperModule):
            return
        module._candidate_mode = DISABLE_CANDIDATE

    self.apply(_switch_off)
b/xautodl/xlayers/super_rearrange.py @@ -8,10 +8,14 @@ import torch.nn as nn import torch.nn.functional as F import math +import numpy as np +import itertools +import functools +from collections import OrderedDict from typing import Optional, Callable from xautodl import spaces -from .misc_utils import ParsedExpression +from .misc_utils import ParsedExpression, AnonymousAxis from .super_module import SuperModule from .super_module import IntSpaceType from .super_module import BoolSpaceType @@ -31,11 +35,133 @@ class SuperReArrange(SuperModule): left, right = pattern.split("->") left = ParsedExpression(left) right = ParsedExpression(right) + difference = set.symmetric_difference(left.identifiers, right.identifiers) + if difference: + raise ValueError( + "Identifiers only on one side of expression (should be on both): {}".format( + difference + ) + ) - import pdb + # parsing all dimensions to find out lengths + axis_name2known_length = OrderedDict() + for composite_axis in left.composition: + for axis_name in composite_axis: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = None + for axis_name in right.identifiers: + if axis_name not in axis_name2known_length: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = None - pdb.set_trace() - print("-") + axis_name2position = { + name: position for position, name in enumerate(axis_name2known_length) + } + for elementary_axis, axis_length in axes_lengths: + if not ParsedExpression.check_axis_name(elementary_axis): + raise ValueError("Invalid name for an axis", elementary_axis) + if elementary_axis not in axis_name2known_length: + raise ValueError( + "Axis {} is not used in transform".format(elementary_axis) + ) + axis_name2known_length[elementary_axis] = axis_length + + input_composite_axes = [] + # some of shapes will be inferred later - all 
def reconstruct_from_shape(self, shape):
    """Resolve the reshape/permute/reshape recipe for one concrete input shape.

    Args:
        shape: sequence of input dimension sizes, one entry per composite axis
            on the left-hand side of the pattern.

    Returns:
        A tuple ``(axes_lengths, axes_reordering, final_shape)``:
        ``axes_lengths`` is a list with the length of every elementary axis,
        ``axes_reordering`` is ``self.axes_permutation``, and ``final_shape``
        is a list with the size of every composite output axis.

    Raises:
        ValueError: if the number of dimensions does not match the pattern, or
            a dimension disagrees with / is not divisible by the known lengths.
    """
    # The previous version wrapped this method in functools.lru_cache, which
    # keys on ``self`` and therefore keeps every module instance alive in a
    # class-level cache.  A small per-instance dict gives the same speed-up and
    # is released together with the instance.
    max_cached_shapes = 1024  # same bound as the former lru_cache(maxsize=1024)
    shape = tuple(shape)
    cache = vars(self).setdefault("_reconstruct_cache", {})
    cached = cache.get(shape)
    if cached is not None:
        lengths, reordering, out_shape = cached
        # Fresh lists on every call: callers can never corrupt the cache.
        return list(lengths), reordering, list(out_shape)

    if len(shape) != len(self.input_composite_axes):
        raise ValueError(
            "Expected {} dimensions, got {}".format(
                len(self.input_composite_axes), len(shape)
            )
        )
    axes_lengths = list(self.elementary_axes_lengths)
    for length, (known_axes, unknown_axes) in zip(shape, self.input_composite_axes):
        known_product = 1
        for axis in known_axes:
            known_product *= axes_lengths[axis]
        both_ints = isinstance(length, int) and isinstance(known_product, int)
        if len(unknown_axes) == 0:
            if both_ints and length != known_product:
                raise ValueError(
                    "Shape mismatch, {} != {}".format(length, known_product)
                )
            continue
        if both_ints and length % known_product != 0:
            raise ValueError(
                "Shape mismatch, can't divide axis of length {} in chunks of {}".format(
                    length, known_product
                )
            )
        # At most one unknown axis per composite axis is allowed by construction.
        (unknown_axis,) = unknown_axes
        axes_lengths[unknown_axis] = length // known_product

    # At this point all axes_lengths are computed (no None left).
    final_shape = []
    for grouping in self.output_composite_axes:
        # Exact Python-int product; avoids fixed-width overflow of np.prod.
        size = 1
        for elementary_axis in grouping:
            size *= axes_lengths[elementary_axis]
        final_shape.append(int(size))
    axes_reordering = self.axes_permutation

    if len(cache) >= max_cached_shapes:
        cache.clear()
    cache[shape] = (tuple(axes_lengths), axes_reordering, tuple(final_shape))
    return axes_lengths, axes_reordering, final_shape