diff --git a/lib/datasets/math_base_funcs.py b/lib/datasets/math_base_funcs.py index 42a4bd4..a77634b 100644 --- a/lib/datasets/math_base_funcs.py +++ b/lib/datasets/math_base_funcs.py @@ -103,11 +103,32 @@ class FitFunc(abc.ABC): ) +class LinearFunc(FitFunc): + """The linear function that outputs f(x) = a * x + b.""" + + def __init__(self, list_of_points=None, params=None): + super(LinearFunc, self).__init__(2, list_of_points, params) + + def __call__(self, x): + self.check_valid() + return self._params[0] * x + self._params[1] + + def _getitem(self, x, weights): + return weights[0] * x + weights[1] + + def __repr__(self): + return "{name}({a} * x + {b})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + ) + + class QuadraticFunc(FitFunc): """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" - def __init__(self, list_of_points=None): - super(QuadraticFunc, self).__init__(3, list_of_points) + def __init__(self, list_of_points=None, params=None): + super(QuadraticFunc, self).__init__(3, list_of_points, params) def __call__(self, x): self.check_valid() diff --git a/lib/datasets/math_core.py b/lib/datasets/math_core.py index 6b12d88..5dd2429 100644 --- a/lib/datasets/math_core.py +++ b/lib/datasets/math_core.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # ##################################################### -from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc +from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc from .math_dynamic_funcs import DynamicLinearFunc from .math_dynamic_funcs import DynamicQuadraticFunc from .math_adv_funcs import ConstantFunc diff --git a/lib/datasets/synthetic_core.py b/lib/datasets/synthetic_core.py index f05024b..9d05d7a 100644 --- a/lib/datasets/synthetic_core.py +++ b/lib/datasets/synthetic_core.py @@ -3,9 +3,10 @@ ##################################################### from .synthetic_utils import TimeStamp from .synthetic_env import SyntheticDEnv -from .math_dynamic_funcs import DynamicLinearFunc -from .math_dynamic_funcs import DynamicQuadraticFunc -from .math_adv_funcs import ConstantFunc, ComposedSinFunc +from .math_core import LinearFunc +from .math_core import DynamicLinearFunc +from .math_core import DynamicQuadraticFunc +from .math_core import ConstantFunc, ComposedSinFunc __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] @@ -32,7 +33,8 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio function = DynamicLinearFunc() function_param = dict() function_param[0] = ComposedSinFunc( - amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(10) + amplitude_scale=ConstantFunc(1.0), + period_phase_shift=LinearFunc(params={0: 10, 1: 0}), ) function_param[1] = ConstantFunc(constant=0.9) elif version == "v2":