Updates
This commit is contained in:
parent
198771b78d
commit
a5b7d986b3
112
exps/synthetic/baseline.py
Normal file
112
exps/synthetic/baseline.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
||||||
|
#####################################################
|
||||||
|
# python exps/synthetic/baseline.py #
|
||||||
|
#####################################################
|
||||||
|
import os, sys, copy
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
from collections import OrderedDict
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
from matplotlib import cm
|
||||||
|
|
||||||
|
matplotlib.use("agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.ticker as ticker
|
||||||
|
|
||||||
|
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||||
|
if str(lib_dir) not in sys.path:
|
||||||
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
|
||||||
|
from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv
|
||||||
|
from datasets import DynamicQuadraticFunc
|
||||||
|
from datasets.synthetic_example import create_example_v1
|
||||||
|
|
||||||
|
|
||||||
|
def draw_fig(save_dir, timestamp, scatter_list):
|
||||||
|
save_path = save_dir / "{:04d}".format(timestamp)
|
||||||
|
# print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))
|
||||||
|
dpi, width, height = 40, 1500, 1500
|
||||||
|
figsize = width / float(dpi), height / float(dpi)
|
||||||
|
LabelSize, LegendFontsize, font_gap = 80, 80, 5
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=figsize)
|
||||||
|
|
||||||
|
cur_ax = fig.add_subplot(1, 1, 1)
|
||||||
|
for scatter_dict in scatter_list:
|
||||||
|
cur_ax.scatter(
|
||||||
|
scatter_dict["xaxis"],
|
||||||
|
scatter_dict["yaxis"],
|
||||||
|
color=scatter_dict["color"],
|
||||||
|
s=scatter_dict["s"],
|
||||||
|
alpha=scatter_dict["alpha"],
|
||||||
|
label=scatter_dict["label"],
|
||||||
|
)
|
||||||
|
cur_ax.set_xlabel("X", fontsize=LabelSize)
|
||||||
|
cur_ax.set_ylabel("f(X)", rotation=0, fontsize=LabelSize)
|
||||||
|
cur_ax.set_xlim(-6, 6)
|
||||||
|
cur_ax.set_ylim(-40, 40)
|
||||||
|
for tick in cur_ax.xaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(LabelSize - font_gap)
|
||||||
|
tick.label.set_rotation(10)
|
||||||
|
for tick in cur_ax.yaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(LabelSize - font_gap)
|
||||||
|
|
||||||
|
plt.legend(loc=1, fontsize=LegendFontsize)
|
||||||
|
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
|
||||||
|
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
|
||||||
|
plt.close("all")
|
||||||
|
|
||||||
|
|
||||||
|
def main(save_dir):
|
||||||
|
save_dir = Path(str(save_dir))
|
||||||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dynamic_env, function = create_example_v1(100, num_per_task=500)
|
||||||
|
|
||||||
|
additional_xaxis = np.arange(-6, 6, 0.1)
|
||||||
|
for timestamp, dataset in tqdm(dynamic_env, ncols=50):
|
||||||
|
num = dataset.shape[0]
|
||||||
|
xaxis = dataset[:, 0].numpy()
|
||||||
|
# compute the ground truth
|
||||||
|
function.set_timestamp(timestamp)
|
||||||
|
yaxis = function(xaxis)
|
||||||
|
# xaxis = np.concatenate((additional_xaxis, xaxis))
|
||||||
|
# the first plot
|
||||||
|
scatter_list = []
|
||||||
|
scatter_list.append(
|
||||||
|
{
|
||||||
|
"xaxis": xaxis,
|
||||||
|
"yaxis": yaxis,
|
||||||
|
"color": "k",
|
||||||
|
"s": 10,
|
||||||
|
"alpha": 0.99,
|
||||||
|
"label": "Timestamp={:02d}".format(timestamp),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
draw_fig(save_dir, timestamp, scatter_list)
|
||||||
|
print("Save all figures into {:}".format(save_dir))
|
||||||
|
save_dir = save_dir.resolve()
|
||||||
|
cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {xdir}/vis.mp4".format(
|
||||||
|
xdir=save_dir
|
||||||
|
)
|
||||||
|
os.system(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser("Baseline")
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_dir",
|
||||||
|
type=str,
|
||||||
|
default="./outputs/vis-synthetic",
|
||||||
|
help="The save directory.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args.save_dir)
|
Loading…
Reference in New Issue
Block a user