diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py index 6ed35dd..4797d38 100644 --- a/lib/datasets/__init__.py +++ b/lib/datasets/__init__.py @@ -5,7 +5,7 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders from .SearchDatasetWrap import SearchDataset from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc -from .math_dynamic_funcs import DynamicQuadraticFunc +from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc from .math_adv_funcs import ConstantFunc from .math_adv_funcs import ComposedSinFunc diff --git a/lib/datasets/math_adv_funcs.py b/lib/datasets/math_adv_funcs.py index d84a5e0..4d577ad 100644 --- a/lib/datasets/math_adv_funcs.py +++ b/lib/datasets/math_adv_funcs.py @@ -59,18 +59,24 @@ class ComposedSinFunc(FitFunc): max_amplitude = kwargs.get("max_amplitude", 4) phase_shift = kwargs.get("phase_shift", 0.0) # create parameters - amplitude_scale = QuadraticFunc( - [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] - ) - fitting_data = [] - temp_max_scalar = 2 ** (num_sin_phase - 1) - for i in range(num_sin_phase): - value = (2 ** i) / temp_max_scalar - next_value = (2 ** (i + 1)) / temp_max_scalar - for _phase in (0, 0.25, 0.5, 0.75): - inter_value = value + (next_value - value) * _phase - fitting_data.append((inter_value, math.pi * (2 * i + _phase))) - period_phase_shift = QuarticFunc(fitting_data) + if kwargs.get("amplitude_scale", None) is None: + amplitude_scale = QuadraticFunc( + [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] + ) + else: + amplitude_scale = kwargs.get("amplitude_scale") + if kwargs.get("period_phase_shift", None) is None: + fitting_data = [] + temp_max_scalar = 2 ** (num_sin_phase - 1) + for i in range(num_sin_phase): + value = (2 ** i) / temp_max_scalar + next_value = (2 ** (i + 1)) / temp_max_scalar + for _phase in (0, 0.25, 0.5, 0.75): + inter_value = value + (next_value - value) * _phase + fitting_data.append((inter_value, math.pi * (2 * i + _phase))) + period_phase_shift = QuarticFunc(fitting_data) + else: + period_phase_shift = kwargs.get("period_phase_shift") self.set( dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift) ) diff --git a/lib/datasets/math_dynamic_funcs.py b/lib/datasets/math_dynamic_funcs.py index 0a86716..e4e43c4 100644 --- a/lib/datasets/math_dynamic_funcs.py +++ b/lib/datasets/math_dynamic_funcs.py @@ -37,6 +37,33 @@ class DynamicFunc(FitFunc): return noise_y +class DynamicLinearFunc(DynamicFunc): + """The dynamic linear function that outputs f(x) = a * x + b. + The a and b is a function of timestamp. + """ + + def __init__(self, params=None): + super(DynamicLinearFunc, self).__init__(3, params) + + def __call__(self, x, timestamp=None): + self.check_valid() + if timestamp is None: + timestamp = self._timestamp + a = self._params[0](timestamp) + b = self._params[1](timestamp) + convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x + a, b = convert_fn(a), convert_fn(b) + return a * x + b + + def __repr__(self): + return "{name}({a} * x + {b}, timestamp={timestamp})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + timestamp=self._timestamp, + ) + + class DynamicQuadraticFunc(DynamicFunc): """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. The a, b, and c is a function of timestamp. diff --git a/lib/datasets/synthetic_example.py b/lib/datasets/synthetic_example.py index f72f15c..f5fea7b 100644 --- a/lib/datasets/synthetic_example.py +++ b/lib/datasets/synthetic_example.py @@ -3,11 +3,20 @@ ##################################################### import copy -from .math_dynamic_funcs import DynamicQuadraticFunc +from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc from .math_adv_funcs import ConstantFunc, ComposedSinFunc from .synthetic_env import SyntheticDEnv +def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"): + if indicator == "v1": + return create_example_v1(timestamp_config, num_per_task) + elif indicator == "v2": + return create_example_v2(timestamp_config, num_per_task) + else: + raise ValueError("Unkonwn indicator: {:}".format(indicator)) + + def create_example_v1( timestamp_config=None, num_per_task=5000, @@ -35,3 +44,29 @@ def create_example_v1( dynamic_env.set_oracle_map(copy.deepcopy(function)) return dynamic_env, function + + +def create_example_v2( + timestamp_config=None, + num_per_task=5000, +): + mean_generator = ConstantFunc(0) + std_generator = ConstantFunc(1) + + dynamic_env = SyntheticDEnv( + [mean_generator], + [[std_generator]], + num_per_task=num_per_task, + timestamp_config=timestamp_config, + ) + + function = DynamicLinearFunc() + function_param = dict() + function_param[0] = ComposedSinFunc( + amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) + ) + function_param[1] = ConstantFunc(constant=0.9) + function.set(function_param) + + dynamic_env.set_oracle_map(copy.deepcopy(function)) + return dynamic_env, function diff --git a/tests/test_math_adv.py b/tests/test_math_adv.py index d9ac1d0..c1ca38d 100644 --- a/tests/test_math_adv.py +++ b/tests/test_math_adv.py @@ -15,6 +15,7 @@ if str(lib_dir) not in sys.path: from datasets import QuadraticFunc from datasets import ConstantFunc +from datasets import DynamicLinearFunc from datasets import DynamicQuadraticFunc from datasets import ComposedSinFunc @@ -50,3 +51,20 @@ class TestDynamicFunc(unittest.TestCase): function.set_timestamp(1) print(function(2)) + + def test_simple_linear(self): + timestamps = 30 + function = DynamicLinearFunc() + function_param = dict() + function_param[0] = ComposedSinFunc( + num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 + ) + function_param[1] = ConstantFunc(constant=0.9) + function.set(function_param) + print(function) + + with self.assertRaises(TypeError) as context: + function(0) + + function.set_timestamp(1) + print(function(2))