Updates
This commit is contained in:
		| @@ -174,7 +174,9 @@ def compare_cl(save_dir): | |||||||
|     ) |     ) | ||||||
|     print(video_cmd + "\n") |     print(video_cmd + "\n") | ||||||
|     os.system(video_cmd) |     os.system(video_cmd) | ||||||
|     os.system("{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir)) |     os.system( | ||||||
|  |         "{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def visualize_env(save_dir): | def visualize_env(save_dir): | ||||||
| @@ -307,7 +309,6 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | |||||||
|     dynamic_env = env_info["dynamic_env"] |     dynamic_env = env_info["dynamic_env"] | ||||||
|     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp |     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||||
|  |  | ||||||
|  |  | ||||||
|     linewidths = 10 |     linewidths = 10 | ||||||
|     for idx, (timestamp, (ori_allx, ori_ally)) in enumerate( |     for idx, (timestamp, (ori_allx, ori_ally)) in enumerate( | ||||||
|         tqdm(dynamic_env, ncols=50) |         tqdm(dynamic_env, ncols=50) | ||||||
| @@ -337,7 +338,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | |||||||
|                 metric(predicts, ori_ally) |                 metric(predicts, ori_ally) | ||||||
|                 predicts = predicts.view(-1).numpy() |                 predicts = predicts.view(-1).numpy() | ||||||
|                 alg2xs[alg].append(idx) |                 alg2xs[alg].append(idx) | ||||||
|                 alg2ys[alg].append(metric.get_info()['mse']) |                 alg2ys[alg].append(metric.get_info()["mse"]) | ||||||
|             plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg) |             plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg) | ||||||
|  |  | ||||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) |         cur_ax.set_xlabel("X", fontsize=LabelSize) | ||||||
| @@ -355,7 +356,14 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"): | |||||||
|         cur_ax = fig.add_subplot(2, 1, 2) |         cur_ax = fig.add_subplot(2, 1, 2) | ||||||
|         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): |         for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||||
|             # plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg) |             # plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg) | ||||||
|             cur_ax.plot(alg2xs[alg], alg2ys[alg], color=colors[idx_alg], linestyle='-', linewidth=5, label=alg) |             cur_ax.plot( | ||||||
|  |                 alg2xs[alg], | ||||||
|  |                 alg2ys[alg], | ||||||
|  |                 color=colors[idx_alg], | ||||||
|  |                 linestyle="-", | ||||||
|  |                 linewidth=5, | ||||||
|  |                 label=alg, | ||||||
|  |             ) | ||||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) |         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||||
|  |  | ||||||
|         cur_ax.set_xlabel("Timestamp", fontsize=LabelSize) |         cur_ax.set_xlabel("Timestamp", fontsize=LabelSize) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user