Update synthetic codes
This commit is contained in:
parent
282038585e
commit
9168c62855
@ -203,7 +203,7 @@ def visualize_env(save_dir, version):
|
|||||||
tick.label.set_fontsize(LabelSize - font_gap)
|
tick.label.set_fontsize(LabelSize - font_gap)
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
cur_ax.set_xlim(-2, 2)
|
cur_ax.set_xlim(-2, 2)
|
||||||
cur_ax.set_ylim(-60, 60)
|
cur_ax.set_ylim(-8, 8)
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
cur_ax.set_xlim(-10, 10)
|
cur_ax.set_xlim(-10, 10)
|
||||||
cur_ax.set_ylim(-60, 60)
|
cur_ax.set_ylim(-60, 60)
|
||||||
|
@ -55,6 +55,7 @@ class ComposedSinFunc(FitFunc):
|
|||||||
|
|
||||||
def fit(self, **kwargs):
|
def fit(self, **kwargs):
|
||||||
num_sin_phase = kwargs.get("num_sin_phase", 7)
|
num_sin_phase = kwargs.get("num_sin_phase", 7)
|
||||||
|
sin_speed_use_power = kwargs.get("sin_speed_use_power", True)
|
||||||
min_amplitude = kwargs.get("min_amplitude", 1)
|
min_amplitude = kwargs.get("min_amplitude", 1)
|
||||||
max_amplitude = kwargs.get("max_amplitude", 4)
|
max_amplitude = kwargs.get("max_amplitude", 4)
|
||||||
phase_shift = kwargs.get("phase_shift", 0.0)
|
phase_shift = kwargs.get("phase_shift", 0.0)
|
||||||
@ -67,10 +68,17 @@ class ComposedSinFunc(FitFunc):
|
|||||||
amplitude_scale = kwargs.get("amplitude_scale")
|
amplitude_scale = kwargs.get("amplitude_scale")
|
||||||
if kwargs.get("period_phase_shift", None) is None:
|
if kwargs.get("period_phase_shift", None) is None:
|
||||||
fitting_data = []
|
fitting_data = []
|
||||||
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
if sin_speed_use_power:
|
||||||
|
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
||||||
|
else:
|
||||||
|
temp_max_scalar = num_sin_phase - 1
|
||||||
for i in range(num_sin_phase):
|
for i in range(num_sin_phase):
|
||||||
value = (2 ** i) / temp_max_scalar
|
if sin_speed_use_power:
|
||||||
next_value = (2 ** (i + 1)) / temp_max_scalar
|
value = (2 ** i) / temp_max_scalar
|
||||||
|
next_value = (2 ** (i + 1)) / temp_max_scalar
|
||||||
|
else:
|
||||||
|
value = i / temp_max_scalar
|
||||||
|
next_value = (i + 1) / temp_max_scalar
|
||||||
for _phase in (0, 0.25, 0.5, 0.75):
|
for _phase in (0, 0.25, 0.5, 0.75):
|
||||||
inter_value = value + (next_value - value) * _phase
|
inter_value = value + (next_value - value) * _phase
|
||||||
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
|
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
|
||||||
|
@ -33,8 +33,9 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
|
|||||||
function = DynamicLinearFunc()
|
function = DynamicLinearFunc()
|
||||||
function_param = dict()
|
function_param = dict()
|
||||||
function_param[0] = ComposedSinFunc(
|
function_param[0] = ComposedSinFunc(
|
||||||
amplitude_scale=ConstantFunc(1.0),
|
amplitude_scale=ConstantFunc(3.0),
|
||||||
period_phase_shift=LinearFunc(params={0: 10, 1: 0}),
|
num_sin_phase=9,
|
||||||
|
sin_speed_use_power=False,
|
||||||
)
|
)
|
||||||
function_param[1] = ConstantFunc(constant=0.9)
|
function_param[1] = ConstantFunc(constant=0.9)
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
|
Loading…
Reference in New Issue
Block a user