Add the SuperMLP class
This commit is contained in:
		| @@ -22,19 +22,32 @@ class Space(metaclass=abc.ABCMeta): | ||||
|     All search space must inherit from this basic class. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         # used to avoid duplicate sample | ||||
|         self._last_sample = None | ||||
|         self._last_abstract = None | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def xrepr(self, prefix="") -> Text: | ||||
|     def xrepr(self, depth=0) -> Text: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self) -> Text: | ||||
|         return self.xrepr() | ||||
|  | ||||
|     @abc.abstractproperty | ||||
|     def abstract(self) -> "Space": | ||||
|     def abstract(self, reuse_last=False) -> "Space": | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def random(self, recursion=True): | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def clean_last_sample(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def clean_last_abstract(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractproperty | ||||
| @@ -63,6 +76,7 @@ class VirtualNode(Space): | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, id=None, value=None): | ||||
|         super(VirtualNode, self).__init__() | ||||
|         self._id = id | ||||
|         self._value = value | ||||
|         self._attributes = OrderedDict() | ||||
| @@ -82,26 +96,51 @@ class VirtualNode(Space): | ||||
|         #    raise ValueError("Can not attach a determined value: {:}".format(value)) | ||||
|         self._attributes[key] = value | ||||
|  | ||||
|     def xrepr(self, prefix="  ") -> Text: | ||||
|     def xrepr(self, depth=0) -> Text: | ||||
|         strs = [self.__class__.__name__ + "(value={:}".format(self._value)] | ||||
|         for key, value in self._attributes.items(): | ||||
|             strs.append(value.xrepr(prefix + "  " + key + " = ")) | ||||
|             strs.append(key + " = " + value.xrepr(depth + 1)) | ||||
|         strs.append(")") | ||||
|         return prefix + "".join(strs) if len(strs) == 2 else ",\n".join(strs) | ||||
|         if len(strs) == 2: | ||||
|             return "".join(strs) | ||||
|         else: | ||||
|             space = "  " | ||||
|             xstrs = ( | ||||
|                 [strs[0]] | ||||
|                 + [space * (depth + 1) + x for x in strs[1:-1]] | ||||
|                 + [space * depth + strs[-1]] | ||||
|             ) | ||||
|             return ",\n".join(xstrs) | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         node = VirtualNode(id(self)) | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined: | ||||
|                 node.append(value.abstract()) | ||||
|         return node | ||||
|                 node.append(value.abstract(reuse_last)) | ||||
|         self._last_abstract = node | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         node = VirtualNode(None, self._value) | ||||
|         for key, value in self._attributes.items(): | ||||
|             node.append(key, value.random(recursion)) | ||||
|             node.append(key, value.random(recursion, reuse_last)) | ||||
|         self._last_sample = node  # record the last sample | ||||
|         return node | ||||
|  | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|         for key, value in self._attributes.items(): | ||||
|             value.clean_last_sample() | ||||
|  | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|         for key, value in self._attributes.items(): | ||||
|             value.clean_last_abstract() | ||||
|  | ||||
|     def has(self, x) -> bool: | ||||
|         for key, value in self._attributes.items(): | ||||
|             if value.has(x): | ||||
| @@ -117,7 +156,7 @@ class VirtualNode(Space): | ||||
|     @property | ||||
|     def determined(self) -> bool: | ||||
|         for key, value in self._attributes.items(): | ||||
|             if not value.determined(x): | ||||
|             if not value.determined: | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
| @@ -138,6 +177,7 @@ class Categorical(Space): | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, *data, default: Optional[int] = None): | ||||
|         super(Categorical, self).__init__() | ||||
|         self._candidates = [*data] | ||||
|         self._default = default | ||||
|         assert self._default is None or 0 <= self._default < len( | ||||
| @@ -169,32 +209,54 @@ class Categorical(Space): | ||||
|     def __len__(self): | ||||
|         return len(self._candidates) | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|         if self.determined: | ||||
|             return VirtualNode(id(self), self) | ||||
|         # [TO-IMPROVE] | ||||
|         data = [] | ||||
|         for candidate in self.candidates: | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|         for candidate in self._candidates: | ||||
|             if isinstance(candidate, Space): | ||||
|                 data.append(candidate.abstract()) | ||||
|             else: | ||||
|                 data.append(VirtualNode(id(candidate), candidate)) | ||||
|         return Categorical(*data, default=self._default) | ||||
|                 candidate.clean_last_sample() | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|         for candidate in self._candidates: | ||||
|             if isinstance(candidate, Space): | ||||
|                 candidate.clean_last_abstract() | ||||
|  | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         if self.determined: | ||||
|             result = VirtualNode(id(self), self) | ||||
|         else: | ||||
|             # [TO-IMPROVE] | ||||
|             data = [] | ||||
|             for candidate in self.candidates: | ||||
|                 if isinstance(candidate, Space): | ||||
|                     data.append(candidate.abstract()) | ||||
|                 else: | ||||
|                     data.append(VirtualNode(id(candidate), candidate)) | ||||
|             result = Categorical(*data, default=self._default) | ||||
|         self._last_abstract = result | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         sample = random.choice(self._candidates) | ||||
|         if recursion and isinstance(sample, Space): | ||||
|             sample = sample.random(recursion) | ||||
|             sample = sample.random(recursion, reuse_last) | ||||
|         if isinstance(sample, VirtualNode): | ||||
|             return sample.copy() | ||||
|             sample = sample.copy() | ||||
|         else: | ||||
|             return VirtualNode(None, sample) | ||||
|             sample = VirtualNode(None, sample) | ||||
|         self._last_sample = sample | ||||
|         return self._last_sample | ||||
|  | ||||
|     def xrepr(self, prefix=""): | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( | ||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||
|         ) | ||||
|         return prefix + xrepr | ||||
|         return xrepr | ||||
|  | ||||
|     def has(self, x): | ||||
|         super().has(x) | ||||
| @@ -213,7 +275,7 @@ class Categorical(Space): | ||||
|         if self.default != other.default: | ||||
|             return False | ||||
|         for index in range(len(self)): | ||||
|             if self.__getitem__[index] != other[index]: | ||||
|             if self.__getitem__(index) != other[index]: | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
| @@ -235,14 +297,15 @@ class Integer(Categorical): | ||||
|             default = data.index(default) | ||||
|         super(Integer, self).__init__(*data, default=default) | ||||
|  | ||||
|     def xrepr(self, prefix=""): | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             lower=self._raw_lower, | ||||
|             upper=self._raw_upper, | ||||
|             default=self._raw_default, | ||||
|         ) | ||||
|         return prefix + xrepr | ||||
|         return xrepr | ||||
|  | ||||
|  | ||||
| np_float_types = (np.float16, np.float32, np.float64) | ||||
| @@ -269,6 +332,7 @@ class Continuous(Space): | ||||
|         log: bool = False, | ||||
|         eps: float = _EPS, | ||||
|     ): | ||||
|         super(Continuous, self).__init__() | ||||
|         self._lower = lower | ||||
|         self._upper = upper | ||||
|         self._default = default | ||||
| @@ -295,19 +359,26 @@ class Continuous(Space): | ||||
|     def eps(self): | ||||
|         return self._eps | ||||
|  | ||||
|     def abstract(self) -> Space: | ||||
|         return self.copy() | ||||
|     def abstract(self, reuse_last=False) -> Space: | ||||
|         if reuse_last and self._last_abstract is not None: | ||||
|             return self._last_abstract | ||||
|         self._last_abstract = self.copy() | ||||
|         return self._last_abstract | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|     def random(self, recursion=True, reuse_last=False): | ||||
|         del recursion | ||||
|         if reuse_last and self._last_sample is not None: | ||||
|             return self._last_sample | ||||
|         if self._log_scale: | ||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||
|             sample = math.exp(sample) | ||||
|         else: | ||||
|             sample = random.uniform(self._lower, self._upper) | ||||
|         return VirtualNode(None, sample) | ||||
|         self._last_sample = VirtualNode(None, sample) | ||||
|         return self._last_sample | ||||
|  | ||||
|     def xrepr(self, prefix=""): | ||||
|     def xrepr(self, depth=0): | ||||
|         del depth | ||||
|         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             lower=self._lower, | ||||
| @@ -315,7 +386,7 @@ class Continuous(Space): | ||||
|             default=self._default, | ||||
|             log=self._log_scale, | ||||
|         ) | ||||
|         return prefix + xrepr | ||||
|         return xrepr | ||||
|  | ||||
|     def convert(self, x): | ||||
|         if isinstance(x, np_float_types) and x.size == 1: | ||||
| @@ -338,6 +409,12 @@ class Continuous(Space): | ||||
|     def determined(self): | ||||
|         return abs(self.lower - self.upper) <= self._eps | ||||
|  | ||||
|     def clean_last_sample(self): | ||||
|         self._last_sample = None | ||||
|  | ||||
|     def clean_last_abstract(self): | ||||
|         self._last_abstract = None | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if not isinstance(other, Continuous): | ||||
|             return False | ||||
|   | ||||
		Reference in New Issue
	
	Block a user