Add simple spaces
This commit is contained in:
parent
7de8c0dec4
commit
85ee0ad4eb
@ -11,7 +11,7 @@ The Python files in this folder are used to re-produce the results in ``NATS-Ben
|
||||
|
||||
## Requirements
|
||||
|
||||
- `nats_bench`>=v1.1 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench)
|
||||
- `nats_bench`>=v1.2 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench)
|
||||
- `hpbandster` : if you want to run BOHB
|
||||
|
||||
## Citation
|
||||
|
@ -5,3 +5,4 @@
|
||||
#####################################################
|
||||
|
||||
from .basic_space import Categorical
|
||||
from .basic_space import Continuous
|
||||
|
@ -3,12 +3,15 @@
|
||||
#####################################################
|
||||
|
||||
import abc
|
||||
import math
|
||||
import random
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Space(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def random(self):
|
||||
def random(self, recursion=True):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -17,8 +20,12 @@ class Space(metaclass=abc.ABCMeta):
|
||||
|
||||
|
||||
class Categorical(Space):
|
||||
def __init__(self, *data):
|
||||
def __init__(self, *data, default: Optional[int] = None):
|
||||
self._candidates = [*data]
|
||||
self._default = default
|
||||
assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format(
|
||||
len(self._candidates)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._candidates[index]
|
||||
@ -27,7 +34,50 @@ class Categorical(Space):
|
||||
return len(self._candidates)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name:}(candidates={cs:})".format(name=self.__class__.__name__, cs=self._candidates)
|
||||
return "{name:}(candidates={cs:}, default_index={default:})".format(
|
||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||
)
|
||||
|
||||
def random(self):
|
||||
return random.choice(self._candidates)
|
||||
def random(self, recursion=True):
|
||||
sample = random.choice(self._candidates)
|
||||
if recursion and isinstance(sample, Space):
|
||||
return sample.random(recursion)
|
||||
else:
|
||||
return sample
|
||||
|
||||
|
||||
class Continuous(Space):
|
||||
def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False):
|
||||
self._lower = lower
|
||||
self._upper = upper
|
||||
self._default = default
|
||||
self._log_scale = log
|
||||
|
||||
@property
|
||||
def lower(self):
|
||||
return self._lower
|
||||
|
||||
@property
|
||||
def upper(self):
|
||||
return self._upper
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
def __repr__(self):
|
||||
return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._lower,
|
||||
upper=self._upper,
|
||||
default=self._default,
|
||||
log=self._log_scale,
|
||||
)
|
||||
|
||||
def random(self, recursion=True):
|
||||
del recursion
|
||||
if self._log_scale:
|
||||
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
||||
return math.exp(sample)
|
||||
else:
|
||||
return random.uniform(self._lower, self._upper)
|
||||
|
@ -12,6 +12,7 @@ if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from spaces import Categorical
|
||||
from spaces import Continuous
|
||||
|
||||
|
||||
class TestBasicSpace(unittest.TestCase):
|
||||
@ -19,4 +20,19 @@ class TestBasicSpace(unittest.TestCase):
|
||||
space = Categorical(1, 2, 3, 4)
|
||||
for i in range(4):
|
||||
self.assertEqual(space[i], i + 1)
|
||||
self.assertEqual("Categorical(candidates=[1, 2, 3, 4])", str(space))
|
||||
self.assertEqual("Categorical(candidates=[1, 2, 3, 4], default_index=None)", str(space))
|
||||
|
||||
def test_continuous(self):
|
||||
space = Continuous(0, 1)
|
||||
self.assertGreaterEqual(space.random(), 0)
|
||||
self.assertGreaterEqual(1, space.random())
|
||||
|
||||
lower, upper = 1.5, 4.6
|
||||
space = Continuous(lower, upper, log=False)
|
||||
values = []
|
||||
for i in range(100000):
|
||||
x = space.random()
|
||||
self.assertGreaterEqual(x, lower)
|
||||
self.assertGreaterEqual(upper, x)
|
||||
values.append(x)
|
||||
self.assertAlmostEqual((lower + upper) / 2, sum(values) / len(values), places=2)
|
||||
|
Loading…
Reference in New Issue
Block a user