Update sync-notebook
This commit is contained in:
		| @@ -3,6 +3,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| import math | import math | ||||||
| import abc | import abc | ||||||
|  | import copy | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import Optional | from typing import Optional | ||||||
| import torch | import torch | ||||||
| @@ -29,7 +30,7 @@ class FitFunc(abc.ABC): | |||||||
|                 raise ValueError("The {:} is None".format(key)) |                 raise ValueError("The {:} is None".format(key)) | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def __getitem__(self, x): |     def __call__(self, x): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
| @@ -96,7 +97,7 @@ class QuadraticFunc(FitFunc): | |||||||
|     def __init__(self, list_of_points=None): |     def __init__(self, list_of_points=None): | ||||||
|         super(QuadraticFunc, self).__init__(3, list_of_points) |         super(QuadraticFunc, self).__init__(3, list_of_points) | ||||||
|  |  | ||||||
|     def __getitem__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|         return self._params[0] * x * x + self._params[1] * x + self._params[2] |         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||||
|  |  | ||||||
| @@ -118,7 +119,7 @@ class CubicFunc(FitFunc): | |||||||
|     def __init__(self, list_of_points=None): |     def __init__(self, list_of_points=None): | ||||||
|         super(CubicFunc, self).__init__(4, list_of_points) |         super(CubicFunc, self).__init__(4, list_of_points) | ||||||
|  |  | ||||||
|     def __getitem__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|         return ( |         return ( | ||||||
|             self._params[0] * x ** 3 |             self._params[0] * x ** 3 | ||||||
| @@ -146,7 +147,7 @@ class QuarticFunc(FitFunc): | |||||||
|     def __init__(self, list_of_points=None): |     def __init__(self, list_of_points=None): | ||||||
|         super(QuarticFunc, self).__init__(5, list_of_points) |         super(QuarticFunc, self).__init__(5, list_of_points) | ||||||
|  |  | ||||||
|     def __getitem__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|         return ( |         return ( | ||||||
|             self._params[0] * x ** 4 |             self._params[0] * x ** 4 | ||||||
| @@ -183,13 +184,14 @@ class DynamicQuadraticFunc(FitFunc): | |||||||
|         super(DynamicQuadraticFunc, self).__init__(3, list_of_points) |         super(DynamicQuadraticFunc, self).__init__(3, list_of_points) | ||||||
|         self._timestamp = None |         self._timestamp = None | ||||||
|  |  | ||||||
|     def __getitem__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|         return ( |         a = self._params[0][self._timestamp] | ||||||
|             self._params[0][self._timestamp] * x * x |         b = self._params[1][self._timestamp] | ||||||
|             + self._params[1][self._timestamp] * x |         c = self._params[2][self._timestamp] | ||||||
|             + self._params[2][self._timestamp] |         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||||
|         ) |         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||||
|  |         return a * x * x + b * x + c | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |     def _getitem(self, x, weights): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|   | |||||||
| @@ -77,7 +77,7 @@ class SinGenerator(UnifiedSplit, data.Dataset): | |||||||
|                 fitting_data.append((inter_value, math.pi * (2 * i + _phase))) |                 fitting_data.append((inter_value, math.pi * (2 * i + _phase))) | ||||||
|         self._period_phase_shift = QuarticFunc(fitting_data) |         self._period_phase_shift = QuarticFunc(fitting_data) | ||||||
|         UnifiedSplit.__init__(self, self._total_num, mode) |         UnifiedSplit.__init__(self, self._total_num, mode) | ||||||
|         self._transform = lambda x: x |         self._transform = None | ||||||
|  |  | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|         self._iter_num = 0 |         self._iter_num = 0 | ||||||
| @@ -92,14 +92,20 @@ class SinGenerator(UnifiedSplit, data.Dataset): | |||||||
|     def set_transform(self, transform): |     def set_transform(self, transform): | ||||||
|         self._transform = transform |         self._transform = transform | ||||||
|  |  | ||||||
|  |     def transform(self, x): | ||||||
|  |         if self._transform is None: | ||||||
|  |             return x | ||||||
|  |         else: | ||||||
|  |             return self._transform(x) | ||||||
|  |  | ||||||
|     def __getitem__(self, index): |     def __getitem__(self, index): | ||||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) |         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||||
|         index = self._indexes[index] |         index = self._indexes[index] | ||||||
|         position = self._interval * index |         position = self._interval * index | ||||||
|         value = self._amplitude_scale[position] * math.sin( |         value = self._amplitude_scale(position) * math.sin( | ||||||
|             self._period_phase_shift[position] |             self._period_phase_shift(position) | ||||||
|         ) |         ) | ||||||
|         return index, position, self._transform(value) |         return index, position, self.transform(value) | ||||||
|  |  | ||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._indexes) |         return len(self._indexes) | ||||||
|   | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -15,6 +15,7 @@ if str(lib_dir) not in sys.path: | |||||||
|  |  | ||||||
| from datasets import QuadraticFunc | from datasets import QuadraticFunc | ||||||
| from datasets import ConstantGenerator, SinGenerator | from datasets import ConstantGenerator, SinGenerator | ||||||
|  | from datasets import DynamicQuadraticFunc | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestQuadraticFunc(unittest.TestCase): | class TestQuadraticFunc(unittest.TestCase): | ||||||
| @@ -24,20 +25,20 @@ class TestQuadraticFunc(unittest.TestCase): | |||||||
|         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) |         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||||
|         print(function) |         print(function) | ||||||
|         for x in (0, 0.5, 1): |         for x in (0, 0.5, 1): | ||||||
|             print("f({:})={:}".format(x, function[x])) |             print("f({:})={:}".format(x, function(x))) | ||||||
|         thresh = 0.2 |         thresh = 0.2 | ||||||
|         self.assertTrue(abs(function[0] - 1) < thresh) |         self.assertTrue(abs(function(0) - 1) < thresh) | ||||||
|         self.assertTrue(abs(function[0.5] - 4) < thresh) |         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||||
|         self.assertTrue(abs(function[1] - 1) < thresh) |         self.assertTrue(abs(function(1) - 1) < thresh) | ||||||
|  |  | ||||||
|     def test_none(self): |     def test_none(self): | ||||||
|         function = QuadraticFunc() |         function = QuadraticFunc() | ||||||
|         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True) |         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False) | ||||||
|         print(function) |         print(function) | ||||||
|         thresh = 0.2 |         thresh = 0.15 | ||||||
|         self.assertTrue(abs(function[0] - 1) < thresh) |         self.assertTrue(abs(function(0) - 1) < thresh) | ||||||
|         self.assertTrue(abs(function[0.5] - 4) < thresh) |         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||||
|         self.assertTrue(abs(function[1] - 1) < thresh) |         self.assertTrue(abs(function(1) - 1) < thresh) | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestConstantGenerator(unittest.TestCase): | class TestConstantGenerator(unittest.TestCase): | ||||||
| @@ -59,3 +60,27 @@ class TestSinGenerator(unittest.TestCase): | |||||||
|             assert i == idx, "First loop: {:} vs {:}".format(i, idx) |             assert i == idx, "First loop: {:} vs {:}".format(i, idx) | ||||||
|         for i, (idx, t, x) in enumerate(dataset): |         for i, (idx, t, x) in enumerate(dataset): | ||||||
|             assert i == idx, "Second loop: {:} vs {:}".format(i, idx) |             assert i == idx, "Second loop: {:} vs {:}".format(i, idx) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestDynamicFunc(unittest.TestCase): | ||||||
|  |     """Test DynamicQuadraticFunc.""" | ||||||
|  |  | ||||||
|  |     def test_simple(self): | ||||||
|  |         timestamps = 30 | ||||||
|  |         function = DynamicQuadraticFunc() | ||||||
|  |         function_param = dict() | ||||||
|  |         function_param[0] = SinGenerator( | ||||||
|  |             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||||
|  |         ) | ||||||
|  |         function_param[1] = ConstantGenerator(constant=0.9) | ||||||
|  |         function_param[2] = SinGenerator( | ||||||
|  |             num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||||
|  |         ) | ||||||
|  |         function.set(function_param) | ||||||
|  |         print(function) | ||||||
|  |  | ||||||
|  |         with self.assertRaises(TypeError) as context: | ||||||
|  |             function(0) | ||||||
|  |  | ||||||
|  |         function.set_timestamp(1) | ||||||
|  |         print(function(2)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user