Refine LFNA vis codes
This commit is contained in:
		| @@ -1,7 +1,8 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||||
| ############################################################################ | ############################################################################ | ||||||
| # python exps/LFNA/vis-synthetic.py                                        # | # python exps/LFNA/vis-synthetic.py --env_version v1                       # | ||||||
|  | # python exps/LFNA/vis-synthetic.py --env_version v2                       # | ||||||
| ############################################################################ | ############################################################################ | ||||||
| import os, sys, copy, random | import os, sys, copy, random | ||||||
| import torch | import torch | ||||||
| @@ -223,7 +224,9 @@ def visualize_env(save_dir, version): | |||||||
|  |  | ||||||
| def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | ||||||
|     save_dir = Path(str(save_dir)) |     save_dir = Path(str(save_dir)) | ||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |     for substr in ("pdf", "png"): | ||||||
|  |       sub_save_dir = save_dir / substr | ||||||
|  |       sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|     dpi, width, height = 30, 3200, 2000 |     dpi, width, height = 30, 3200, 2000 | ||||||
|     figsize = width / float(dpi), height / float(dpi) |     figsize = width / float(dpi), height / float(dpi) | ||||||
| @@ -235,10 +238,10 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | |||||||
|  |  | ||||||
|     alg_name2dir = OrderedDict() |     alg_name2dir = OrderedDict() | ||||||
|     alg_name2dir["Optimal"] = "use-same-timestamp" |     alg_name2dir["Optimal"] = "use-same-timestamp" | ||||||
|     alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" |     # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" | ||||||
|     alg_name2dir["MAML"] = "use-maml-s1" |     # alg_name2dir["MAML"] = "use-maml-s1" | ||||||
|     alg_name2dir["LFNA (fix init)"] = "lfna-fix-init" |     # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init" | ||||||
|     alg_name2dir["LFNA (debug)"] = "lfna-debug" |     alg_name2dir["LFNA (debug)"] = "lfna-tall-hpnet" | ||||||
|     alg_name2all_containers = OrderedDict() |     alg_name2all_containers = OrderedDict() | ||||||
|     if version == "v1": |     if version == "v1": | ||||||
|         poststr = "v1-d16" |         poststr = "v1-d16" | ||||||
| @@ -246,15 +249,16 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | |||||||
|         raise ValueError("Invalid version: {:}".format(version)) |         raise ValueError("Invalid version: {:}".format(version)) | ||||||
|     for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): |     for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): | ||||||
|         ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth" |         ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth" | ||||||
|         xdata = torch.load(ckp_path) |         xdata = torch.load(ckp_path, map_location="cpu") | ||||||
|         alg_name2all_containers[alg] = xdata["w_container_per_epoch"] |         alg_name2all_containers[alg] = xdata["w_container_per_epoch"] | ||||||
|     # load the basic model |     # load the basic model | ||||||
|     model = get_model( |     model = get_model( | ||||||
|         dict(model_type="simple_mlp"), |         dict(model_type="norm_mlp"), | ||||||
|         act_cls="leaky_relu", |  | ||||||
|         norm_cls="identity", |  | ||||||
|         input_dim=1, |         input_dim=1, | ||||||
|         output_dim=1, |         output_dim=1, | ||||||
|  |         hidden_dims=[16] * 2, | ||||||
|  |         act_cls="gelu", | ||||||
|  |         norm_cls="layer_norm_1d", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     alg2xs, alg2ys = defaultdict(list), defaultdict(list) |     alg2xs, alg2ys = defaultdict(list), defaultdict(list) | ||||||
| @@ -331,13 +335,14 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | |||||||
|         cur_ax.set_ylim(0, 10) |         cur_ax.set_ylim(0, 10) | ||||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) |         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||||
|  |  | ||||||
|         save_path = save_dir / "v{:}-{:05d}".format(version, idx) |         pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx) | ||||||
|         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") |         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") | ||||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") |         png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx) | ||||||
|  |         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") | ||||||
|         plt.close("all") |         plt.close("all") | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( |     base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( | ||||||
|         xdir=save_dir, w=width, h=height, ver=version |         xdir=save_dir / "png", w=width, h=height, ver=version | ||||||
|     ) |     ) | ||||||
|     os.system( |     os.system( | ||||||
|         "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version) |         "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version) | ||||||
| @@ -356,9 +361,15 @@ if __name__ == "__main__": | |||||||
|         default="./outputs/vis-synthetic", |         default="./outputs/vis-synthetic", | ||||||
|         help="The save directory.", |         help="The save directory.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--env_version", | ||||||
|  |         type=str, | ||||||
|  |         required=True, | ||||||
|  |         help="The synthetic enviornment version.", | ||||||
|  |     ) | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") |     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") |     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||||
|     compare_algs(os.path.join(args.save_dir, "compare-alg-v2"), "v1") |     compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) |     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user