diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 60b5abe..395eb0f 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # ############################################################################ -# CUDA_VISIBLE_DEVICES=0 python exps/LFNA/vis-synthetic.py # +# python exps/LFNA/vis-synthetic.py # ############################################################################ import os, sys, copy, random import torch @@ -83,7 +83,7 @@ def find_max(cur, others): def compare_cl(save_dir): save_dir = Path(str(save_dir)) save_dir.mkdir(parents=True, exist_ok=True) - dynamic_env, function = create_example_v1( + dynamic_env, cl_function = create_example_v1( # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), timestamp_config=dict(num=200), num_per_task=1000, @@ -91,7 +91,6 @@ def compare_cl(save_dir): models = dict() - cl_function = copy.deepcopy(function) cl_function.set_timestamp(0) cl_xaxis_min = None cl_xaxis_max = None @@ -99,23 +98,15 @@ def compare_cl(save_dir): all_data = OrderedDict() for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): - xaxis_all = dataset[:, 0].numpy() + xaxis_all = dataset[0][:, 0].numpy() + yaxis_all = dataset[1][:, 0].numpy() current_data = dict() - - function.set_timestamp(timestamp) - yaxis_all = function.noise_call(xaxis_all) current_data["lfna_xaxis_all"] = xaxis_all current_data["lfna_yaxis_all"] = yaxis_all # compute cl-min cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std()) cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) - """ - cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05) - cl_yaxis_all = cl_function.noise_call(cl_xaxis_all) - current_data["cl_xaxis_all"] = cl_xaxis_all - current_data["cl_yaxis_all"] = cl_yaxis_all - """ all_data[timestamp] = current_data global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1) @@ -170,10 +161,12 @@ def compare_cl(save_dir): xdir=save_dir ) ) - video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(base_cmd, xdir=save_dir) + video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format( + base_cmd, xdir=save_dir + ) print(video_cmd + "\n") os.system(video_cmd) - # os.system("{:} {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) + os.system("{:} -pix_fmt yuv420p {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) if __name__ == "__main__": diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py index 62e79cc..6ed35dd 100644 --- a/lib/datasets/__init__.py +++ b/lib/datasets/__init__.py @@ -5,7 +5,8 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders from .SearchDatasetWrap import SearchDataset from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc -from .math_adv_funcs import DynamicQuadraticFunc, ConstantFunc +from .math_dynamic_funcs import DynamicQuadraticFunc +from .math_adv_funcs import ConstantFunc from .math_adv_funcs import ComposedSinFunc from .synthetic_utils import TimeStamp diff --git a/lib/datasets/math_adv_funcs.py b/lib/datasets/math_adv_funcs.py index 4315258..d84a5e0 100644 --- a/lib/datasets/math_adv_funcs.py +++ b/lib/datasets/math_adv_funcs.py @@ -14,41 +14,6 @@ from .math_base_funcs import QuadraticFunc from .math_base_funcs import QuarticFunc -class DynamicQuadraticFunc(FitFunc): - """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. - The a, b, and c is a function of timestamp. - """ - - def __init__(self, list_of_points=None): - super(DynamicQuadraticFunc, self).__init__(3, list_of_points) - self._timestamp = None - - def __call__(self, x, timestamp=None): - self.check_valid() - if timestamp is None: - timestamp = self._timestamp - a = self._params[0](timestamp) - b = self._params[1](timestamp) - c = self._params[2](timestamp) - convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x - a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) - return a * x * x + b * x + c - - def _getitem(self, x, weights): - raise NotImplementedError - - def set_timestamp(self, timestamp): - self._timestamp = timestamp - - def __repr__(self): - return "{name}({a} * x^2 + {b} * x + {c})".format( - name=self.__class__.__name__, - a=self._params[0], - b=self._params[1], - c=self._params[2], - ) - - class ConstantFunc(FitFunc): """The constant function: f(x) = c.""" diff --git a/lib/datasets/math_base_funcs.py b/lib/datasets/math_base_funcs.py index cab66a2..42a4bd4 100644 --- a/lib/datasets/math_base_funcs.py +++ b/lib/datasets/math_base_funcs.py @@ -13,20 +13,20 @@ import torch.utils.data as data class FitFunc(abc.ABC): """The fit function that outputs f(x) = a * x^2 + b * x + c.""" - def __init__(self, freedom: int, list_of_points=None, _params=None): + def __init__(self, freedom: int, list_of_points=None, params=None): self._params = dict() for i in range(freedom): self._params[i] = None self._freedom = freedom - if list_of_points is not None and _params is not None: - raise ValueError("list_of_points and _params can not be set simultaneously") + if list_of_points is not None and params is not None: + raise ValueError("list_of_points and params can not be set simultaneously") if list_of_points is not None: self.fit(list_of_points=list_of_points) - if _params is not None: - self.set(_params) + if params is not None: + self.set(params) - def set(self, _params): - self._params = copy.deepcopy(_params) + def set(self, params): + self._params = copy.deepcopy(params) def check_valid(self): for key, value in self._params.items(): diff --git a/lib/datasets/math_dynamic_funcs.py b/lib/datasets/math_dynamic_funcs.py new file mode 100644 index 0000000..0a86716 --- /dev/null +++ b/lib/datasets/math_dynamic_funcs.py @@ -0,0 +1,66 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +import math +import abc +import copy +import numpy as np +from typing import Optional +import torch +import torch.utils.data as data + +from .math_base_funcs import FitFunc + + +class DynamicFunc(FitFunc): + """The dynamic quadratic function, where each param is a function.""" + + def __init__(self, freedom: int, params=None): + super(DynamicFunc, self).__init__(freedom, None, params) + self._timestamp = None + + def __call__(self, x, timestamp=None): + raise NotImplementedError + + def _getitem(self, x, weights): + raise NotImplementedError + + def set_timestamp(self, timestamp): + self._timestamp = timestamp + + def noise_call(self, x, timestamp=None, std=0.1): + clean_y = self.__call__(x, timestamp) + if isinstance(clean_y, np.ndarray): + noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) + else: + raise ValueError("Unkonwn type: {:}".format(type(clean_y))) + return noise_y + + +class DynamicQuadraticFunc(DynamicFunc): + """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. + The a, b, and c is a function of timestamp. + """ + + def __init__(self, params=None): + super(DynamicQuadraticFunc, self).__init__(3, params) + + def __call__(self, x, timestamp=None): + self.check_valid() + if timestamp is None: + timestamp = self._timestamp + a = self._params[0](timestamp) + b = self._params[1](timestamp) + c = self._params[2](timestamp) + convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x + a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) + return a * x * x + b * x + c + + def __repr__(self): + return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + c=self._params[2], + timestamp=self._timestamp, + ) diff --git a/lib/datasets/synthetic_env.py b/lib/datasets/synthetic_env.py index fa52dc3..ad73d6b 100644 --- a/lib/datasets/synthetic_env.py +++ b/lib/datasets/synthetic_env.py @@ -41,6 +41,11 @@ class SyntheticDEnv(data.Dataset): self._mean_functors = mean_functors self._cov_functors = cov_functors + self._oracle_map = None + + def set_oracle_map(self, functor): + self._oracle_map = functor + def __iter__(self): self._iter_num = 0 return self @@ -63,7 +68,11 @@ class SyntheticDEnv(data.Dataset): dataset = np.random.multivariate_normal( mean_list, cov_matrix, size=self._num_per_task ) - return timestamp, torch.Tensor(dataset) + if self._oracle_map is None: + return timestamp, torch.Tensor(dataset) + else: + targets = self._oracle_map.noise_call(dataset, timestamp) + return timestamp, (torch.Tensor(dataset), torch.Tensor(targets)) def __len__(self): return len(self._timestamp_generator) diff --git a/lib/datasets/synthetic_example.py b/lib/datasets/synthetic_example.py index 40e917f..f72f15c 100644 --- a/lib/datasets/synthetic_example.py +++ b/lib/datasets/synthetic_example.py @@ -1,8 +1,9 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### +import copy -from .math_adv_funcs import DynamicQuadraticFunc +from .math_dynamic_funcs import DynamicQuadraticFunc from .math_adv_funcs import ConstantFunc, ComposedSinFunc from .synthetic_env import SyntheticDEnv @@ -11,7 +12,6 @@ def create_example_v1( timestamp_config=None, num_per_task=5000, ): - # timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0), mean_generator = ComposedSinFunc() std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) @@ -32,4 +32,6 @@ def create_example_v1( num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 ) function.set(function_param) + + dynamic_env.set_oracle_map(copy.deepcopy(function)) return dynamic_env, function diff --git a/scripts/black.sh b/scripts/black.sh index 10c55fc..ba1a10e 100644 --- a/scripts/black.sh +++ b/scripts/black.sh @@ -6,3 +6,4 @@ black ./lib/datasets black ./lib/xlayers black ./exps/LFNA black ./exps/trading +black ./lib/procedures