From 58dee23a11cf46fd8b64315161e40e7d23c3ad7f Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 9 May 2021 18:53:18 +0800 Subject: [PATCH] Reformulate syn-math --- lib/datasets/__init__.py | 8 ----- lib/datasets/math_core.py | 8 +++++ lib/datasets/synthetic_core.py | 53 +++++++++++++++++++++++++--------- lib/xlayers/super_module.py | 2 +- tests/test_math_adv.py | 10 +++---- tests/test_math_base.py | 2 +- tests/test_synthetic_env.py | 4 +-- tests/test_synthetic_utils.py | 2 +- 8 files changed, 57 insertions(+), 32 deletions(-) create mode 100644 lib/datasets/math_core.py diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py index 4797d38..f96d0ef 100644 --- a/lib/datasets/__init__.py +++ b/lib/datasets/__init__.py @@ -3,11 +3,3 @@ ################################################## from .get_dataset_with_transform import get_datasets, get_nas_search_loaders from .SearchDatasetWrap import SearchDataset - -from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc -from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc -from .math_adv_funcs import ConstantFunc -from .math_adv_funcs import ComposedSinFunc - -from .synthetic_utils import TimeStamp -from .synthetic_env import SyntheticDEnv diff --git a/lib/datasets/math_core.py b/lib/datasets/math_core.py new file mode 100644 index 0000000..6b12d88 --- /dev/null +++ b/lib/datasets/math_core.py @@ -0,0 +1,8 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # +##################################################### +from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc +from .math_dynamic_funcs import DynamicLinearFunc +from .math_dynamic_funcs import DynamicQuadraticFunc +from .math_adv_funcs import ConstantFunc +from .math_adv_funcs import ComposedSinFunc diff --git a/lib/datasets/synthetic_core.py b/lib/datasets/synthetic_core.py index ecfee69..161f1d8 100644 --- a/lib/datasets/synthetic_core.py +++ b/lib/datasets/synthetic_core.py @@ -1,12 +1,25 @@ -import copy +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # +##################################################### +from .synthetic_utils import TimeStamp from .synthetic_env import SyntheticDEnv +from .math_dynamic_funcs import DynamicLinearFunc from .math_dynamic_funcs import DynamicQuadraticFunc from .math_adv_funcs import ConstantFunc, ComposedSinFunc -def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): - mean_generator = ComposedSinFunc() - std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) +__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] + + +def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"): + if version == "v1": + mean_generator = ConstantFunc(0) + std_generator = ConstantFunc(1) + elif version == "v2": + mean_generator = ComposedSinFunc() + std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) + else: + raise ValueError("Unknown version: {:}".format(version)) dynamic_env = SyntheticDEnv( [mean_generator], [[std_generator]], @@ -15,15 +28,27 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode ), ) - function = DynamicQuadraticFunc() - function_param = dict() - function_param[0] = ComposedSinFunc( - num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 - ) - function_param[1] = ConstantFunc(constant=0.9) - function_param[2] = ComposedSinFunc( - num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 - ) + if version == "v1": + function = DynamicQuadraticFunc() + function_param = dict() + function_param[0] = ComposedSinFunc( + num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 + ) + function_param[1] = ConstantFunc(constant=0.9) + function_param[2] = ComposedSinFunc( + num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 + ) + elif version == "v2": + function = DynamicLinearFunc() + function_param = dict() + function_param[0] = ComposedSinFunc( + amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) + ) + function_param[1] = ConstantFunc(constant=0.9) + else: + raise ValueError("Unknown version: {:}".format(version)) + function.set(function_param) - dynamic_env.set_oracle_map(copy.deepcopy(function)) + # dynamic_env.set_oracle_map(copy.deepcopy(function)) + dynamic_env.set_oracle_map(function) return dynamic_env diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index 8ee9ad9..aeed535 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -57,7 +57,7 @@ class TensorContainer: def requires_grad_(self, requires_grad=True): for tensor in self._tensors: - tensor.requires_grad_(requires_grad) + tensor.requires_grad_(requires_grad) @property def tensors(self): diff --git a/tests/test_math_adv.py b/tests/test_math_adv.py index c1ca38d..b9b85e1 100644 --- a/tests/test_math_adv.py +++ b/tests/test_math_adv.py @@ -13,11 +13,11 @@ print("library path: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from datasets import QuadraticFunc -from datasets import ConstantFunc -from datasets import DynamicLinearFunc -from datasets import DynamicQuadraticFunc -from datasets import ComposedSinFunc +from datasets.math_core import QuadraticFunc +from datasets.math_core import ConstantFunc +from datasets.math_core import DynamicLinearFunc +from datasets.math_core import DynamicQuadraticFunc +from datasets.math_core import ComposedSinFunc class TestConstantFunc(unittest.TestCase): diff --git a/tests/test_math_base.py b/tests/test_math_base.py index 5512fd5..3a33626 100644 --- a/tests/test_math_base.py +++ b/tests/test_math_base.py @@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from datasets import QuadraticFunc +from datasets.math_core import QuadraticFunc class TestQuadraticFunc(unittest.TestCase): diff --git a/tests/test_synthetic_env.py b/tests/test_synthetic_env.py index ac1fe0b..8cac2fb 100644 --- a/tests/test_synthetic_env.py +++ b/tests/test_synthetic_env.py @@ -13,8 +13,8 @@ print("library path: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from datasets import ConstantFunc, ComposedSinFunc -from datasets import SyntheticDEnv +from datasets.math_core import ConstantFunc, ComposedSinFunc +from datasets.synthetic_core import SyntheticDEnv class TestSynethicEnv(unittest.TestCase): diff --git a/tests/test_synthetic_utils.py b/tests/test_synthetic_utils.py index 2f95884..5cd33a0 100644 --- a/tests/test_synthetic_utils.py +++ b/tests/test_synthetic_utils.py @@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from datasets import TimeStamp +from datasets.synthetic_core import TimeStamp class TestTimeStamp(unittest.TestCase):