Add __eq__
This commit is contained in:
		| @@ -9,7 +9,7 @@ import random | ||||
| import numpy as np | ||||
| from collections import OrderedDict | ||||
|  | ||||
| from typing import Optional | ||||
| from typing import Optional, Text | ||||
|  | ||||
|  | ||||
| __all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"] | ||||
| @@ -22,29 +22,37 @@ class Space(metaclass=abc.ABCMeta): | ||||
|     All search space must inherit from this basic class. | ||||
|     """ | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def xrepr(self, indent=0) -> Text: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self) -> Text: | ||||
|         return self.xrepr() | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def abstract(self) -> "Space": | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def random(self, recursion=True): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def determined(self): | ||||
|     def determined(self) -> bool: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def xrepr(self, indent=0): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return self.xrepr() | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def has(self, x): | ||||
|     def has(self, x) -> bool: | ||||
|         """Check whether x is in this search space.""" | ||||
|         assert not isinstance( | ||||
|             x, Space | ||||
|         ), "The input value itself can not be a search space." | ||||
|  | ||||
|     def copy(self): | ||||
|     @abc.abstractmethod | ||||
|     def __eq__(self, other): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def copy(self) -> "Space": | ||||
|         return copy.deepcopy(self) | ||||
|  | ||||
|  | ||||
| @@ -59,33 +67,56 @@ class VirtualNode(Space): | ||||
|         self._value = value | ||||
|         self._attributes = OrderedDict() | ||||
|  | ||||
|     def has(self, x): | ||||
|         for key, value in self._attributes.items(): | ||||
|             if value.has(x): | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     def append(self, key, value): | ||||
|         if not isinstance(value, Space): | ||||
|             raise ValueError("Invalid type of value: {:}".format(type(value))) | ||||
|         self._attributes[key] = value | ||||
|  | ||||
|     def determined(self): | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined(x): | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def xrepr(self, indent=0): | ||||
|     def xrepr(self, indent=0) -> Text: | ||||
|         strs = [self.__class__.__name__ + "("] | ||||
|         for key, value in self._attributes.items(): | ||||
|             strs.append(value.xrepr(indent + 2) + ",") | ||||
|         strs.append(")") | ||||
|         return "\n".join(strs) | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|         node = VirtualNode(id(self)) | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined: | ||||
|                 node.append(value.abstract()) | ||||
|         return node | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def has(self, x) -> bool: | ||||
|         for key, value in self._attributes.items(): | ||||
|             if value.has(x): | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     def __contains__(self, key): | ||||
|         return key in self._attributes | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         return self._attributes[key] | ||||
|  | ||||
|     def determined(self) -> bool: | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined(x): | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, VirtualNode): | ||||
|             return False | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not key in other: | ||||
|                 return False | ||||
|             if value != other[key]: | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class Categorical(Space): | ||||
|     """A space contains the categorical values. | ||||
| @@ -104,6 +135,10 @@ class Categorical(Space): | ||||
|     def candidates(self): | ||||
|         return self._candidates | ||||
|  | ||||
|     @property | ||||
|     def default(self): | ||||
|         return self._default | ||||
|  | ||||
|     @property | ||||
|     def determined(self): | ||||
|         if len(self) == 1: | ||||
| @@ -120,6 +155,25 @@ class Categorical(Space): | ||||
|     def __len__(self): | ||||
|         return len(self._candidates) | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|         if self.determined: | ||||
|             return VirtualNode(id(self), self) | ||||
|         # [TO-IMPROVE] | ||||
|         data = [] | ||||
|         for candidate in self.candidates: | ||||
|             if isinstance(candidate, Space): | ||||
|                 data.append(candidate.abstract()) | ||||
|             else: | ||||
|                 data.append(VirtualNode(id(candidate), candidate)) | ||||
|         return Categorical(*data, self._default) | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|         sample = random.choice(self._candidates) | ||||
|         if recursion and isinstance(sample, Space): | ||||
|             return sample.random(recursion) | ||||
|         else: | ||||
|             return sample | ||||
|  | ||||
|     def xrepr(self, indent=0): | ||||
|         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( | ||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||
| @@ -135,12 +189,17 @@ class Categorical(Space): | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|         sample = random.choice(self._candidates) | ||||
|         if recursion and isinstance(sample, Space): | ||||
|             return sample.random(recursion) | ||||
|         else: | ||||
|             return sample | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, Categorical): | ||||
|             return False | ||||
|         if len(self) != len(other): | ||||
|             return False | ||||
|         if self.default != other.default: | ||||
|             return False | ||||
|         for index in range(len(self)): | ||||
|             if self.__getitem__[index] != other[index]: | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class Integer(Categorical): | ||||
| @@ -213,8 +272,23 @@ class Continuous(Space): | ||||
|         return self._default | ||||
|  | ||||
|     @property | ||||
|     def determined(self): | ||||
|         return abs(self.lower - self.upper) <= self._eps | ||||
|     def use_log(self): | ||||
|         return self._log_scale | ||||
|  | ||||
|     @property | ||||
|     def eps(self): | ||||
|         return self._eps | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|         return self.copy() | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|         del recursion | ||||
|         if self._log_scale: | ||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||
|             return math.exp(sample) | ||||
|         else: | ||||
|             return random.uniform(self._lower, self._upper) | ||||
|  | ||||
|     def xrepr(self, indent=0): | ||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||
| @@ -243,10 +317,20 @@ class Continuous(Space): | ||||
|         converted_x, success = self.convert(x) | ||||
|         return success and self.lower <= converted_x <= self.upper | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|         del recursion | ||||
|         if self._log_scale: | ||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||
|             return math.exp(sample) | ||||
|     @property | ||||
|     def determined(self): | ||||
|         return abs(self.lower - self.upper) <= self._eps | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, Continuous): | ||||
|             return False | ||||
|         if self is other: | ||||
|             return True | ||||
|         else: | ||||
|             return random.uniform(self._lower, self._upper) | ||||
|             return ( | ||||
|                 self.lower == other.lower | ||||
|                 and self.upper == other.upper | ||||
|                 and self.default == other.default | ||||
|                 and self.use_log == other.use_log | ||||
|                 and self.eps == other.eps | ||||
|             ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user