Update sync-notebook
This commit is contained in:
		| @@ -3,6 +3,7 @@ | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| @@ -29,7 +30,7 @@ class FitFunc(abc.ABC): | ||||
|                 raise ValueError("The {:} is None".format(key)) | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
| @@ -96,7 +97,7 @@ class QuadraticFunc(FitFunc): | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||
|  | ||||
| @@ -118,7 +119,7 @@ class CubicFunc(FitFunc): | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(CubicFunc, self).__init__(4, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 3 | ||||
| @@ -146,7 +147,7 @@ class QuarticFunc(FitFunc): | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuarticFunc, self).__init__(5, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 4 | ||||
| @@ -183,13 +184,14 @@ class DynamicQuadraticFunc(FitFunc): | ||||
|         super(DynamicQuadraticFunc, self).__init__(3, list_of_points) | ||||
|         self._timestamp = None | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0][self._timestamp] * x * x | ||||
|             + self._params[1][self._timestamp] * x | ||||
|             + self._params[2][self._timestamp] | ||||
|         ) | ||||
|         a = self._params[0][self._timestamp] | ||||
|         b = self._params[1][self._timestamp] | ||||
|         c = self._params[2][self._timestamp] | ||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||
|         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||
|         return a * x * x + b * x + c | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|   | ||||
| @@ -77,7 +77,7 @@ class SinGenerator(UnifiedSplit, data.Dataset): | ||||
|                 fitting_data.append((inter_value, math.pi * (2 * i + _phase))) | ||||
|         self._period_phase_shift = QuarticFunc(fitting_data) | ||||
|         UnifiedSplit.__init__(self, self._total_num, mode) | ||||
|         self._transform = lambda x: x | ||||
|         self._transform = None | ||||
|  | ||||
|     def __iter__(self): | ||||
|         self._iter_num = 0 | ||||
| @@ -92,14 +92,20 @@ class SinGenerator(UnifiedSplit, data.Dataset): | ||||
|     def set_transform(self, transform): | ||||
|         self._transform = transform | ||||
|  | ||||
|     def transform(self, x): | ||||
|         if self._transform is None: | ||||
|             return x | ||||
|         else: | ||||
|             return self._transform(x) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||
|         index = self._indexes[index] | ||||
|         position = self._interval * index | ||||
|         value = self._amplitude_scale[position] * math.sin( | ||||
|             self._period_phase_shift[position] | ||||
|         value = self._amplitude_scale(position) * math.sin( | ||||
|             self._period_phase_shift(position) | ||||
|         ) | ||||
|         return index, position, self._transform(value) | ||||
|         return index, position, self.transform(value) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._indexes) | ||||
|   | ||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -15,6 +15,7 @@ if str(lib_dir) not in sys.path: | ||||
|  | ||||
| from datasets import QuadraticFunc | ||||
| from datasets import ConstantGenerator, SinGenerator | ||||
| from datasets import DynamicQuadraticFunc | ||||
|  | ||||
|  | ||||
| class TestQuadraticFunc(unittest.TestCase): | ||||
| @@ -24,20 +25,20 @@ class TestQuadraticFunc(unittest.TestCase): | ||||
|         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||
|         print(function) | ||||
|         for x in (0, 0.5, 1): | ||||
|             print("f({:})={:}".format(x, function[x])) | ||||
|             print("f({:})={:}".format(x, function(x))) | ||||
|         thresh = 0.2 | ||||
|         self.assertTrue(abs(function[0] - 1) < thresh) | ||||
|         self.assertTrue(abs(function[0.5] - 4) < thresh) | ||||
|         self.assertTrue(abs(function[1] - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
|  | ||||
|     def test_none(self): | ||||
|         function = QuadraticFunc() | ||||
|         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True) | ||||
|         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False) | ||||
|         print(function) | ||||
|         thresh = 0.2 | ||||
|         self.assertTrue(abs(function[0] - 1) < thresh) | ||||
|         self.assertTrue(abs(function[0.5] - 4) < thresh) | ||||
|         self.assertTrue(abs(function[1] - 1) < thresh) | ||||
|         thresh = 0.15 | ||||
|         self.assertTrue(abs(function(0) - 1) < thresh) | ||||
|         self.assertTrue(abs(function(0.5) - 4) < thresh) | ||||
|         self.assertTrue(abs(function(1) - 1) < thresh) | ||||
|  | ||||
|  | ||||
| class TestConstantGenerator(unittest.TestCase): | ||||
| @@ -59,3 +60,27 @@ class TestSinGenerator(unittest.TestCase): | ||||
|             assert i == idx, "First loop: {:} vs {:}".format(i, idx) | ||||
|         for i, (idx, t, x) in enumerate(dataset): | ||||
|             assert i == idx, "Second loop: {:} vs {:}".format(i, idx) | ||||
|  | ||||
|  | ||||
| class TestDynamicFunc(unittest.TestCase): | ||||
|     """Test DynamicQuadraticFunc.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         timestamps = 30 | ||||
|         function = DynamicQuadraticFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = SinGenerator( | ||||
|             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantGenerator(constant=0.9) | ||||
|         function_param[2] = SinGenerator( | ||||
|             num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||
|         ) | ||||
|         function.set(function_param) | ||||
|         print(function) | ||||
|  | ||||
|         with self.assertRaises(TypeError) as context: | ||||
|             function(0) | ||||
|  | ||||
|         function.set_timestamp(1) | ||||
|         print(function(2)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user