Update the sync data v1
This commit is contained in:
		
							
								
								
									
										6
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -56,12 +56,16 @@ jobs: | ||||
|           python -m pytest ./tests/test_basic_space.py -s | ||||
|         shell: bash | ||||
|  | ||||
|       - name: Test Synthetic Data | ||||
|       - name: Test Math | ||||
|         run: | | ||||
|           python -m pip install pytest numpy | ||||
|           python -m pip install parameterized | ||||
|           python -m pip install torch torchvision | ||||
|           python --version | ||||
|           python -m pytest ./tests/test_math*.py -s | ||||
|         shell: bash | ||||
|  | ||||
|       - name: Test Synthetic Data | ||||
|         run: | | ||||
|           python -m pytest ./tests/test_synthetic*.py -s | ||||
|         shell: bash | ||||
|   | ||||
| @@ -222,7 +222,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     logger, env_info, model_kwargs = lfna_setup(args) | ||||
|     logger, model_kwargs = lfna_setup(args) | ||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||
|     valid_env = get_synthetic_env(mode="valid", version=args.env_version) | ||||
|     all_env = get_synthetic_env(mode=None, version=args.env_version) | ||||
|   | ||||
| @@ -11,33 +11,6 @@ from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| def lfna_setup(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     cache_path = ( | ||||
|         logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) | ||||
|     ).resolve() | ||||
|     if cache_path.exists(): | ||||
|         env_info = torch.load(cache_path) | ||||
|     else: | ||||
|         env_info = dict() | ||||
|         dynamic_env = get_synthetic_env(version=args.env_version) | ||||
|         env_info["total"] = len(dynamic_env) | ||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||
|             env_info["{:}-timestamp".format(idx)] = timestamp | ||||
|             env_info["{:}-x".format(idx)] = _allx | ||||
|             env_info["{:}-y".format(idx)] = _ally | ||||
|         env_info["dynamic_env"] = dynamic_env | ||||
|         torch.save(env_info, cache_path) | ||||
|  | ||||
|     """ | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="simple_mlp"), | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|         hidden_dim=args.hidden_dim, | ||||
|         act_cls="leaky_relu", | ||||
|         norm_cls="identity", | ||||
|     ) | ||||
|     """ | ||||
|     model_kwargs = dict( | ||||
|         config=dict(model_type="norm_mlp"), | ||||
|         input_dim=1, | ||||
| @@ -46,7 +19,7 @@ def lfna_setup(args): | ||||
|         act_cls="gelu", | ||||
|         norm_cls="layer_norm_1d", | ||||
|     ) | ||||
|     return logger, env_info, model_kwargs | ||||
|     return logger, model_kwargs | ||||
|  | ||||
|  | ||||
| def train_model(model, dataset, lr, epochs): | ||||
|   | ||||
| @@ -20,14 +20,13 @@ matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from models.xcore import get_model | ||||
| from datasets.synthetic_core import get_synthetic_env | ||||
| from utils.temp_sync import optimize_fn, evaluate_fn | ||||
| from procedures.metric_utils import MSEMetric | ||||
| from xautodl.models.xcore import get_model | ||||
| from xautodl.datasets.synthetic_core import get_synthetic_env | ||||
| from xautodl.procedures.metric_utils import MSEMetric | ||||
|  | ||||
|  | ||||
| def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | ||||
| @@ -181,10 +180,17 @@ def compare_cl(save_dir): | ||||
|  | ||||
| def visualize_env(save_dir, version): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     for substr in ("pdf", "png"): | ||||
|         sub_save_dir = save_dir / substr | ||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dynamic_env = get_synthetic_env(version=version) | ||||
|     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||
|     # min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||
|     allxs, allys = [], [] | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         dpi, width, height = 30, 1800, 1400 | ||||
|         figsize = width / float(dpi), height / float(dpi) | ||||
| @@ -201,21 +207,18 @@ def visualize_env(save_dir, version): | ||||
|             tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         if version == "v1": | ||||
|             cur_ax.set_xlim(-2, 2) | ||||
|             cur_ax.set_ylim(-8, 8) | ||||
|         elif version == "v2": | ||||
|             cur_ax.set_xlim(-10, 10) | ||||
|             cur_ax.set_ylim(-60, 60) | ||||
|         cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||
|         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||
|  | ||||
|         save_path = save_dir / "v{:}-{:05d}".format(version, idx) | ||||
|         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") | ||||
|         pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx) | ||||
|         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|         png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx) | ||||
|         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") | ||||
|         plt.close("all") | ||||
|     save_dir = save_dir.resolve() | ||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( | ||||
|         xdir=save_dir, version=version | ||||
|         xdir=save_dir / "png", version=version | ||||
|     ) | ||||
|     print(base_cmd) | ||||
|     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) | ||||
| @@ -371,7 +374,7 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||
|     visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||
|     compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||
|     # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||
|   | ||||
| @@ -13,7 +13,10 @@ from xautodl.config_utils import dict2config | ||||
|  | ||||
| # NAS-Bench-201 related module or function | ||||
| from xautodl.models import CellStructure, get_cell_based_tiny_net | ||||
| from xautodl.procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders | ||||
| from xautodl.procedures import ( | ||||
|     bench_pure_evaluate as pure_evaluate, | ||||
|     get_nas_bench_loaders, | ||||
| ) | ||||
| from nas_201_api import NASBench201API, ArchResults, ResultsCount | ||||
|  | ||||
| api = NASBench201API( | ||||
|   | ||||
							
								
								
									
										21
									
								
								exps/experimental/test-dynamic.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								exps/experimental/test-dynamic.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| # python test-dynamic.py | ||||
| ##################################################### | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.datasets.math_core import ConstantFunc | ||||
| from xautodl.datasets.math_core import GaussianDGenerator | ||||
|  | ||||
| mean_generator = ConstantFunc(0) | ||||
| cov_generator = ConstantFunc(1) | ||||
|  | ||||
| generator = GaussianDGenerator([mean_generator], [[cov_generator]], (-1, 1)) | ||||
| generator(0, 10) | ||||
| @@ -19,9 +19,11 @@ import seaborn as sns | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from log_utils import time_string | ||||
| from nats_bench import create | ||||
| from models import get_cell_based_tiny_net | ||||
|   | ||||
| @@ -3,11 +3,7 @@ from copy import deepcopy | ||||
| import torchvision.models as models | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from utils import weight_watcher | ||||
| from xautodl.utils import weight_watcher | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|   | ||||
| @@ -17,10 +17,10 @@ from .math_base_funcs import QuarticFunc | ||||
| class ConstantFunc(FitFunc): | ||||
|     """The constant function: f(x) = c.""" | ||||
|  | ||||
|     def __init__(self, constant=None): | ||||
|     def __init__(self, constant=None, xstr="x"): | ||||
|         param = dict() | ||||
|         param[0] = constant | ||||
|         super(ConstantFunc, self).__init__(0, None, param) | ||||
|         super(ConstantFunc, self).__init__(0, None, param, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
| @@ -37,6 +37,34 @@ class ConstantFunc(FitFunc): | ||||
|  | ||||
|  | ||||
| class ComposedSinFunc(FitFunc): | ||||
|     """The composed sin function that outputs: | ||||
|     f(x) = a * sin( b*x ) + c | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params, xstr="x"): | ||||
|         super(ComposedSinFunc, self).__init__(3, None, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         a = self._params[0] | ||||
|         b = self._params[1] | ||||
|         c = self._params[2] | ||||
|         return a * math.sin(b * x) + c | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class ComposedSinFuncV2(FitFunc): | ||||
|     """The composed sin function that outputs: | ||||
|       f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) | ||||
|     - the amplitude scale is a quadratic function of x | ||||
| @@ -44,7 +72,7 @@ class ComposedSinFunc(FitFunc): | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         super(ComposedSinFunc, self).__init__(0, None) | ||||
|         super(ComposedSinFuncV2, self).__init__(0, None) | ||||
|         self.fit(**kwargs) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|   | ||||
| @@ -5,15 +5,13 @@ import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class FitFunc(abc.ABC): | ||||
|     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, freedom: int, list_of_points=None, params=None): | ||||
|     def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"): | ||||
|         self._params = dict() | ||||
|         for i in range(freedom): | ||||
|             self._params[i] = None | ||||
| @@ -24,6 +22,7 @@ class FitFunc(abc.ABC): | ||||
|             self.fit(list_of_points=list_of_points) | ||||
|         if params is not None: | ||||
|             self.set(params) | ||||
|         self._xstr = str(xstr) | ||||
|  | ||||
|     def set(self, params): | ||||
|         self._params = copy.deepcopy(params) | ||||
| @@ -33,6 +32,13 @@ class FitFunc(abc.ABC): | ||||
|             if value is None: | ||||
|                 raise ValueError("The {:} is None".format(key)) | ||||
|  | ||||
|     @property | ||||
|     def xstr(self): | ||||
|         return self._xstr | ||||
|  | ||||
|     def reset_xstr(self, xstr): | ||||
|         self._xstr = str(xstr) | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def __call__(self, x): | ||||
|         raise NotImplementedError | ||||
| @@ -106,8 +112,8 @@ class FitFunc(abc.ABC): | ||||
| class LinearFunc(FitFunc): | ||||
|     """The linear function that outputs f(x) = a * x + b.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None, params=None): | ||||
|         super(LinearFunc, self).__init__(2, list_of_points, params) | ||||
|     def __init__(self, list_of_points=None, params=None, xstr="x"): | ||||
|         super(LinearFunc, self).__init__(2, list_of_points, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
| @@ -117,18 +123,19 @@ class LinearFunc(FitFunc): | ||||
|         return weights[0] * x + weights[1] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x + {b})".format( | ||||
|         return "{name}({a} * {x} + {b})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuadraticFunc(FitFunc): | ||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None, params=None): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points, params) | ||||
|     def __init__(self, list_of_points=None, params=None, xstr="x"): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
| @@ -138,11 +145,12 @@ class QuadraticFunc(FitFunc): | ||||
|         return weights[0] * x * x + weights[1] * x + weights[2] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^2 + {b} * x + {c})".format( | ||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @@ -165,12 +173,13 @@ class CubicFunc(FitFunc): | ||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format( | ||||
|         return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             d=self._params[3], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -6,3 +6,4 @@ from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
| from .math_adv_funcs import ComposedSinFunc | ||||
| from .math_dynamic_generator import GaussianDGenerator | ||||
|   | ||||
| @@ -15,20 +15,19 @@ from .math_base_funcs import FitFunc | ||||
| class DynamicFunc(FitFunc): | ||||
|     """The dynamic quadratic function, where each param is a function.""" | ||||
|  | ||||
|     def __init__(self, freedom: int, params=None): | ||||
|         super(DynamicFunc, self).__init__(freedom, None, params) | ||||
|         self._timestamp = None | ||||
|     def __init__(self, freedom: int, params=None, xstr="x"): | ||||
|         if params is not None: | ||||
|             for param in params: | ||||
|                 param.reset_xstr("t") if isinstance(param, FitFunc) else None | ||||
|         super(DynamicFunc, self).__init__(freedom, None, params, xstr) | ||||
|  | ||||
|     def __call__(self, x, timestamp=None): | ||||
|     def __call__(self, x, timestamp): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def set_timestamp(self, timestamp): | ||||
|         self._timestamp = timestamp | ||||
|  | ||||
|     def noise_call(self, x, timestamp=None, std=0.1): | ||||
|     def noise_call(self, x, timestamp, std): | ||||
|         clean_y = self.__call__(x, timestamp) | ||||
|         if isinstance(clean_y, np.ndarray): | ||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||
| @@ -42,13 +41,10 @@ class DynamicLinearFunc(DynamicFunc): | ||||
|     The a and b is a function of timestamp. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params=None): | ||||
|         super(DynamicLinearFunc, self).__init__(3, params) | ||||
|     def __init__(self, params=None, xstr="x"): | ||||
|         super(DynamicLinearFunc, self).__init__(3, params, xstr) | ||||
|  | ||||
|     def __call__(self, x, timestamp=None): | ||||
|         self.check_valid() | ||||
|         if timestamp is None: | ||||
|             timestamp = self._timestamp | ||||
|     def __call__(self, x, timestamp): | ||||
|         a = self._params[0](timestamp) | ||||
|         b = self._params[1](timestamp) | ||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||
| @@ -56,11 +52,11 @@ class DynamicLinearFunc(DynamicFunc): | ||||
|         return a * x + b | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x + {b}, timestamp={timestamp})".format( | ||||
|         return "{name}({a} * {x} + {b})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             timestamp=self._timestamp, | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										58
									
								
								xautodl/datasets/math_dynamic_generator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								xautodl/datasets/math_dynamic_generator.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import abc | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def assert_list_tuple(x): | ||||
|     assert isinstance(x, (list, tuple)) | ||||
|     return len(x) | ||||
|  | ||||
|  | ||||
| class DynamicGenerator(abc.ABC): | ||||
|     """The dynamic quadratic function, where each param is a function.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._ndim = None | ||||
|  | ||||
|     def __call__(self, time, num): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class GaussianDGenerator(DynamicGenerator): | ||||
|     def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)): | ||||
|         super(GaussianDGenerator, self).__init__() | ||||
|         self._ndim = assert_list_tuple(mean_functors) | ||||
|         assert self._ndim == len( | ||||
|             cov_functors | ||||
|         ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors)) | ||||
|         assert_list_tuple(cov_functors) | ||||
|         for cov_functor in cov_functors: | ||||
|             assert self._ndim == assert_list_tuple( | ||||
|                 cov_functor | ||||
|             ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) | ||||
|         assert ( | ||||
|             isinstance(trunc, (list, tuple)) and len(trunc) == 2 and trunc[0] < trunc[1] | ||||
|         ) | ||||
|         self._mean_functors = mean_functors | ||||
|         self._cov_functors = cov_functors | ||||
|         if trunc is not None: | ||||
|             assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1] | ||||
|         self._trunc = trunc | ||||
|  | ||||
|     def __call__(self, time, num): | ||||
|         mean_list = [functor(time) for functor in self._mean_functors] | ||||
|         cov_matrix = [ | ||||
|             [abs(cov_gen(time)) for cov_gen in cov_functor] | ||||
|             for cov_functor in self._cov_functors | ||||
|         ] | ||||
|         values = np.random.multivariate_normal(mean_list, cov_matrix, size=num) | ||||
|         if self._trunc is not None: | ||||
|             np.clip(values, self._trunc[0], self._trunc[1], out=values) | ||||
|         return values | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({ndim} dims, trunc={trunc})".format( | ||||
|             name=self.__class__.__name__, ndim=self._ndim, trunc=self._trunc | ||||
|         ) | ||||
| @@ -1,13 +1,14 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| import math | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import EnvSampler | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_core import LinearFunc | ||||
| from .math_core import DynamicLinearFunc | ||||
| from .math_core import DynamicQuadraticFunc | ||||
| from .math_core import ConstantFunc, ComposedSinFunc | ||||
| from .math_core import GaussianDGenerator | ||||
|  | ||||
|  | ||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||
| @@ -17,42 +18,21 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio | ||||
|     if version == "v1": | ||||
|         mean_generator = ConstantFunc(0) | ||||
|         std_generator = ConstantFunc(1) | ||||
|     elif version == "v2": | ||||
|         mean_generator = ComposedSinFunc() | ||||
|         std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|         data_generator = GaussianDGenerator( | ||||
|             [mean_generator], [[std_generator]], (-2, 2) | ||||
|         ) | ||||
|         time_generator = TimeStamp( | ||||
|             min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode | ||||
|         ) | ||||
|         oracle_map = DynamicLinearFunc( | ||||
|             params={ | ||||
|                 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), | ||||
|                 1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}), | ||||
|             } | ||||
|         ) | ||||
|         dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
|         num_per_task=num_per_task, | ||||
|         timestamp_config=dict( | ||||
|             min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode | ||||
|         ), | ||||
|     ) | ||||
|     if version == "v1": | ||||
|         function = DynamicLinearFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             amplitude_scale=ConstantFunc(3.0), | ||||
|             num_sin_phase=9, | ||||
|             sin_speed_use_power=False, | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|     elif version == "v2": | ||||
|         function = DynamicQuadraticFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|         function_param[2] = ComposedSinFunc( | ||||
|             num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|             data_generator, oracle_map, time_generator, num_per_task | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|  | ||||
|     function.set(function_param) | ||||
|     # dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     dynamic_env.set_oracle_map(function) | ||||
|     return dynamic_env | ||||
|   | ||||
| @@ -1,15 +1,9 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import math | ||||
| import random | ||||
| import numpy as np | ||||
| from typing import List, Optional, Dict | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
| from .synthetic_utils import TimeStamp | ||||
|  | ||||
|  | ||||
| def is_list_tuple(x): | ||||
|     return isinstance(x, (tuple, list)) | ||||
| @@ -38,46 +32,33 @@ class SyntheticDEnv(data.Dataset): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         mean_functors: List[data.Dataset], | ||||
|         cov_functors: List[List[data.Dataset]], | ||||
|         data_generator, | ||||
|         oracle_map, | ||||
|         time_generator, | ||||
|         num_per_task: int = 5000, | ||||
|         timestamp_config: Optional[Dict] = None, | ||||
|         mode: Optional[str] = None, | ||||
|         timestamp_noise_scale: float = 0.3, | ||||
|         noise: float = 0.1, | ||||
|     ): | ||||
|         self._ndim = len(mean_functors) | ||||
|         assert self._ndim == len( | ||||
|             cov_functors | ||||
|         ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors)) | ||||
|         for cov_functor in cov_functors: | ||||
|             assert self._ndim == len( | ||||
|                 cov_functor | ||||
|             ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) | ||||
|         self._data_generator = data_generator | ||||
|         self._time_generator = time_generator | ||||
|         self._oracle_map = oracle_map | ||||
|         self._num_per_task = num_per_task | ||||
|         if timestamp_config is None: | ||||
|             timestamp_config = dict(mode=mode) | ||||
|         elif "mode" not in timestamp_config: | ||||
|             timestamp_config["mode"] = mode | ||||
|  | ||||
|         self._timestamp_generator = TimeStamp(**timestamp_config) | ||||
|         self._timestamp_noise_scale = timestamp_noise_scale | ||||
|  | ||||
|         self._mean_functors = mean_functors | ||||
|         self._cov_functors = cov_functors | ||||
|  | ||||
|         self._oracle_map = None | ||||
|         self._noise = noise | ||||
|  | ||||
|     @property | ||||
|     def min_timestamp(self): | ||||
|         return self._timestamp_generator.min_timestamp | ||||
|         return self._time_generator.min_timestamp | ||||
|  | ||||
|     @property | ||||
|     def max_timestamp(self): | ||||
|         return self._timestamp_generator.max_timestamp | ||||
|         return self._time_generator.max_timestamp | ||||
|  | ||||
|     @property | ||||
|     def timestamp_interval(self): | ||||
|         return self._timestamp_generator.interval | ||||
|     def time_interval(self): | ||||
|         return self._time_generator.interval | ||||
|  | ||||
|     @property | ||||
|     def mode(self): | ||||
|         return self._time_generator.mode | ||||
|  | ||||
|     def random_timestamp(self, min_timestamp=None, max_timestamp=None): | ||||
|         if min_timestamp is None: | ||||
| @@ -89,16 +70,13 @@ class SyntheticDEnv(data.Dataset): | ||||
|     def get_timestamp(self, index): | ||||
|         if index is None: | ||||
|             timestamps = [] | ||||
|             for index in range(len(self._timestamp_generator)): | ||||
|                 timestamps.append(self._timestamp_generator[index][1]) | ||||
|             for index in range(len(self._time_generator)): | ||||
|                 timestamps.append(self._time_generator[index][1]) | ||||
|             return tuple(timestamps) | ||||
|         else: | ||||
|             index, timestamp = self._timestamp_generator[index] | ||||
|             index, timestamp = self._time_generator[index] | ||||
|             return timestamp | ||||
|  | ||||
|     def set_oracle_map(self, functor): | ||||
|         self._oracle_map = functor | ||||
|  | ||||
|     def __iter__(self): | ||||
|         self._iter_num = 0 | ||||
|         return self | ||||
| @@ -111,7 +89,7 @@ class SyntheticDEnv(data.Dataset): | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||
|         index, timestamp = self._timestamp_generator[index] | ||||
|         index, timestamp = self._time_generator[index] | ||||
|         return self.__call__(timestamp) | ||||
|  | ||||
|     def seq_call(self, timestamps): | ||||
| @@ -122,52 +100,24 @@ class SyntheticDEnv(data.Dataset): | ||||
|             return zip_sequence(xdata) | ||||
|  | ||||
|     def __call__(self, timestamp): | ||||
|         mean_list = [functor(timestamp) for functor in self._mean_functors] | ||||
|         cov_matrix = [ | ||||
|             [abs(cov_gen(timestamp)) for cov_gen in cov_functor] | ||||
|             for cov_functor in self._cov_functors | ||||
|         ] | ||||
|  | ||||
|         dataset = np.random.multivariate_normal( | ||||
|             mean_list, cov_matrix, size=self._num_per_task | ||||
|         ) | ||||
|         if self._oracle_map is None: | ||||
|             return torch.Tensor([timestamp]), torch.Tensor(dataset) | ||||
|         else: | ||||
|             targets = self._oracle_map.noise_call(dataset, timestamp) | ||||
|         dataset = self._data_generator(timestamp, self._num_per_task) | ||||
|         targets = self._oracle_map.noise_call(dataset, timestamp, self._noise) | ||||
|         return torch.Tensor([timestamp]), ( | ||||
|             torch.Tensor(dataset), | ||||
|             torch.Tensor(targets), | ||||
|         ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._timestamp_generator) | ||||
|         return len(self._time_generator) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             cur_num=len(self), | ||||
|             total=len(self._timestamp_generator), | ||||
|             total=len(self._time_generator), | ||||
|             ndim=self._ndim, | ||||
|             num_per_task=self._num_per_task, | ||||
|             xrange_min=self.min_timestamp, | ||||
|             xrange_max=self.max_timestamp, | ||||
|             mode=self._timestamp_generator.mode, | ||||
|             mode=self.mode, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class EnvSampler: | ||||
|     def __init__(self, env, batch, enlarge): | ||||
|         indexes = list(range(len(env))) | ||||
|         self._indexes = indexes * enlarge | ||||
|         self._batch = batch | ||||
|         self._iterations = len(self._indexes) // self._batch | ||||
|  | ||||
|     def __iter__(self): | ||||
|         random.shuffle(self._indexes) | ||||
|         for it in range(self._iterations): | ||||
|             indexes = self._indexes[it * self._batch : (it + 1) * self._batch] | ||||
|             yield indexes | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self._iterations | ||||
|   | ||||
| @@ -1,72 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import copy | ||||
|  | ||||
| from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||
| from .synthetic_env import SyntheticDEnv | ||||
|  | ||||
|  | ||||
| def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"): | ||||
|     if indicator == "v1": | ||||
|         return create_example_v1(timestamp_config, num_per_task) | ||||
|     elif indicator == "v2": | ||||
|         return create_example_v2(timestamp_config, num_per_task) | ||||
|     else: | ||||
|         raise ValueError("Unkonwn indicator: {:}".format(indicator)) | ||||
|  | ||||
|  | ||||
| def create_example_v1( | ||||
|     timestamp_config=None, | ||||
|     num_per_task=5000, | ||||
| ): | ||||
|     mean_generator = ComposedSinFunc() | ||||
|     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) | ||||
|  | ||||
|     dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
|         num_per_task=num_per_task, | ||||
|         timestamp_config=timestamp_config, | ||||
|     ) | ||||
|  | ||||
|     function = DynamicQuadraticFunc() | ||||
|     function_param = dict() | ||||
|     function_param[0] = ComposedSinFunc( | ||||
|         num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|     ) | ||||
|     function_param[1] = ConstantFunc(constant=0.9) | ||||
|     function_param[2] = ComposedSinFunc( | ||||
|         num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|     ) | ||||
|     function.set(function_param) | ||||
|  | ||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     return dynamic_env, function | ||||
|  | ||||
|  | ||||
| def create_example_v2( | ||||
|     timestamp_config=None, | ||||
|     num_per_task=5000, | ||||
| ): | ||||
|     mean_generator = ConstantFunc(0) | ||||
|     std_generator = ConstantFunc(1) | ||||
|  | ||||
|     dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
|         num_per_task=num_per_task, | ||||
|         timestamp_config=timestamp_config, | ||||
|     ) | ||||
|  | ||||
|     function = DynamicLinearFunc() | ||||
|     function_param = dict() | ||||
|     function_param[0] = ComposedSinFunc( | ||||
|         amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) | ||||
|     ) | ||||
|     function_param[1] = ConstantFunc(constant=0.9) | ||||
|     function.set(function_param) | ||||
|  | ||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     return dynamic_env, function | ||||
| @@ -13,11 +13,11 @@ class UnifiedSplit: | ||||
|     """A class to unify the split strategy.""" | ||||
|  | ||||
|     def __init__(self, total_num, mode): | ||||
|         # Training Set 60% | ||||
|         num_of_train = int(total_num * 0.6) | ||||
|         # Validation Set 20% | ||||
|         num_of_valid = int(total_num * 0.2) | ||||
|         # Test Set 20% | ||||
|         # Training Set 65% | ||||
|         num_of_train = int(total_num * 0.65) | ||||
|         # Validation Set 05% | ||||
|         num_of_valid = int(total_num * 0.05) | ||||
|         # Test Set 30% | ||||
|         num_of_set = total_num - num_of_train - num_of_valid | ||||
|         all_indexes = list(range(total_num)) | ||||
|         if mode is None: | ||||
| @@ -28,6 +28,8 @@ class UnifiedSplit: | ||||
|             self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid] | ||||
|         elif mode.lower() in ("test", "testing"): | ||||
|             self._indexes = all_indexes[num_of_train + num_of_valid :] | ||||
|         elif mode.lower() in ("trainval", "trainvalidation"): | ||||
|             self._indexes = all_indexes[: num_of_train + num_of_valid] | ||||
|         else: | ||||
|             raise ValueError("Unkonwn mode of {:}".format(mode)) | ||||
|         self._all_indexes = all_indexes | ||||
|   | ||||
		Reference in New Issue
	
	Block a user