From 1209fffbaaf085f6b7ca2c5b68ed8e8e134e15ad Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 29 Apr 2021 04:48:21 -0700 Subject: [PATCH] Upgrade same/his --- .gitignore | 2 + exps/LFNA/basic-his.py | 3 +- exps/LFNA/basic-same.py | 3 +- exps/LFNA/vis-synthetic.py | 87 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 90 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 20a318b..6b6c19a 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,5 @@ outputs pytest_cache *.pkl *.pth + +*.tgz diff --git a/exps/LFNA/basic-his.py b/exps/LFNA/basic-his.py index c8b369b..2506ceb 100644 --- a/exps/LFNA/basic-his.py +++ b/exps/LFNA/basic-his.py @@ -136,7 +136,8 @@ def main(args): ) save_checkpoint( { - "model": model.state_dict(), + "model_state_dict": model.state_dict(), + "model": model, "index": idx, "timestamp": env_info["{:}-timestamp".format(idx)], }, diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py index 0a889a9..4fcdf5d 100644 --- a/exps/LFNA/basic-same.py +++ b/exps/LFNA/basic-same.py @@ -132,7 +132,8 @@ def main(args): ) save_checkpoint( { - "model": model.state_dict(), + "model_state_dict": model.state_dict(), + "model": model, "index": idx, "timestamp": env_info["{:}-timestamp".format(idx)], }, diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 99f7fbb..4e9c972 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -213,7 +213,87 @@ def visualize_env(save_dir): xdir=save_dir ) os.system("{:} {xdir}/env.mp4".format(base_cmd, xdir=save_dir)) - os.system("{:} {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) + os.system("{:} {xdir}/env.webm".format(base_cmd, xdir=save_dir)) + + +def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"): + save_dir = Path(str(save_dir)) + save_dir.mkdir(parents=True, exist_ok=True) + + dpi, width, height = 30, 1800, 1400 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize, font_gap = 80, 80, 5 + + cache_path = Path(alg_dir) / "env-info.pth" + assert cache_path.exists(), "{:} does not exist".format(cache_path) + env_info = torch.load(cache_path) + + alg_name2dir = {"Optimal": "use-same-timestamp", "History SL": "use-all-past-data"} + colors = ["r", "g"] + + dynamic_env = env_info["dynamic_env"] + min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp + for idx, (timestamp, (ori_allx, ori_ally)) in enumerate( + tqdm(dynamic_env, ncols=50) + ): + if idx == 0: + continue + fig = plt.figure(figsize=figsize) + cur_ax = fig.add_subplot(1, 1, 1) + + # the data + allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy() + cur_ax.scatter( + allx, + ally, + color="k", + alpha=0.99, + s=10, + label=None, + ) + + for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): + ckp_path = ( + Path(alg_dir) + / xdir + / "{:04d}-{:04d}.pth".format(idx, env_info["total"]) + ) + assert ckp_path.exists() + ckp_data = torch.load(ckp_path) + with torch.no_grad(): + predicts = ckp_data["model"](ori_allx) + predicts = predicts.cpu().view(-1).numpy() + cur_ax.scatter( + allx, + predicts, + color=colors[idx_alg], + alpha=0.99, + s=20, + label=alg, + ) + + cur_ax.set_xlabel("X", fontsize=LabelSize) + cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) + for tick in cur_ax.xaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize - font_gap) + tick.label.set_rotation(10) + for tick in cur_ax.yaxis.get_major_ticks(): + tick.label.set_fontsize(LabelSize - font_gap) + cur_ax.set_xlim(-10, 10) + cur_ax.set_ylim(-60, 60) + cur_ax.legend(loc=1, fontsize=LegendFontsize) + + save_path = save_dir / "{:05d}".format(idx) + fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") + fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") + plt.close("all") + save_dir = save_dir.resolve() + base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( + xdir=save_dir, w=width, h=height + ) + os.system("{:} {xdir}/compare_alg.mp4".format(base_cmd, xdir=save_dir)) + os.system("{:} {xdir}/compare_alg.webm".format(base_cmd, xdir=save_dir)) + # the trajectory data if __name__ == "__main__": @@ -227,5 +307,6 @@ if __name__ == "__main__": ) args = parser.parse_args() - visualize_env(os.path.join(args.save_dir, "vis-env")) - compare_cl(os.path.join(args.save_dir, "compare-cl")) + compare_algs(os.path.join(args.save_dir, "compare-alg")) + # visualize_env(os.path.join(args.save_dir, "vis-env")) + # compare_cl(os.path.join(args.save_dir, "compare-cl"))