Add unit tests for super-linear
This commit is contained in:
parent
badb6cf51d
commit
ca22d61259
2
.github/workflows/basic_test.yml
vendored
2
.github/workflows/basic_test.yml
vendored
@ -32,7 +32,7 @@ jobs:
|
|||||||
echo $PWD ; ls
|
echo $PWD ; ls
|
||||||
python -m black ./exps -l 88 --check --diff --verbose
|
python -m black ./exps -l 88 --check --diff --verbose
|
||||||
python -m black ./tests -l 88 --check --diff --verbose
|
python -m black ./tests -l 88 --check --diff --verbose
|
||||||
python -m black ./lib/layers -l 88 --check --diff --verbose
|
python -m black ./lib/xlayers -l 88 --check --diff --verbose
|
||||||
python -m black ./lib/spaces -l 88 --check --diff --verbose
|
python -m black ./lib/spaces -l 88 --check --diff --verbose
|
||||||
python -m black ./lib/trade_models -l 88 --check --diff --verbose
|
python -m black ./lib/trade_models -l 88 --check --diff --verbose
|
||||||
|
|
||||||
|
@ -11,5 +11,6 @@ from .basic_space import Space
|
|||||||
from .basic_space import VirtualNode
|
from .basic_space import VirtualNode
|
||||||
from .basic_op import has_categorical
|
from .basic_op import has_categorical
|
||||||
from .basic_op import has_continuous
|
from .basic_op import has_continuous
|
||||||
|
from .basic_op import is_determined
|
||||||
from .basic_op import get_min
|
from .basic_op import get_min
|
||||||
from .basic_op import get_max
|
from .basic_op import get_max
|
||||||
|
@ -19,6 +19,13 @@ def has_continuous(space_or_value, x):
|
|||||||
return abs(space_or_value - x) <= _EPS
|
return abs(space_or_value - x) <= _EPS
|
||||||
|
|
||||||
|
|
||||||
|
def is_determined(space_or_value):
|
||||||
|
if isinstance(space_or_value, Space):
|
||||||
|
return space_or_value.determined
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_max(space_or_value):
|
def get_max(space_or_value):
|
||||||
if isinstance(space_or_value, Integer):
|
if isinstance(space_or_value, Integer):
|
||||||
return max(space_or_value.candidates)
|
return max(space_or_value.candidates)
|
||||||
|
@ -30,10 +30,13 @@ class Space(metaclass=abc.ABCMeta):
|
|||||||
def determined(self):
|
def determined(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractproperty
|
||||||
def __repr__(self):
|
def xrepr(self, indent=0):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.xrepr()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def has(self, x):
|
def has(self, x):
|
||||||
"""Check whether x is in this search space."""
|
"""Check whether x is in this search space."""
|
||||||
@ -58,15 +61,28 @@ class VirtualNode(Space):
|
|||||||
|
|
||||||
def has(self, x):
|
def has(self, x):
|
||||||
for key, value in self._attributes.items():
|
for key, value in self._attributes.items():
|
||||||
if isinstance(value, Space) and value.has(x):
|
if value.has(x):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def __repr__(self):
|
def append(self, key, value):
|
||||||
strs = [self.__class__.__name__ + "("]
|
if not isinstance(value, Space):
|
||||||
indent = " " * 4
|
raise ValueError("Invalid type of value: {:}".format(type(value)))
|
||||||
|
self._attributes[key] = value
|
||||||
|
|
||||||
|
def determined(self):
|
||||||
for key, value in self._attributes.items():
|
for key, value in self._attributes.items():
|
||||||
strs.append(indent + strs(value))
|
if not value.determined(x):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def random(self, recursion=True):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def xrepr(self, indent=0):
|
||||||
|
strs = [self.__class__.__name__ + "("]
|
||||||
|
for key, value in self._attributes.items():
|
||||||
|
strs.append(value.xrepr(indent + 2) + ",")
|
||||||
strs.append(")")
|
strs.append(")")
|
||||||
return "\n".join(strs)
|
return "\n".join(strs)
|
||||||
|
|
||||||
@ -104,10 +120,11 @@ class Categorical(Space):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._candidates)
|
return len(self._candidates)
|
||||||
|
|
||||||
def __repr__(self):
|
def xrepr(self, indent=0):
|
||||||
return "{name:}(candidates={cs:}, default_index={default:})".format(
|
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
|
||||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||||
)
|
)
|
||||||
|
return " " * indent + xrepr
|
||||||
|
|
||||||
def has(self, x):
|
def has(self, x):
|
||||||
super().has(x)
|
super().has(x)
|
||||||
@ -143,13 +160,14 @@ class Integer(Categorical):
|
|||||||
default = data.index(default)
|
default = data.index(default)
|
||||||
super(Integer, self).__init__(*data, default=default)
|
super(Integer, self).__init__(*data, default=default)
|
||||||
|
|
||||||
def __repr__(self):
|
def xrepr(self, indent=0):
|
||||||
return "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
|
xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
|
||||||
name=self.__class__.__name__,
|
name=self.__class__.__name__,
|
||||||
lower=self._raw_lower,
|
lower=self._raw_lower,
|
||||||
upper=self._raw_upper,
|
upper=self._raw_upper,
|
||||||
default=self._raw_default,
|
default=self._raw_default,
|
||||||
)
|
)
|
||||||
|
return " " * indent + xrepr
|
||||||
|
|
||||||
|
|
||||||
np_float_types = (np.float16, np.float32, np.float64)
|
np_float_types = (np.float16, np.float32, np.float64)
|
||||||
@ -198,14 +216,15 @@ class Continuous(Space):
|
|||||||
def determined(self):
|
def determined(self):
|
||||||
return abs(self.lower - self.upper) <= self._eps
|
return abs(self.lower - self.upper) <= self._eps
|
||||||
|
|
||||||
def __repr__(self):
|
def xrepr(self, indent=0):
|
||||||
return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||||
name=self.__class__.__name__,
|
name=self.__class__.__name__,
|
||||||
lower=self._lower,
|
lower=self._lower,
|
||||||
upper=self._upper,
|
upper=self._upper,
|
||||||
default=self._default,
|
default=self._default,
|
||||||
log=self._log_scale,
|
log=self._log_scale,
|
||||||
)
|
)
|
||||||
|
return " " * indent + xrepr
|
||||||
|
|
||||||
def convert(self, x):
|
def convert(self, x):
|
||||||
if isinstance(x, np_float_types) and x.size == 1:
|
if isinstance(x, np_float_types) and x.size == 1:
|
||||||
|
@ -30,7 +30,7 @@ class SuperLinear(SuperModule):
|
|||||||
self._in_features = in_features
|
self._in_features = in_features
|
||||||
self._out_features = out_features
|
self._out_features = out_features
|
||||||
self._bias = bias
|
self._bias = bias
|
||||||
|
# weights to be optimized
|
||||||
self._super_weight = torch.nn.Parameter(
|
self._super_weight = torch.nn.Parameter(
|
||||||
torch.Tensor(self.out_features, self.in_features)
|
torch.Tensor(self.out_features, self.in_features)
|
||||||
)
|
)
|
||||||
@ -53,7 +53,14 @@ class SuperLinear(SuperModule):
|
|||||||
return spaces.has_categorical(self._bias, True)
|
return spaces.has_categorical(self._bias, True)
|
||||||
|
|
||||||
def abstract_search_space(self):
|
def abstract_search_space(self):
|
||||||
print('-')
|
root_node = spaces.VirtualNode(id(self))
|
||||||
|
if not spaces.is_determined(self._in_features):
|
||||||
|
root_node.append("_in_features", self._in_features)
|
||||||
|
if not spaces.is_determined(self._out_features):
|
||||||
|
root_node.append("_out_features", self._out_features)
|
||||||
|
if not spaces.is_determined(self._bias):
|
||||||
|
root_node.append("_bias", self._bias)
|
||||||
|
return root_node
|
||||||
|
|
||||||
def reset_parameters(self) -> None:
|
def reset_parameters(self) -> None:
|
||||||
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))
|
nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5))
|
||||||
|
@ -14,7 +14,7 @@ if str(lib_dir) not in sys.path:
|
|||||||
from spaces import Categorical
|
from spaces import Categorical
|
||||||
from spaces import Continuous
|
from spaces import Continuous
|
||||||
from spaces import Integer
|
from spaces import Integer
|
||||||
from spaces import Integer
|
from spaces import is_determined
|
||||||
from spaces import get_min
|
from spaces import get_min
|
||||||
from spaces import get_max
|
from spaces import get_max
|
||||||
|
|
||||||
@ -92,3 +92,7 @@ class TestBasicSpace(unittest.TestCase):
|
|||||||
print(nested_space)
|
print(nested_space)
|
||||||
for i in range(1, 13):
|
for i in range(1, 13):
|
||||||
self.assertTrue(nested_space.has(i))
|
self.assertTrue(nested_space.has(i))
|
||||||
|
|
||||||
|
# Test Simple Op
|
||||||
|
self.assertTrue(is_determined(1))
|
||||||
|
self.assertFalse(is_determined(nested_space))
|
||||||
|
@ -26,3 +26,5 @@ class TestSuperLinear(unittest.TestCase):
|
|||||||
bias = spaces.Categorical(True, False)
|
bias = spaces.Categorical(True, False)
|
||||||
model = super_core.SuperLinear(10, out_features, bias=bias)
|
model = super_core.SuperLinear(10, out_features, bias=bias)
|
||||||
print(model)
|
print(model)
|
||||||
|
print(model.super_run_type)
|
||||||
|
print(model.abstract_search_space())
|
||||||
|
Loading…
Reference in New Issue
Block a user