Add simple spaces
This commit is contained in:
		| @@ -11,7 +11,7 @@ The Python files in this folder are used to re-produce the results in ``NATS-Ben | |||||||
|  |  | ||||||
| ## Requirements | ## Requirements | ||||||
|  |  | ||||||
| - `nats_bench`>=v1.1 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench) | - `nats_bench`>=v1.2 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench) | ||||||
| - `hpbandster` : if you want to run BOHB | - `hpbandster` : if you want to run BOHB | ||||||
|  |  | ||||||
| ## Citation | ## Citation | ||||||
|   | |||||||
| @@ -5,3 +5,4 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
| from .basic_space import Categorical | from .basic_space import Categorical | ||||||
|  | from .basic_space import Continuous | ||||||
|   | |||||||
| @@ -3,12 +3,15 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
| import abc | import abc | ||||||
|  | import math | ||||||
| import random | import random | ||||||
|  |  | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  |  | ||||||
| class Space(metaclass=abc.ABCMeta): | class Space(metaclass=abc.ABCMeta): | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def random(self): |     def random(self, recursion=True): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
| @@ -17,8 +20,12 @@ class Space(metaclass=abc.ABCMeta): | |||||||
|  |  | ||||||
|  |  | ||||||
| class Categorical(Space): | class Categorical(Space): | ||||||
|     def __init__(self, *data): |     def __init__(self, *data, default: Optional[int] = None): | ||||||
|         self._candidates = [*data] |         self._candidates = [*data] | ||||||
|  |         self._default = default | ||||||
|  |         assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format( | ||||||
|  |             len(self._candidates) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def __getitem__(self, index): |     def __getitem__(self, index): | ||||||
|         return self._candidates[index] |         return self._candidates[index] | ||||||
| @@ -27,7 +34,50 @@ class Categorical(Space): | |||||||
|         return len(self._candidates) |         return len(self._candidates) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name:}(candidates={cs:})".format(name=self.__class__.__name__, cs=self._candidates) |         return "{name:}(candidates={cs:}, default_index={default:})".format( | ||||||
|  |             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def random(self): |     def random(self, recursion=True): | ||||||
|         return random.choice(self._candidates) |         sample = random.choice(self._candidates) | ||||||
|  |         if recursion and isinstance(sample, Space): | ||||||
|  |             return sample.random(recursion) | ||||||
|  |         else: | ||||||
|  |             return sample | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Continuous(Space): | ||||||
|  |     def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False): | ||||||
|  |         self._lower = lower | ||||||
|  |         self._upper = upper | ||||||
|  |         self._default = default | ||||||
|  |         self._log_scale = log | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def lower(self): | ||||||
|  |         return self._lower | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def upper(self): | ||||||
|  |         return self._upper | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def default(self): | ||||||
|  |         return self._default | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             lower=self._lower, | ||||||
|  |             upper=self._upper, | ||||||
|  |             default=self._default, | ||||||
|  |             log=self._log_scale, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def random(self, recursion=True): | ||||||
|  |         del recursion | ||||||
|  |         if self._log_scale: | ||||||
|  |             sample = random.uniform(math.log(self._lower), math.log(self._upper)) | ||||||
|  |             return math.exp(sample) | ||||||
|  |         else: | ||||||
|  |             return random.uniform(self._lower, self._upper) | ||||||
|   | |||||||
| @@ -12,6 +12,7 @@ if str(lib_dir) not in sys.path: | |||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from spaces import Categorical | from spaces import Categorical | ||||||
|  | from spaces import Continuous | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestBasicSpace(unittest.TestCase): | class TestBasicSpace(unittest.TestCase): | ||||||
| @@ -19,4 +20,19 @@ class TestBasicSpace(unittest.TestCase): | |||||||
|         space = Categorical(1, 2, 3, 4) |         space = Categorical(1, 2, 3, 4) | ||||||
|         for i in range(4): |         for i in range(4): | ||||||
|             self.assertEqual(space[i], i + 1) |             self.assertEqual(space[i], i + 1) | ||||||
|         self.assertEqual("Categorical(candidates=[1, 2, 3, 4])", str(space)) |         self.assertEqual("Categorical(candidates=[1, 2, 3, 4], default_index=None)", str(space)) | ||||||
|  |  | ||||||
|  |     def test_continuous(self): | ||||||
|  |         space = Continuous(0, 1) | ||||||
|  |         self.assertGreaterEqual(space.random(), 0) | ||||||
|  |         self.assertGreaterEqual(1, space.random()) | ||||||
|  |  | ||||||
|  |         lower, upper = 1.5, 4.6 | ||||||
|  |         space = Continuous(lower, upper, log=False) | ||||||
|  |         values = [] | ||||||
|  |         for i in range(100000): | ||||||
|  |             x = space.random() | ||||||
|  |             self.assertGreaterEqual(x, lower) | ||||||
|  |             self.assertGreaterEqual(upper, x) | ||||||
|  |             values.append(x) | ||||||
|  |         self.assertAlmostEqual((lower + upper) / 2, sum(values) / len(values), places=2) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user