Update synthetic codes

This commit is contained in:
D-X-Y 2021-05-09 19:23:18 +08:00
parent 282038585e
commit 9168c62855
3 changed files with 15 additions and 6 deletions

View File

@ -203,7 +203,7 @@ def visualize_env(save_dir, version):
tick.label.set_fontsize(LabelSize - font_gap)
if version == "v1":
cur_ax.set_xlim(-2, 2)
cur_ax.set_ylim(-60, 60)
cur_ax.set_ylim(-8, 8)
elif version == "v2":
cur_ax.set_xlim(-10, 10)
cur_ax.set_ylim(-60, 60)

View File

@ -55,6 +55,7 @@ class ComposedSinFunc(FitFunc):
def fit(self, **kwargs):
num_sin_phase = kwargs.get("num_sin_phase", 7)
sin_speed_use_power = kwargs.get("sin_speed_use_power", True)
min_amplitude = kwargs.get("min_amplitude", 1)
max_amplitude = kwargs.get("max_amplitude", 4)
phase_shift = kwargs.get("phase_shift", 0.0)
@ -67,10 +68,17 @@ class ComposedSinFunc(FitFunc):
amplitude_scale = kwargs.get("amplitude_scale")
if kwargs.get("period_phase_shift", None) is None:
fitting_data = []
temp_max_scalar = 2 ** (num_sin_phase - 1)
if sin_speed_use_power:
temp_max_scalar = 2 ** (num_sin_phase - 1)
else:
temp_max_scalar = num_sin_phase - 1
for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar
next_value = (2 ** (i + 1)) / temp_max_scalar
if sin_speed_use_power:
value = (2 ** i) / temp_max_scalar
next_value = (2 ** (i + 1)) / temp_max_scalar
else:
value = i / temp_max_scalar
next_value = (i + 1) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))

View File

@ -33,8 +33,9 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
function = DynamicLinearFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
amplitude_scale=ConstantFunc(1.0),
period_phase_shift=LinearFunc(params={0: 10, 1: 0}),
amplitude_scale=ConstantFunc(3.0),
num_sin_phase=9,
sin_speed_use_power=False,
)
function_param[1] = ConstantFunc(constant=0.9)
elif version == "v2":