Updates
This commit is contained in:
parent
3117d4f5f5
commit
3586c2d996
@ -174,7 +174,9 @@ def compare_cl(save_dir):
|
|||||||
)
|
)
|
||||||
print(video_cmd + "\n")
|
print(video_cmd + "\n")
|
||||||
os.system(video_cmd)
|
os.system(video_cmd)
|
||||||
os.system("{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir))
|
os.system(
|
||||||
|
"{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def visualize_env(save_dir):
|
def visualize_env(save_dir):
|
||||||
@ -307,7 +309,6 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
dynamic_env = env_info["dynamic_env"]
|
dynamic_env = env_info["dynamic_env"]
|
||||||
min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
|
min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
|
||||||
|
|
||||||
|
|
||||||
linewidths = 10
|
linewidths = 10
|
||||||
for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
|
for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
|
||||||
tqdm(dynamic_env, ncols=50)
|
tqdm(dynamic_env, ncols=50)
|
||||||
@ -337,7 +338,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
metric(predicts, ori_ally)
|
metric(predicts, ori_ally)
|
||||||
predicts = predicts.view(-1).numpy()
|
predicts = predicts.view(-1).numpy()
|
||||||
alg2xs[alg].append(idx)
|
alg2xs[alg].append(idx)
|
||||||
alg2ys[alg].append(metric.get_info()['mse'])
|
alg2ys[alg].append(metric.get_info()["mse"])
|
||||||
plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg)
|
plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg)
|
||||||
|
|
||||||
cur_ax.set_xlabel("X", fontsize=LabelSize)
|
cur_ax.set_xlabel("X", fontsize=LabelSize)
|
||||||
@ -355,7 +356,14 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
cur_ax = fig.add_subplot(2, 1, 2)
|
cur_ax = fig.add_subplot(2, 1, 2)
|
||||||
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
||||||
# plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg)
|
# plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg)
|
||||||
cur_ax.plot(alg2xs[alg], alg2ys[alg], color=colors[idx_alg], linestyle='-', linewidth=5, label=alg)
|
cur_ax.plot(
|
||||||
|
alg2xs[alg],
|
||||||
|
alg2ys[alg],
|
||||||
|
color=colors[idx_alg],
|
||||||
|
linestyle="-",
|
||||||
|
linewidth=5,
|
||||||
|
label=alg,
|
||||||
|
)
|
||||||
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
||||||
|
|
||||||
cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
|
cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
|
||||||
|
Loading…
Reference in New Issue
Block a user