xautodl/lib/datasets/synthetic_core.py
2021-05-09 19:05:07 +08:00

55 lines
2.0 KiB
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
from .synthetic_utils import TimeStamp
from .synthetic_env import SyntheticDEnv
from .math_dynamic_funcs import DynamicLinearFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
# Public names of this module: the two classes imported above from the
# sibling modules, plus the `get_synthetic_env` factory defined below.
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"):
    """Build a `SyntheticDEnv` preset and attach its oracle function.

    Args:
        total_timestamp: forwarded as `num` inside the environment's
            `timestamp_config` (the range is fixed to [-0.5, 1.5]).
        num_per_task: forwarded unchanged to `SyntheticDEnv`.
        mode: forwarded as `mode` inside the environment's
            `timestamp_config`.
        version: name of the preset, either "v1" or "v2".
            * "v1": `ConstantFunc(0)` / `ConstantFunc(1)` as the mean / std
              generators and a `DynamicLinearFunc` oracle.
            * "v2": `ComposedSinFunc` mean / std generators and a
              `DynamicQuadraticFunc` oracle.

    Returns:
        The constructed `SyntheticDEnv`, after `set_oracle_map` was called
        on it with the preset's function.

    Raises:
        ValueError: if `version` is not a supported preset name.
    """
    # Validate once, before anything is built, so the per-version branches
    # below never need their own (previously duplicated) error arm.
    if version not in ("v1", "v2"):
        raise ValueError("Unknown version: {:}".format(version))

    if version == "v1":
        mean_generator = ConstantFunc(0)
        std_generator = ConstantFunc(1)
    else:  # version == "v2"
        mean_generator = ComposedSinFunc()
        std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)

    # The std generator is passed as a nested (1x1) list, the mean as a flat
    # one-element list, exactly as the environment constructor is called in
    # the original code.
    dynamic_env = SyntheticDEnv(
        [mean_generator],
        [[std_generator]],
        num_per_task=num_per_task,
        timestamp_config=dict(
            min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode
        ),
    )

    # NOTE(review): the oracle and its parameters are deliberately created
    # after the environment, preserving the original construction order in
    # case any of these constructors depend on shared state -- confirm.
    if version == "v1":
        function = DynamicLinearFunc()
        function_param = {
            0: ComposedSinFunc(
                amplitude_scale=ConstantFunc(1.0),
                period_phase_shift=ConstantFunc(10),
            ),
            1: ConstantFunc(constant=0.9),
        }
    else:  # version == "v2"
        function = DynamicQuadraticFunc()
        function_param = {
            0: ComposedSinFunc(num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0),
            1: ConstantFunc(constant=0.9),
            2: ComposedSinFunc(num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9),
        }
    function.set(function_param)

    # The function object itself (not a deep copy) is handed to the
    # environment, so the environment and this local share one instance.
    dynamic_env.set_oracle_map(function)
    return dynamic_env