Update vis codes

This commit is contained in:
D-X-Y 2021-04-29 23:37:50 +08:00
parent 1209fffbaa
commit 3117d4f5f5

View File

@ -7,7 +7,7 @@ import os, sys, copy, random
import torch
import numpy as np
import argparse
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from tqdm import tqdm
from pprint import pprint
@ -27,6 +27,12 @@ if str(lib_dir) not in sys.path:
from datasets.synthetic_core import get_synthetic_env
from datasets.synthetic_example import create_example_v1
from utils.temp_sync import optimize_fn, evaluate_fn
from procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths, label=label)
cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=1.5, label=None)
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
@ -44,16 +50,17 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
for idx, scatter_dict in enumerate(scatter_list):
cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1)
cur_ax.scatter(
plot_scatter(
cur_ax,
scatter_dict["xaxis"],
scatter_dict["yaxis"],
color=scatter_dict["color"],
s=scatter_dict["s"],
alpha=scatter_dict["alpha"],
label=scatter_dict["label"],
scatter_dict["color"],
scatter_dict["alpha"],
scatter_dict["linewidths"],
scatter_dict["label"],
)
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("f(X)", rotation=0, fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
cur_ax.set_xlim(scatter_dict["xlim"][0], scatter_dict["xlim"][1])
cur_ax.set_ylim(scatter_dict["ylim"][0], scatter_dict["ylim"][1])
for tick in cur_ax.xaxis.get_major_ticks():
@ -120,7 +127,7 @@ def compare_cl(save_dir):
"xaxis": xdata["lfna_xaxis_all"],
"yaxis": xdata["lfna_yaxis_all"],
"color": "k",
"s": 12,
"linewidths": 15,
"alpha": 0.99,
"xlim": (-6, 6),
"ylim": (-40, 40),
@ -140,7 +147,7 @@ def compare_cl(save_dir):
"xaxis": cl_xaxis_all,
"yaxis": cl_yaxis_all,
"color": "k",
"s": 12,
"linewidths": 15,
"xlim": (round(cl_xaxis_min, 1), round(cl_xaxis_max, 1)),
"ylim": (-20, 6),
"alpha": 0.99,
@ -167,7 +174,7 @@ def compare_cl(save_dir):
)
print(video_cmd + "\n")
os.system(video_cmd)
os.system("{:} -pix_fmt yuv420p {xdir}/vis.webm".format(base_cmd, xdir=save_dir))
os.system("{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir))
def visualize_env(save_dir):
@ -184,15 +191,7 @@ def visualize_env(save_dir):
cur_ax = fig.add_subplot(1, 1, 1)
allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
cur_ax.scatter(
allx,
ally,
color="k",
linestyle="-",
alpha=0.99,
s=10,
label="timestamp={:05d}".format(idx),
)
plot_scatter(cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx))
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
@ -228,11 +227,15 @@ def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"):
assert cache_path.exists(), "{:} does not exist".format(cache_path)
env_info = torch.load(cache_path)
alg_name2dir = {"Optimal": "use-same-timestamp", "History SL": "use-all-past-data"}
alg_name2dir = OrderedDict()
alg_name2dir["Optimal"] = "use-same-timestamp"
alg_name2dir["History SL"] = "use-all-past-data"
colors = ["r", "g"]
dynamic_env = env_info["dynamic_env"]
min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
linewidths = 10
for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
tqdm(dynamic_env, ncols=50)
):
@ -243,14 +246,7 @@ def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"):
# the data
allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
cur_ax.scatter(
allx,
ally,
color="k",
alpha=0.99,
s=10,
label=None,
)
plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
ckp_path = (
@ -263,14 +259,7 @@ def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"):
with torch.no_grad():
predicts = ckp_data["model"](ori_allx)
predicts = predicts.cpu().view(-1).numpy()
cur_ax.scatter(
allx,
predicts,
color=colors[idx_alg],
alpha=0.99,
s=20,
label=alg,
)
plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg)
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
@ -291,9 +280,105 @@ def compare_algs(save_dir, alg_dir="./outputs/lfna-synthetic"):
base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir, w=width, h=height
)
os.system("{:} {xdir}/compare_alg.mp4".format(base_cmd, xdir=save_dir))
os.system("{:} {xdir}/compare_alg.webm".format(base_cmd, xdir=save_dir))
# the trajectory data
os.system("{:} {xdir}/compare-alg.mp4".format(base_cmd, xdir=save_dir))
os.system("{:} {xdir}/compare-alg.webm".format(base_cmd, xdir=save_dir))
def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 30, 3200, 2000
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize, font_gap = 80, 80, 5
cache_path = Path(alg_dir) / "env-info.pth"
assert cache_path.exists(), "{:} does not exist".format(cache_path)
env_info = torch.load(cache_path)
alg_name2dir = OrderedDict()
alg_name2dir["Optimal"] = "use-same-timestamp"
alg_name2dir["History SL"] = "use-all-past-data"
colors = ["r", "g"]
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
colors = ["r", "g"]
dynamic_env = env_info["dynamic_env"]
min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
linewidths = 10
for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
tqdm(dynamic_env, ncols=50)
):
if idx == 0:
continue
fig = plt.figure(figsize=figsize)
cur_ax = fig.add_subplot(2, 1, 1)
# the data
allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
ckp_path = (
Path(alg_dir)
/ xdir
/ "{:04d}-{:04d}.pth".format(idx, env_info["total"])
)
assert ckp_path.exists()
ckp_data = torch.load(ckp_path)
with torch.no_grad():
predicts = ckp_data["model"](ori_allx)
predicts = predicts.cpu()
# keep data
metric = MSEMetric()
metric(predicts, ori_ally)
predicts = predicts.view(-1).numpy()
alg2xs[alg].append(idx)
alg2ys[alg].append(metric.get_info()['mse'])
plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg)
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.set_xlim(-10, 10)
cur_ax.set_ylim(-60, 60)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
# the trajectory data
cur_ax = fig.add_subplot(2, 1, 2)
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
# plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg)
cur_ax.plot(alg2xs[alg], alg2ys[alg], color=colors[idx_alg], linestyle='-', linewidth=5, label=alg)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
cur_ax.set_ylabel("MSE", fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.set_xlim(1, len(dynamic_env))
cur_ax.set_ylim(0, 10)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
save_path = save_dir / "{:05d}".format(idx)
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
save_dir = save_dir.resolve()
base_cmd = "ffmpeg -y -i {xdir}/%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir, w=width, h=height
)
os.system("{:} {xdir}/compare-alg.mp4".format(base_cmd, xdir=save_dir))
os.system("{:} {xdir}/compare-alg.webm".format(base_cmd, xdir=save_dir))
if __name__ == "__main__":
@ -307,6 +392,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
compare_algs(os.path.join(args.save_dir, "compare-alg"))
compare_algs_v2(os.path.join(args.save_dir, "compare-alg-v2"))
# visualize_env(os.path.join(args.save_dir, "vis-env"))
# compare_cl(os.path.join(args.save_dir, "compare-cl"))
# compare_algs(os.path.join(args.save_dir, "compare-alg"))