diff --git a/exps/GeMOSA/baselines/slbm-nof.py b/exps/GeMOSA/baselines/slbm-nof.py index d4c87eb..cfe1604 100644 --- a/exps/GeMOSA/baselines/slbm-nof.py +++ b/exps/GeMOSA/baselines/slbm-nof.py @@ -178,7 +178,7 @@ if __name__ == "__main__": help="The hidden dimension.", ) parser.add_argument( - "--seq_length", type=int, default=10, help="The sequence length." + "--seq_length", type=int, default=20, help="The sequence length." ) parser.add_argument( "--init_lr", diff --git a/tests/test_math_adv.py b/tests/test_math_adv.py deleted file mode 100644 index d31fca0..0000000 --- a/tests/test_math_adv.py +++ /dev/null @@ -1,62 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # -##################################################### -# pytest tests/test_math_adv.py -s # -##################################################### -import unittest - -from xautodl.datasets.math_core import QuadraticFunc -from xautodl.datasets.math_core import ConstantFunc -from xautodl.datasets.math_core import DynamicLinearFunc -from xautodl.datasets.math_core import DynamicQuadraticFunc -from xautodl.datasets.math_core import ComposedSinFunc - - -class TestConstantFunc(unittest.TestCase): - """Test the constant function.""" - - def test_simple(self): - function = ConstantFunc(0.1) - for i in range(100): - assert function(i) == 0.1 - - -class TestDynamicFunc(unittest.TestCase): - """Test DynamicQuadraticFunc.""" - - def test_simple(self): - timestamps = 30 - function = DynamicQuadraticFunc() - function_param = dict() - function_param[0] = ComposedSinFunc( - num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 - ) - function_param[1] = ConstantFunc(constant=0.9) - function_param[2] = ComposedSinFunc( - num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 - ) - function.set(function_param) - print(function) - - with self.assertRaises(TypeError) as context: - function(0) - - function.set_timestamp(1) - print(function(2)) - - def test_simple_linear(self): - timestamps = 30 - function = DynamicLinearFunc() - function_param = dict() - function_param[0] = ComposedSinFunc( - num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 - ) - function_param[1] = ConstantFunc(constant=0.9) - function.set(function_param) - print(function) - - with self.assertRaises(TypeError) as context: - function(0) - - function.set_timestamp(1) - print(function(2)) diff --git a/tests/test_math_base.py b/tests/test_math_base.py deleted file mode 100644 index 5dbc4ce..0000000 --- a/tests/test_math_base.py +++ /dev/null @@ -1,33 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # -##################################################### -# pytest tests/test_math_base.py -s # -##################################################### -import unittest - -from xautodl.datasets.math_core import QuadraticFunc - - -class TestQuadraticFunc(unittest.TestCase): - """Test the quadratic function.""" - - def test_simple(self): - function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) - print(function) - for x in (0, 0.5, 1): - print("f({:})={:}".format(x, function(x))) - thresh = 0.2 - self.assertTrue(abs(function(0) - 1) < thresh) - self.assertTrue(abs(function(0.5) - 4) < thresh) - self.assertTrue(abs(function(1) - 1) < thresh) - - def test_none(self): - function = QuadraticFunc() - function.fit( - list_of_points=[[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False - ) - print(function) - thresh = 0.15 - self.assertTrue(abs(function(0) - 1) < thresh) - self.assertTrue(abs(function(0.5) - 4) < thresh) - self.assertTrue(abs(function(1) - 1) < thresh) diff --git a/tests/test_math_static.py b/tests/test_math_static.py new file mode 100644 index 0000000..00ac383 --- /dev/null +++ b/tests/test_math_static.py @@ -0,0 +1,32 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +# pytest tests/test_math_static.py -s # +##################################################### +import unittest + +from xautodl.datasets.math_core import QuadraticSFunc +from xautodl.datasets.math_core import ConstantFunc + + +class TestConstantFunc(unittest.TestCase): + """Test the constant function.""" + + def test_simple(self): + function = ConstantFunc(0.1) + for i in range(100): + assert function(i) == 0.1 + + +class TestQuadraticSFunc(unittest.TestCase): + """Test the quadratic function.""" + + def test_simple(self): + function = QuadraticSFunc([[0, 1], [0.5, 4], [1, 1]]) + print(function) + for x in (0, 0.5, 1): + print("f({:})={:}".format(x, function(x))) + thresh = 0.2 + self.assertTrue(abs(function(0) - 1) < thresh) + self.assertTrue(abs(function(0.5) - 4) < thresh) + self.assertTrue(abs(function(1) - 1) < thresh) diff --git a/tests/test_synthetic_env.py b/tests/test_synthetic_env.py index ec10b68..b1b2f7f 100644 --- a/tests/test_synthetic_env.py +++ b/tests/test_synthetic_env.py @@ -5,7 +5,7 @@ ##################################################### import unittest -from xautodl.datasets.math_core import ConstantFunc, ComposedSinFunc +from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc from xautodl.datasets.synthetic_core import SyntheticDEnv @@ -13,7 +13,7 @@ class TestSynethicEnv(unittest.TestCase): """Test the synethtic environment.""" def test_simple(self): - mean_generator = ComposedSinFunc(constant=0.1) + mean_generator = ConstantFunc(constant=0.1) std_generator = ConstantFunc(constant=0.5) dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) print(dataset) @@ -21,7 +21,7 @@ class TestSynethicEnv(unittest.TestCase): self.assertEqual(tau.shape, (5000, 1)) def test_length(self): - mean_generator = ComposedSinFunc(constant=0.1) + mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3}) std_generator = ConstantFunc(constant=0.5) dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) self.assertEqual(len(dataset), 100)