Update synthetic codes
This commit is contained in:
parent
282038585e
commit
9168c62855
@ -203,7 +203,7 @@ def visualize_env(save_dir, version):
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
if version == "v1":
|
||||
cur_ax.set_xlim(-2, 2)
|
||||
cur_ax.set_ylim(-60, 60)
|
||||
cur_ax.set_ylim(-8, 8)
|
||||
elif version == "v2":
|
||||
cur_ax.set_xlim(-10, 10)
|
||||
cur_ax.set_ylim(-60, 60)
|
||||
|
@ -55,6 +55,7 @@ class ComposedSinFunc(FitFunc):
|
||||
|
||||
def fit(self, **kwargs):
|
||||
num_sin_phase = kwargs.get("num_sin_phase", 7)
|
||||
sin_speed_use_power = kwargs.get("sin_speed_use_power", True)
|
||||
min_amplitude = kwargs.get("min_amplitude", 1)
|
||||
max_amplitude = kwargs.get("max_amplitude", 4)
|
||||
phase_shift = kwargs.get("phase_shift", 0.0)
|
||||
@ -67,10 +68,17 @@ class ComposedSinFunc(FitFunc):
|
||||
amplitude_scale = kwargs.get("amplitude_scale")
|
||||
if kwargs.get("period_phase_shift", None) is None:
|
||||
fitting_data = []
|
||||
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
||||
if sin_speed_use_power:
|
||||
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
||||
else:
|
||||
temp_max_scalar = num_sin_phase - 1
|
||||
for i in range(num_sin_phase):
|
||||
value = (2 ** i) / temp_max_scalar
|
||||
next_value = (2 ** (i + 1)) / temp_max_scalar
|
||||
if sin_speed_use_power:
|
||||
value = (2 ** i) / temp_max_scalar
|
||||
next_value = (2 ** (i + 1)) / temp_max_scalar
|
||||
else:
|
||||
value = i / temp_max_scalar
|
||||
next_value = (i + 1) / temp_max_scalar
|
||||
for _phase in (0, 0.25, 0.5, 0.75):
|
||||
inter_value = value + (next_value - value) * _phase
|
||||
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
|
||||
|
@ -33,8 +33,9 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
|
||||
function = DynamicLinearFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
amplitude_scale=ConstantFunc(1.0),
|
||||
period_phase_shift=LinearFunc(params={0: 10, 1: 0}),
|
||||
amplitude_scale=ConstantFunc(3.0),
|
||||
num_sin_phase=9,
|
||||
sin_speed_use_power=False,
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
elif version == "v2":
|
||||
|
Loading…
Reference in New Issue
Block a user