Add unit tests for super-linear

This commit is contained in:
D-X-Y 2021-03-18 20:44:22 +08:00
parent badb6cf51d
commit ca22d61259
7 changed files with 57 additions and 17 deletions

View File

@ -32,7 +32,7 @@ jobs:
echo $PWD ; ls echo $PWD ; ls
python -m black ./exps -l 88 --check --diff --verbose python -m black ./exps -l 88 --check --diff --verbose
python -m black ./tests -l 88 --check --diff --verbose python -m black ./tests -l 88 --check --diff --verbose
python -m black ./lib/layers -l 88 --check --diff --verbose python -m black ./lib/xlayers -l 88 --check --diff --verbose
python -m black ./lib/spaces -l 88 --check --diff --verbose python -m black ./lib/spaces -l 88 --check --diff --verbose
python -m black ./lib/trade_models -l 88 --check --diff --verbose python -m black ./lib/trade_models -l 88 --check --diff --verbose

View File

@ -11,5 +11,6 @@ from .basic_space import Space
from .basic_space import VirtualNode from .basic_space import VirtualNode
from .basic_op import has_categorical from .basic_op import has_categorical
from .basic_op import has_continuous from .basic_op import has_continuous
from .basic_op import is_determined
from .basic_op import get_min from .basic_op import get_min
from .basic_op import get_max from .basic_op import get_max

View File

@ -19,6 +19,13 @@ def has_continuous(space_or_value, x):
return abs(space_or_value - x) <= _EPS return abs(space_or_value - x) <= _EPS
def is_determined(space_or_value):
if isinstance(space_or_value, Space):
return space_or_value.determined
else:
return True
def get_max(space_or_value): def get_max(space_or_value):
if isinstance(space_or_value, Integer): if isinstance(space_or_value, Integer):
return max(space_or_value.candidates) return max(space_or_value.candidates)

View File

@ -30,10 +30,13 @@ class Space(metaclass=abc.ABCMeta):
def determined(self): def determined(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractproperty
def __repr__(self): def xrepr(self, indent=0):
raise NotImplementedError raise NotImplementedError
def __repr__(self):
return self.xrepr()
@abc.abstractmethod @abc.abstractmethod
def has(self, x): def has(self, x):
"""Check whether x is in this search space.""" """Check whether x is in this search space."""
@ -58,15 +61,28 @@ class VirtualNode(Space):
def has(self, x): def has(self, x):
for key, value in self._attributes.items(): for key, value in self._attributes.items():
if isinstance(value, Space) and value.has(x): if value.has(x):
return True return True
return False return False
def __repr__(self): def append(self, key, value):
strs = [self.__class__.__name__ + "("] if not isinstance(value, Space):
indent = " " * 4 raise ValueError("Invalid type of value: {:}".format(type(value)))
self._attributes[key] = value
def determined(self):
for key, value in self._attributes.items(): for key, value in self._attributes.items():
strs.append(indent + strs(value)) if not value.determined(x):
return False
return True
def random(self, recursion=True):
raise NotImplementedError
def xrepr(self, indent=0):
strs = [self.__class__.__name__ + "("]
for key, value in self._attributes.items():
strs.append(value.xrepr(indent + 2) + ",")
strs.append(")") strs.append(")")
return "\n".join(strs) return "\n".join(strs)
@ -104,10 +120,11 @@ class Categorical(Space):
def __len__(self): def __len__(self):
return len(self._candidates) return len(self._candidates)
def __repr__(self): def xrepr(self, indent=0):
return "{name:}(candidates={cs:}, default_index={default:})".format( xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
name=self.__class__.__name__, cs=self._candidates, default=self._default name=self.__class__.__name__, cs=self._candidates, default=self._default
) )
return " " * indent + xrepr
def has(self, x): def has(self, x):
super().has(x) super().has(x)
@ -143,13 +160,14 @@ class Integer(Categorical):
default = data.index(default) default = data.index(default)
super(Integer, self).__init__(*data, default=default) super(Integer, self).__init__(*data, default=default)
def __repr__(self): def xrepr(self, indent=0):
return "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
lower=self._raw_lower, lower=self._raw_lower,
upper=self._raw_upper, upper=self._raw_upper,
default=self._raw_default, default=self._raw_default,
) )
return " " * indent + xrepr
np_float_types = (np.float16, np.float32, np.float64) np_float_types = (np.float16, np.float32, np.float64)
@ -198,14 +216,15 @@ class Continuous(Space):
def determined(self): def determined(self):
return abs(self.lower - self.upper) <= self._eps return abs(self.lower - self.upper) <= self._eps
def __repr__(self): def xrepr(self, indent=0):
return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
lower=self._lower, lower=self._lower,
upper=self._upper, upper=self._upper,
default=self._default, default=self._default,
log=self._log_scale, log=self._log_scale,
) )
return " " * indent + xrepr
def convert(self, x): def convert(self, x):
if isinstance(x, np_float_types) and x.size == 1: if isinstance(x, np_float_types) and x.size == 1:

View File

@ -30,7 +30,7 @@ class SuperLinear(SuperModule):
self._in_features = in_features self._in_features = in_features
self._out_features = out_features self._out_features = out_features
self._bias = bias self._bias = bias
# weights to be optimized
self._super_weight = torch.nn.Parameter( self._super_weight = torch.nn.Parameter(
torch.Tensor(self.out_features, self.in_features) torch.Tensor(self.out_features, self.in_features)
) )
@ -53,7 +53,14 @@ class SuperLinear(SuperModule):
return spaces.has_categorical(self._bias, True) return spaces.has_categorical(self._bias, True)
def abstract_search_space(self): def abstract_search_space(self):
print('-') root_node = spaces.VirtualNode(id(self))
if not spaces.is_determined(self._in_features):
root_node.append("_in_features", self._in_features)
if not spaces.is_determined(self._out_features):
root_node.append("_out_features", self._out_features)
if not spaces.is_determined(self._bias):
root_node.append("_bias", self._bias)
return root_node
def reset_parameters(self) -> None: def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5)) nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))

View File

@ -14,7 +14,7 @@ if str(lib_dir) not in sys.path:
from spaces import Categorical from spaces import Categorical
from spaces import Continuous from spaces import Continuous
from spaces import Integer from spaces import Integer
from spaces import Integer from spaces import is_determined
from spaces import get_min from spaces import get_min
from spaces import get_max from spaces import get_max
@ -92,3 +92,7 @@ class TestBasicSpace(unittest.TestCase):
print(nested_space) print(nested_space)
for i in range(1, 13): for i in range(1, 13):
self.assertTrue(nested_space.has(i)) self.assertTrue(nested_space.has(i))
# Test Simple Op
self.assertTrue(is_determined(1))
self.assertFalse(is_determined(nested_space))

View File

@ -26,3 +26,5 @@ class TestSuperLinear(unittest.TestCase):
bias = spaces.Categorical(True, False) bias = spaces.Categorical(True, False)
model = super_core.SuperLinear(10, out_features, bias=bias) model = super_core.SuperLinear(10, out_features, bias=bias)
print(model) print(model)
print(model.super_run_type)
print(model.abstract_search_space())