Add unit tests for super-linear

This commit is contained in:
D-X-Y 2021-03-18 20:44:22 +08:00
parent badb6cf51d
commit ca22d61259
7 changed files with 57 additions and 17 deletions

View File

@ -32,7 +32,7 @@ jobs:
echo $PWD ; ls
python -m black ./exps -l 88 --check --diff --verbose
python -m black ./tests -l 88 --check --diff --verbose
python -m black ./lib/layers -l 88 --check --diff --verbose
python -m black ./lib/xlayers -l 88 --check --diff --verbose
python -m black ./lib/spaces -l 88 --check --diff --verbose
python -m black ./lib/trade_models -l 88 --check --diff --verbose

View File

@ -11,5 +11,6 @@ from .basic_space import Space
from .basic_space import VirtualNode
from .basic_op import has_categorical
from .basic_op import has_continuous
from .basic_op import is_determined
from .basic_op import get_min
from .basic_op import get_max

View File

@ -19,6 +19,13 @@ def has_continuous(space_or_value, x):
return abs(space_or_value - x) <= _EPS
def is_determined(space_or_value):
if isinstance(space_or_value, Space):
return space_or_value.determined
else:
return True
def get_max(space_or_value):
if isinstance(space_or_value, Integer):
return max(space_or_value.candidates)

View File

@ -30,10 +30,13 @@ class Space(metaclass=abc.ABCMeta):
def determined(self):
raise NotImplementedError
@abc.abstractmethod
def __repr__(self):
@abc.abstractproperty
def xrepr(self, indent=0):
raise NotImplementedError
def __repr__(self):
return self.xrepr()
@abc.abstractmethod
def has(self, x):
"""Check whether x is in this search space."""
@ -58,15 +61,28 @@ class VirtualNode(Space):
def has(self, x):
for key, value in self._attributes.items():
if isinstance(value, Space) and value.has(x):
if value.has(x):
return True
return False
def __repr__(self):
strs = [self.__class__.__name__ + "("]
indent = " " * 4
def append(self, key, value):
if not isinstance(value, Space):
raise ValueError("Invalid type of value: {:}".format(type(value)))
self._attributes[key] = value
def determined(self):
for key, value in self._attributes.items():
strs.append(indent + strs(value))
if not value.determined(x):
return False
return True
def random(self, recursion=True):
raise NotImplementedError
def xrepr(self, indent=0):
strs = [self.__class__.__name__ + "("]
for key, value in self._attributes.items():
strs.append(value.xrepr(indent + 2) + ",")
strs.append(")")
return "\n".join(strs)
@ -104,10 +120,11 @@ class Categorical(Space):
def __len__(self):
return len(self._candidates)
def __repr__(self):
return "{name:}(candidates={cs:}, default_index={default:})".format(
def xrepr(self, indent=0):
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
name=self.__class__.__name__, cs=self._candidates, default=self._default
)
return " " * indent + xrepr
def has(self, x):
super().has(x)
@ -143,13 +160,14 @@ class Integer(Categorical):
default = data.index(default)
super(Integer, self).__init__(*data, default=default)
def __repr__(self):
return "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
def xrepr(self, indent=0):
xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
name=self.__class__.__name__,
lower=self._raw_lower,
upper=self._raw_upper,
default=self._raw_default,
)
return " " * indent + xrepr
np_float_types = (np.float16, np.float32, np.float64)
@ -198,14 +216,15 @@ class Continuous(Space):
def determined(self):
return abs(self.lower - self.upper) <= self._eps
def __repr__(self):
return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
def xrepr(self, indent=0):
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
name=self.__class__.__name__,
lower=self._lower,
upper=self._upper,
default=self._default,
log=self._log_scale,
)
return " " * indent + xrepr
def convert(self, x):
if isinstance(x, np_float_types) and x.size == 1:

View File

@ -30,7 +30,7 @@ class SuperLinear(SuperModule):
self._in_features = in_features
self._out_features = out_features
self._bias = bias
# weights to be optimized
self._super_weight = torch.nn.Parameter(
torch.Tensor(self.out_features, self.in_features)
)
@ -53,7 +53,14 @@ class SuperLinear(SuperModule):
return spaces.has_categorical(self._bias, True)
def abstract_search_space(self):
print('-')
root_node = spaces.VirtualNode(id(self))
if not spaces.is_determined(self._in_features):
root_node.append("_in_features", self._in_features)
if not spaces.is_determined(self._out_features):
root_node.append("_out_features", self._out_features)
if not spaces.is_determined(self._bias):
root_node.append("_bias", self._bias)
return root_node
def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))

View File

@ -14,7 +14,7 @@ if str(lib_dir) not in sys.path:
from spaces import Categorical
from spaces import Continuous
from spaces import Integer
from spaces import Integer
from spaces import is_determined
from spaces import get_min
from spaces import get_max
@ -92,3 +92,7 @@ class TestBasicSpace(unittest.TestCase):
print(nested_space)
for i in range(1, 13):
self.assertTrue(nested_space.has(i))
# Test Simple Op
self.assertTrue(is_determined(1))
self.assertFalse(is_determined(nested_space))

View File

@ -26,3 +26,5 @@ class TestSuperLinear(unittest.TestCase):
bias = spaces.Categorical(True, False)
model = super_core.SuperLinear(10, out_features, bias=bias)
print(model)
print(model.super_run_type)
print(model.abstract_search_space())