Update baselines
This commit is contained in:
		| @@ -178,7 +178,7 @@ if __name__ == "__main__": | |||||||
|         help="The hidden dimension.", |         help="The hidden dimension.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--seq_length", type=int, default=10, help="The sequence length." |         "--seq_length", type=int, default=20, help="The sequence length." | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--init_lr", |         "--init_lr", | ||||||
|   | |||||||
| @@ -1,62 +0,0 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # |  | ||||||
| ##################################################### |  | ||||||
| # pytest tests/test_math_adv.py -s                  # |  | ||||||
| ##################################################### |  | ||||||
| import unittest |  | ||||||
|  |  | ||||||
| from xautodl.datasets.math_core import QuadraticFunc |  | ||||||
| from xautodl.datasets.math_core import ConstantFunc |  | ||||||
| from xautodl.datasets.math_core import DynamicLinearFunc |  | ||||||
| from xautodl.datasets.math_core import DynamicQuadraticFunc |  | ||||||
| from xautodl.datasets.math_core import ComposedSinFunc |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestConstantFunc(unittest.TestCase): |  | ||||||
|     """Test the constant function.""" |  | ||||||
|  |  | ||||||
|     def test_simple(self): |  | ||||||
|         function = ConstantFunc(0.1) |  | ||||||
|         for i in range(100): |  | ||||||
|             assert function(i) == 0.1 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestDynamicFunc(unittest.TestCase): |  | ||||||
|     """Test DynamicQuadraticFunc.""" |  | ||||||
|  |  | ||||||
|     def test_simple(self): |  | ||||||
|         timestamps = 30 |  | ||||||
|         function = DynamicQuadraticFunc() |  | ||||||
|         function_param = dict() |  | ||||||
|         function_param[0] = ComposedSinFunc( |  | ||||||
|             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 |  | ||||||
|         ) |  | ||||||
|         function_param[1] = ConstantFunc(constant=0.9) |  | ||||||
|         function_param[2] = ComposedSinFunc( |  | ||||||
|             num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 |  | ||||||
|         ) |  | ||||||
|         function.set(function_param) |  | ||||||
|         print(function) |  | ||||||
|  |  | ||||||
|         with self.assertRaises(TypeError) as context: |  | ||||||
|             function(0) |  | ||||||
|  |  | ||||||
|         function.set_timestamp(1) |  | ||||||
|         print(function(2)) |  | ||||||
|  |  | ||||||
|     def test_simple_linear(self): |  | ||||||
|         timestamps = 30 |  | ||||||
|         function = DynamicLinearFunc() |  | ||||||
|         function_param = dict() |  | ||||||
|         function_param[0] = ComposedSinFunc( |  | ||||||
|             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 |  | ||||||
|         ) |  | ||||||
|         function_param[1] = ConstantFunc(constant=0.9) |  | ||||||
|         function.set(function_param) |  | ||||||
|         print(function) |  | ||||||
|  |  | ||||||
|         with self.assertRaises(TypeError) as context: |  | ||||||
|             function(0) |  | ||||||
|  |  | ||||||
|         function.set_timestamp(1) |  | ||||||
|         print(function(2)) |  | ||||||
| @@ -1,33 +0,0 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # |  | ||||||
| ##################################################### |  | ||||||
| # pytest tests/test_math_base.py -s                 # |  | ||||||
| ##################################################### |  | ||||||
| import unittest |  | ||||||
|  |  | ||||||
| from xautodl.datasets.math_core import QuadraticFunc |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestQuadraticFunc(unittest.TestCase): |  | ||||||
|     """Test the quadratic function.""" |  | ||||||
|  |  | ||||||
|     def test_simple(self): |  | ||||||
|         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) |  | ||||||
|         print(function) |  | ||||||
|         for x in (0, 0.5, 1): |  | ||||||
|             print("f({:})={:}".format(x, function(x))) |  | ||||||
|         thresh = 0.2 |  | ||||||
|         self.assertTrue(abs(function(0) - 1) < thresh) |  | ||||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) |  | ||||||
|         self.assertTrue(abs(function(1) - 1) < thresh) |  | ||||||
|  |  | ||||||
|     def test_none(self): |  | ||||||
|         function = QuadraticFunc() |  | ||||||
|         function.fit( |  | ||||||
|             list_of_points=[[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False |  | ||||||
|         ) |  | ||||||
|         print(function) |  | ||||||
|         thresh = 0.15 |  | ||||||
|         self.assertTrue(abs(function(0) - 1) < thresh) |  | ||||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) |  | ||||||
|         self.assertTrue(abs(function(1) - 1) < thresh) |  | ||||||
							
								
								
									
										32
									
								
								tests/test_math_static.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								tests/test_math_static.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | # pytest tests/test_math_static.py -s               # | ||||||
|  | ##################################################### | ||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from xautodl.datasets.math_core import QuadraticSFunc | ||||||
|  | from xautodl.datasets.math_core import ConstantFunc | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestConstantFunc(unittest.TestCase): | ||||||
|  |     """Test the constant function.""" | ||||||
|  |  | ||||||
|  |     def test_simple(self): | ||||||
|  |         function = ConstantFunc(0.1) | ||||||
|  |         for i in range(100): | ||||||
|  |             assert function(i) == 0.1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestQuadraticSFunc(unittest.TestCase): | ||||||
|  |     """Test the quadratic function.""" | ||||||
|  |  | ||||||
|  |     def test_simple(self): | ||||||
|  |         function = QuadraticSFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||||
|  |         print(function) | ||||||
|  |         for x in (0, 0.5, 1): | ||||||
|  |             print("f({:})={:}".format(x, function(x))) | ||||||
|  |         thresh = 0.2 | ||||||
|  |         self.assertTrue(abs(function(0) - 1) < thresh) | ||||||
|  |         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||||
|  |         self.assertTrue(abs(function(1) - 1) < thresh) | ||||||
| @@ -5,7 +5,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| import unittest | import unittest | ||||||
|  |  | ||||||
| from xautodl.datasets.math_core import ConstantFunc, ComposedSinFunc | from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc | ||||||
| from xautodl.datasets.synthetic_core import SyntheticDEnv | from xautodl.datasets.synthetic_core import SyntheticDEnv | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -13,7 +13,7 @@ class TestSynethicEnv(unittest.TestCase): | |||||||
|     """Test the synethtic environment.""" |     """Test the synethtic environment.""" | ||||||
|  |  | ||||||
|     def test_simple(self): |     def test_simple(self): | ||||||
|         mean_generator = ComposedSinFunc(constant=0.1) |         mean_generator = ConstantFunc(constant=0.1) | ||||||
|         std_generator = ConstantFunc(constant=0.5) |         std_generator = ConstantFunc(constant=0.5) | ||||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) |         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) | ||||||
|         print(dataset) |         print(dataset) | ||||||
| @@ -21,7 +21,7 @@ class TestSynethicEnv(unittest.TestCase): | |||||||
|             self.assertEqual(tau.shape, (5000, 1)) |             self.assertEqual(tau.shape, (5000, 1)) | ||||||
|  |  | ||||||
|     def test_length(self): |     def test_length(self): | ||||||
|         mean_generator = ComposedSinFunc(constant=0.1) |         mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3}) | ||||||
|         std_generator = ConstantFunc(constant=0.5) |         std_generator = ConstantFunc(constant=0.5) | ||||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) |         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) | ||||||
|         self.assertEqual(len(dataset), 100) |         self.assertEqual(len(dataset), 100) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user