Reformulate syn-math
This commit is contained in:
		| @@ -3,11 +3,3 @@ | ||||
| ################################################## | ||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
|  | ||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
| from .math_adv_funcs import ComposedSinFunc | ||||
|  | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import SyntheticDEnv | ||||
|   | ||||
							
								
								
									
										8
									
								
								lib/datasets/math_core.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								lib/datasets/math_core.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
| from .math_adv_funcs import ComposedSinFunc | ||||
| @@ -1,12 +1,25 @@ | ||||
| import copy | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||
|  | ||||
|  | ||||
| def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): | ||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||
|  | ||||
|  | ||||
| def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"): | ||||
|     if version == "v1": | ||||
|         mean_generator = ConstantFunc(0) | ||||
|         std_generator = ConstantFunc(1) | ||||
|     elif version == "v2": | ||||
|         mean_generator = ComposedSinFunc() | ||||
|         std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|     dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
| @@ -15,6 +28,7 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): | ||||
|             min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode | ||||
|         ), | ||||
|     ) | ||||
|     if version == "v1": | ||||
|         function = DynamicQuadraticFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
| @@ -24,6 +38,17 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): | ||||
|         function_param[2] = ComposedSinFunc( | ||||
|             num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|         ) | ||||
|     elif version == "v2": | ||||
|         function = DynamicLinearFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|  | ||||
|     function.set(function_param) | ||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     # dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     dynamic_env.set_oracle_map(function) | ||||
|     return dynamic_env | ||||
|   | ||||
| @@ -13,11 +13,11 @@ print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import QuadraticFunc | ||||
| from datasets import ConstantFunc | ||||
| from datasets import DynamicLinearFunc | ||||
| from datasets import DynamicQuadraticFunc | ||||
| from datasets import ComposedSinFunc | ||||
| from datasets.math_core import QuadraticFunc | ||||
| from datasets.math_core import ConstantFunc | ||||
| from datasets.math_core import DynamicLinearFunc | ||||
| from datasets.math_core import DynamicQuadraticFunc | ||||
| from datasets.math_core import ComposedSinFunc | ||||
|  | ||||
|  | ||||
| class TestConstantFunc(unittest.TestCase): | ||||
|   | ||||
| @@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import QuadraticFunc | ||||
| from datasets.math_core import QuadraticFunc | ||||
|  | ||||
|  | ||||
| class TestQuadraticFunc(unittest.TestCase): | ||||
|   | ||||
| @@ -13,8 +13,8 @@ print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import ConstantFunc, ComposedSinFunc | ||||
| from datasets import SyntheticDEnv | ||||
| from datasets.math_core import ConstantFunc, ComposedSinFunc | ||||
| from datasets.synthetic_core import SyntheticDEnv | ||||
|  | ||||
|  | ||||
| class TestSynethicEnv(unittest.TestCase): | ||||
|   | ||||
| @@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import TimeStamp | ||||
| from datasets.synthetic_core import TimeStamp | ||||
|  | ||||
|  | ||||
| class TestTimeStamp(unittest.TestCase): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user