Add one more synthetic env
This commit is contained in:
parent
34560ad8d1
commit
853a702926
@ -5,7 +5,7 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
|
|||||||
from .SearchDatasetWrap import SearchDataset
|
from .SearchDatasetWrap import SearchDataset
|
||||||
|
|
||||||
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
|
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
|
||||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc
|
||||||
from .math_adv_funcs import ConstantFunc
|
from .math_adv_funcs import ConstantFunc
|
||||||
from .math_adv_funcs import ComposedSinFunc
|
from .math_adv_funcs import ComposedSinFunc
|
||||||
|
|
||||||
|
@ -59,18 +59,24 @@ class ComposedSinFunc(FitFunc):
|
|||||||
max_amplitude = kwargs.get("max_amplitude", 4)
|
max_amplitude = kwargs.get("max_amplitude", 4)
|
||||||
phase_shift = kwargs.get("phase_shift", 0.0)
|
phase_shift = kwargs.get("phase_shift", 0.0)
|
||||||
# create parameters
|
# create parameters
|
||||||
amplitude_scale = QuadraticFunc(
|
if kwargs.get("amplitude_scale", None) is None:
|
||||||
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
|
amplitude_scale = QuadraticFunc(
|
||||||
)
|
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
|
||||||
fitting_data = []
|
)
|
||||||
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
else:
|
||||||
for i in range(num_sin_phase):
|
amplitude_scale = kwargs.get("amplitude_scale")
|
||||||
value = (2 ** i) / temp_max_scalar
|
if kwargs.get("period_phase_shift", None) is None:
|
||||||
next_value = (2 ** (i + 1)) / temp_max_scalar
|
fitting_data = []
|
||||||
for _phase in (0, 0.25, 0.5, 0.75):
|
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
||||||
inter_value = value + (next_value - value) * _phase
|
for i in range(num_sin_phase):
|
||||||
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
|
value = (2 ** i) / temp_max_scalar
|
||||||
period_phase_shift = QuarticFunc(fitting_data)
|
next_value = (2 ** (i + 1)) / temp_max_scalar
|
||||||
|
for _phase in (0, 0.25, 0.5, 0.75):
|
||||||
|
inter_value = value + (next_value - value) * _phase
|
||||||
|
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
|
||||||
|
period_phase_shift = QuarticFunc(fitting_data)
|
||||||
|
else:
|
||||||
|
period_phase_shift = kwargs.get("period_phase_shift")
|
||||||
self.set(
|
self.set(
|
||||||
dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift)
|
dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift)
|
||||||
)
|
)
|
||||||
|
@ -37,6 +37,33 @@ class DynamicFunc(FitFunc):
|
|||||||
return noise_y
|
return noise_y
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicLinearFunc(DynamicFunc):
|
||||||
|
"""The dynamic linear function that outputs f(x) = a * x + b.
|
||||||
|
The a and b is a function of timestamp.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params=None):
|
||||||
|
super(DynamicLinearFunc, self).__init__(3, params)
|
||||||
|
|
||||||
|
def __call__(self, x, timestamp=None):
|
||||||
|
self.check_valid()
|
||||||
|
if timestamp is None:
|
||||||
|
timestamp = self._timestamp
|
||||||
|
a = self._params[0](timestamp)
|
||||||
|
b = self._params[1](timestamp)
|
||||||
|
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
|
||||||
|
a, b = convert_fn(a), convert_fn(b)
|
||||||
|
return a * x + b
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({a} * x + {b}, timestamp={timestamp})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
a=self._params[0],
|
||||||
|
b=self._params[1],
|
||||||
|
timestamp=self._timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DynamicQuadraticFunc(DynamicFunc):
|
class DynamicQuadraticFunc(DynamicFunc):
|
||||||
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
|
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
|
||||||
The a, b, and c is a function of timestamp.
|
The a, b, and c is a function of timestamp.
|
||||||
|
@ -3,11 +3,20 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc
|
||||||
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
||||||
from .synthetic_env import SyntheticDEnv
|
from .synthetic_env import SyntheticDEnv
|
||||||
|
|
||||||
|
|
||||||
|
def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"):
|
||||||
|
if indicator == "v1":
|
||||||
|
return create_example_v1(timestamp_config, num_per_task)
|
||||||
|
elif indicator == "v2":
|
||||||
|
return create_example_v2(timestamp_config, num_per_task)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unkonwn indicator: {:}".format(indicator))
|
||||||
|
|
||||||
|
|
||||||
def create_example_v1(
|
def create_example_v1(
|
||||||
timestamp_config=None,
|
timestamp_config=None,
|
||||||
num_per_task=5000,
|
num_per_task=5000,
|
||||||
@ -35,3 +44,29 @@ def create_example_v1(
|
|||||||
|
|
||||||
dynamic_env.set_oracle_map(copy.deepcopy(function))
|
dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||||
return dynamic_env, function
|
return dynamic_env, function
|
||||||
|
|
||||||
|
|
||||||
|
def create_example_v2(
|
||||||
|
timestamp_config=None,
|
||||||
|
num_per_task=5000,
|
||||||
|
):
|
||||||
|
mean_generator = ConstantFunc(0)
|
||||||
|
std_generator = ConstantFunc(1)
|
||||||
|
|
||||||
|
dynamic_env = SyntheticDEnv(
|
||||||
|
[mean_generator],
|
||||||
|
[[std_generator]],
|
||||||
|
num_per_task=num_per_task,
|
||||||
|
timestamp_config=timestamp_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
function = DynamicLinearFunc()
|
||||||
|
function_param = dict()
|
||||||
|
function_param[0] = ComposedSinFunc(
|
||||||
|
amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0)
|
||||||
|
)
|
||||||
|
function_param[1] = ConstantFunc(constant=0.9)
|
||||||
|
function.set(function_param)
|
||||||
|
|
||||||
|
dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||||
|
return dynamic_env, function
|
||||||
|
@ -15,6 +15,7 @@ if str(lib_dir) not in sys.path:
|
|||||||
|
|
||||||
from datasets import QuadraticFunc
|
from datasets import QuadraticFunc
|
||||||
from datasets import ConstantFunc
|
from datasets import ConstantFunc
|
||||||
|
from datasets import DynamicLinearFunc
|
||||||
from datasets import DynamicQuadraticFunc
|
from datasets import DynamicQuadraticFunc
|
||||||
from datasets import ComposedSinFunc
|
from datasets import ComposedSinFunc
|
||||||
|
|
||||||
@ -50,3 +51,20 @@ class TestDynamicFunc(unittest.TestCase):
|
|||||||
|
|
||||||
function.set_timestamp(1)
|
function.set_timestamp(1)
|
||||||
print(function(2))
|
print(function(2))
|
||||||
|
|
||||||
|
def test_simple_linear(self):
|
||||||
|
timestamps = 30
|
||||||
|
function = DynamicLinearFunc()
|
||||||
|
function_param = dict()
|
||||||
|
function_param[0] = ComposedSinFunc(
|
||||||
|
num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
|
||||||
|
)
|
||||||
|
function_param[1] = ConstantFunc(constant=0.9)
|
||||||
|
function.set(function_param)
|
||||||
|
print(function)
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError) as context:
|
||||||
|
function(0)
|
||||||
|
|
||||||
|
function.set_timestamp(1)
|
||||||
|
print(function(2))
|
||||||
|
Loading…
Reference in New Issue
Block a user