Reformulate syn-math
This commit is contained in:
		| @@ -3,11 +3,3 @@ | ||||
| ################################################## | ||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
|  | ||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
| from .math_adv_funcs import ComposedSinFunc | ||||
|  | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import SyntheticDEnv | ||||
|   | ||||
							
								
								
									
										8
									
								
								lib/datasets/math_core.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								lib/datasets/math_core.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
| from .math_adv_funcs import ComposedSinFunc | ||||
| @@ -1,12 +1,25 @@ | ||||
| import copy | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||
|  | ||||
|  | ||||
| def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): | ||||
|     mean_generator = ComposedSinFunc() | ||||
|     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) | ||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||
|  | ||||
|  | ||||
| def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"): | ||||
|     if version == "v1": | ||||
|         mean_generator = ConstantFunc(0) | ||||
|         std_generator = ConstantFunc(1) | ||||
|     elif version == "v2": | ||||
|         mean_generator = ComposedSinFunc() | ||||
|         std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|     dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
| @@ -15,15 +28,27 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None): | ||||
|             min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode | ||||
|         ), | ||||
|     ) | ||||
|     function = DynamicQuadraticFunc() | ||||
|     function_param = dict() | ||||
|     function_param[0] = ComposedSinFunc( | ||||
|         num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|     ) | ||||
|     function_param[1] = ConstantFunc(constant=0.9) | ||||
|     function_param[2] = ComposedSinFunc( | ||||
|         num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|     ) | ||||
|     if version == "v1": | ||||
|         function = DynamicQuadraticFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|         function_param[2] = ComposedSinFunc( | ||||
|             num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|         ) | ||||
|     elif version == "v2": | ||||
|         function = DynamicLinearFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|  | ||||
|     function.set(function_param) | ||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     # dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     dynamic_env.set_oracle_map(function) | ||||
|     return dynamic_env | ||||
|   | ||||
| @@ -57,7 +57,7 @@ class TensorContainer: | ||||
|  | ||||
|     def requires_grad_(self, requires_grad=True): | ||||
|         for tensor in self._tensors: | ||||
|           tensor.requires_grad_(requires_grad) | ||||
|             tensor.requires_grad_(requires_grad) | ||||
|  | ||||
|     @property | ||||
|     def tensors(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user