Add __eq__

This commit is contained in:
D-X-Y 2021-03-19 12:30:32 +08:00
parent ae7136645f
commit b3eed4ca5a
2 changed files with 136 additions and 41 deletions

View File

@ -9,7 +9,7 @@ import random
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
from typing import Optional from typing import Optional, Text
__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"] __all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]
@ -22,29 +22,37 @@ class Space(metaclass=abc.ABCMeta):
All search space must inherit from this basic class. All search space must inherit from this basic class.
""" """
@abc.abstractproperty
def xrepr(self, indent=0) -> Text:
raise NotImplementedError
def __repr__(self) -> Text:
return self.xrepr()
@abc.abstractproperty
def abstract(self) -> "Space":
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def random(self, recursion=True): def random(self, recursion=True):
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty @abc.abstractproperty
def determined(self): def determined(self) -> bool:
raise NotImplementedError raise NotImplementedError
@abc.abstractproperty
def xrepr(self, indent=0):
raise NotImplementedError
def __repr__(self):
return self.xrepr()
@abc.abstractmethod @abc.abstractmethod
def has(self, x): def has(self, x) -> bool:
"""Check whether x is in this search space.""" """Check whether x is in this search space."""
assert not isinstance( assert not isinstance(
x, Space x, Space
), "The input value itself can not be a search space." ), "The input value itself can not be a search space."
def copy(self): @abc.abstractmethod
def __eq__(self, other):
raise NotImplementedError
def copy(self) -> "Space":
return copy.deepcopy(self) return copy.deepcopy(self)
@ -59,33 +67,56 @@ class VirtualNode(Space):
self._value = value self._value = value
self._attributes = OrderedDict() self._attributes = OrderedDict()
def has(self, x):
for key, value in self._attributes.items():
if value.has(x):
return True
return False
def append(self, key, value): def append(self, key, value):
if not isinstance(value, Space): if not isinstance(value, Space):
raise ValueError("Invalid type of value: {:}".format(type(value))) raise ValueError("Invalid type of value: {:}".format(type(value)))
self._attributes[key] = value self._attributes[key] = value
def determined(self): def xrepr(self, indent=0) -> Text:
for key, value in self._attributes.items():
if not value.determined(x):
return False
return True
def random(self, recursion=True):
raise NotImplementedError
def xrepr(self, indent=0):
strs = [self.__class__.__name__ + "("] strs = [self.__class__.__name__ + "("]
for key, value in self._attributes.items(): for key, value in self._attributes.items():
strs.append(value.xrepr(indent + 2) + ",") strs.append(value.xrepr(indent + 2) + ",")
strs.append(")") strs.append(")")
return "\n".join(strs) return "\n".join(strs)
def abstract(self) -> Space:
node = VirtualNode(id(self))
for key, value in self._attributes.items():
if not value.determined:
node.append(value.abstract())
return node
def random(self, recursion=True):
raise NotImplementedError
def has(self, x) -> bool:
for key, value in self._attributes.items():
if value.has(x):
return True
return False
def __contains__(self, key):
return key in self._attributes
def __getitem__(self, key):
return self._attributes[key]
def determined(self) -> bool:
for key, value in self._attributes.items():
if not value.determined(x):
return False
return True
def __eq__(self, other):
if not isinstance(other, VirtualNode):
return False
for key, value in self._attributes.items():
if not key in other:
return False
if value != other[key]:
return False
return True
class Categorical(Space): class Categorical(Space):
"""A space contains the categorical values. """A space contains the categorical values.
@ -104,6 +135,10 @@ class Categorical(Space):
def candidates(self): def candidates(self):
return self._candidates return self._candidates
@property
def default(self):
return self._default
@property @property
def determined(self): def determined(self):
if len(self) == 1: if len(self) == 1:
@ -120,6 +155,25 @@ class Categorical(Space):
def __len__(self): def __len__(self):
return len(self._candidates) return len(self._candidates)
def abstract(self) -> Space:
if self.determined:
return VirtualNode(id(self), self)
# [TO-IMPROVE]
data = []
for candidate in self.candidates:
if isinstance(candidate, Space):
data.append(candidate.abstract())
else:
data.append(VirtualNode(id(candidate), candidate))
return Categorical(*data, self._default)
def random(self, recursion=True):
sample = random.choice(self._candidates)
if recursion and isinstance(sample, Space):
return sample.random(recursion)
else:
return sample
def xrepr(self, indent=0): def xrepr(self, indent=0):
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
name=self.__class__.__name__, cs=self._candidates, default=self._default name=self.__class__.__name__, cs=self._candidates, default=self._default
@ -135,12 +189,17 @@ class Categorical(Space):
return True return True
return False return False
def random(self, recursion=True): def __eq__(self, other):
sample = random.choice(self._candidates) if not isinstance(other, Categorical):
if recursion and isinstance(sample, Space): return False
return sample.random(recursion) if len(self) != len(other):
else: return False
return sample if self.default != other.default:
return False
for index in range(len(self)):
if self.__getitem__[index] != other[index]:
return False
return True
class Integer(Categorical): class Integer(Categorical):
@ -213,8 +272,23 @@ class Continuous(Space):
return self._default return self._default
@property @property
def determined(self): def use_log(self):
return abs(self.lower - self.upper) <= self._eps return self._log_scale
@property
def eps(self):
return self._eps
def abstract(self) -> Space:
return self.copy()
def random(self, recursion=True):
del recursion
if self._log_scale:
sample = random.uniform(math.log(self._lower), math.log(self._upper))
return math.exp(sample)
else:
return random.uniform(self._lower, self._upper)
def xrepr(self, indent=0): def xrepr(self, indent=0):
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
@ -243,10 +317,20 @@ class Continuous(Space):
converted_x, success = self.convert(x) converted_x, success = self.convert(x)
return success and self.lower <= converted_x <= self.upper return success and self.lower <= converted_x <= self.upper
def random(self, recursion=True): @property
del recursion def determined(self):
if self._log_scale: return abs(self.lower - self.upper) <= self._eps
sample = random.uniform(math.log(self._lower), math.log(self._upper))
return math.exp(sample) def __eq__(self, other):
if not isinstance(other, Continuous):
return False
if self is other:
return True
else: else:
return random.uniform(self._lower, self._upper) return (
self.lower == other.lower
and self.upper == other.upper
and self.default == other.default
and self.use_log == other.use_log
and self.eps == other.eps
)

View File

@ -96,3 +96,14 @@ class TestBasicSpace(unittest.TestCase):
# Test Simple Op # Test Simple Op
self.assertTrue(is_determined(1)) self.assertTrue(is_determined(1))
self.assertFalse(is_determined(nested_space)) self.assertFalse(is_determined(nested_space))
class TestAbstractSpace(unittest.TestCase):
"""Test the abstract search spaces."""
def test_continous(self):
space = Continuous(0, 1)
self.assertEqual(space, space.abstract())
print(space.abstract())