Update synthetic codes

This commit is contained in:
D-X-Y 2021-05-09 19:11:56 +08:00
parent 870cfc40c2
commit 282038585e
3 changed files with 30 additions and 7 deletions

View File

@ -103,11 +103,32 @@ class FitFunc(abc.ABC):
)
class LinearFunc(FitFunc):
"""The linear function that outputs f(x) = a * x + b."""
def __init__(self, list_of_points=None, params=None):
super(LinearFunc, self).__init__(2, list_of_points, params)
def __call__(self, x):
self.check_valid()
return self._params[0] * x + self._params[1]
def _getitem(self, x, weights):
return weights[0] * x + weights[1]
def __repr__(self):
return "{name}({a} * x + {b})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
)
class QuadraticFunc(FitFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None):
super(QuadraticFunc, self).__init__(3, list_of_points)
def __init__(self, list_of_points=None, params=None):
super(QuadraticFunc, self).__init__(3, list_of_points, params)
def __call__(self, x):
self.check_valid()

View File

@ -1,7 +1,7 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
from .math_dynamic_funcs import DynamicLinearFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc

View File

@ -3,9 +3,10 @@
#####################################################
from .synthetic_utils import TimeStamp
from .synthetic_env import SyntheticDEnv
from .math_dynamic_funcs import DynamicLinearFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
from .math_core import LinearFunc
from .math_core import DynamicLinearFunc
from .math_core import DynamicQuadraticFunc
from .math_core import ConstantFunc, ComposedSinFunc
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
@ -32,7 +33,8 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
function = DynamicLinearFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(10)
amplitude_scale=ConstantFunc(1.0),
period_phase_shift=LinearFunc(params={0: 10, 1: 0}),
)
function_param[1] = ConstantFunc(constant=0.9)
elif version == "v2":