Add one more synthetic env

This commit is contained in:
D-X-Y 2021-05-09 18:37:37 +08:00
parent 34560ad8d1
commit 853a702926
5 changed files with 100 additions and 14 deletions

View File

@ -5,7 +5,7 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset from .SearchDatasetWrap import SearchDataset
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
from .math_dynamic_funcs import DynamicQuadraticFunc from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc
from .math_adv_funcs import ConstantFunc from .math_adv_funcs import ConstantFunc
from .math_adv_funcs import ComposedSinFunc from .math_adv_funcs import ComposedSinFunc

View File

@ -59,18 +59,24 @@ class ComposedSinFunc(FitFunc):
max_amplitude = kwargs.get("max_amplitude", 4) max_amplitude = kwargs.get("max_amplitude", 4)
phase_shift = kwargs.get("phase_shift", 0.0) phase_shift = kwargs.get("phase_shift", 0.0)
# create parameters # create parameters
amplitude_scale = QuadraticFunc( if kwargs.get("amplitude_scale", None) is None:
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] amplitude_scale = QuadraticFunc(
) [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
fitting_data = [] )
temp_max_scalar = 2 ** (num_sin_phase - 1) else:
for i in range(num_sin_phase): amplitude_scale = kwargs.get("amplitude_scale")
value = (2 ** i) / temp_max_scalar if kwargs.get("period_phase_shift", None) is None:
next_value = (2 ** (i + 1)) / temp_max_scalar fitting_data = []
for _phase in (0, 0.25, 0.5, 0.75): temp_max_scalar = 2 ** (num_sin_phase - 1)
inter_value = value + (next_value - value) * _phase for i in range(num_sin_phase):
fitting_data.append((inter_value, math.pi * (2 * i + _phase))) value = (2 ** i) / temp_max_scalar
period_phase_shift = QuarticFunc(fitting_data) next_value = (2 ** (i + 1)) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
period_phase_shift = QuarticFunc(fitting_data)
else:
period_phase_shift = kwargs.get("period_phase_shift")
self.set( self.set(
dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift) dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift)
) )

View File

@ -37,6 +37,33 @@ class DynamicFunc(FitFunc):
return noise_y return noise_y
class DynamicLinearFunc(DynamicFunc):
"""The dynamic linear function that outputs f(x) = a * x + b.
The a and b is a function of timestamp.
"""
def __init__(self, params=None):
super(DynamicLinearFunc, self).__init__(3, params)
def __call__(self, x, timestamp=None):
self.check_valid()
if timestamp is None:
timestamp = self._timestamp
a = self._params[0](timestamp)
b = self._params[1](timestamp)
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
a, b = convert_fn(a), convert_fn(b)
return a * x + b
def __repr__(self):
return "{name}({a} * x + {b}, timestamp={timestamp})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
timestamp=self._timestamp,
)
class DynamicQuadraticFunc(DynamicFunc): class DynamicQuadraticFunc(DynamicFunc):
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
The a, b, and c is a function of timestamp. The a, b, and c is a function of timestamp.

View File

@ -3,11 +3,20 @@
##################################################### #####################################################
import copy import copy
from .math_dynamic_funcs import DynamicQuadraticFunc from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc, ComposedSinFunc from .math_adv_funcs import ConstantFunc, ComposedSinFunc
from .synthetic_env import SyntheticDEnv from .synthetic_env import SyntheticDEnv
def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"):
if indicator == "v1":
return create_example_v1(timestamp_config, num_per_task)
elif indicator == "v2":
return create_example_v2(timestamp_config, num_per_task)
else:
raise ValueError("Unkonwn indicator: {:}".format(indicator))
def create_example_v1( def create_example_v1(
timestamp_config=None, timestamp_config=None,
num_per_task=5000, num_per_task=5000,
@ -35,3 +44,29 @@ def create_example_v1(
dynamic_env.set_oracle_map(copy.deepcopy(function)) dynamic_env.set_oracle_map(copy.deepcopy(function))
return dynamic_env, function return dynamic_env, function
def create_example_v2(
timestamp_config=None,
num_per_task=5000,
):
mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1)
dynamic_env = SyntheticDEnv(
[mean_generator],
[[std_generator]],
num_per_task=num_per_task,
timestamp_config=timestamp_config,
)
function = DynamicLinearFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0)
)
function_param[1] = ConstantFunc(constant=0.9)
function.set(function_param)
dynamic_env.set_oracle_map(copy.deepcopy(function))
return dynamic_env, function

View File

@ -15,6 +15,7 @@ if str(lib_dir) not in sys.path:
from datasets import QuadraticFunc from datasets import QuadraticFunc
from datasets import ConstantFunc from datasets import ConstantFunc
from datasets import DynamicLinearFunc
from datasets import DynamicQuadraticFunc from datasets import DynamicQuadraticFunc
from datasets import ComposedSinFunc from datasets import ComposedSinFunc
@ -50,3 +51,20 @@ class TestDynamicFunc(unittest.TestCase):
function.set_timestamp(1) function.set_timestamp(1)
print(function(2)) print(function(2))
def test_simple_linear(self):
timestamps = 30
function = DynamicLinearFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
)
function_param[1] = ConstantFunc(constant=0.9)
function.set(function_param)
print(function)
with self.assertRaises(TypeError) as context:
function(0)
function.set_timestamp(1)
print(function(2))