diff --git a/.github/workflows/basic_test.yml b/.github/workflows/basic_test.yml index d6b5a4c..942454c 100644 --- a/.github/workflows/basic_test.yml +++ b/.github/workflows/basic_test.yml @@ -56,12 +56,16 @@ jobs: python -m pytest ./tests/test_basic_space.py -s shell: bash - - name: Test Synthetic Data + - name: Test Math run: | python -m pip install pytest numpy python -m pip install parameterized python -m pip install torch torchvision python --version python -m pytest ./tests/test_math*.py -s + shell: bash + + - name: Test Synthetic Data + run: | python -m pytest ./tests/test_synthetic*.py -s shell: bash diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index f0ec021..71b9b0b 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -222,7 +222,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): def main(args): - logger, env_info, model_kwargs = lfna_setup(args) + logger, model_kwargs = lfna_setup(args) train_env = get_synthetic_env(mode="train", version=args.env_version) valid_env = get_synthetic_env(mode="valid", version=args.env_version) all_env = get_synthetic_env(mode=None, version=args.env_version) diff --git a/exps/LFNA/lfna_utils.py b/exps/LFNA/lfna_utils.py index 44489e1..ea9451d 100644 --- a/exps/LFNA/lfna_utils.py +++ b/exps/LFNA/lfna_utils.py @@ -11,33 +11,6 @@ from xautodl.datasets.synthetic_core import get_synthetic_env def lfna_setup(args): prepare_seed(args.rand_seed) logger = prepare_logger(args) - - cache_path = ( - logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) - ).resolve() - if cache_path.exists(): - env_info = torch.load(cache_path) - else: - env_info = dict() - dynamic_env = get_synthetic_env(version=args.env_version) - env_info["total"] = len(dynamic_env) - for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): - env_info["{:}-timestamp".format(idx)] = timestamp - env_info["{:}-x".format(idx)] = _allx - env_info["{:}-y".format(idx)] = _ally - env_info["dynamic_env"] = dynamic_env - torch.save(env_info, cache_path) - - """ - model_kwargs = dict( - config=dict(model_type="simple_mlp"), - input_dim=1, - output_dim=1, - hidden_dim=args.hidden_dim, - act_cls="leaky_relu", - norm_cls="identity", - ) - """ model_kwargs = dict( config=dict(model_type="norm_mlp"), input_dim=1, @@ -46,7 +19,7 @@ def lfna_setup(args): act_cls="gelu", norm_cls="layer_norm_1d", ) - return logger, env_info, model_kwargs + return logger, model_kwargs def train_model(model, dataset, lr, epochs): diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 027776e..d30ac38 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -20,14 +20,13 @@ matplotlib.use("agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +lib_dir = (Path(__file__).parent / ".." / "..").resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from models.xcore import get_model -from datasets.synthetic_core import get_synthetic_env -from utils.temp_sync import optimize_fn, evaluate_fn -from procedures.metric_utils import MSEMetric +from xautodl.models.xcore import get_model +from xautodl.datasets.synthetic_core import get_synthetic_env +from xautodl.procedures.metric_utils import MSEMetric def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): @@ -181,10 +180,17 @@ def compare_cl(save_dir): def visualize_env(save_dir, version): save_dir = Path(str(save_dir)) - save_dir.mkdir(parents=True, exist_ok=True) + for substr in ("pdf", "png"): + sub_save_dir = save_dir / substr + sub_save_dir.mkdir(parents=True, exist_ok=True) dynamic_env = get_synthetic_env(version=version) - min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp + # min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp + allxs, allys = [], [] + for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): + allxs.append(allx) + allys.append(ally) + allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): dpi, width, height = 30, 1800, 1400 figsize = width / float(dpi), height / float(dpi) @@ -201,21 +207,18 @@ def visualize_env(save_dir, version): tick.label.set_rotation(10) for tick in cur_ax.yaxis.get_major_ticks(): tick.label.set_fontsize(LabelSize - font_gap) - if version == "v1": - cur_ax.set_xlim(-2, 2) - cur_ax.set_ylim(-8, 8) - elif version == "v2": - cur_ax.set_xlim(-10, 10) - cur_ax.set_ylim(-60, 60) + cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) + cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) cur_ax.legend(loc=1, fontsize=LegendFontsize) - save_path = save_dir / "v{:}-{:05d}".format(version, idx) - fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") - fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") + pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx) + fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") + png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx) + fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") plt.close("all") save_dir = save_dir.resolve() base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( - xdir=save_dir, version=version + xdir=save_dir / "png", version=version ) print(base_cmd) os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) @@ -371,7 +374,7 @@ if __name__ == "__main__": ) args = parser.parse_args() - # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") + visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") - compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) + # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) # compare_cl(os.path.join(args.save_dir, "compare-cl")) diff --git a/exps/NAS-Bench-201/statistics-v2.py b/exps/NAS-Bench-201/statistics-v2.py index 037af0f..d1a54a1 100644 --- a/exps/NAS-Bench-201/statistics-v2.py +++ b/exps/NAS-Bench-201/statistics-v2.py @@ -13,7 +13,10 @@ from xautodl.config_utils import dict2config # NAS-Bench-201 related module or function from xautodl.models import CellStructure, get_cell_based_tiny_net -from xautodl.procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders +from xautodl.procedures import ( + bench_pure_evaluate as pure_evaluate, + get_nas_bench_loaders, +) from nas_201_api import NASBench201API, ArchResults, ResultsCount api = NASBench201API( diff --git a/exps/experimental/test-dynamic.py b/exps/experimental/test-dynamic.py new file mode 100644 index 0000000..a1a1e28 --- /dev/null +++ b/exps/experimental/test-dynamic.py @@ -0,0 +1,21 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # +##################################################### +# python test-dynamic.py +##################################################### +import sys +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / "..").resolve() +print("LIB-DIR: {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +from xautodl.datasets.math_core import ConstantFunc +from xautodl.datasets.math_core import GaussianDGenerator + +mean_generator = ConstantFunc(0) +cov_generator = ConstantFunc(1) + +generator = GaussianDGenerator([mean_generator], [[cov_generator]], (-1, 1)) +generator(0, 10) diff --git a/exps/experimental/test-ww-bench.py b/exps/experimental/test-ww-bench.py index 21d7668..5f398b3 100644 --- a/exps/experimental/test-ww-bench.py +++ b/exps/experimental/test-ww-bench.py @@ -19,9 +19,11 @@ import seaborn as sns matplotlib.use("agg") import matplotlib.pyplot as plt -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +lib_dir = (Path(__file__).parent / ".." / "..").resolve() +print("LIB-DIR: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + from log_utils import time_string from nats_bench import create from models import get_cell_based_tiny_net diff --git a/exps/experimental/test-ww.py b/exps/experimental/test-ww.py index 626a273..97597ca 100644 --- a/exps/experimental/test-ww.py +++ b/exps/experimental/test-ww.py @@ -3,11 +3,7 @@ from copy import deepcopy import torchvision.models as models from pathlib import Path -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) - -from utils import weight_watcher +from xautodl.utils import weight_watcher def main(): diff --git a/xautodl/datasets/math_adv_funcs.py b/xautodl/datasets/math_adv_funcs.py index 2a093c7..e2529c0 100644 --- a/xautodl/datasets/math_adv_funcs.py +++ b/xautodl/datasets/math_adv_funcs.py @@ -17,10 +17,10 @@ from .math_base_funcs import QuarticFunc class ConstantFunc(FitFunc): """The constant function: f(x) = c.""" - def __init__(self, constant=None): + def __init__(self, constant=None, xstr="x"): param = dict() param[0] = constant - super(ConstantFunc, self).__init__(0, None, param) + super(ConstantFunc, self).__init__(0, None, param, xstr) def __call__(self, x): self.check_valid() @@ -37,6 +37,34 @@ class ConstantFunc(FitFunc): class ComposedSinFunc(FitFunc): + """The composed sin function that outputs: + f(x) = a * sin( b*x ) + c + """ + + def __init__(self, params, xstr="x"): + super(ComposedSinFunc, self).__init__(3, None, params, xstr) + + def __call__(self, x): + self.check_valid() + a = self._params[0] + b = self._params[1] + c = self._params[2] + return a * math.sin(b * x) + c + + def _getitem(self, x, weights): + raise NotImplementedError + + def __repr__(self): + return "{name}({a} * sin({b} * {x}) + {c})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + c=self._params[2], + x=self.xstr, + ) + + +class ComposedSinFuncV2(FitFunc): """The composed sin function that outputs: f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) - the amplitude scale is a quadratic function of x @@ -44,7 +72,7 @@ class ComposedSinFunc(FitFunc): """ def __init__(self, **kwargs): - super(ComposedSinFunc, self).__init__(0, None) + super(ComposedSinFuncV2, self).__init__(0, None) self.fit(**kwargs) def __call__(self, x): diff --git a/xautodl/datasets/math_base_funcs.py b/xautodl/datasets/math_base_funcs.py index a77634b..e560755 100644 --- a/xautodl/datasets/math_base_funcs.py +++ b/xautodl/datasets/math_base_funcs.py @@ -5,15 +5,13 @@ import math import abc import copy import numpy as np -from typing import Optional import torch -import torch.utils.data as data class FitFunc(abc.ABC): """The fit function that outputs f(x) = a * x^2 + b * x + c.""" - def __init__(self, freedom: int, list_of_points=None, params=None): + def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"): self._params = dict() for i in range(freedom): self._params[i] = None @@ -24,6 +22,7 @@ class FitFunc(abc.ABC): self.fit(list_of_points=list_of_points) if params is not None: self.set(params) + self._xstr = str(xstr) def set(self, params): self._params = copy.deepcopy(params) @@ -33,6 +32,13 @@ class FitFunc(abc.ABC): if value is None: raise ValueError("The {:} is None".format(key)) + @property + def xstr(self): + return self._xstr + + def reset_xstr(self, xstr): + self._xstr = str(xstr) + @abc.abstractmethod def __call__(self, x): raise NotImplementedError @@ -106,8 +112,8 @@ class FitFunc(abc.ABC): class LinearFunc(FitFunc): """The linear function that outputs f(x) = a * x + b.""" - def __init__(self, list_of_points=None, params=None): - super(LinearFunc, self).__init__(2, list_of_points, params) + def __init__(self, list_of_points=None, params=None, xstr="x"): + super(LinearFunc, self).__init__(2, list_of_points, params, xstr) def __call__(self, x): self.check_valid() @@ -117,18 +123,19 @@ class LinearFunc(FitFunc): return weights[0] * x + weights[1] def __repr__(self): - return "{name}({a} * x + {b})".format( + return "{name}({a} * {x} + {b})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], + x=self.xstr, ) class QuadraticFunc(FitFunc): """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" - def __init__(self, list_of_points=None, params=None): - super(QuadraticFunc, self).__init__(3, list_of_points, params) + def __init__(self, list_of_points=None, params=None, xstr="x"): + super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr) def __call__(self, x): self.check_valid() @@ -138,11 +145,12 @@ class QuadraticFunc(FitFunc): return weights[0] * x * x + weights[1] * x + weights[2] def __repr__(self): - return "{name}({a} * x^2 + {b} * x + {c})".format( + return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], c=self._params[2], + x=self.xstr, ) @@ -165,12 +173,13 @@ class CubicFunc(FitFunc): return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] def __repr__(self): - return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format( + return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], c=self._params[2], d=self._params[3], + x=self.xstr, ) diff --git a/xautodl/datasets/math_core.py b/xautodl/datasets/math_core.py index 5dd2429..2579282 100644 --- a/xautodl/datasets/math_core.py +++ b/xautodl/datasets/math_core.py @@ -6,3 +6,4 @@ from .math_dynamic_funcs import DynamicLinearFunc from .math_dynamic_funcs import DynamicQuadraticFunc from .math_adv_funcs import ConstantFunc from .math_adv_funcs import ComposedSinFunc +from .math_dynamic_generator import GaussianDGenerator diff --git a/xautodl/datasets/math_dynamic_funcs.py b/xautodl/datasets/math_dynamic_funcs.py index e4e43c4..e83d8db 100644 --- a/xautodl/datasets/math_dynamic_funcs.py +++ b/xautodl/datasets/math_dynamic_funcs.py @@ -15,20 +15,19 @@ from .math_base_funcs import FitFunc class DynamicFunc(FitFunc): """The dynamic quadratic function, where each param is a function.""" - def __init__(self, freedom: int, params=None): - super(DynamicFunc, self).__init__(freedom, None, params) - self._timestamp = None + def __init__(self, freedom: int, params=None, xstr="x"): + if params is not None: + for param in params: + param.reset_xstr("t") if isinstance(param, FitFunc) else None + super(DynamicFunc, self).__init__(freedom, None, params, xstr) - def __call__(self, x, timestamp=None): + def __call__(self, x, timestamp): raise NotImplementedError def _getitem(self, x, weights): raise NotImplementedError - def set_timestamp(self, timestamp): - self._timestamp = timestamp - - def noise_call(self, x, timestamp=None, std=0.1): + def noise_call(self, x, timestamp, std): clean_y = self.__call__(x, timestamp) if isinstance(clean_y, np.ndarray): noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) @@ -42,13 +41,10 @@ class DynamicLinearFunc(DynamicFunc): The a and b is a function of timestamp. """ - def __init__(self, params=None): - super(DynamicLinearFunc, self).__init__(3, params) + def __init__(self, params=None, xstr="x"): + super(DynamicLinearFunc, self).__init__(3, params, xstr) - def __call__(self, x, timestamp=None): - self.check_valid() - if timestamp is None: - timestamp = self._timestamp + def __call__(self, x, timestamp): a = self._params[0](timestamp) b = self._params[1](timestamp) convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x @@ -56,11 +52,11 @@ class DynamicLinearFunc(DynamicFunc): return a * x + b def __repr__(self): - return "{name}({a} * x + {b}, timestamp={timestamp})".format( + return "{name}({a} * {x} + {b})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], - timestamp=self._timestamp, + x=self.xstr, ) diff --git a/xautodl/datasets/math_dynamic_generator.py b/xautodl/datasets/math_dynamic_generator.py new file mode 100644 index 0000000..33fc478 --- /dev/null +++ b/xautodl/datasets/math_dynamic_generator.py @@ -0,0 +1,58 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +import abc +import numpy as np + + +def assert_list_tuple(x): + assert isinstance(x, (list, tuple)) + return len(x) + + +class DynamicGenerator(abc.ABC): + """The dynamic quadratic function, where each param is a function.""" + + def __init__(self): + self._ndim = None + + def __call__(self, time, num): + raise NotImplementedError + + +class GaussianDGenerator(DynamicGenerator): + def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)): + super(GaussianDGenerator, self).__init__() + self._ndim = assert_list_tuple(mean_functors) + assert self._ndim == len( + cov_functors + ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors)) + assert_list_tuple(cov_functors) + for cov_functor in cov_functors: + assert self._ndim == assert_list_tuple( + cov_functor + ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) + assert ( + isinstance(trunc, (list, tuple)) and len(trunc) == 2 and trunc[0] < trunc[1] + ) + self._mean_functors = mean_functors + self._cov_functors = cov_functors + if trunc is not None: + assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1] + self._trunc = trunc + + def __call__(self, time, num): + mean_list = [functor(time) for functor in self._mean_functors] + cov_matrix = [ + [abs(cov_gen(time)) for cov_gen in cov_functor] + for cov_functor in self._cov_functors + ] + values = np.random.multivariate_normal(mean_list, cov_matrix, size=num) + if self._trunc is not None: + np.clip(values, self._trunc[0], self._trunc[1], out=values) + return values + + def __repr__(self): + return "{name}({ndim} dims, trunc={trunc})".format( + name=self.__class__.__name__, ndim=self._ndim, trunc=self._trunc + ) diff --git a/xautodl/datasets/synthetic_core.py b/xautodl/datasets/synthetic_core.py index 5f2bfee..9a3eb2e 100644 --- a/xautodl/datasets/synthetic_core.py +++ b/xautodl/datasets/synthetic_core.py @@ -1,13 +1,14 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # ##################################################### +import math from .synthetic_utils import TimeStamp -from .synthetic_env import EnvSampler from .synthetic_env import SyntheticDEnv from .math_core import LinearFunc from .math_core import DynamicLinearFunc from .math_core import DynamicQuadraticFunc from .math_core import ConstantFunc, ComposedSinFunc +from .math_core import GaussianDGenerator __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] @@ -17,42 +18,21 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio if version == "v1": mean_generator = ConstantFunc(0) std_generator = ConstantFunc(1) - elif version == "v2": - mean_generator = ComposedSinFunc() - std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) - else: - raise ValueError("Unknown version: {:}".format(version)) - dynamic_env = SyntheticDEnv( - [mean_generator], - [[std_generator]], - num_per_task=num_per_task, - timestamp_config=dict( - min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode - ), - ) - if version == "v1": - function = DynamicLinearFunc() - function_param = dict() - function_param[0] = ComposedSinFunc( - amplitude_scale=ConstantFunc(3.0), - num_sin_phase=9, - sin_speed_use_power=False, + data_generator = GaussianDGenerator( + [mean_generator], [[std_generator]], (-2, 2) ) - function_param[1] = ConstantFunc(constant=0.9) - elif version == "v2": - function = DynamicQuadraticFunc() - function_param = dict() - function_param[0] = ComposedSinFunc( - num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 + time_generator = TimeStamp( + min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode ) - function_param[1] = ConstantFunc(constant=0.9) - function_param[2] = ComposedSinFunc( - num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 + oracle_map = DynamicLinearFunc( + params={ + 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), + 1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}), + } + ) + dynamic_env = SyntheticDEnv( + data_generator, oracle_map, time_generator, num_per_task ) else: raise ValueError("Unknown version: {:}".format(version)) - - function.set(function_param) - # dynamic_env.set_oracle_map(copy.deepcopy(function)) - dynamic_env.set_oracle_map(function) return dynamic_env diff --git a/xautodl/datasets/synthetic_env.py b/xautodl/datasets/synthetic_env.py index 8c7854c..66b5254 100644 --- a/xautodl/datasets/synthetic_env.py +++ b/xautodl/datasets/synthetic_env.py @@ -1,15 +1,9 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # -##################################################### import math import random -import numpy as np from typing import List, Optional, Dict import torch import torch.utils.data as data -from .synthetic_utils import TimeStamp - def is_list_tuple(x): return isinstance(x, (tuple, list)) @@ -38,46 +32,33 @@ class SyntheticDEnv(data.Dataset): def __init__( self, - mean_functors: List[data.Dataset], - cov_functors: List[List[data.Dataset]], + data_generator, + oracle_map, + time_generator, num_per_task: int = 5000, - timestamp_config: Optional[Dict] = None, - mode: Optional[str] = None, - timestamp_noise_scale: float = 0.3, + noise: float = 0.1, ): - self._ndim = len(mean_functors) - assert self._ndim == len( - cov_functors - ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors)) - for cov_functor in cov_functors: - assert self._ndim == len( - cov_functor - ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) + self._data_generator = data_generator + self._time_generator = time_generator + self._oracle_map = oracle_map self._num_per_task = num_per_task - if timestamp_config is None: - timestamp_config = dict(mode=mode) - elif "mode" not in timestamp_config: - timestamp_config["mode"] = mode - - self._timestamp_generator = TimeStamp(**timestamp_config) - self._timestamp_noise_scale = timestamp_noise_scale - - self._mean_functors = mean_functors - self._cov_functors = cov_functors - - self._oracle_map = None + self._noise = noise @property def min_timestamp(self): - return self._timestamp_generator.min_timestamp + return self._time_generator.min_timestamp @property def max_timestamp(self): - return self._timestamp_generator.max_timestamp + return self._time_generator.max_timestamp @property - def timestamp_interval(self): - return self._timestamp_generator.interval + def time_interval(self): + return self._time_generator.interval + + @property + def mode(self): + return self._time_generator.mode def random_timestamp(self, min_timestamp=None, max_timestamp=None): if min_timestamp is None: @@ -89,16 +70,13 @@ class SyntheticDEnv(data.Dataset): def get_timestamp(self, index): if index is None: timestamps = [] - for index in range(len(self._timestamp_generator)): - timestamps.append(self._timestamp_generator[index][1]) + for index in range(len(self._time_generator)): + timestamps.append(self._time_generator[index][1]) return tuple(timestamps) else: - index, timestamp = self._timestamp_generator[index] + index, timestamp = self._time_generator[index] return timestamp - def set_oracle_map(self, functor): - self._oracle_map = functor - def __iter__(self): self._iter_num = 0 return self @@ -111,7 +89,7 @@ class SyntheticDEnv(data.Dataset): def __getitem__(self, index): assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) - index, timestamp = self._timestamp_generator[index] + index, timestamp = self._time_generator[index] return self.__call__(timestamp) def seq_call(self, timestamps): @@ -122,52 +100,24 @@ class SyntheticDEnv(data.Dataset): return zip_sequence(xdata) def __call__(self, timestamp): - mean_list = [functor(timestamp) for functor in self._mean_functors] - cov_matrix = [ - [abs(cov_gen(timestamp)) for cov_gen in cov_functor] - for cov_functor in self._cov_functors - ] - - dataset = np.random.multivariate_normal( - mean_list, cov_matrix, size=self._num_per_task + dataset = self._data_generator(timestamp, self._num_per_task) + targets = self._oracle_map.noise_call(dataset, timestamp, self._noise) + return torch.Tensor([timestamp]), ( + torch.Tensor(dataset), + torch.Tensor(targets), ) - if self._oracle_map is None: - return torch.Tensor([timestamp]), torch.Tensor(dataset) - else: - targets = self._oracle_map.noise_call(dataset, timestamp) - return torch.Tensor([timestamp]), ( - torch.Tensor(dataset), - torch.Tensor(targets), - ) def __len__(self): - return len(self._timestamp_generator) + return len(self._time_generator) def __repr__(self): return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format( name=self.__class__.__name__, cur_num=len(self), - total=len(self._timestamp_generator), + total=len(self._time_generator), ndim=self._ndim, num_per_task=self._num_per_task, xrange_min=self.min_timestamp, xrange_max=self.max_timestamp, - mode=self._timestamp_generator.mode, + mode=self.mode, ) - - -class EnvSampler: - def __init__(self, env, batch, enlarge): - indexes = list(range(len(env))) - self._indexes = indexes * enlarge - self._batch = batch - self._iterations = len(self._indexes) // self._batch - - def __iter__(self): - random.shuffle(self._indexes) - for it in range(self._iterations): - indexes = self._indexes[it * self._batch : (it + 1) * self._batch] - yield indexes - - def __len__(self): - return self._iterations diff --git a/xautodl/datasets/synthetic_example.py b/xautodl/datasets/synthetic_example.py deleted file mode 100644 index f5fea7b..0000000 --- a/xautodl/datasets/synthetic_example.py +++ /dev/null @@ -1,72 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # -##################################################### -import copy - -from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc -from .math_adv_funcs import ConstantFunc, ComposedSinFunc -from .synthetic_env import SyntheticDEnv - - -def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"): - if indicator == "v1": - return create_example_v1(timestamp_config, num_per_task) - elif indicator == "v2": - return create_example_v2(timestamp_config, num_per_task) - else: - raise ValueError("Unkonwn indicator: {:}".format(indicator)) - - -def create_example_v1( - timestamp_config=None, - num_per_task=5000, -): - mean_generator = ComposedSinFunc() - std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) - - dynamic_env = SyntheticDEnv( - [mean_generator], - [[std_generator]], - num_per_task=num_per_task, - timestamp_config=timestamp_config, - ) - - function = DynamicQuadraticFunc() - function_param = dict() - function_param[0] = ComposedSinFunc( - num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 - ) - function_param[1] = ConstantFunc(constant=0.9) - function_param[2] = ComposedSinFunc( - num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 - ) - function.set(function_param) - - dynamic_env.set_oracle_map(copy.deepcopy(function)) - return dynamic_env, function - - -def create_example_v2( - timestamp_config=None, - num_per_task=5000, -): - mean_generator = ConstantFunc(0) - std_generator = ConstantFunc(1) - - dynamic_env = SyntheticDEnv( - [mean_generator], - [[std_generator]], - num_per_task=num_per_task, - timestamp_config=timestamp_config, - ) - - function = DynamicLinearFunc() - function_param = dict() - function_param[0] = ComposedSinFunc( - amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) - ) - function_param[1] = ConstantFunc(constant=0.9) - function.set(function_param) - - dynamic_env.set_oracle_map(copy.deepcopy(function)) - return dynamic_env, function diff --git a/xautodl/datasets/synthetic_utils.py b/xautodl/datasets/synthetic_utils.py index 14d32a0..9c70e6b 100644 --- a/xautodl/datasets/synthetic_utils.py +++ b/xautodl/datasets/synthetic_utils.py @@ -13,11 +13,11 @@ class UnifiedSplit: """A class to unify the split strategy.""" def __init__(self, total_num, mode): - # Training Set 60% - num_of_train = int(total_num * 0.6) - # Validation Set 20% - num_of_valid = int(total_num * 0.2) - # Test Set 20% + # Training Set 65% + num_of_train = int(total_num * 0.65) + # Validation Set 05% + num_of_valid = int(total_num * 0.05) + # Test Set 30% num_of_set = total_num - num_of_train - num_of_valid all_indexes = list(range(total_num)) if mode is None: @@ -28,6 +28,8 @@ class UnifiedSplit: self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid] elif mode.lower() in ("test", "testing"): self._indexes = all_indexes[num_of_train + num_of_valid :] + elif mode.lower() in ("trainval", "trainvalidation"): + self._indexes = all_indexes[: num_of_train + num_of_valid] else: raise ValueError("Unkonwn mode of {:}".format(mode)) self._all_indexes = all_indexes