Add simple spaces
This commit is contained in:
		| @@ -11,7 +11,7 @@ The Python files in this folder are used to re-produce the results in ``NATS-Ben | ||||
|  | ||||
| ## Requirements | ||||
|  | ||||
| - `nats_bench`>=v1.1 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench) | ||||
| - `nats_bench`>=v1.2 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench) | ||||
| - `hpbandster` : if you want to run BOHB | ||||
|  | ||||
| ## Citation | ||||
|   | ||||
| @@ -5,3 +5,4 @@ | ||||
| ##################################################### | ||||
|  | ||||
| from .basic_space import Categorical | ||||
| from .basic_space import Continuous | ||||
|   | ||||
| @@ -3,12 +3,15 @@ | ||||
| ##################################################### | ||||
|  | ||||
| import abc | ||||
| import math | ||||
| import random | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
|  | ||||
| class Space(metaclass=abc.ABCMeta): | ||||
|     @abc.abstractmethod | ||||
|     def random(self): | ||||
|     def random(self, recursion=True): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
| @@ -17,8 +20,12 @@ class Space(metaclass=abc.ABCMeta): | ||||
|  | ||||
|  | ||||
| class Categorical(Space): | ||||
|     def __init__(self, *data): | ||||
|     def __init__(self, *data, default: Optional[int] = None): | ||||
|         self._candidates = [*data] | ||||
|         self._default = default | ||||
|         assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format( | ||||
|             len(self._candidates) | ||||
|         ) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         return self._candidates[index] | ||||
| @@ -27,7 +34,50 @@ class Categorical(Space): | ||||
|         return len(self._candidates) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name:}(candidates={cs:})".format(name=self.__class__.__name__, cs=self._candidates) | ||||
|         return "{name:}(candidates={cs:}, default_index={default:})".format( | ||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||
|         ) | ||||
|  | ||||
|     def random(self): | ||||
|         return random.choice(self._candidates) | ||||
|     def random(self, recursion=True): | ||||
|         sample = random.choice(self._candidates) | ||||
|         if recursion and isinstance(sample, Space): | ||||
|             return sample.random(recursion) | ||||
|         else: | ||||
|             return sample | ||||
|  | ||||
|  | ||||
| class Continuous(Space): | ||||
|     def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False): | ||||
|         self._lower = lower | ||||
|         self._upper = upper | ||||
|         self._default = default | ||||
|         self._log_scale = log | ||||
|  | ||||
|     @property | ||||
|     def lower(self): | ||||
|         return self._lower | ||||
|  | ||||
|     @property | ||||
|     def upper(self): | ||||
|         return self._upper | ||||
|  | ||||
|     @property | ||||
|     def default(self): | ||||
|         return self._default | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             lower=self._lower, | ||||
|             upper=self._upper, | ||||
|             default=self._default, | ||||
|             log=self._log_scale, | ||||
|         ) | ||||
|  | ||||
|     def random(self, recursion=True): | ||||
|         del recursion | ||||
|         if self._log_scale: | ||||
|             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||
|             return math.exp(sample) | ||||
|         else: | ||||
|             return random.uniform(self._lower, self._upper) | ||||
|   | ||||
| @@ -12,6 +12,7 @@ if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from spaces import Categorical | ||||
| from spaces import Continuous | ||||
|  | ||||
|  | ||||
| class TestBasicSpace(unittest.TestCase): | ||||
| @@ -19,4 +20,19 @@ class TestBasicSpace(unittest.TestCase): | ||||
|         space = Categorical(1, 2, 3, 4) | ||||
|         for i in range(4): | ||||
|             self.assertEqual(space[i], i + 1) | ||||
|         self.assertEqual("Categorical(candidates=[1, 2, 3, 4])", str(space)) | ||||
|         self.assertEqual("Categorical(candidates=[1, 2, 3, 4], default_index=None)", str(space)) | ||||
|  | ||||
|     def test_continuous(self): | ||||
|         space = Continuous(0, 1) | ||||
|         self.assertGreaterEqual(space.random(), 0) | ||||
|         self.assertGreaterEqual(1, space.random()) | ||||
|  | ||||
|         lower, upper = 1.5, 4.6 | ||||
|         space = Continuous(lower, upper, log=False) | ||||
|         values = [] | ||||
|         for i in range(100000): | ||||
|             x = space.random() | ||||
|             self.assertGreaterEqual(x, lower) | ||||
|             self.assertGreaterEqual(upper, x) | ||||
|             values.append(x) | ||||
|         self.assertAlmostEqual((lower + upper) / 2, sum(values) / len(values), places=2) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user