diff --git a/exps/NATS-algos/README.md b/exps/NATS-algos/README.md index 361ce4f..b4b2e39 100644 --- a/exps/NATS-algos/README.md +++ b/exps/NATS-algos/README.md @@ -11,7 +11,7 @@ The Python files in this folder are used to re-produce the results in ``NATS-Ben ## Requirements -- `nats_bench`>=v1.1 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench) +- `nats_bench`>=v1.2 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench) - `hpbandster` : if you want to run BOHB ## Citation diff --git a/lib/spaces/__init__.py b/lib/spaces/__init__.py index 5550320..7b4614f 100644 --- a/lib/spaces/__init__.py +++ b/lib/spaces/__init__.py @@ -5,3 +5,4 @@ ##################################################### from .basic_space import Categorical +from .basic_space import Continuous diff --git a/lib/spaces/basic_space.py b/lib/spaces/basic_space.py index 3f04020..9db5723 100644 --- a/lib/spaces/basic_space.py +++ b/lib/spaces/basic_space.py @@ -3,12 +3,15 @@ ##################################################### import abc +import math import random +from typing import Optional + class Space(metaclass=abc.ABCMeta): @abc.abstractmethod - def random(self): + def random(self, recursion=True): raise NotImplementedError @abc.abstractmethod @@ -17,8 +20,12 @@ class Space(metaclass=abc.ABCMeta): class Categorical(Space): - def __init__(self, *data): + def __init__(self, *data, default: Optional[int] = None): self._candidates = [*data] + self._default = default + assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format( + len(self._candidates) + ) def __getitem__(self, index): return self._candidates[index] @@ -27,7 +34,50 @@ class Categorical(Space): return len(self._candidates) def __repr__(self): - return "{name:}(candidates={cs:})".format(name=self.__class__.__name__, cs=self._candidates) + return "{name:}(candidates={cs:}, default_index={default:})".format( + name=self.__class__.__name__, cs=self._candidates, default=self._default + ) - def random(self): - return random.choice(self._candidates) + def random(self, recursion=True): + sample = random.choice(self._candidates) + if recursion and isinstance(sample, Space): + return sample.random(recursion) + else: + return sample + + +class Continuous(Space): + def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False): + self._lower = lower + self._upper = upper + self._default = default + self._log_scale = log + + @property + def lower(self): + return self._lower + + @property + def upper(self): + return self._upper + + @property + def default(self): + return self._default + + def __repr__(self): + return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( + name=self.__class__.__name__, + lower=self._lower, + upper=self._upper, + default=self._default, + log=self._log_scale, + ) + + def random(self, recursion=True): + del recursion + if self._log_scale: + sample = random.uniform(math.log(self._lower), math.log(self._upper)) + return math.exp(sample) + else: + return random.uniform(self._lower, self._upper) diff --git a/tests/test_basic.py b/tests/test_basic.py index 6757e3a..637e282 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -12,6 +12,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from spaces import Categorical +from spaces import Continuous class TestBasicSpace(unittest.TestCase): @@ -19,4 +20,19 @@ class TestBasicSpace(unittest.TestCase): space = Categorical(1, 2, 3, 4) for i in range(4): self.assertEqual(space[i], i + 1) - self.assertEqual("Categorical(candidates=[1, 2, 3, 4])", str(space)) + self.assertEqual("Categorical(candidates=[1, 2, 3, 4], default_index=None)", str(space)) + + def test_continuous(self): + space = Continuous(0, 1) + self.assertGreaterEqual(space.random(), 0) + self.assertGreaterEqual(1, space.random()) + + lower, upper = 1.5, 4.6 + space = Continuous(lower, upper, log=False) + values = [] + for i in range(100000): + x = space.random() + self.assertGreaterEqual(x, lower) + self.assertGreaterEqual(upper, x) + values.append(x) + self.assertAlmostEqual((lower + upper) / 2, sum(values) / len(values), places=2)