Update synthetic
This commit is contained in:
parent
9168c62855
commit
6e7b1c551f
@ -1,7 +1,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
#####################################################
|
#####################################################
|
||||||
# python exps/LFNA/basic-his.py --srange 1-999
|
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, time, copy, torch, random, argparse
|
import sys, time, copy, torch, random, argparse
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -36,12 +36,14 @@ def main(args):
|
|||||||
prepare_seed(args.rand_seed)
|
prepare_seed(args.rand_seed)
|
||||||
logger = prepare_logger(args)
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
cache_path = (logger.path(None) / ".." / "env-info.pth").resolve()
|
cache_path = (
|
||||||
|
logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version)
|
||||||
|
).resolve()
|
||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
env_info = torch.load(cache_path)
|
env_info = torch.load(cache_path)
|
||||||
else:
|
else:
|
||||||
env_info = dict()
|
env_info = dict()
|
||||||
dynamic_env = get_synthetic_env()
|
dynamic_env = get_synthetic_env(version=args.env_version)
|
||||||
env_info["total"] = len(dynamic_env)
|
env_info["total"] = len(dynamic_env)
|
||||||
for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
|
for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
|
||||||
env_info["{:}-timestamp".format(idx)] = timestamp
|
env_info["{:}-timestamp".format(idx)] = timestamp
|
||||||
@ -169,6 +171,12 @@ if __name__ == "__main__":
|
|||||||
default="./outputs/lfna-synthetic/use-all-past-data",
|
default="./outputs/lfna-synthetic/use-all-past-data",
|
||||||
help="The checkpoint directory.",
|
help="The checkpoint directory.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--env_version",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The synthetic enviornment version.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--init_lr",
|
"--init_lr",
|
||||||
type=float,
|
type=float,
|
||||||
@ -202,4 +210,5 @@ if __name__ == "__main__":
|
|||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||||
|
args.save_dir = "{:}-{:}".format(args.save_dir, args.env_version)
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -25,7 +25,6 @@ if str(lib_dir) not in sys.path:
|
|||||||
|
|
||||||
from models.xcore import get_model
|
from models.xcore import get_model
|
||||||
from datasets.synthetic_core import get_synthetic_env
|
from datasets.synthetic_core import get_synthetic_env
|
||||||
from datasets.synthetic_example import create_example_v1
|
|
||||||
from utils.temp_sync import optimize_fn, evaluate_fn
|
from utils.temp_sync import optimize_fn, evaluate_fn
|
||||||
from procedures.metric_utils import MSEMetric
|
from procedures.metric_utils import MSEMetric
|
||||||
|
|
||||||
@ -214,9 +213,10 @@ def visualize_env(save_dir, version):
|
|||||||
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
|
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
|
||||||
plt.close("all")
|
plt.close("all")
|
||||||
save_dir = save_dir.resolve()
|
save_dir = save_dir.resolve()
|
||||||
base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
|
base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
|
||||||
xdir=save_dir
|
xdir=save_dir, version=version
|
||||||
)
|
)
|
||||||
|
print(base_cmd)
|
||||||
os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
|
os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
|
||||||
os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))
|
os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user