Update sync-notebook
This commit is contained in:
		| @@ -3,6 +3,7 @@ | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| @@ -29,7 +30,7 @@ class FitFunc(abc.ABC): | ||||
|                 raise ValueError("The {:} is None".format(key)) | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
| @@ -96,7 +97,7 @@ class QuadraticFunc(FitFunc): | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||
|  | ||||
| @@ -118,7 +119,7 @@ class CubicFunc(FitFunc): | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(CubicFunc, self).__init__(4, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 3 | ||||
| @@ -146,7 +147,7 @@ class QuarticFunc(FitFunc): | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuarticFunc, self).__init__(5, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 4 | ||||
| @@ -183,13 +184,14 @@ class DynamicQuadraticFunc(FitFunc): | ||||
|         super(DynamicQuadraticFunc, self).__init__(3, list_of_points) | ||||
|         self._timestamp = None | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0][self._timestamp] * x * x | ||||
|             + self._params[1][self._timestamp] * x | ||||
|             + self._params[2][self._timestamp] | ||||
|         ) | ||||
|         a = self._params[0][self._timestamp] | ||||
|         b = self._params[1][self._timestamp] | ||||
|         c = self._params[2][self._timestamp] | ||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||
|         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||
|         return a * x * x + b * x + c | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|   | ||||
| @@ -77,7 +77,7 @@ class SinGenerator(UnifiedSplit, data.Dataset): | ||||
|                 fitting_data.append((inter_value, math.pi * (2 * i + _phase))) | ||||
|         self._period_phase_shift = QuarticFunc(fitting_data) | ||||
|         UnifiedSplit.__init__(self, self._total_num, mode) | ||||
|         self._transform = lambda x: x | ||||
|         self._transform = None | ||||
|  | ||||
|     def __iter__(self): | ||||
|         self._iter_num = 0 | ||||
| @@ -92,14 +92,20 @@ class SinGenerator(UnifiedSplit, data.Dataset): | ||||
|     def set_transform(self, transform): | ||||
|         self._transform = transform | ||||
|  | ||||
|     def transform(self, x): | ||||
|         if self._transform is None: | ||||
|             return x | ||||
|         else: | ||||
|             return self._transform(x) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||
|         index = self._indexes[index] | ||||
|         position = self._interval * index | ||||
|         value = self._amplitude_scale[position] * math.sin( | ||||
|             self._period_phase_shift[position] | ||||
|         value = self._amplitude_scale(position) * math.sin( | ||||
|             self._period_phase_shift(position) | ||||
|         ) | ||||
|         return index, position, self._transform(value) | ||||
|         return index, position, self.transform(value) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._indexes) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user