Update super cores

This commit is contained in:
D-X-Y 2021-03-18 18:32:26 +08:00
parent 63c8bb9bc8
commit eabdd21d97
9 changed files with 209 additions and 18 deletions

View File

@ -43,5 +43,5 @@ jobs:
echo "Show what we have here:"
ls
python --version
python -m pytest ./tests --durations=0
python -m pytest ./tests -s
shell: bash

5
lib/layers/super_core.py Normal file
View File

@ -0,0 +1,5 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from .super_module import SuperModule
from .super_mlp import SuperLinear

View File

@ -1,38 +1,71 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch.nn as nn
from torch.nn.parameter import Parameter
from typing import Optional
from torch import Tensor
import math
from typing import Optional, Union
import spaces
from layers.super_module import SuperModule
from layers.super_module import SuperModule
from layers.super_module import SuperRunType
IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
BoolSpaceType = Union[bool, spaces.Categorical]
class SuperLinear(SuperModule):
"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`"""
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
def __init__(
self,
in_features: IntSpaceType,
out_features: IntSpaceType,
bias: BoolSpaceType = True,
) -> None:
super(SuperLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
# the raw input args
self._in_features = in_features
self._out_features = out_features
self._bias = bias
self._super_weight = Parameter(
torch.Tensor(self.out_features, self.in_features)
)
if bias:
self.bias = Parameter(torch.Tensor(out_features))
self._super_bias = Parameter(torch.Tensor(self.out_features))
else:
self.register_parameter("bias", None)
self.register_parameter("_super_bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
@property
def in_features(self):
return spaces.get_max(self._in_features)
def forward(self, input: Tensor) -> Tensor:
return F.linear(input, self.weight, self.bias)
@property
def out_features(self):
return spaces.get_max(self._out_features)
@property
def bias(self):
return spaces.has_categorical(self._bias, True)
def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))
if self.bias:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._super_weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self._super_bias, -bound, bound)
def forward_raw(self, input: Tensor) -> Tensor:
return F.linear(input, self._super_weight, self._super_bias)
def extra_repr(self) -> str:
return "in_features={:}, out_features={:}, bias={:}".format(
self.in_features, self.out_features, self.bias is not None
self.in_features, self.out_features, self.bias
)

View File

@ -4,6 +4,14 @@
import abc
import torch.nn as nn
from enum import Enum
class SuperRunMode(Enum):
"""This class defines the enumerations for Super Model Running Mode."""
FullModel = "fullmodel"
Default = "fullmodel"
class SuperModule(abc.ABCMeta, nn.Module):
@ -11,7 +19,24 @@ class SuperModule(abc.ABCMeta, nn.Module):
def __init__(self):
super(SuperModule, self).__init__()
self._super_run_type = SuperRunMode.default
@abc.abstractmethod
def abstract_search_space(self):
raise NotImplementedError
@property
def super_run_type(self):
return self._super_run_type
@abc.abstractmethod
def forward_raw(self, *inputs):
raise NotImplementedError
def forward(self, *inputs):
if self.super_run_type == SuperRunMode.FullModel:
return self.forward_raw(*inputs)
else:
raise ModeError(
"Unknown Super Model Run Mode: {:}".format(self.super_run_type)
)

View File

@ -9,3 +9,5 @@ from .basic_space import Continuous
from .basic_space import Integer
from .basic_op import has_categorical
from .basic_op import has_continuous
from .basic_op import get_min
from .basic_op import get_max

View File

@ -1,4 +1,7 @@
from spaces.basic_space import Space
from spaces.basic_space import Integer
from spaces.basic_space import Continuous
from spaces.basic_space import Categorical
from spaces.basic_space import _EPS
@ -14,3 +17,33 @@ def has_continuous(space_or_value, x):
return space_or_value.has(x)
else:
return abs(space_or_value - x) <= _EPS
def get_max(space_or_value):
if isinstance(space_or_value, Integer):
return max(space_or_value.candidates)
elif isinstance(space_or_value, Continuous):
return space_or_value.upper
elif isinstance(space_or_value, Categorical):
values = []
for index in range(len(space_or_value)):
max_value = get_max(space_or_value[index])
values.append(max_value)
return max(values)
else:
return space_or_value
def get_min(space_or_value):
if isinstance(space_or_value, Integer):
return min(space_or_value.candidates)
elif isinstance(space_or_value, Continuous):
return space_or_value.lower
elif isinstance(space_or_value, Categorical):
values = []
for index in range(len(space_or_value)):
min_value = get_min(space_or_value[index])
values.append(min_value)
return min(values)
else:
return space_or_value

View File

@ -10,6 +10,9 @@ import numpy as np
from typing import Optional
__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]
_EPS = 1e-9
@ -54,6 +57,10 @@ class Categorical(Space):
), "default >= {:}".format(len(self._candidates))
assert len(self) > 0, "Please provide at least one candidate"
@property
def candidates(self):
return self._candidates
@property
def determined(self):
if len(self) == 1:

View File

@ -0,0 +1,80 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n"
]
}
],
"source": [
"#####################################################\n",
"# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #\n",
"#####################################################\n",
"import abc, os, sys\n",
"from pathlib import Path\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"\n",
"lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n",
"print(\"library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SuperRunType.FullModel\n",
"SuperRunType.FullModel\n",
"True\n",
"True\n"
]
}
],
"source": [
"from layers.super_core import SuperLinear\n",
"from layers.super_module import SuperRunMode\n",
"\n",
"print(SuperRunMode.Default)\n",
"print(SuperRunMode.FullModel)\n",
"print(SuperRunMode.Default == SuperRunMode.FullModel)\n",
"print(SuperRunMode.FullModel == SuperRunMode.FullModel)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -14,6 +14,9 @@ if str(lib_dir) not in sys.path:
from spaces import Categorical
from spaces import Continuous
from spaces import Integer
from spaces import Integer
from spaces import get_min
from spaces import get_max
class TestBasicSpace(unittest.TestCase):
@ -32,6 +35,8 @@ class TestBasicSpace(unittest.TestCase):
for i in range(4):
self.assertEqual(space[i], i + 1)
self.assertEqual("Integer(lower=1, upper=4, default=None)", str(space))
self.assertEqual(get_max(space), 4)
self.assertEqual(get_min(space), 1)
def test_continuous(self):
random.seed(999)
@ -84,5 +89,6 @@ class TestBasicSpace(unittest.TestCase):
Categorical(4, Categorical(5, 6, 7, Categorical(8, 9), 10), 11),
12,
)
print(nested_space)
for i in range(1, 13):
self.assertTrue(nested_space.has(i))