Complete xlayers.rearrange

D-X-Y 2021-06-08 23:47:52 -07:00
parent f9bbf974de
commit 744ce97bc5
10 changed files with 218 additions and 96 deletions

View File

@@ -29,6 +29,7 @@ class TestSuperSelfAttention(unittest.TestCase):
         abstract_child = abstract_space.random(reuse_last=True)
         print("The abstract child program is:\n{:}".format(abstract_child))
         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)
         outputs = model(inputs)
         return abstract_child, outputs

View File

@@ -25,6 +25,7 @@ def _internal_func(inputs, model):
     abstract_space.clean_last()
     abstract_child = abstract_space.random(reuse_last=True)
     print("The abstract child program is:\n{:}".format(abstract_child))
+    model.enable_candidate()
     model.set_super_run_type(super_core.SuperRunMode.Candidate)
     model.apply_candidate(abstract_child)
     outputs = model(inputs)

View File

@@ -37,6 +37,7 @@ class TestSuperLinear(unittest.TestCase):
         print("The abstract child program:\n{:}".format(abstract_child))
         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)
         output_shape = (20, abstract_child["_out_features"].value)
@@ -77,6 +78,7 @@ class TestSuperLinear(unittest.TestCase):
         )
         mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
+        mlp.enable_candidate()
         mlp.apply_candidate(abstract_child)
         outputs = mlp(inputs)
         output_shape = (4, abstract_child["fc2"]["_out_features"].value)
@@ -103,6 +105,7 @@ class TestSuperLinear(unittest.TestCase):
         print("The abstract child program is:\n{:}".format(abstract_child))
         mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
+        mlp.enable_candidate()
         mlp.apply_candidate(abstract_child)
         outputs = mlp(inputs)
         output_shape = (4, abstract_child["_out_features"].value)
@@ -120,6 +123,7 @@ class TestSuperLinear(unittest.TestCase):
         print("The abstract child program:\n{:}".format(abstract_child))
         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)
         outputs = model(inputs)
         output_shape = (4, 60, abstract_child["_embed_dim"].value)

View File

@@ -38,6 +38,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
         print("The abstract child program:\n{:}".format(abstract_child))
         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)
         output_shape = (20, abstract_child["1"]["_out_features"].value)
@@ -70,6 +71,7 @@ class TestSuperSimpleNorm(unittest.TestCase):
         print("The abstract child program:\n{:}".format(abstract_child))
         model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.enable_candidate()
         model.apply_candidate(abstract_child)
         output_shape = (20, abstract_child["2"]["_out_features"].value)

View File

@@ -5,12 +5,6 @@
 #####################################################
 import sys
 import unittest
-from pathlib import Path
-
-lib_dir = (Path(__file__).parent / "..").resolve()
-print("LIB-DIR: {:}".format(lib_dir))
-if str(lib_dir) not in sys.path:
-    sys.path.insert(0, str(lib_dir))
 import torch
 from xautodl import xlayers
@@ -28,3 +22,4 @@ class TestSuperReArrange(unittest.TestCase):
         print(layer)
         outs = layer(tensor)
         print("The output tensor shape: {:}".format(outs.shape))
+        assert tuple(outs.shape) == (8, 32 * 32 // 16, 4 * 4 * 4)
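The new assertion pins the expected output shape. The hunk does not show the rearrange pattern itself; below is a minimal sketch of a ViT-style patchify pattern that produces exactly this shape, cross-checked with einops (the pattern, patch size, and input shape are assumptions, not taken from this diff):

import torch
from einops import rearrange  # reference behaviour that the SuperReArrange layer mirrors

# assumed input: batch 8, 4 channels, a 32x32 map split into 4x4 patches
tensor = torch.rand(8, 4, 32, 32)
outs = rearrange(tensor, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=4, p2=4)
assert tuple(outs.shape) == (8, 32 * 32 // 16, 4 * 4 * 4)  # i.e. (8, 64, 64)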

View File

@@ -1,36 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
-#####################################################
-# pytest ./tests/test_super_model.py -s #
-#####################################################
-import unittest
-import torch
-
-from xautodl.xlayers.super_core import SuperRunMode
-from xautodl.trade_models import get_transformer
-
-
-class TestSuperTransformer(unittest.TestCase):
-    """Test the super transformer."""
-
-    def test_super_transformer(self):
-        model = get_transformer(None)
-        model.apply_verbose(False)
-        print(model)
-
-        inputs = torch.rand(10, 360)
-        print("Input shape: {:}".format(inputs.shape))
-        outputs = model(inputs)
-        self.assertEqual(tuple(outputs.shape), (10,))
-
-        abstract_space = model.abstract_search_space
-        abstract_space.clean_last()
-        abstract_child = abstract_space.random(reuse_last=True)
-        print("The abstract searc space:\n{:}".format(abstract_space))
-        print("The abstract child program:\n{:}".format(abstract_child))
-        model.set_super_run_type(SuperRunMode.Candidate)
-        model.apply_candidate(abstract_child)
-        outputs = model(inputs)
-        self.assertEqual(tuple(outputs.shape), (10,))

View File

@@ -1,4 +1,28 @@
 # borrowed from https://github.com/arogozhnikov/einops/blob/master/einops/parsing.py
+import warnings
+import keyword
+from typing import List
+
+
+class AnonymousAxis:
+    """Important thing: all instances of this class are not equal to each other"""
+
+    def __init__(self, value: str):
+        self.value = int(value)
+        if self.value <= 1:
+            if self.value == 1:
+                raise EinopsError(
+                    "No need to create anonymous axis of length 1. Report this as an issue"
+                )
+            else:
+                raise EinopsError(
+                    "Anonymous axis should have positive length, not {}".format(
+                        self.value
+                    )
+                )
+
+    def __repr__(self):
+        return "{}-axis".format(str(self.value))
+
+
 class ParsedExpression:
@@ -8,24 +32,13 @@ class ParsedExpression:
     """

     def __init__(self, expression):
-        self.has_ellipsis = False
-        self.has_ellipsis_parenthesized = None
         self.identifiers = set()
         # that's axes like 2, 3 or 5. Axes with size 1 are exceptional and replaced with empty composition
         self.has_non_unitary_anonymous_axes = False
         # composition keeps structure of composite axes, see how different corner cases are handled in tests
         self.composition = []
         if "." in expression:
-            if "..." not in expression:
-                raise ValueError(
-                    "Expression may contain dots only inside ellipsis (...)"
-                )
-            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
-                raise ValueError(
-                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
-                )
-            expression = expression.replace("...", _ellipsis)
-            self.has_ellipsis = True
+            raise ValueError("Does not support . in the expression.")

         bracket_group = None
@@ -37,37 +50,28 @@ class ParsedExpression:
                             x
                         )
                     )
-                if x == _ellipsis:
-                    self.identifiers.add(_ellipsis)
-                    if bracket_group is None:
-                        self.composition.append(_ellipsis)
-                        self.has_ellipsis_parenthesized = False
-                    else:
-                        bracket_group.append(_ellipsis)
-                        self.has_ellipsis_parenthesized = True
-                else:
-                    is_number = str.isdecimal(x)
-                    if is_number and int(x) == 1:
-                        # handling the case of anonymous axis of length 1
-                        if bracket_group is None:
-                            self.composition.append([])
-                        else:
-                            pass  # no need to think about 1s inside parenthesis
-                        return
-                    is_axis_name, reason = self.check_axis_name(x, return_reason=True)
-                    if not (is_number or is_axis_name):
-                        raise ValueError(
-                            "Invalid axis identifier: {}\n{}".format(x, reason)
-                        )
-                    if is_number:
-                        x = AnonymousAxis(x)
-                    self.identifiers.add(x)
-                    if is_number:
-                        self.has_non_unitary_anonymous_axes = True
-                    if bracket_group is None:
-                        self.composition.append([x])
-                    else:
-                        bracket_group.append(x)
+                is_number = str.isdecimal(x)
+                if is_number and int(x) == 1:
+                    # handling the case of anonymous axis of length 1
+                    if bracket_group is None:
+                        self.composition.append([])
+                    else:
+                        pass  # no need to think about 1s inside parenthesis
+                    return
+                is_axis_name, reason = self.check_axis_name(x, return_reason=True)
+                if not (is_number or is_axis_name):
+                    raise ValueError(
+                        "Invalid axis identifier: {}\n{}".format(x, reason)
+                    )
+                if is_number:
+                    x = AnonymousAxis(x)
+                self.identifiers.add(x)
+                if is_number:
+                    self.has_non_unitary_anonymous_axes = True
+                if bracket_group is None:
+                    self.composition.append([x])
+                else:
+                    bracket_group.append(x)

         current_identifier = None
         for char in expression:
@@ -85,7 +89,7 @@ class ParsedExpression:
                         raise ValueError("Brackets are not balanced")
                     self.composition.append(bracket_group)
                     bracket_group = None
-            elif str.isalnum(char) or char in ["_", _ellipsis]:
+            elif str.isalnum(char) or char == "_":
                 if current_identifier is None:
                     current_identifier = char
                 else:
@@ -143,3 +147,8 @@ class ParsedExpression:
             return result
         else:
             return result[0]
+
+    def __repr__(self) -> str:
+        return "{name}({composition})".format(
+            name=self.__class__.__name__, composition=self.composition
+        )
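With these changes ParsedExpression rejects dotted (ellipsis) patterns and records every axis in identifiers and composition, turning numeric axes into AnonymousAxis instances. A small illustration of the behaviour defined above; the import path follows the xlayers package layout used by super_rearrange.py below and should be treated as an assumption:

from xautodl.xlayers.misc_utils import ParsedExpression

expr = ParsedExpression("b c (h p1) (w p2)")
print(expr.identifiers)  # expected: {"b", "c", "h", "p1", "w", "p2"}
print(expr.composition)  # expected: [["b"], ["c"], ["h", "p1"], ["w", "p2"]]

try:
    ParsedExpression("b ... c")  # dots are no longer supported
except ValueError as err:
    print(err)  # "Does not support . in the expression."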

View File

@@ -21,6 +21,8 @@ from .super_utils import ShapeContainer
 BEST_DIR_KEY = "best_model_dir"
 BEST_NAME_KEY = "best_model_name"
 BEST_SCORE_KEY = "best_model_score"
+ENABLE_CANDIDATE = 0
+DISABLE_CANDIDATE = 1


 class SuperModule(abc.ABC, nn.Module):
@@ -32,6 +34,7 @@ class SuperModule(abc.ABC, nn.Module):
         self._abstract_child = None
         self._verbose = False
         self._meta_info = {}
+        self._candidate_mode = DISABLE_CANDIDATE

     def set_super_run_type(self, super_run_type):
         def _reset_super_run(m):
@@ -65,6 +68,20 @@ class SuperModule(abc.ABC, nn.Module):
             )
         self._abstract_child = abstract_child

+    def enable_candidate(self):
+        def _enable_candidate(m):
+            if isinstance(m, SuperModule):
+                m._candidate_mode = ENABLE_CANDIDATE
+
+        self.apply(_enable_candidate)
+
+    def disable_candidate(self):
+        def _disable_candidate(m):
+            if isinstance(m, SuperModule):
+                m._candidate_mode = DISABLE_CANDIDATE
+
+        self.apply(_disable_candidate)
+
     def get_w_container(self):
         container = TensorContainer()
         for name, param in self.named_parameters():
@@ -191,9 +208,11 @@ class SuperModule(abc.ABC, nn.Module):
         if self.super_run_type == SuperRunMode.FullModel:
             outputs = self.forward_raw(*inputs)
         elif self.super_run_type == SuperRunMode.Candidate:
+            if self._candidate_mode == DISABLE_CANDIDATE:
+                raise ValueError("candidate mode is disabled")
             outputs = self.forward_candidate(*inputs)
         else:
-            raise ModeError(
+            raise ValueError(
                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
             )
         if self.verbose:
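In short, SuperModule.forward now refuses to run in Candidate mode until enable_candidate() has been called on the module tree. A schematic fragment, reusing the names from the tests above (model, inputs, abstract_child, and super_core come from that test context, not from this hunk):

model.set_super_run_type(super_core.SuperRunMode.Candidate)
model.apply_candidate(abstract_child)
# calling model(inputs) here would raise ValueError("candidate mode is disabled")

model.enable_candidate()   # recursively sets _candidate_mode = ENABLE_CANDIDATE on every SuperModule child
outputs = model(inputs)    # now dispatches to forward_candidate(...)

model.disable_candidate()  # restore the guard when candidate evaluation is done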

View File

@@ -8,10 +8,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
+import numpy as np
+import itertools
+import functools
+from collections import OrderedDict
 from typing import Optional, Callable

 from xautodl import spaces
-from .misc_utils import ParsedExpression
+from .misc_utils import ParsedExpression, AnonymousAxis
 from .super_module import SuperModule
 from .super_module import IntSpaceType
 from .super_module import BoolSpaceType
@@ -31,11 +35,133 @@ class SuperReArrange(SuperModule):
         left, right = pattern.split("->")
         left = ParsedExpression(left)
         right = ParsedExpression(right)
-        import pdb
-
-        pdb.set_trace()
-        print("-")
+        difference = set.symmetric_difference(left.identifiers, right.identifiers)
+        if difference:
+            raise ValueError(
+                "Identifiers only on one side of expression (should be on both): {}".format(
+                    difference
+                )
+            )
+        # parsing all dimensions to find out lengths
+        axis_name2known_length = OrderedDict()
+        for composite_axis in left.composition:
+            for axis_name in composite_axis:
+                if isinstance(axis_name, AnonymousAxis):
+                    axis_name2known_length[axis_name] = axis_name.value
+                else:
+                    axis_name2known_length[axis_name] = None
+        for axis_name in right.identifiers:
+            if axis_name not in axis_name2known_length:
+                if isinstance(axis_name, AnonymousAxis):
+                    axis_name2known_length[axis_name] = axis_name.value
+                else:
+                    axis_name2known_length[axis_name] = None
+
+        axis_name2position = {
+            name: position for position, name in enumerate(axis_name2known_length)
+        }
+        for elementary_axis, axis_length in axes_lengths:
+            if not ParsedExpression.check_axis_name(elementary_axis):
+                raise ValueError("Invalid name for an axis", elementary_axis)
+            if elementary_axis not in axis_name2known_length:
+                raise ValueError(
+                    "Axis {} is not used in transform".format(elementary_axis)
+                )
+            axis_name2known_length[elementary_axis] = axis_length
+
+        input_composite_axes = []
+        # some of shapes will be inferred later - all information is prepared for faster inference
+        for composite_axis in left.composition:
+            known = {
+                axis
+                for axis in composite_axis
+                if axis_name2known_length[axis] is not None
+            }
+            unknown = {
+                axis for axis in composite_axis if axis_name2known_length[axis] is None
+            }
+            if len(unknown) > 1:
+                raise ValueError("Could not infer sizes for {}".format(unknown))
+            assert len(unknown) + len(known) == len(composite_axis)
+            input_composite_axes.append(
+                (
+                    [axis_name2position[axis] for axis in known],
+                    [axis_name2position[axis] for axis in unknown],
+                )
+            )
+
+        axis_position_after_reduction = {}
+        for axis_name in itertools.chain(*left.composition):
+            if axis_name in right.identifiers:
+                axis_position_after_reduction[axis_name] = len(
+                    axis_position_after_reduction
+                )
+
+        result_axes_grouping = []
+        for composite_axis in right.composition:
+            result_axes_grouping.append(
+                [axis_name2position[axis] for axis in composite_axis]
+            )
+        ordered_axis_right = list(itertools.chain(*right.composition))
+        axes_permutation = tuple(
+            axis_position_after_reduction[axis]
+            for axis in ordered_axis_right
+            if axis in left.identifiers
+        )
+        #
+        self.input_composite_axes = input_composite_axes
+        self.output_composite_axes = result_axes_grouping
+        self.elementary_axes_lengths = list(axis_name2known_length.values())
+        self.axes_permutation = axes_permutation
+
+    @functools.lru_cache(maxsize=1024)
+    def reconstruct_from_shape(self, shape):
+        if len(shape) != len(self.input_composite_axes):
+            raise ValueError(
+                "Expected {} dimensions, got {}".format(
+                    len(self.input_composite_axes), len(shape)
+                )
+            )
+        axes_lengths = list(self.elementary_axes_lengths)
+        for input_axis, (known_axes, unknown_axes) in enumerate(
+            self.input_composite_axes
+        ):
+            length = shape[input_axis]
+            known_product = 1
+            for axis in known_axes:
+                known_product *= axes_lengths[axis]
+            if len(unknown_axes) == 0:
+                if (
+                    isinstance(length, int)
+                    and isinstance(known_product, int)
+                    and length != known_product
+                ):
+                    raise ValueError(
+                        "Shape mismatch, {} != {}".format(length, known_product)
+                    )
+            else:
+                if (
+                    isinstance(length, int)
+                    and isinstance(known_product, int)
+                    and length % known_product != 0
+                ):
+                    raise ValueError(
+                        "Shape mismatch, can't divide axis of length {} in chunks of {}".format(
+                            length, known_product
+                        )
+                    )
+                (unknown_axis,) = unknown_axes
+                axes_lengths[unknown_axis] = length // known_product
+        # at this point all axes_lengths are computed (either have values or variables, but not Nones)
+        final_shape = []
+        for output_axis, grouping in enumerate(self.output_composite_axes):
+            lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
+            final_shape.append(int(np.prod(lengths)))
+        axes_reordering = self.axes_permutation
+        return axes_lengths, axes_reordering, final_shape

     @property
     def abstract_search_space(self):
@@ -46,10 +172,13 @@ class SuperReArrange(SuperModule):
         self.forward_raw(input)

     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        import pdb
-
-        pdb.set_trace()
-        raise NotImplementedError
+        init_shape, axes_reordering, final_shape = self.reconstruct_from_shape(
+            tuple(input.shape)
+        )
+        tensor = torch.reshape(input, init_shape)
+        tensor = tensor.permute(axes_reordering)
+        tensor = torch.reshape(tensor, final_shape)
+        return tensor

     def extra_repr(self) -> str:
         params = repr(self._pattern)
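forward_raw now performs the rearrange as reshape -> permute -> reshape, driven by the cached reconstruct_from_shape plan. A hedged usage sketch for the patchify-style pattern from the test above; the pattern, the constructor signature (pattern plus axis-length keyword arguments, mirroring einops.Rearrange), and the xlayers.SuperReArrange export path are assumptions inferred from the test file, not shown in this diff:

import torch
from xautodl import xlayers

layer = xlayers.SuperReArrange("b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=4, p2=4)
tensor = torch.rand(8, 4, 32, 32)
# reconstruct_from_shape((8, 4, 32, 32)) infers h = 32 // 4 and w = 32 // 4, then forward_raw does:
#   reshape -> (8, 4, 8, 4, 8, 4)   elementary axes b, c, h, p1, w, p2
#   permute -> b, h, w, c, p1, p2
#   reshape -> (8, 64, 64)          composite output axes b, (h w), (c p1 p2)
outs = layer(tensor)
print(outs.shape)  # expected: torch.Size([8, 64, 64])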

View File

@@ -1,8 +1,6 @@
 #####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
 #####################################################
-from __future__ import division
-from __future__ import print_function
 import math
 from functools import partial
 from typing import Optional, Text, List