Add simple spaces

This commit is contained in:
D-X-Y 2021-03-18 14:05:29 +08:00
parent 7de8c0dec4
commit 85ee0ad4eb
4 changed files with 74 additions and 7 deletions

View File

@ -11,7 +11,7 @@ The Python files in this folder are used to re-produce the results in ``NATS-Ben
## Requirements
- `nats_bench`>=v1.1 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench)
- `nats_bench`>=v1.2 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench)
- `hpbandster` : if you want to run BOHB
## Citation

View File

@ -5,3 +5,4 @@
#####################################################
from .basic_space import Categorical
from .basic_space import Continuous

View File

@ -3,12 +3,15 @@
#####################################################
import abc
import math
import random
from typing import Optional
class Space(metaclass=abc.ABCMeta):
@abc.abstractmethod
def random(self):
def random(self, recursion=True):
raise NotImplementedError
@abc.abstractmethod
@ -17,8 +20,12 @@ class Space(metaclass=abc.ABCMeta):
class Categorical(Space):
def __init__(self, *data):
def __init__(self, *data, default: Optional[int] = None):
self._candidates = [*data]
self._default = default
assert self._default is None or 0 <= self._default < len(self._candidates), "default >= {:}".format(
len(self._candidates)
)
def __getitem__(self, index):
return self._candidates[index]
@ -27,7 +34,50 @@ class Categorical(Space):
return len(self._candidates)
def __repr__(self):
return "{name:}(candidates={cs:})".format(name=self.__class__.__name__, cs=self._candidates)
return "{name:}(candidates={cs:}, default_index={default:})".format(
name=self.__class__.__name__, cs=self._candidates, default=self._default
)
def random(self):
return random.choice(self._candidates)
def random(self, recursion=True):
sample = random.choice(self._candidates)
if recursion and isinstance(sample, Space):
return sample.random(recursion)
else:
return sample
class Continuous(Space):
def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False):
self._lower = lower
self._upper = upper
self._default = default
self._log_scale = log
@property
def lower(self):
return self._lower
@property
def upper(self):
return self._upper
@property
def default(self):
return self._default
def __repr__(self):
return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
name=self.__class__.__name__,
lower=self._lower,
upper=self._upper,
default=self._default,
log=self._log_scale,
)
def random(self, recursion=True):
del recursion
if self._log_scale:
sample = random.uniform(math.log(self._lower), math.log(self._upper))
return math.exp(sample)
else:
return random.uniform(self._lower, self._upper)

View File

@ -12,6 +12,7 @@ if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from spaces import Categorical
from spaces import Continuous
class TestBasicSpace(unittest.TestCase):
@ -19,4 +20,19 @@ class TestBasicSpace(unittest.TestCase):
space = Categorical(1, 2, 3, 4)
for i in range(4):
self.assertEqual(space[i], i + 1)
self.assertEqual("Categorical(candidates=[1, 2, 3, 4])", str(space))
self.assertEqual("Categorical(candidates=[1, 2, 3, 4], default_index=None)", str(space))
def test_continuous(self):
space = Continuous(0, 1)
self.assertGreaterEqual(space.random(), 0)
self.assertGreaterEqual(1, space.random())
lower, upper = 1.5, 4.6
space = Continuous(lower, upper, log=False)
values = []
for i in range(100000):
x = space.random()
self.assertGreaterEqual(x, lower)
self.assertGreaterEqual(upper, x)
values.append(x)
self.assertAlmostEqual((lower + upper) / 2, sum(values) / len(values), places=2)