diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 7eefc0a..23a3f85 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -24,10 +24,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv -from datasets import DynamicQuadraticFunc from datasets.synthetic_example import create_example_v1 - from utils.temp_sync import optimize_fn, evaluate_fn @@ -61,43 +58,72 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): tick.label.set_rotation(10) for tick in cur_ax.yaxis.get_major_ticks(): tick.label.set_fontsize(LabelSize - font_gap) - plt.legend(loc=1, fontsize=LegendFontsize) + cur_ax.legend(loc=1, fontsize=LegendFontsize) fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") plt.close("all") +def find_min(cur, others): + if cur is None: + return float(others.min()) + else: + return float(min(cur, others.min())) + + +def find_max(cur, others): + if cur is None: + return float(others.max()) + else: + return float(max(cur, others.max())) + + def compare_cl(save_dir): save_dir = Path(str(save_dir)) save_dir.mkdir(parents=True, exist_ok=True) - dynamic_env, function = create_example_v1(100, num_per_task=1000) + dynamic_env, function = create_example_v1( + timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), + num_per_task=1000, + ) - additional_xaxis = np.arange(-6, 6, 0.2) models = dict() cl_function = copy.deepcopy(function) cl_function.set_timestamp(0) - cl_xaxis_all = None + cl_xaxis_min = None + cl_xaxis_max = None + + all_data = OrderedDict() for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): xaxis_all = dataset[:, 0].numpy() - # xaxis_all = np.concatenate((additional_xaxis, xaxis_all)) - # compute the ground truth + current_data = dict() + function.set_timestamp(timestamp) yaxis_all = function.noise_call(xaxis_all) + current_data["lfna_xaxis_all"] = xaxis_all + current_data["lfna_yaxis_all"] = yaxis_all - # create CL data - if cl_xaxis_all is None: - cl_xaxis_all = xaxis_all - else: - cl_xaxis_all = np.concatenate((cl_xaxis_all, xaxis_all + timestamp * 0.2)) - cl_yaxis_all = cl_function(cl_xaxis_all) + import pdb + pdb.set_trace() + + # compute cl-min + cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all) + cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1 + cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05) + + cl_yaxis_all = cl_function.noise_call(cl_xaxis_all) + current_data["cl_xaxis_all"] = cl_xaxis_all + current_data["cl_yaxis_all"] = cl_yaxis_all + all_data[timestamp] = current_data + + for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)): scatter_list = [] scatter_list.append( { - "xaxis": xaxis_all, - "yaxis": yaxis_all, + "xaxis": xdata["lfna_xaxis_all"], + "yaxis": xdata["lfna_yaxis_all"], "color": "k", "s": 10, "alpha": 0.99, @@ -107,6 +133,9 @@ def compare_cl(save_dir): } ) + cl_xaxis_all = current_data["cl_xaxis_all"] + cl_yaxis_all = current_data["cl_yaxis_all"] + scatter_list.append( { "xaxis": cl_xaxis_all, @@ -121,15 +150,21 @@ def compare_cl(save_dir): ) draw_multi_fig( - save_dir, timestamp, scatter_list, - wh=(2000, 1300), fig_title="Timestamp={:03d}".format(timestamp) + save_dir, + timestamp, + scatter_list, + wh=(2000, 1300), + fig_title="Timestamp={:03d}".format(timestamp), ) print("Save all figures into {:}".format(save_dir)) save_dir = save_dir.resolve() - cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=2000:1300 -vb 5000k {xdir}/vis.mp4".format( - xdir=save_dir + base_cmd = ( + "ffmpeg -y -i {xdir}/%04d.png -vf fps=2 -vf scale=2000:1300 -vb 5000k".format( + xdir=save_dir + ) ) - os.system(cmd) + os.system("{:} -pix_fmt yuv420p {xdir}/vis.mp4".format(base_cmd, xdir=save_dir)) + os.system("{:} -c:a libvorbis {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) if __name__ == "__main__": diff --git a/lib/datasets/synthetic_env.py b/lib/datasets/synthetic_env.py index 9d64f3b..f4193bc 100644 --- a/lib/datasets/synthetic_env.py +++ b/lib/datasets/synthetic_env.py @@ -19,7 +19,7 @@ class SyntheticDEnv(data.Dataset): mean_functors: List[data.Dataset], cov_functors: List[List[data.Dataset]], num_per_task: int = 5000, - time_stamp_config: Optional[Dict] = None, + timestamp_config: Optional[Dict] = None, mode: Optional[str] = None, ): self._ndim = len(mean_functors) @@ -31,12 +31,12 @@ class SyntheticDEnv(data.Dataset): cov_functor ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) self._num_per_task = num_per_task - if time_stamp_config is None: - time_stamp_config = dict(mode=mode) + if timestamp_config is None: + timestamp_config = dict(mode=mode) else: - time_stamp_config["mode"] = mode + timestamp_config["mode"] = mode - self._timestamp_generator = TimeStamp(**time_stamp_config) + self._timestamp_generator = TimeStamp(**timestamp_config) self._mean_functors = mean_functors self._cov_functors = cov_functors diff --git a/lib/datasets/synthetic_example.py b/lib/datasets/synthetic_example.py index 0fd780c..ffe7e67 100644 --- a/lib/datasets/synthetic_example.py +++ b/lib/datasets/synthetic_example.py @@ -2,21 +2,23 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -from .math_base_funcs import DynamicQuadraticFunc +from .math_adv_funcs import DynamicQuadraticFunc from .math_adv_funcs import ConstantFunc, ComposedSinFunc from .synthetic_env import SyntheticDEnv -def create_example_v1(timestamps=50, num_per_task=5000): +def create_example_v1( + timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0), + num_per_task=5000, +): mean_generator = ComposedSinFunc() std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) - std_generator.set_transform(lambda x: x + 1) dynamic_env = SyntheticDEnv( [mean_generator], [[std_generator]], num_per_task=num_per_task, - time_stamp_config=dict(num=timestamps), + timestamp_config=timestamp_config, ) function = DynamicQuadraticFunc() diff --git a/scripts/black.sh b/scripts/black.sh new file mode 100644 index 0000000..10c55fc --- /dev/null +++ b/scripts/black.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# bash ./scripts/black.sh + +black ./tests/ +black ./lib/datasets +black ./lib/xlayers +black ./exps/LFNA +black ./exps/trading