diff --git a/.github/workflows/basic_test.yml b/.github/workflows/basic_test.yml index 3d01193..17e2fa5 100644 --- a/.github/workflows/basic_test.yml +++ b/.github/workflows/basic_test.yml @@ -32,7 +32,7 @@ jobs: echo $PWD ; ls python -m black ./exps -l 88 --check --diff --verbose python -m black ./tests -l 88 --check --diff --verbose - python -m black ./lib/layers -l 88 --check --diff --verbose + python -m black ./lib/xlayers -l 88 --check --diff --verbose python -m black ./lib/spaces -l 88 --check --diff --verbose python -m black ./lib/trade_models -l 88 --check --diff --verbose diff --git a/lib/spaces/__init__.py b/lib/spaces/__init__.py index eb422cf..9cfe5b1 100644 --- a/lib/spaces/__init__.py +++ b/lib/spaces/__init__.py @@ -11,5 +11,6 @@ from .basic_space import Space from .basic_space import VirtualNode from .basic_op import has_categorical from .basic_op import has_continuous +from .basic_op import is_determined from .basic_op import get_min from .basic_op import get_max diff --git a/lib/spaces/basic_op.py b/lib/spaces/basic_op.py index 2ba999a..fbb75b3 100644 --- a/lib/spaces/basic_op.py +++ b/lib/spaces/basic_op.py @@ -19,6 +19,13 @@ def has_continuous(space_or_value, x): return abs(space_or_value - x) <= _EPS +def is_determined(space_or_value): + if isinstance(space_or_value, Space): + return space_or_value.determined + else: + return True + + def get_max(space_or_value): if isinstance(space_or_value, Integer): return max(space_or_value.candidates) diff --git a/lib/spaces/basic_space.py b/lib/spaces/basic_space.py index 2f45de5..9bf707b 100644 --- a/lib/spaces/basic_space.py +++ b/lib/spaces/basic_space.py @@ -30,10 +30,13 @@ class Space(metaclass=abc.ABCMeta): def determined(self): raise NotImplementedError - @abc.abstractmethod - def __repr__(self): + @abc.abstractproperty + def xrepr(self, indent=0): raise NotImplementedError + def __repr__(self): + return self.xrepr() + @abc.abstractmethod def has(self, x): """Check whether x is in this search space.""" @@ -58,15 +61,28 @@ class VirtualNode(Space): def has(self, x): for key, value in self._attributes.items(): - if isinstance(value, Space) and value.has(x): + if value.has(x): return True return False - def __repr__(self): - strs = [self.__class__.__name__ + "("] - indent = " " * 4 + def append(self, key, value): + if not isinstance(value, Space): + raise ValueError("Invalid type of value: {:}".format(type(value))) + self._attributes[key] = value + + def determined(self): for key, value in self._attributes.items(): - strs.append(indent + strs(value)) + if not value.determined(x): + return False + return True + + def random(self, recursion=True): + raise NotImplementedError + + def xrepr(self, indent=0): + strs = [self.__class__.__name__ + "("] + for key, value in self._attributes.items(): + strs.append(value.xrepr(indent + 2) + ",") strs.append(")") return "\n".join(strs) @@ -104,10 +120,11 @@ class Categorical(Space): def __len__(self): return len(self._candidates) - def __repr__(self): - return "{name:}(candidates={cs:}, default_index={default:})".format( + def xrepr(self, indent=0): + xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( name=self.__class__.__name__, cs=self._candidates, default=self._default ) + return " " * indent + xrepr def has(self, x): super().has(x) @@ -143,13 +160,14 @@ class Integer(Categorical): default = data.index(default) super(Integer, self).__init__(*data, default=default) - def __repr__(self): - return "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( + def xrepr(self, indent=0): + xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( name=self.__class__.__name__, lower=self._raw_lower, upper=self._raw_upper, default=self._raw_default, ) + return " " * indent + xrepr np_float_types = (np.float16, np.float32, np.float64) @@ -198,14 +216,15 @@ class Continuous(Space): def determined(self): return abs(self.lower - self.upper) <= self._eps - def __repr__(self): - return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( + def xrepr(self, indent=0): + xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( name=self.__class__.__name__, lower=self._lower, upper=self._upper, default=self._default, log=self._log_scale, ) + return " " * indent + xrepr def convert(self, x): if isinstance(x, np_float_types) and x.size == 1: diff --git a/lib/xlayers/super_mlp.py b/lib/xlayers/super_mlp.py index c3ed3e7..3d79b1b 100644 --- a/lib/xlayers/super_mlp.py +++ b/lib/xlayers/super_mlp.py @@ -30,7 +30,7 @@ class SuperLinear(SuperModule): self._in_features = in_features self._out_features = out_features self._bias = bias - + # weights to be optimized self._super_weight = torch.nn.Parameter( torch.Tensor(self.out_features, self.in_features) ) @@ -53,7 +53,14 @@ class SuperLinear(SuperModule): return spaces.has_categorical(self._bias, True) def abstract_search_space(self): - print('-') + root_node = spaces.VirtualNode(id(self)) + if not spaces.is_determined(self._in_features): + root_node.append("_in_features", self._in_features) + if not spaces.is_determined(self._out_features): + root_node.append("_out_features", self._out_features) + if not spaces.is_determined(self._bias): + root_node.append("_bias", self._bias) + return root_node def reset_parameters(self) -> None: nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5)) diff --git a/tests/test_basic_space.py b/tests/test_basic_space.py index c02cddd..983514c 100644 --- a/tests/test_basic_space.py +++ b/tests/test_basic_space.py @@ -14,7 +14,7 @@ if str(lib_dir) not in sys.path: from spaces import Categorical from spaces import Continuous from spaces import Integer -from spaces import Integer +from spaces import is_determined from spaces import get_min from spaces import get_max @@ -92,3 +92,7 @@ class TestBasicSpace(unittest.TestCase): print(nested_space) for i in range(1, 13): self.assertTrue(nested_space.has(i)) + + # Test Simple Op + self.assertTrue(is_determined(1)) + self.assertFalse(is_determined(nested_space)) diff --git a/tests/test_super_model.py b/tests/test_super_model.py index 8b6c207..c117363 100644 --- a/tests/test_super_model.py +++ b/tests/test_super_model.py @@ -26,3 +26,5 @@ class TestSuperLinear(unittest.TestCase): bias = spaces.Categorical(True, False) model = super_core.SuperLinear(10, out_features, bias=bias) print(model) + print(model.super_run_type) + print(model.abstract_search_space())