Reformulate syn-math
This commit is contained in:
parent
853a702926
commit
58dee23a11
@ -3,11 +3,3 @@
|
||||
##################################################
|
||||
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
|
||||
from .SearchDatasetWrap import SearchDataset
|
||||
|
||||
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc
|
||||
from .math_adv_funcs import ConstantFunc
|
||||
from .math_adv_funcs import ComposedSinFunc
|
||||
|
||||
from .synthetic_utils import TimeStamp
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
|
8
lib/datasets/math_core.py
Normal file
8
lib/datasets/math_core.py
Normal file
@ -0,0 +1,8 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
||||
#####################################################
|
||||
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
|
||||
from .math_dynamic_funcs import DynamicLinearFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc
|
||||
from .math_adv_funcs import ComposedSinFunc
|
@ -1,12 +1,25 @@
|
||||
import copy
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
||||
#####################################################
|
||||
from .synthetic_utils import TimeStamp
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
from .math_dynamic_funcs import DynamicLinearFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
||||
|
||||
|
||||
def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None):
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)
|
||||
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
|
||||
|
||||
|
||||
def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"):
|
||||
if version == "v1":
|
||||
mean_generator = ConstantFunc(0)
|
||||
std_generator = ConstantFunc(1)
|
||||
elif version == "v2":
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)
|
||||
else:
|
||||
raise ValueError("Unknown version: {:}".format(version))
|
||||
dynamic_env = SyntheticDEnv(
|
||||
[mean_generator],
|
||||
[[std_generator]],
|
||||
@ -15,15 +28,27 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None):
|
||||
min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode
|
||||
),
|
||||
)
|
||||
function = DynamicQuadraticFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
function_param[2] = ComposedSinFunc(
|
||||
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
|
||||
)
|
||||
if version == "v1":
|
||||
function = DynamicQuadraticFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
function_param[2] = ComposedSinFunc(
|
||||
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
|
||||
)
|
||||
elif version == "v2":
|
||||
function = DynamicLinearFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0)
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
else:
|
||||
raise ValueError("Unknown version: {:}".format(version))
|
||||
|
||||
function.set(function_param)
|
||||
dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||
# dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||
dynamic_env.set_oracle_map(function)
|
||||
return dynamic_env
|
||||
|
@ -57,7 +57,7 @@ class TensorContainer:
|
||||
|
||||
def requires_grad_(self, requires_grad=True):
|
||||
for tensor in self._tensors:
|
||||
tensor.requires_grad_(requires_grad)
|
||||
tensor.requires_grad_(requires_grad)
|
||||
|
||||
@property
|
||||
def tensors(self):
|
||||
|
@ -13,11 +13,11 @@ print("library path: {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from datasets import QuadraticFunc
|
||||
from datasets import ConstantFunc
|
||||
from datasets import DynamicLinearFunc
|
||||
from datasets import DynamicQuadraticFunc
|
||||
from datasets import ComposedSinFunc
|
||||
from datasets.math_core import QuadraticFunc
|
||||
from datasets.math_core import ConstantFunc
|
||||
from datasets.math_core import DynamicLinearFunc
|
||||
from datasets.math_core import DynamicQuadraticFunc
|
||||
from datasets.math_core import ComposedSinFunc
|
||||
|
||||
|
||||
class TestConstantFunc(unittest.TestCase):
|
||||
|
@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from datasets import QuadraticFunc
|
||||
from datasets.math_core import QuadraticFunc
|
||||
|
||||
|
||||
class TestQuadraticFunc(unittest.TestCase):
|
||||
|
@ -13,8 +13,8 @@ print("library path: {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from datasets import ConstantFunc, ComposedSinFunc
|
||||
from datasets import SyntheticDEnv
|
||||
from datasets.math_core import ConstantFunc, ComposedSinFunc
|
||||
from datasets.synthetic_core import SyntheticDEnv
|
||||
|
||||
|
||||
class TestSynethicEnv(unittest.TestCase):
|
||||
|
@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from datasets import TimeStamp
|
||||
from datasets.synthetic_core import TimeStamp
|
||||
|
||||
|
||||
class TestTimeStamp(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user