Update synthetic
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/basic-his.py --srange 1-999 | # python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -36,12 +36,14 @@ def main(args): | |||||||
|     prepare_seed(args.rand_seed) |     prepare_seed(args.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|     cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() |     cache_path = ( | ||||||
|  |         logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) | ||||||
|  |     ).resolve() | ||||||
|     if cache_path.exists(): |     if cache_path.exists(): | ||||||
|         env_info = torch.load(cache_path) |         env_info = torch.load(cache_path) | ||||||
|     else: |     else: | ||||||
|         env_info = dict() |         env_info = dict() | ||||||
|         dynamic_env = get_synthetic_env() |         dynamic_env = get_synthetic_env(version=args.env_version) | ||||||
|         env_info["total"] = len(dynamic_env) |         env_info["total"] = len(dynamic_env) | ||||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): |         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||||
|             env_info["{:}-timestamp".format(idx)] = timestamp |             env_info["{:}-timestamp".format(idx)] = timestamp | ||||||
| @@ -169,6 +171,12 @@ if __name__ == "__main__": | |||||||
|         default="./outputs/lfna-synthetic/use-all-past-data", |         default="./outputs/lfna-synthetic/use-all-past-data", | ||||||
|         help="The checkpoint directory.", |         help="The checkpoint directory.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--env_version", | ||||||
|  |         type=str, | ||||||
|  |         required=True, | ||||||
|  |         help="The synthetic enviornment version.", | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--init_lr", |         "--init_lr", | ||||||
|         type=float, |         type=float, | ||||||
| @@ -202,4 +210,5 @@ if __name__ == "__main__": | |||||||
|     if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|         args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|  |     args.save_dir = "{:}-{:}".format(args.save_dir, args.env_version) | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
| @@ -25,7 +25,6 @@ if str(lib_dir) not in sys.path: | |||||||
|  |  | ||||||
| from models.xcore import get_model | from models.xcore import get_model | ||||||
| from datasets.synthetic_core import get_synthetic_env | from datasets.synthetic_core import get_synthetic_env | ||||||
| from datasets.synthetic_example import create_example_v1 |  | ||||||
| from utils.temp_sync import optimize_fn, evaluate_fn | from utils.temp_sync import optimize_fn, evaluate_fn | ||||||
| from procedures.metric_utils import MSEMetric | from procedures.metric_utils import MSEMetric | ||||||
|  |  | ||||||
| @@ -214,9 +213,10 @@ def visualize_env(save_dir, version): | |||||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") |         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") | ||||||
|         plt.close("all") |         plt.close("all") | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|     base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( |     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( | ||||||
|         xdir=save_dir |         xdir=save_dir, version=version | ||||||
|     ) |     ) | ||||||
|  |     print(base_cmd) | ||||||
|     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) |     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) | ||||||
|     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) |     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user