diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 5b6ea93..83ee170 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -203,7 +203,7 @@ def visualize_env(save_dir, version): tick.label.set_fontsize(LabelSize - font_gap) if version == "v1": cur_ax.set_xlim(-2, 2) - cur_ax.set_ylim(-60, 60) + cur_ax.set_ylim(-8, 8) elif version == "v2": cur_ax.set_xlim(-10, 10) cur_ax.set_ylim(-60, 60) diff --git a/lib/datasets/math_adv_funcs.py b/lib/datasets/math_adv_funcs.py index 4d577ad..2a093c7 100644 --- a/lib/datasets/math_adv_funcs.py +++ b/lib/datasets/math_adv_funcs.py @@ -55,6 +55,7 @@ class ComposedSinFunc(FitFunc): def fit(self, **kwargs): num_sin_phase = kwargs.get("num_sin_phase", 7) + sin_speed_use_power = kwargs.get("sin_speed_use_power", True) min_amplitude = kwargs.get("min_amplitude", 1) max_amplitude = kwargs.get("max_amplitude", 4) phase_shift = kwargs.get("phase_shift", 0.0) @@ -67,10 +68,17 @@ class ComposedSinFunc(FitFunc): amplitude_scale = kwargs.get("amplitude_scale") if kwargs.get("period_phase_shift", None) is None: fitting_data = [] - temp_max_scalar = 2 ** (num_sin_phase - 1) + if sin_speed_use_power: + temp_max_scalar = 2 ** (num_sin_phase - 1) + else: + temp_max_scalar = num_sin_phase - 1 for i in range(num_sin_phase): - value = (2 ** i) / temp_max_scalar - next_value = (2 ** (i + 1)) / temp_max_scalar + if sin_speed_use_power: + value = (2 ** i) / temp_max_scalar + next_value = (2 ** (i + 1)) / temp_max_scalar + else: + value = i / temp_max_scalar + next_value = (i + 1) / temp_max_scalar for _phase in (0, 0.25, 0.5, 0.75): inter_value = value + (next_value - value) * _phase fitting_data.append((inter_value, math.pi * (2 * i + _phase))) diff --git a/lib/datasets/synthetic_core.py b/lib/datasets/synthetic_core.py index 9d05d7a..0fb7238 100644 --- a/lib/datasets/synthetic_core.py +++ b/lib/datasets/synthetic_core.py @@ -33,8 +33,9 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio function = DynamicLinearFunc() function_param = dict() function_param[0] = ComposedSinFunc( - amplitude_scale=ConstantFunc(1.0), - period_phase_shift=LinearFunc(params={0: 10, 1: 0}), + amplitude_scale=ConstantFunc(3.0), + num_sin_phase=9, + sin_speed_use_power=False, ) function_param[1] = ConstantFunc(constant=0.9) elif version == "v2":