Update baselines
This commit is contained in:
		| @@ -178,7 +178,7 @@ if __name__ == "__main__": | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--seq_length", type=int, default=10, help="The sequence length." | ||||
|         "--seq_length", type=int, default=20, help="The sequence length." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|   | ||||
| @@ -1,62 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| # pytest tests/test_math_adv.py -s                  # | ||||
| ##################################################### | ||||
| import unittest | ||||
|  | ||||
| from xautodl.datasets.math_core import QuadraticFunc | ||||
| from xautodl.datasets.math_core import ConstantFunc | ||||
| from xautodl.datasets.math_core import DynamicLinearFunc | ||||
| from xautodl.datasets.math_core import DynamicQuadraticFunc | ||||
| from xautodl.datasets.math_core import ComposedSinFunc | ||||
|  | ||||
|  | ||||
| class TestConstantFunc(unittest.TestCase): | ||||
|     """Test the constant function.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = ConstantFunc(0.1) | ||||
|         for i in range(100): | ||||
|             assert function(i) == 0.1 | ||||
|  | ||||
|  | ||||
| class TestDynamicFunc(unittest.TestCase): | ||||
|     """Test DynamicQuadraticFunc.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         timestamps = 30 | ||||
|         function = DynamicQuadraticFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|         function_param[2] = ComposedSinFunc( | ||||
|             num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|         ) | ||||
|         function.set(function_param) | ||||
|         print(function) | ||||
|  | ||||
|         with self.assertRaises(TypeError) as context: | ||||
|             function(0) | ||||
|  | ||||
|         function.set_timestamp(1) | ||||
|         print(function(2)) | ||||
|  | ||||
|     def test_simple_linear(self): | ||||
|         timestamps = 30 | ||||
|         function = DynamicLinearFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|         function.set(function_param) | ||||
|         print(function) | ||||
|  | ||||
|         with self.assertRaises(TypeError) as context: | ||||
|             function(0) | ||||
|  | ||||
|         function.set_timestamp(1) | ||||
|         print(function(2)) | ||||
| @@ -1,33 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| # pytest tests/test_math_base.py -s                 # | ||||
| ##################################################### | ||||
| import unittest | ||||
|  | ||||
| from xautodl.datasets.math_core import QuadraticFunc | ||||
|  | ||||
|  | ||||
| class TestQuadraticFunc(unittest.TestCase): | ||||
|     """Test the quadratic function.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||
|         print(function) | ||||
|         for x in (0, 0.5, 1): | ||||
|             print("f({:})={:}".format(x, function(x))) | ||||
|         thresh = 0.2 | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
|  | ||||
|     def test_none(self): | ||||
|         function = QuadraticFunc() | ||||
|         function.fit( | ||||
|             list_of_points=[[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False | ||||
|         ) | ||||
|         print(function) | ||||
|         thresh = 0.15 | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
							
								
								
									
										32
									
								
								tests/test_math_static.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								tests/test_math_static.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| # pytest tests/test_math_static.py -s               # | ||||
| ##################################################### | ||||
| import unittest | ||||
|  | ||||
| from xautodl.datasets.math_core import QuadraticSFunc | ||||
| from xautodl.datasets.math_core import ConstantFunc | ||||
|  | ||||
|  | ||||
| class TestConstantFunc(unittest.TestCase): | ||||
|     """Test the constant function.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = ConstantFunc(0.1) | ||||
|         for i in range(100): | ||||
|             assert function(i) == 0.1 | ||||
|  | ||||
|  | ||||
| class TestQuadraticSFunc(unittest.TestCase): | ||||
|     """Test the quadratic function.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = QuadraticSFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||
|         print(function) | ||||
|         for x in (0, 0.5, 1): | ||||
|             print("f({:})={:}".format(x, function(x))) | ||||
|         thresh = 0.2 | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
| @@ -5,7 +5,7 @@ | ||||
| ##################################################### | ||||
| import unittest | ||||
|  | ||||
| from xautodl.datasets.math_core import ConstantFunc, ComposedSinFunc | ||||
| from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc | ||||
| from xautodl.datasets.synthetic_core import SyntheticDEnv | ||||
|  | ||||
|  | ||||
| @@ -13,7 +13,7 @@ class TestSynethicEnv(unittest.TestCase): | ||||
|     """Test the synethtic environment.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         mean_generator = ComposedSinFunc(constant=0.1) | ||||
|         mean_generator = ConstantFunc(constant=0.1) | ||||
|         std_generator = ConstantFunc(constant=0.5) | ||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) | ||||
|         print(dataset) | ||||
| @@ -21,7 +21,7 @@ class TestSynethicEnv(unittest.TestCase): | ||||
|             self.assertEqual(tau.shape, (5000, 1)) | ||||
|  | ||||
|     def test_length(self): | ||||
|         mean_generator = ComposedSinFunc(constant=0.1) | ||||
|         mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3}) | ||||
|         std_generator = ConstantFunc(constant=0.5) | ||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) | ||||
|         self.assertEqual(len(dataset), 100) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user