xautodl/lib/spaces/basic_space.py

84 lines
2.3 KiB
Python
Raw Normal View History

2021-03-17 13:48:43 +01:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
import abc
2021-03-18 07:05:29 +01:00
import math
2021-03-17 13:48:43 +01:00
import random
2021-03-18 07:05:29 +01:00
from typing import Optional
2021-03-17 13:48:43 +01:00
class Space(metaclass=abc.ABCMeta):
    """Abstract interface shared by every search space in this module.

    A concrete subclass has to provide both :meth:`random` and
    :meth:`__repr__`; a subclass that leaves either of them abstract cannot
    be instantiated.
    """

    @abc.abstractmethod
    def random(self, recursion=True):
        """Draw a single sample from the space.

        Args:
            recursion: flag whose interpretation is left to each subclass.

        Raises:
            NotImplementedError: always, when the base implementation is
                reached (e.g. through ``super()``).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def __repr__(self):
        """Return a human-readable description of the space.

        Raises:
            NotImplementedError: always, when the base implementation is
                reached (e.g. through ``super()``).
        """
        raise NotImplementedError
class Categorical(Space):
    """A space consisting of a finite, ordered list of candidates.

    Candidates may be arbitrary objects, including other ``Space``
    instances, in which case :meth:`random` can sample them recursively.
    """

    def __init__(self, *data, default: Optional[int] = None):
        """Store the candidates and the optional default index.

        Args:
            *data: the candidate values, kept in the given order.
            default: index into the candidates marking the default choice,
                or ``None`` when there is no default.

        Raises:
            AssertionError: if ``default`` is not ``None`` and lies outside
                ``[0, len(candidates))``. Like every ``assert``, this check
                is skipped when Python runs with ``-O``.
        """
        self._candidates = [*data]
        self._default = default
        # Report both the offending index and the valid range, so the message
        # is also accurate for negative indices.
        assert self._default is None or 0 <= self._default < len(self._candidates), (
            "default index {:} is out of range [0, {:})".format(self._default, len(self._candidates))
        )

    @property
    def candidates(self):
        """The list of candidate values (the internal list, not a copy)."""
        return self._candidates

    @property
    def default(self):
        """The default candidate index, or ``None`` if none was given."""
        return self._default

    def __getitem__(self, index):
        """Return the candidate(s) at ``index`` using list indexing rules."""
        return self._candidates[index]

    def __len__(self):
        """Return the number of candidates."""
        return len(self._candidates)

    def __repr__(self):
        """Return the class name together with candidates and default index."""
        return "{name:}(candidates={cs:}, default_index={default:})".format(
            name=self.__class__.__name__, cs=self._candidates, default=self._default
        )

    def random(self, recursion=True):
        """Pick one candidate uniformly at random.

        Args:
            recursion: when true and the picked candidate is itself a
                ``Space``, return a sample drawn from that space instead of
                the space object.

        Returns:
            The picked candidate, or a sample from it (see ``recursion``).

        Raises:
            IndexError: if there are no candidates (raised by
                ``random.choice`` on an empty sequence).
        """
        chosen = random.choice(self._candidates)
        if recursion and isinstance(chosen, Space):
            return chosen.random(recursion)
        return chosen
class Continuous(Space):
    """A space covering the real interval between ``lower`` and ``upper``."""

    def __init__(self, lower: float, upper: float, default: Optional[float] = None, log: bool = False):
        """Record the bounds, the optional default value and the scale.

        Args:
            lower: one end of the sampling interval.
            upper: the other end of the sampling interval.
            default: optional default value; it is stored and exposed via
                :attr:`default` but not used by :meth:`random`.
            log: when true, :meth:`random` samples uniformly in log-space.
                ``math.log`` is applied to both bounds in that case, so they
                need to be positive.
        """
        self._lower, self._upper = lower, upper
        self._default = default
        self._log_scale = log

    @property
    def lower(self):
        """The ``lower`` bound given at construction."""
        return self._lower

    @property
    def upper(self):
        """The ``upper`` bound given at construction."""
        return self._upper

    @property
    def default(self):
        """The default value given at construction (may be ``None``)."""
        return self._default

    def __repr__(self):
        """Return the class name together with bounds, default and scale."""
        template = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})"
        fields = {
            "name": self.__class__.__name__,
            "lower": self._lower,
            "upper": self._upper,
            "default": self._default,
            "log": self._log_scale,
        }
        return template.format(**fields)

    def random(self, recursion=True):
        """Draw one value between the two bounds.

        Args:
            recursion: accepted for interface compatibility and ignored.

        Returns:
            A float drawn with ``random.uniform`` over the bounds, or, in
            log-scale mode, the exponential of a value drawn uniformly
            between ``log(lower)`` and ``log(upper)``.
        """
        del recursion  # unused here; present only to match Space.random
        if not self._log_scale:
            return random.uniform(self._lower, self._upper)
        log_lower = math.log(self._lower)
        log_upper = math.log(self._upper)
        return math.exp(random.uniform(log_lower, log_upper))