From 6e7b1c551f1d871f626a156706d0b7ecf656b451 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 9 May 2021 23:36:55 +0800 Subject: [PATCH] Update synthetic --- exps/LFNA/basic-his.py | 15 ++++++++++++--- exps/LFNA/vis-synthetic.py | 6 +++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/exps/LFNA/basic-his.py b/exps/LFNA/basic-his.py index 857d1da..216ea66 100644 --- a/exps/LFNA/basic-his.py +++ b/exps/LFNA/basic-his.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/basic-his.py --srange 1-999 +# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -36,12 +36,14 @@ def main(args): prepare_seed(args.rand_seed) logger = prepare_logger(args) - cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() + cache_path = ( + logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) + ).resolve() if cache_path.exists(): env_info = torch.load(cache_path) else: env_info = dict() - dynamic_env = get_synthetic_env() + dynamic_env = get_synthetic_env(version=args.env_version) env_info["total"] = len(dynamic_env) for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): env_info["{:}-timestamp".format(idx)] = timestamp @@ -169,6 +171,12 @@ if __name__ == "__main__": default="./outputs/lfna-synthetic/use-all-past-data", help="The checkpoint directory.", ) + parser.add_argument( + "--env_version", + type=str, + required=True, + help="The synthetic enviornment version.", + ) parser.add_argument( "--init_lr", type=float, @@ -202,4 +210,5 @@ if __name__ == "__main__": if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) assert args.save_dir is not None, "The save dir argument can not be None" + args.save_dir = "{:}-{:}".format(args.save_dir, args.env_version) main(args) diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 83ee170..98be80f 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -25,7 +25,6 @@ if str(lib_dir) not in sys.path: from models.xcore import get_model from datasets.synthetic_core import get_synthetic_env -from datasets.synthetic_example import create_example_v1 from utils.temp_sync import optimize_fn, evaluate_fn from procedures.metric_utils import MSEMetric @@ -214,9 +213,10 @@ def visualize_env(save_dir, version): fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") plt.close("all") save_dir = save_dir.resolve() - base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( - xdir=save_dir + base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( + xdir=save_dir, version=version ) + print(base_cmd) os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))