Update the sync data v1
This commit is contained in:
		
							
								
								
									
										6
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -56,12 +56,16 @@ jobs: | |||||||
|           python -m pytest ./tests/test_basic_space.py -s |           python -m pytest ./tests/test_basic_space.py -s | ||||||
|         shell: bash |         shell: bash | ||||||
|  |  | ||||||
|       - name: Test Synthetic Data |       - name: Test Math | ||||||
|         run: | |         run: | | ||||||
|           python -m pip install pytest numpy |           python -m pip install pytest numpy | ||||||
|           python -m pip install parameterized |           python -m pip install parameterized | ||||||
|           python -m pip install torch torchvision |           python -m pip install torch torchvision | ||||||
|           python --version |           python --version | ||||||
|           python -m pytest ./tests/test_math*.py -s |           python -m pytest ./tests/test_math*.py -s | ||||||
|  |         shell: bash | ||||||
|  |  | ||||||
|  |       - name: Test Synthetic Data | ||||||
|  |         run: | | ||||||
|           python -m pytest ./tests/test_synthetic*.py -s |           python -m pytest ./tests/test_synthetic*.py -s | ||||||
|         shell: bash |         shell: bash | ||||||
|   | |||||||
| @@ -222,7 +222,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|     logger, env_info, model_kwargs = lfna_setup(args) |     logger, model_kwargs = lfna_setup(args) | ||||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) |     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||||
|     valid_env = get_synthetic_env(mode="valid", version=args.env_version) |     valid_env = get_synthetic_env(mode="valid", version=args.env_version) | ||||||
|     all_env = get_synthetic_env(mode=None, version=args.env_version) |     all_env = get_synthetic_env(mode=None, version=args.env_version) | ||||||
|   | |||||||
| @@ -11,33 +11,6 @@ from xautodl.datasets.synthetic_core import get_synthetic_env | |||||||
| def lfna_setup(args): | def lfna_setup(args): | ||||||
|     prepare_seed(args.rand_seed) |     prepare_seed(args.rand_seed) | ||||||
|     logger = prepare_logger(args) |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|     cache_path = ( |  | ||||||
|         logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) |  | ||||||
|     ).resolve() |  | ||||||
|     if cache_path.exists(): |  | ||||||
|         env_info = torch.load(cache_path) |  | ||||||
|     else: |  | ||||||
|         env_info = dict() |  | ||||||
|         dynamic_env = get_synthetic_env(version=args.env_version) |  | ||||||
|         env_info["total"] = len(dynamic_env) |  | ||||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): |  | ||||||
|             env_info["{:}-timestamp".format(idx)] = timestamp |  | ||||||
|             env_info["{:}-x".format(idx)] = _allx |  | ||||||
|             env_info["{:}-y".format(idx)] = _ally |  | ||||||
|         env_info["dynamic_env"] = dynamic_env |  | ||||||
|         torch.save(env_info, cache_path) |  | ||||||
|  |  | ||||||
|     """ |  | ||||||
|     model_kwargs = dict( |  | ||||||
|         config=dict(model_type="simple_mlp"), |  | ||||||
|         input_dim=1, |  | ||||||
|         output_dim=1, |  | ||||||
|         hidden_dim=args.hidden_dim, |  | ||||||
|         act_cls="leaky_relu", |  | ||||||
|         norm_cls="identity", |  | ||||||
|     ) |  | ||||||
|     """ |  | ||||||
|     model_kwargs = dict( |     model_kwargs = dict( | ||||||
|         config=dict(model_type="norm_mlp"), |         config=dict(model_type="norm_mlp"), | ||||||
|         input_dim=1, |         input_dim=1, | ||||||
| @@ -46,7 +19,7 @@ def lfna_setup(args): | |||||||
|         act_cls="gelu", |         act_cls="gelu", | ||||||
|         norm_cls="layer_norm_1d", |         norm_cls="layer_norm_1d", | ||||||
|     ) |     ) | ||||||
|     return logger, env_info, model_kwargs |     return logger, model_kwargs | ||||||
|  |  | ||||||
|  |  | ||||||
| def train_model(model, dataset, lr, epochs): | def train_model(model, dataset, lr, epochs): | ||||||
|   | |||||||
| @@ -20,14 +20,13 @@ matplotlib.use("agg") | |||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
| import matplotlib.ticker as ticker | import matplotlib.ticker as ticker | ||||||
|  |  | ||||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from models.xcore import get_model | from xautodl.models.xcore import get_model | ||||||
| from datasets.synthetic_core import get_synthetic_env | from xautodl.datasets.synthetic_core import get_synthetic_env | ||||||
| from utils.temp_sync import optimize_fn, evaluate_fn | from xautodl.procedures.metric_utils import MSEMetric | ||||||
| from procedures.metric_utils import MSEMetric |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | ||||||
| @@ -181,10 +180,17 @@ def compare_cl(save_dir): | |||||||
|  |  | ||||||
| def visualize_env(save_dir, version): | def visualize_env(save_dir, version): | ||||||
|     save_dir = Path(str(save_dir)) |     save_dir = Path(str(save_dir)) | ||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |     for substr in ("pdf", "png"): | ||||||
|  |         sub_save_dir = save_dir / substr | ||||||
|  |         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|     dynamic_env = get_synthetic_env(version=version) |     dynamic_env = get_synthetic_env(version=version) | ||||||
|     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp |     # min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp | ||||||
|  |     allxs, allys = [], [] | ||||||
|  |     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|  |         allxs.append(allx) | ||||||
|  |         allys.append(ally) | ||||||
|  |     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): |     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|         dpi, width, height = 30, 1800, 1400 |         dpi, width, height = 30, 1800, 1400 | ||||||
|         figsize = width / float(dpi), height / float(dpi) |         figsize = width / float(dpi), height / float(dpi) | ||||||
| @@ -201,21 +207,18 @@ def visualize_env(save_dir, version): | |||||||
|             tick.label.set_rotation(10) |             tick.label.set_rotation(10) | ||||||
|         for tick in cur_ax.yaxis.get_major_ticks(): |         for tick in cur_ax.yaxis.get_major_ticks(): | ||||||
|             tick.label.set_fontsize(LabelSize - font_gap) |             tick.label.set_fontsize(LabelSize - font_gap) | ||||||
|         if version == "v1": |         cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||||
|             cur_ax.set_xlim(-2, 2) |         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||||
|             cur_ax.set_ylim(-8, 8) |  | ||||||
|         elif version == "v2": |  | ||||||
|             cur_ax.set_xlim(-10, 10) |  | ||||||
|             cur_ax.set_ylim(-60, 60) |  | ||||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) |         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||||
|  |  | ||||||
|         save_path = save_dir / "v{:}-{:05d}".format(version, idx) |         pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx) | ||||||
|         fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") |         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") | ||||||
|         fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") |         png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx) | ||||||
|  |         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") | ||||||
|         plt.close("all") |         plt.close("all") | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( |     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( | ||||||
|         xdir=save_dir, version=version |         xdir=save_dir / "png", version=version | ||||||
|     ) |     ) | ||||||
|     print(base_cmd) |     print(base_cmd) | ||||||
|     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) |     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) | ||||||
| @@ -371,7 +374,7 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") |     visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") | ||||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") |     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||||
|     compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) |     # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) |     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||||
|   | |||||||
| @@ -13,7 +13,10 @@ from xautodl.config_utils import dict2config | |||||||
|  |  | ||||||
| # NAS-Bench-201 related module or function | # NAS-Bench-201 related module or function | ||||||
| from xautodl.models import CellStructure, get_cell_based_tiny_net | from xautodl.models import CellStructure, get_cell_based_tiny_net | ||||||
| from xautodl.procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders | from xautodl.procedures import ( | ||||||
|  |     bench_pure_evaluate as pure_evaluate, | ||||||
|  |     get_nas_bench_loaders, | ||||||
|  | ) | ||||||
| from nas_201_api import NASBench201API, ArchResults, ResultsCount | from nas_201_api import NASBench201API, ArchResults, ResultsCount | ||||||
|  |  | ||||||
| api = NASBench201API( | api = NASBench201API( | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								exps/experimental/test-dynamic.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								exps/experimental/test-dynamic.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||||
|  | ##################################################### | ||||||
|  | # python test-dynamic.py | ||||||
|  | ##################################################### | ||||||
|  | import sys | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||||
|  | print("LIB-DIR: {:}".format(lib_dir)) | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | from xautodl.datasets.math_core import ConstantFunc | ||||||
|  | from xautodl.datasets.math_core import GaussianDGenerator | ||||||
|  |  | ||||||
|  | mean_generator = ConstantFunc(0) | ||||||
|  | cov_generator = ConstantFunc(1) | ||||||
|  |  | ||||||
|  | generator = GaussianDGenerator([mean_generator], [[cov_generator]], (-1, 1)) | ||||||
|  | generator(0, 10) | ||||||
| @@ -19,9 +19,11 @@ import seaborn as sns | |||||||
| matplotlib.use("agg") | matplotlib.use("agg") | ||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
|  |  | ||||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||||
|  | print("LIB-DIR: {:}".format(lib_dir)) | ||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from log_utils import time_string | from log_utils import time_string | ||||||
| from nats_bench import create | from nats_bench import create | ||||||
| from models import get_cell_based_tiny_net | from models import get_cell_based_tiny_net | ||||||
|   | |||||||
| @@ -3,11 +3,7 @@ from copy import deepcopy | |||||||
| import torchvision.models as models | import torchvision.models as models | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | from xautodl.utils import weight_watcher | ||||||
| if str(lib_dir) not in sys.path: |  | ||||||
|     sys.path.insert(0, str(lib_dir)) |  | ||||||
|  |  | ||||||
| from utils import weight_watcher |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|   | |||||||
| @@ -17,10 +17,10 @@ from .math_base_funcs import QuarticFunc | |||||||
| class ConstantFunc(FitFunc): | class ConstantFunc(FitFunc): | ||||||
|     """The constant function: f(x) = c.""" |     """The constant function: f(x) = c.""" | ||||||
|  |  | ||||||
|     def __init__(self, constant=None): |     def __init__(self, constant=None, xstr="x"): | ||||||
|         param = dict() |         param = dict() | ||||||
|         param[0] = constant |         param[0] = constant | ||||||
|         super(ConstantFunc, self).__init__(0, None, param) |         super(ConstantFunc, self).__init__(0, None, param, xstr) | ||||||
|  |  | ||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
| @@ -37,6 +37,34 @@ class ConstantFunc(FitFunc): | |||||||
|  |  | ||||||
|  |  | ||||||
| class ComposedSinFunc(FitFunc): | class ComposedSinFunc(FitFunc): | ||||||
|  |     """The composed sin function that outputs: | ||||||
|  |     f(x) = a * sin( b*x ) + c | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, params, xstr="x"): | ||||||
|  |         super(ComposedSinFunc, self).__init__(3, None, params, xstr) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         a = self._params[0] | ||||||
|  |         b = self._params[1] | ||||||
|  |         c = self._params[2] | ||||||
|  |         return a * math.sin(b * x) + c | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ComposedSinFuncV2(FitFunc): | ||||||
|     """The composed sin function that outputs: |     """The composed sin function that outputs: | ||||||
|       f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) |       f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) | ||||||
|     - the amplitude scale is a quadratic function of x |     - the amplitude scale is a quadratic function of x | ||||||
| @@ -44,7 +72,7 @@ class ComposedSinFunc(FitFunc): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs): | ||||||
|         super(ComposedSinFunc, self).__init__(0, None) |         super(ComposedSinFuncV2, self).__init__(0, None) | ||||||
|         self.fit(**kwargs) |         self.fit(**kwargs) | ||||||
|  |  | ||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|   | |||||||
| @@ -5,15 +5,13 @@ import math | |||||||
| import abc | import abc | ||||||
| import copy | import copy | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import Optional |  | ||||||
| import torch | import torch | ||||||
| import torch.utils.data as data |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class FitFunc(abc.ABC): | class FitFunc(abc.ABC): | ||||||
|     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" |     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|     def __init__(self, freedom: int, list_of_points=None, params=None): |     def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"): | ||||||
|         self._params = dict() |         self._params = dict() | ||||||
|         for i in range(freedom): |         for i in range(freedom): | ||||||
|             self._params[i] = None |             self._params[i] = None | ||||||
| @@ -24,6 +22,7 @@ class FitFunc(abc.ABC): | |||||||
|             self.fit(list_of_points=list_of_points) |             self.fit(list_of_points=list_of_points) | ||||||
|         if params is not None: |         if params is not None: | ||||||
|             self.set(params) |             self.set(params) | ||||||
|  |         self._xstr = str(xstr) | ||||||
|  |  | ||||||
|     def set(self, params): |     def set(self, params): | ||||||
|         self._params = copy.deepcopy(params) |         self._params = copy.deepcopy(params) | ||||||
| @@ -33,6 +32,13 @@ class FitFunc(abc.ABC): | |||||||
|             if value is None: |             if value is None: | ||||||
|                 raise ValueError("The {:} is None".format(key)) |                 raise ValueError("The {:} is None".format(key)) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def xstr(self): | ||||||
|  |         return self._xstr | ||||||
|  |  | ||||||
|  |     def reset_xstr(self, xstr): | ||||||
|  |         self._xstr = str(xstr) | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| @@ -106,8 +112,8 @@ class FitFunc(abc.ABC): | |||||||
| class LinearFunc(FitFunc): | class LinearFunc(FitFunc): | ||||||
|     """The linear function that outputs f(x) = a * x + b.""" |     """The linear function that outputs f(x) = a * x + b.""" | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None, params=None): |     def __init__(self, list_of_points=None, params=None, xstr="x"): | ||||||
|         super(LinearFunc, self).__init__(2, list_of_points, params) |         super(LinearFunc, self).__init__(2, list_of_points, params, xstr) | ||||||
|  |  | ||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
| @@ -117,18 +123,19 @@ class LinearFunc(FitFunc): | |||||||
|         return weights[0] * x + weights[1] |         return weights[0] * x + weights[1] | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * x + {b})".format( |         return "{name}({a} * {x} + {b})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|  |             x=self.xstr, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuadraticFunc(FitFunc): | class QuadraticFunc(FitFunc): | ||||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" |     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None, params=None): |     def __init__(self, list_of_points=None, params=None, xstr="x"): | ||||||
|         super(QuadraticFunc, self).__init__(3, list_of_points, params) |         super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr) | ||||||
|  |  | ||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
| @@ -138,11 +145,12 @@ class QuadraticFunc(FitFunc): | |||||||
|         return weights[0] * x * x + weights[1] * x + weights[2] |         return weights[0] * x * x + weights[1] * x + weights[2] | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * x^2 + {b} * x + {c})".format( |         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|             c=self._params[2], |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -165,12 +173,13 @@ class CubicFunc(FitFunc): | |||||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] |         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format( |         return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|             c=self._params[2], |             c=self._params[2], | ||||||
|             d=self._params[3], |             d=self._params[3], | ||||||
|  |             x=self.xstr, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,3 +6,4 @@ from .math_dynamic_funcs import DynamicLinearFunc | |||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | from .math_dynamic_funcs import DynamicQuadraticFunc | ||||||
| from .math_adv_funcs import ConstantFunc | from .math_adv_funcs import ConstantFunc | ||||||
| from .math_adv_funcs import ComposedSinFunc | from .math_adv_funcs import ComposedSinFunc | ||||||
|  | from .math_dynamic_generator import GaussianDGenerator | ||||||
|   | |||||||
| @@ -15,20 +15,19 @@ from .math_base_funcs import FitFunc | |||||||
| class DynamicFunc(FitFunc): | class DynamicFunc(FitFunc): | ||||||
|     """The dynamic quadratic function, where each param is a function.""" |     """The dynamic quadratic function, where each param is a function.""" | ||||||
|  |  | ||||||
|     def __init__(self, freedom: int, params=None): |     def __init__(self, freedom: int, params=None, xstr="x"): | ||||||
|         super(DynamicFunc, self).__init__(freedom, None, params) |         if params is not None: | ||||||
|         self._timestamp = None |             for param in params: | ||||||
|  |                 param.reset_xstr("t") if isinstance(param, FitFunc) else None | ||||||
|  |         super(DynamicFunc, self).__init__(freedom, None, params, xstr) | ||||||
|  |  | ||||||
|     def __call__(self, x, timestamp=None): |     def __call__(self, x, timestamp): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |     def _getitem(self, x, weights): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def set_timestamp(self, timestamp): |     def noise_call(self, x, timestamp, std): | ||||||
|         self._timestamp = timestamp |  | ||||||
|  |  | ||||||
|     def noise_call(self, x, timestamp=None, std=0.1): |  | ||||||
|         clean_y = self.__call__(x, timestamp) |         clean_y = self.__call__(x, timestamp) | ||||||
|         if isinstance(clean_y, np.ndarray): |         if isinstance(clean_y, np.ndarray): | ||||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) |             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||||
| @@ -42,13 +41,10 @@ class DynamicLinearFunc(DynamicFunc): | |||||||
|     The a and b is a function of timestamp. |     The a and b is a function of timestamp. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, params=None): |     def __init__(self, params=None, xstr="x"): | ||||||
|         super(DynamicLinearFunc, self).__init__(3, params) |         super(DynamicLinearFunc, self).__init__(3, params, xstr) | ||||||
|  |  | ||||||
|     def __call__(self, x, timestamp=None): |     def __call__(self, x, timestamp): | ||||||
|         self.check_valid() |  | ||||||
|         if timestamp is None: |  | ||||||
|             timestamp = self._timestamp |  | ||||||
|         a = self._params[0](timestamp) |         a = self._params[0](timestamp) | ||||||
|         b = self._params[1](timestamp) |         b = self._params[1](timestamp) | ||||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x |         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||||
| @@ -56,11 +52,11 @@ class DynamicLinearFunc(DynamicFunc): | |||||||
|         return a * x + b |         return a * x + b | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * x + {b}, timestamp={timestamp})".format( |         return "{name}({a} * {x} + {b})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|             timestamp=self._timestamp, |             x=self.xstr, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										58
									
								
								xautodl/datasets/math_dynamic_generator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								xautodl/datasets/math_dynamic_generator.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | import abc | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def assert_list_tuple(x): | ||||||
|  |     assert isinstance(x, (list, tuple)) | ||||||
|  |     return len(x) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DynamicGenerator(abc.ABC): | ||||||
|  |     """The dynamic quadratic function, where each param is a function.""" | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self._ndim = None | ||||||
|  |  | ||||||
|  |     def __call__(self, time, num): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class GaussianDGenerator(DynamicGenerator): | ||||||
|  |     def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)): | ||||||
|  |         super(GaussianDGenerator, self).__init__() | ||||||
|  |         self._ndim = assert_list_tuple(mean_functors) | ||||||
|  |         assert self._ndim == len( | ||||||
|  |             cov_functors | ||||||
|  |         ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors)) | ||||||
|  |         assert_list_tuple(cov_functors) | ||||||
|  |         for cov_functor in cov_functors: | ||||||
|  |             assert self._ndim == assert_list_tuple( | ||||||
|  |                 cov_functor | ||||||
|  |             ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) | ||||||
|  |         assert ( | ||||||
|  |             isinstance(trunc, (list, tuple)) and len(trunc) == 2 and trunc[0] < trunc[1] | ||||||
|  |         ) | ||||||
|  |         self._mean_functors = mean_functors | ||||||
|  |         self._cov_functors = cov_functors | ||||||
|  |         if trunc is not None: | ||||||
|  |             assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1] | ||||||
|  |         self._trunc = trunc | ||||||
|  |  | ||||||
|  |     def __call__(self, time, num): | ||||||
|  |         mean_list = [functor(time) for functor in self._mean_functors] | ||||||
|  |         cov_matrix = [ | ||||||
|  |             [abs(cov_gen(time)) for cov_gen in cov_functor] | ||||||
|  |             for cov_functor in self._cov_functors | ||||||
|  |         ] | ||||||
|  |         values = np.random.multivariate_normal(mean_list, cov_matrix, size=num) | ||||||
|  |         if self._trunc is not None: | ||||||
|  |             np.clip(values, self._trunc[0], self._trunc[1], out=values) | ||||||
|  |         return values | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({ndim} dims, trunc={trunc})".format( | ||||||
|  |             name=self.__class__.__name__, ndim=self._ndim, trunc=self._trunc | ||||||
|  |         ) | ||||||
| @@ -1,13 +1,14 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||||
| ##################################################### | ##################################################### | ||||||
|  | import math | ||||||
| from .synthetic_utils import TimeStamp | from .synthetic_utils import TimeStamp | ||||||
| from .synthetic_env import EnvSampler |  | ||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
| from .math_core import LinearFunc | from .math_core import LinearFunc | ||||||
| from .math_core import DynamicLinearFunc | from .math_core import DynamicLinearFunc | ||||||
| from .math_core import DynamicQuadraticFunc | from .math_core import DynamicQuadraticFunc | ||||||
| from .math_core import ConstantFunc, ComposedSinFunc | from .math_core import ConstantFunc, ComposedSinFunc | ||||||
|  | from .math_core import GaussianDGenerator | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||||
| @@ -17,42 +18,21 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio | |||||||
|     if version == "v1": |     if version == "v1": | ||||||
|         mean_generator = ConstantFunc(0) |         mean_generator = ConstantFunc(0) | ||||||
|         std_generator = ConstantFunc(1) |         std_generator = ConstantFunc(1) | ||||||
|     elif version == "v2": |         data_generator = GaussianDGenerator( | ||||||
|         mean_generator = ComposedSinFunc() |             [mean_generator], [[std_generator]], (-2, 2) | ||||||
|         std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) |         ) | ||||||
|     else: |         time_generator = TimeStamp( | ||||||
|         raise ValueError("Unknown version: {:}".format(version)) |             min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode | ||||||
|  |         ) | ||||||
|  |         oracle_map = DynamicLinearFunc( | ||||||
|  |             params={ | ||||||
|  |                 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), | ||||||
|  |                 1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|         dynamic_env = SyntheticDEnv( |         dynamic_env = SyntheticDEnv( | ||||||
|         [mean_generator], |             data_generator, oracle_map, time_generator, num_per_task | ||||||
|         [[std_generator]], |  | ||||||
|         num_per_task=num_per_task, |  | ||||||
|         timestamp_config=dict( |  | ||||||
|             min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     if version == "v1": |  | ||||||
|         function = DynamicLinearFunc() |  | ||||||
|         function_param = dict() |  | ||||||
|         function_param[0] = ComposedSinFunc( |  | ||||||
|             amplitude_scale=ConstantFunc(3.0), |  | ||||||
|             num_sin_phase=9, |  | ||||||
|             sin_speed_use_power=False, |  | ||||||
|         ) |  | ||||||
|         function_param[1] = ConstantFunc(constant=0.9) |  | ||||||
|     elif version == "v2": |  | ||||||
|         function = DynamicQuadraticFunc() |  | ||||||
|         function_param = dict() |  | ||||||
|         function_param[0] = ComposedSinFunc( |  | ||||||
|             num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 |  | ||||||
|         ) |  | ||||||
|         function_param[1] = ConstantFunc(constant=0.9) |  | ||||||
|         function_param[2] = ComposedSinFunc( |  | ||||||
|             num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 |  | ||||||
|         ) |         ) | ||||||
|     else: |     else: | ||||||
|         raise ValueError("Unknown version: {:}".format(version)) |         raise ValueError("Unknown version: {:}".format(version)) | ||||||
|  |  | ||||||
|     function.set(function_param) |  | ||||||
|     # dynamic_env.set_oracle_map(copy.deepcopy(function)) |  | ||||||
|     dynamic_env.set_oracle_map(function) |  | ||||||
|     return dynamic_env |     return dynamic_env | ||||||
|   | |||||||
| @@ -1,15 +1,9 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # |  | ||||||
| ##################################################### |  | ||||||
| import math | import math | ||||||
| import random | import random | ||||||
| import numpy as np |  | ||||||
| from typing import List, Optional, Dict | from typing import List, Optional, Dict | ||||||
| import torch | import torch | ||||||
| import torch.utils.data as data | import torch.utils.data as data | ||||||
|  |  | ||||||
| from .synthetic_utils import TimeStamp |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def is_list_tuple(x): | def is_list_tuple(x): | ||||||
|     return isinstance(x, (tuple, list)) |     return isinstance(x, (tuple, list)) | ||||||
| @@ -38,46 +32,33 @@ class SyntheticDEnv(data.Dataset): | |||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         mean_functors: List[data.Dataset], |         data_generator, | ||||||
|         cov_functors: List[List[data.Dataset]], |         oracle_map, | ||||||
|  |         time_generator, | ||||||
|         num_per_task: int = 5000, |         num_per_task: int = 5000, | ||||||
|         timestamp_config: Optional[Dict] = None, |         noise: float = 0.1, | ||||||
|         mode: Optional[str] = None, |  | ||||||
|         timestamp_noise_scale: float = 0.3, |  | ||||||
|     ): |     ): | ||||||
|         self._ndim = len(mean_functors) |         self._data_generator = data_generator | ||||||
|         assert self._ndim == len( |         self._time_generator = time_generator | ||||||
|             cov_functors |         self._oracle_map = oracle_map | ||||||
|         ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors)) |  | ||||||
|         for cov_functor in cov_functors: |  | ||||||
|             assert self._ndim == len( |  | ||||||
|                 cov_functor |  | ||||||
|             ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) |  | ||||||
|         self._num_per_task = num_per_task |         self._num_per_task = num_per_task | ||||||
|         if timestamp_config is None: |         self._noise = noise | ||||||
|             timestamp_config = dict(mode=mode) |  | ||||||
|         elif "mode" not in timestamp_config: |  | ||||||
|             timestamp_config["mode"] = mode |  | ||||||
|  |  | ||||||
|         self._timestamp_generator = TimeStamp(**timestamp_config) |  | ||||||
|         self._timestamp_noise_scale = timestamp_noise_scale |  | ||||||
|  |  | ||||||
|         self._mean_functors = mean_functors |  | ||||||
|         self._cov_functors = cov_functors |  | ||||||
|  |  | ||||||
|         self._oracle_map = None |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def min_timestamp(self): |     def min_timestamp(self): | ||||||
|         return self._timestamp_generator.min_timestamp |         return self._time_generator.min_timestamp | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def max_timestamp(self): |     def max_timestamp(self): | ||||||
|         return self._timestamp_generator.max_timestamp |         return self._time_generator.max_timestamp | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def timestamp_interval(self): |     def time_interval(self): | ||||||
|         return self._timestamp_generator.interval |         return self._time_generator.interval | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def mode(self): | ||||||
|  |         return self._time_generator.mode | ||||||
|  |  | ||||||
|     def random_timestamp(self, min_timestamp=None, max_timestamp=None): |     def random_timestamp(self, min_timestamp=None, max_timestamp=None): | ||||||
|         if min_timestamp is None: |         if min_timestamp is None: | ||||||
| @@ -89,16 +70,13 @@ class SyntheticDEnv(data.Dataset): | |||||||
|     def get_timestamp(self, index): |     def get_timestamp(self, index): | ||||||
|         if index is None: |         if index is None: | ||||||
|             timestamps = [] |             timestamps = [] | ||||||
|             for index in range(len(self._timestamp_generator)): |             for index in range(len(self._time_generator)): | ||||||
|                 timestamps.append(self._timestamp_generator[index][1]) |                 timestamps.append(self._time_generator[index][1]) | ||||||
|             return tuple(timestamps) |             return tuple(timestamps) | ||||||
|         else: |         else: | ||||||
|             index, timestamp = self._timestamp_generator[index] |             index, timestamp = self._time_generator[index] | ||||||
|             return timestamp |             return timestamp | ||||||
|  |  | ||||||
|     def set_oracle_map(self, functor): |  | ||||||
|         self._oracle_map = functor |  | ||||||
|  |  | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|         self._iter_num = 0 |         self._iter_num = 0 | ||||||
|         return self |         return self | ||||||
| @@ -111,7 +89,7 @@ class SyntheticDEnv(data.Dataset): | |||||||
|  |  | ||||||
|     def __getitem__(self, index): |     def __getitem__(self, index): | ||||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) |         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||||
|         index, timestamp = self._timestamp_generator[index] |         index, timestamp = self._time_generator[index] | ||||||
|         return self.__call__(timestamp) |         return self.__call__(timestamp) | ||||||
|  |  | ||||||
|     def seq_call(self, timestamps): |     def seq_call(self, timestamps): | ||||||
| @@ -122,52 +100,24 @@ class SyntheticDEnv(data.Dataset): | |||||||
|             return zip_sequence(xdata) |             return zip_sequence(xdata) | ||||||
|  |  | ||||||
|     def __call__(self, timestamp): |     def __call__(self, timestamp): | ||||||
|         mean_list = [functor(timestamp) for functor in self._mean_functors] |         dataset = self._data_generator(timestamp, self._num_per_task) | ||||||
|         cov_matrix = [ |         targets = self._oracle_map.noise_call(dataset, timestamp, self._noise) | ||||||
|             [abs(cov_gen(timestamp)) for cov_gen in cov_functor] |  | ||||||
|             for cov_functor in self._cov_functors |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|         dataset = np.random.multivariate_normal( |  | ||||||
|             mean_list, cov_matrix, size=self._num_per_task |  | ||||||
|         ) |  | ||||||
|         if self._oracle_map is None: |  | ||||||
|             return torch.Tensor([timestamp]), torch.Tensor(dataset) |  | ||||||
|         else: |  | ||||||
|             targets = self._oracle_map.noise_call(dataset, timestamp) |  | ||||||
|         return torch.Tensor([timestamp]), ( |         return torch.Tensor([timestamp]), ( | ||||||
|             torch.Tensor(dataset), |             torch.Tensor(dataset), | ||||||
|             torch.Tensor(targets), |             torch.Tensor(targets), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._timestamp_generator) |         return len(self._time_generator) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format( |         return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             cur_num=len(self), |             cur_num=len(self), | ||||||
|             total=len(self._timestamp_generator), |             total=len(self._time_generator), | ||||||
|             ndim=self._ndim, |             ndim=self._ndim, | ||||||
|             num_per_task=self._num_per_task, |             num_per_task=self._num_per_task, | ||||||
|             xrange_min=self.min_timestamp, |             xrange_min=self.min_timestamp, | ||||||
|             xrange_max=self.max_timestamp, |             xrange_max=self.max_timestamp, | ||||||
|             mode=self._timestamp_generator.mode, |             mode=self.mode, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnvSampler: |  | ||||||
|     def __init__(self, env, batch, enlarge): |  | ||||||
|         indexes = list(range(len(env))) |  | ||||||
|         self._indexes = indexes * enlarge |  | ||||||
|         self._batch = batch |  | ||||||
|         self._iterations = len(self._indexes) // self._batch |  | ||||||
|  |  | ||||||
|     def __iter__(self): |  | ||||||
|         random.shuffle(self._indexes) |  | ||||||
|         for it in range(self._iterations): |  | ||||||
|             indexes = self._indexes[it * self._batch : (it + 1) * self._batch] |  | ||||||
|             yield indexes |  | ||||||
|  |  | ||||||
|     def __len__(self): |  | ||||||
|         return self._iterations |  | ||||||
|   | |||||||
| @@ -1,72 +0,0 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # |  | ||||||
| ##################################################### |  | ||||||
| import copy |  | ||||||
|  |  | ||||||
| from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc |  | ||||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc |  | ||||||
| from .synthetic_env import SyntheticDEnv |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"): |  | ||||||
|     if indicator == "v1": |  | ||||||
|         return create_example_v1(timestamp_config, num_per_task) |  | ||||||
|     elif indicator == "v2": |  | ||||||
|         return create_example_v2(timestamp_config, num_per_task) |  | ||||||
|     else: |  | ||||||
|         raise ValueError("Unkonwn indicator: {:}".format(indicator)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_example_v1( |  | ||||||
|     timestamp_config=None, |  | ||||||
|     num_per_task=5000, |  | ||||||
| ): |  | ||||||
|     mean_generator = ComposedSinFunc() |  | ||||||
|     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) |  | ||||||
|  |  | ||||||
|     dynamic_env = SyntheticDEnv( |  | ||||||
|         [mean_generator], |  | ||||||
|         [[std_generator]], |  | ||||||
|         num_per_task=num_per_task, |  | ||||||
|         timestamp_config=timestamp_config, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     function = DynamicQuadraticFunc() |  | ||||||
|     function_param = dict() |  | ||||||
|     function_param[0] = ComposedSinFunc( |  | ||||||
|         num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 |  | ||||||
|     ) |  | ||||||
|     function_param[1] = ConstantFunc(constant=0.9) |  | ||||||
|     function_param[2] = ComposedSinFunc( |  | ||||||
|         num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 |  | ||||||
|     ) |  | ||||||
|     function.set(function_param) |  | ||||||
|  |  | ||||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) |  | ||||||
|     return dynamic_env, function |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_example_v2( |  | ||||||
|     timestamp_config=None, |  | ||||||
|     num_per_task=5000, |  | ||||||
| ): |  | ||||||
|     mean_generator = ConstantFunc(0) |  | ||||||
|     std_generator = ConstantFunc(1) |  | ||||||
|  |  | ||||||
|     dynamic_env = SyntheticDEnv( |  | ||||||
|         [mean_generator], |  | ||||||
|         [[std_generator]], |  | ||||||
|         num_per_task=num_per_task, |  | ||||||
|         timestamp_config=timestamp_config, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     function = DynamicLinearFunc() |  | ||||||
|     function_param = dict() |  | ||||||
|     function_param[0] = ComposedSinFunc( |  | ||||||
|         amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) |  | ||||||
|     ) |  | ||||||
|     function_param[1] = ConstantFunc(constant=0.9) |  | ||||||
|     function.set(function_param) |  | ||||||
|  |  | ||||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) |  | ||||||
|     return dynamic_env, function |  | ||||||
| @@ -13,11 +13,11 @@ class UnifiedSplit: | |||||||
|     """A class to unify the split strategy.""" |     """A class to unify the split strategy.""" | ||||||
|  |  | ||||||
|     def __init__(self, total_num, mode): |     def __init__(self, total_num, mode): | ||||||
|         # Training Set 60% |         # Training Set 65% | ||||||
|         num_of_train = int(total_num * 0.6) |         num_of_train = int(total_num * 0.65) | ||||||
|         # Validation Set 20% |         # Validation Set 05% | ||||||
|         num_of_valid = int(total_num * 0.2) |         num_of_valid = int(total_num * 0.05) | ||||||
|         # Test Set 20% |         # Test Set 30% | ||||||
|         num_of_set = total_num - num_of_train - num_of_valid |         num_of_set = total_num - num_of_train - num_of_valid | ||||||
|         all_indexes = list(range(total_num)) |         all_indexes = list(range(total_num)) | ||||||
|         if mode is None: |         if mode is None: | ||||||
| @@ -28,6 +28,8 @@ class UnifiedSplit: | |||||||
|             self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid] |             self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid] | ||||||
|         elif mode.lower() in ("test", "testing"): |         elif mode.lower() in ("test", "testing"): | ||||||
|             self._indexes = all_indexes[num_of_train + num_of_valid :] |             self._indexes = all_indexes[num_of_train + num_of_valid :] | ||||||
|  |         elif mode.lower() in ("trainval", "trainvalidation"): | ||||||
|  |             self._indexes = all_indexes[: num_of_train + num_of_valid] | ||||||
|         else: |         else: | ||||||
|             raise ValueError("Unkonwn mode of {:}".format(mode)) |             raise ValueError("Unkonwn mode of {:}".format(mode)) | ||||||
|         self._all_indexes = all_indexes |         self._all_indexes = all_indexes | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user