Upgrade same/his

This commit is contained in:
D-X-Y 2021-04-29 04:48:21 -07:00
parent f7c2bb5e32
commit 1209fffbaa
4 changed files with 90 additions and 5 deletions

2
.gitignore vendored
View File

@ -134,3 +134,5 @@ outputs
pytest_cache
*.pkl
*.pth
*.tgz

View File

@ -136,7 +136,8 @@ def main(args):
)
save_checkpoint(
{
"model": model.state_dict(),
"model_state_dict": model.state_dict(),
"model": model,
"index": idx,
"timestamp": env_info["{:}-timestamp".format(idx)],
},

View File

@ -132,7 +132,8 @@ def main(args):
)
save_checkpoint(
{
"model": model.state_dict(),
"model_state_dict": model.state_dict(),
"model": model,
"index": idx,
"timestamp": env_info["{:}-timestamp".format(idx)],
},

View File

@ -213,7 +213,87 @@ def visualize_env(save_dir):
xdir=save_dir
)
os.system("{:} {xdir}/env.mp4".format(base_cmd, xdir=save_dir))
os.system("{:} {xdir}/vis.webm".format(base_cmd, xdir=save_dir))
os.system("{:} {xdir}/env.webm".format(base_cmd, xdir=save_dir))
def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"):
save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 30, 1800, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize, font_gap = 80, 80, 5
cache_path = Path(alg_dir) / "env-info.pth"
assert cache_path.exists(), "{:} does not exist".format(cache_path)
env_info = torch.load(cache_path)
alg_name2dir = {"Optimal": "use-same-timestamp", "History SL": "use-all-past-data"}
colors = ["r", "g"]
dynamic_env = env_info["dynamic_env"]
min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
tqdm(dynamic_env, ncols=50)
):
if idx == 0:
continue
fig = plt.figure(figsize=figsize)
cur_ax = fig.add_subplot(1, 1, 1)
# the data
allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
cur_ax.scatter(
allx,
ally,
color="k",
alpha=0.99,
s=10,
label=None,
)
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
ckp_path = (
Path(alg_dir)
/ xdir
/ "{:04d}-{:04d}.pth".format(idx, env_info["total"])
)
assert ckp_path.exists()
ckp_data = torch.load(ckp_path)
with torch.no_grad():
predicts = ckp_data["model"](ori_allx)
predicts = predicts.cpu().view(-1).numpy()
cur_ax.scatter(
allx,
predicts,
color=colors[idx_alg],
alpha=0.99,
s=20,
label=alg,
)
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.set_xlim(-10, 10)
cur_ax.set_ylim(-60, 60)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
save_path = save_dir / "{:05d}".format(idx)
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
save_dir = save_dir.resolve()
base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir, w=width, h=height
)
os.system("{:} {xdir}/compare_alg.mp4".format(base_cmd, xdir=save_dir))
os.system("{:} {xdir}/compare_alg.webm".format(base_cmd, xdir=save_dir))
# the trajectory data
if __name__ == "__main__":
@ -227,5 +307,6 @@ if __name__ == "__main__":
)
args = parser.parse_args()
visualize_env(os.path.join(args.save_dir, "vis-env"))
compare_cl(os.path.join(args.save_dir, "compare-cl"))
compare_algs(os.path.join(args.save_dir, "compare-alg"))
# visualize_env(os.path.join(args.save_dir, "vis-env"))
# compare_cl(os.path.join(args.save_dir, "compare-cl"))