Reformulate syn-math
This commit is contained in:
		| @@ -3,11 +3,3 @@ | |||||||
| ################################################## | ################################################## | ||||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||||
| from .SearchDatasetWrap import SearchDataset | from .SearchDatasetWrap import SearchDataset | ||||||
|  |  | ||||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc |  | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc |  | ||||||
| from .math_adv_funcs import ConstantFunc |  | ||||||
| from .math_adv_funcs import ComposedSinFunc |  | ||||||
|  |  | ||||||
| from .synthetic_utils import TimeStamp |  | ||||||
| from .synthetic_env import SyntheticDEnv |  | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								lib/datasets/math_core.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								lib/datasets/math_core.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||||
|  | ##################################################### | ||||||
|  | from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||||
|  | from .math_dynamic_funcs import DynamicLinearFunc | ||||||
|  | from .math_dynamic_funcs import DynamicQuadraticFunc | ||||||
|  | from .math_adv_funcs import ConstantFunc | ||||||
|  | from .math_adv_funcs import ComposedSinFunc | ||||||
| @@ -1,12 +1,25 @@ | |||||||
| import copy | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||||
|  | ##################################################### | ||||||
|  | from .synthetic_utils import TimeStamp | ||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
|  | from .math_dynamic_funcs import DynamicLinearFunc | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | from .math_dynamic_funcs import DynamicQuadraticFunc | ||||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): | __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||||
|     mean_generator = ComposedSinFunc() |  | ||||||
|     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) |  | ||||||
|  | def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"): | ||||||
|  |     if version == "v1": | ||||||
|  |         mean_generator = ConstantFunc(0) | ||||||
|  |         std_generator = ConstantFunc(1) | ||||||
|  |     elif version == "v2": | ||||||
|  |         mean_generator = ComposedSinFunc() | ||||||
|  |         std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) | ||||||
|  |     else: | ||||||
|  |         raise ValueError("Unknown version: {:}".format(version)) | ||||||
|     dynamic_env = SyntheticDEnv( |     dynamic_env = SyntheticDEnv( | ||||||
|         [mean_generator], |         [mean_generator], | ||||||
|         [[std_generator]], |         [[std_generator]], | ||||||
| @@ -15,15 +28,27 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): | |||||||
|             min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode |             min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode | ||||||
|         ), |         ), | ||||||
|     ) |     ) | ||||||
|     function = DynamicQuadraticFunc() |     if version == "v1": | ||||||
|     function_param = dict() |         function = DynamicQuadraticFunc() | ||||||
|     function_param[0] = ComposedSinFunc( |         function_param = dict() | ||||||
|         num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 |         function_param[0] = ComposedSinFunc( | ||||||
|     ) |             num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||||
|     function_param[1] = ConstantFunc(constant=0.9) |         ) | ||||||
|     function_param[2] = ComposedSinFunc( |         function_param[1] = ConstantFunc(constant=0.9) | ||||||
|         num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 |         function_param[2] = ComposedSinFunc( | ||||||
|     ) |             num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||||
|  |         ) | ||||||
|  |     elif version == "v2": | ||||||
|  |         function = DynamicLinearFunc() | ||||||
|  |         function_param = dict() | ||||||
|  |         function_param[0] = ComposedSinFunc( | ||||||
|  |             amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) | ||||||
|  |         ) | ||||||
|  |         function_param[1] = ConstantFunc(constant=0.9) | ||||||
|  |     else: | ||||||
|  |         raise ValueError("Unknown version: {:}".format(version)) | ||||||
|  |  | ||||||
|     function.set(function_param) |     function.set(function_param) | ||||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) |     # dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||||
|  |     dynamic_env.set_oracle_map(function) | ||||||
|     return dynamic_env |     return dynamic_env | ||||||
|   | |||||||
| @@ -57,7 +57,7 @@ class TensorContainer: | |||||||
|  |  | ||||||
|     def requires_grad_(self, requires_grad=True): |     def requires_grad_(self, requires_grad=True): | ||||||
|         for tensor in self._tensors: |         for tensor in self._tensors: | ||||||
|           tensor.requires_grad_(requires_grad) |             tensor.requires_grad_(requires_grad) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def tensors(self): |     def tensors(self): | ||||||
|   | |||||||
| @@ -13,11 +13,11 @@ print("library path: {:}".format(lib_dir)) | |||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from datasets import QuadraticFunc | from datasets.math_core import QuadraticFunc | ||||||
| from datasets import ConstantFunc | from datasets.math_core import ConstantFunc | ||||||
| from datasets import DynamicLinearFunc | from datasets.math_core import DynamicLinearFunc | ||||||
| from datasets import DynamicQuadraticFunc | from datasets.math_core import DynamicQuadraticFunc | ||||||
| from datasets import ComposedSinFunc | from datasets.math_core import ComposedSinFunc | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestConstantFunc(unittest.TestCase): | class TestConstantFunc(unittest.TestCase): | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir)) | |||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from datasets import QuadraticFunc | from datasets.math_core import QuadraticFunc | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestQuadraticFunc(unittest.TestCase): | class TestQuadraticFunc(unittest.TestCase): | ||||||
|   | |||||||
| @@ -13,8 +13,8 @@ print("library path: {:}".format(lib_dir)) | |||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from datasets import ConstantFunc, ComposedSinFunc | from datasets.math_core import ConstantFunc, ComposedSinFunc | ||||||
| from datasets import SyntheticDEnv | from datasets.synthetic_core import SyntheticDEnv | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSynethicEnv(unittest.TestCase): | class TestSynethicEnv(unittest.TestCase): | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir)) | |||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from datasets import TimeStamp | from datasets.synthetic_core import TimeStamp | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestTimeStamp(unittest.TestCase): | class TestTimeStamp(unittest.TestCase): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user