Reformulate Math Functions
This commit is contained in:
		
							
								
								
									
										52
									
								
								tests/test_math_adv.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								tests/test_math_adv.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| # pytest tests/test_math_adv.py -s                  # | ||||
| ##################################################### | ||||
| import sys, random | ||||
| import unittest | ||||
| import pytest | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||
| print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import QuadraticFunc | ||||
| from datasets import ConstantFunc | ||||
| from datasets import DynamicQuadraticFunc | ||||
| from datasets import ComposedSinFunc | ||||
|  | ||||
|  | ||||
| class TestConstantFunc(unittest.TestCase): | ||||
|     """Test the constant function.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = ConstantFunc(0.1) | ||||
|         for i in range(100): | ||||
|             assert function(i) == 0.1 | ||||
|  | ||||
|  | ||||
| class TestDynamicFunc(unittest.TestCase): | ||||
|     """Test DynamicQuadraticFunc.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         timestamps = 30 | ||||
|         function = DynamicQuadraticFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|         function_param[2] = ComposedSinFunc( | ||||
|             num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|         ) | ||||
|         function.set(function_param) | ||||
|         print(function) | ||||
|  | ||||
|         with self.assertRaises(TypeError) as context: | ||||
|             function(0) | ||||
|  | ||||
|         function.set_timestamp(1) | ||||
|         print(function(2)) | ||||
							
								
								
									
										41
									
								
								tests/test_math_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								tests/test_math_base.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| # pytest tests/test_math_base.py -s                 # | ||||
| ##################################################### | ||||
| import sys, random | ||||
| import unittest | ||||
| import pytest | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||
| print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import QuadraticFunc | ||||
|  | ||||
|  | ||||
| class TestQuadraticFunc(unittest.TestCase): | ||||
|     """Test the quadratic function.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||
|         print(function) | ||||
|         for x in (0, 0.5, 1): | ||||
|             print("f({:})={:}".format(x, function(x))) | ||||
|         thresh = 0.2 | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
|  | ||||
|     def test_none(self): | ||||
|         function = QuadraticFunc() | ||||
|         function.fit( | ||||
|             list_of_points=[[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False | ||||
|         ) | ||||
|         print(function) | ||||
|         thresh = 0.15 | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
| @@ -79,7 +79,7 @@ def test_super_sequential_v1(): | ||||
|         super_core.SuperSimpleNorm(1, 1), | ||||
|         torch.nn.ReLU(), | ||||
|         super_core.SuperLinear(10, 10), | ||||
|         super_core.SuperReLU() | ||||
|         super_core.SuperReLU(), | ||||
|     ) | ||||
|     inputs = torch.rand(10, 10) | ||||
|     print(model) | ||||
|   | ||||
| @@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import ConstantGenerator, SinGenerator | ||||
| from datasets import ConstantFunc, ComposedSinFunc | ||||
| from datasets import SyntheticDEnv | ||||
|  | ||||
|  | ||||
| @@ -21,10 +21,10 @@ class TestSynethicEnv(unittest.TestCase): | ||||
|     """Test the synethtic environment.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         mean_generator = SinGenerator() | ||||
|         std_generator = ConstantGenerator(constant=0.5) | ||||
|         mean_generator = ComposedSinFunc(constant=0.1) | ||||
|         std_generator = ConstantFunc(constant=0.5) | ||||
|  | ||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]]) | ||||
|         dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) | ||||
|         print(dataset) | ||||
|         for timestamp, tau in dataset: | ||||
|             assert tau.shape == (5000, 1) | ||||
|   | ||||
| @@ -13,74 +13,19 @@ print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import QuadraticFunc | ||||
| from datasets import ConstantGenerator, SinGenerator | ||||
| from datasets import DynamicQuadraticFunc | ||||
| from datasets import TimeStamp | ||||
|  | ||||
|  | ||||
| class TestQuadraticFunc(unittest.TestCase): | ||||
|     """Test the quadratic function.""" | ||||
| class TestTimeStamp(unittest.TestCase): | ||||
|     """Test the timestamp generator.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||
|         print(function) | ||||
|         for x in (0, 0.5, 1): | ||||
|             print("f({:})={:}".format(x, function(x))) | ||||
|         thresh = 0.2 | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
|  | ||||
|     def test_none(self): | ||||
|         function = QuadraticFunc() | ||||
|         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False) | ||||
|         print(function) | ||||
|         thresh = 0.15 | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
|  | ||||
|  | ||||
| class TestConstantGenerator(unittest.TestCase): | ||||
|     """Test the constant data generator.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         dataset = ConstantGenerator() | ||||
|         for i, (idx, t, x) in enumerate(dataset): | ||||
|             assert i == idx, "First loop: {:} vs {:}".format(i, idx) | ||||
|             assert x == 0.1 | ||||
|  | ||||
|  | ||||
| class TestSinGenerator(unittest.TestCase): | ||||
|     """Test the synethtic data generator.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         dataset = SinGenerator() | ||||
|         for i, (idx, t, x) in enumerate(dataset): | ||||
|             assert i == idx, "First loop: {:} vs {:}".format(i, idx) | ||||
|         for i, (idx, t, x) in enumerate(dataset): | ||||
|             assert i == idx, "Second loop: {:} vs {:}".format(i, idx) | ||||
|  | ||||
|  | ||||
| class TestDynamicFunc(unittest.TestCase): | ||||
|     """Test DynamicQuadraticFunc.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         timestamps = 30 | ||||
|         function = DynamicQuadraticFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = SinGenerator( | ||||
|             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantGenerator(constant=0.9) | ||||
|         function_param[2] = SinGenerator( | ||||
|             num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|         ) | ||||
|         function.set(function_param) | ||||
|         print(function) | ||||
|  | ||||
|         with self.assertRaises(TypeError) as context: | ||||
|             function(0) | ||||
|  | ||||
|         function.set_timestamp(1) | ||||
|         print(function(2)) | ||||
|         for mode in (None, "train", "valid", "test"): | ||||
|             generator = TimeStamp(0, 1) | ||||
|             print(generator) | ||||
|             for idx, (i, xtime) in enumerate(generator): | ||||
|                 self.assertTrue(i == idx) | ||||
|                 if idx == 0: | ||||
|                     self.assertTrue(xtime == 0) | ||||
|                 if idx + 1 == len(generator): | ||||
|                     self.assertTrue(abs(xtime - 1) < 1e-8) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user