Update synthetic codes
This commit is contained in:
parent
870cfc40c2
commit
282038585e
@ -103,11 +103,32 @@ class FitFunc(abc.ABC):
|
||||
)
|
||||
|
||||
|
||||
class LinearFunc(FitFunc):
|
||||
"""The linear function that outputs f(x) = a * x + b."""
|
||||
|
||||
def __init__(self, list_of_points=None, params=None):
|
||||
super(LinearFunc, self).__init__(2, list_of_points, params)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0] * x + self._params[1]
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x + weights[1]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x + {b})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
)
|
||||
|
||||
|
||||
class QuadraticFunc(FitFunc):
|
||||
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, list_of_points=None):
|
||||
super(QuadraticFunc, self).__init__(3, list_of_points)
|
||||
def __init__(self, list_of_points=None, params=None):
|
||||
super(QuadraticFunc, self).__init__(3, list_of_points, params)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
|
@ -1,7 +1,7 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
||||
#####################################################
|
||||
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
|
||||
from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
|
||||
from .math_dynamic_funcs import DynamicLinearFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc
|
||||
|
@ -3,9 +3,10 @@
|
||||
#####################################################
|
||||
from .synthetic_utils import TimeStamp
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
from .math_dynamic_funcs import DynamicLinearFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
||||
from .math_core import LinearFunc
|
||||
from .math_core import DynamicLinearFunc
|
||||
from .math_core import DynamicQuadraticFunc
|
||||
from .math_core import ConstantFunc, ComposedSinFunc
|
||||
|
||||
|
||||
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
|
||||
@ -32,7 +33,8 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
|
||||
function = DynamicLinearFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(10)
|
||||
amplitude_scale=ConstantFunc(1.0),
|
||||
period_phase_shift=LinearFunc(params={0: 10, 1: 0}),
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
elif version == "v2":
|
||||
|
Loading…
Reference in New Issue
Block a user