Update synthetic environment

This commit is contained in:
D-X-Y 2021-04-22 20:31:20 +08:00
parent 275831b375
commit 78ca90459c
8 changed files with 526 additions and 271 deletions

View File

@ -54,7 +54,7 @@ jobs:
run: | run: |
python -m pip install pytest numpy python -m pip install pytest numpy
python -m pip install parameterized python -m pip install parameterized
python -m pip install torch python -m pip install torch torchvision
python --version python --version
python -m pytest ./tests/test_synthetic.py -s python -m pytest ./tests/test_synthetic.py -s
shell: bash shell: bash

@ -1 +1 @@
Subproject commit 33bfb2eb1388f0273d4cc492091b1f983340879b Subproject commit f955e2ba13ae92ce5af6d28bb47d58eb6d5be249

View File

@ -4,5 +4,5 @@
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset from .SearchDatasetWrap import SearchDataset
from .synthetic_adaptive_environment import QuadraticFunction from .synthetic_adaptive_environment import QuadraticFunc, CubicFunc, QuarticFunc
from .synthetic_adaptive_environment import SynAdaptiveEnv from .synthetic_adaptive_environment import SynAdaptiveEnv

View File

@ -2,38 +2,43 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
##################################################### #####################################################
import math import math
import abc
import numpy as np import numpy as np
from typing import Optional from typing import Optional
import torch import torch
import torch.utils.data as data import torch.utils.data as data
class QuadraticFunction: class FitFunc(abc.ABC):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" """The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None): def __init__(self, freedom: int, list_of_points=None):
self._params = dict(a=None, b=None, c=None) self._params = dict()
for i in range(freedom):
self._params[i] = None
self._freedom = freedom
if list_of_points is not None: if list_of_points is not None:
self.fit(list_of_points) self.fit(list_of_points)
def set(self, a, b, c): def set(self, _params):
self._params["a"] = a self._params = copy.deepcopy(_params)
self._params["b"] = b
self._params["c"] = c
def check_valid(self): def check_valid(self):
for key, value in self._params.items(): for key, value in self._params.items():
if value is None: if value is None:
raise ValueError("The {:} is None".format(key)) raise ValueError("The {:} is None".format(key))
@abc.abstractmethod
def __getitem__(self, x): def __getitem__(self, x):
self.check_valid() raise NotImplementedError
return self._params["a"] * x * x + self._params["b"] * x + self._params["c"]
@abc.abstractmethod
def _getitem(self, x):
raise NotImplementedError
def fit( def fit(
self, self,
list_of_points, list_of_points,
transf=lambda x: x,
max_iter=900, max_iter=900,
lr_max=1.0, lr_max=1.0,
verbose=False, verbose=False,
@ -44,16 +49,24 @@ class QuadraticFunction:
data.shape data.shape
) )
x, y = data[:, 0], data[:, 1] x, y = data[:, 0], data[:, 1]
weights = torch.nn.Parameter(torch.Tensor(3)) weights = torch.nn.Parameter(torch.Tensor(self._freedom))
torch.nn.init.normal_(weights, mean=0.0, std=1.0) torch.nn.init.normal_(weights, mean=0.0, std=1.0)
optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(max_iter*0.25), int(max_iter*0.5), int(max_iter*0.75)], gamma=0.1) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(max_iter * 0.25),
int(max_iter * 0.5),
int(max_iter * 0.75),
],
gamma=0.1,
)
if verbose: if verbose:
print("The optimizer: {:}".format(optimizer)) print("The optimizer: {:}".format(optimizer))
best_loss = None best_loss = None
for _iter in range(max_iter): for _iter in range(max_iter):
y_hat = transf(weights[0] * x * x + weights[1] * x + weights[2]) y_hat = self._getitem(x, weights)
loss = torch.mean(torch.abs(y - y_hat)) loss = torch.mean(torch.abs(y - y_hat))
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
@ -61,23 +74,105 @@ class QuadraticFunction:
lr_scheduler.step() lr_scheduler.step()
if verbose: if verbose:
print( print(
"In QuadraticFunction's fit, loss at the {:02d}/{:02d}-th iter is {:}".format( "In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
_iter, max_iter, loss.item() _iter, max_iter, loss.item()
) )
) )
# Update the params # Update the params
if best_loss is None or best_loss > loss.item(): if best_loss is None or best_loss > loss.item():
best_loss = loss.item() best_loss = loss.item()
self._params["a"] = weights[0].item() for i in range(self._freedom):
self._params["b"] = weights[1].item() self._params[i] = weights[i].item()
self._params["c"] = weights[2].item()
def __repr__(self):
return "{name}(freedom={freedom})".format(
name=self.__class__.__name__, freedom=freedom
)
class QuadraticFunc(FitFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None):
super(QuadraticFunc, self).__init__(3, list_of_points)
def __getitem__(self, x):
self.check_valid()
return self._params[0] * x * x + self._params[1] * x + self._params[2]
def _getitem(self, x, weights):
return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self): def __repr__(self):
return "{name}(y = {a} * x^2 + {b} * x + {c})".format( return "{name}(y = {a} * x^2 + {b} * x + {c})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
a=self._params["a"], a=self._params[0],
b=self._params["b"], b=self._params[1],
c=self._params["c"], c=self._params[2],
)
class CubicFunc(FitFunc):
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
def __init__(self, list_of_points=None):
super(CubicFunc, self).__init__(4, list_of_points)
def __getitem__(self, x):
self.check_valid()
return (
self._params[0] * x ** 3
+ self._params[1] * x ** 2
+ self._params[2] * x
+ self._params[3]
)
def _getitem(self, x, weights):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self):
return "{name}(y = {a} * x^3 + {b} * x^2 + {c} * x + {d})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
)
class QuarticFunc(FitFunc):
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
def __init__(self, list_of_points=None):
super(QuarticFunc, self).__init__(5, list_of_points)
def __getitem__(self, x):
self.check_valid()
return (
self._params[0] * x ** 4
+ self._params[1] * x ** 3
+ self._params[2] * x ** 2
+ self._params[3] * x
+ self._params[4]
)
def _getitem(self, x, weights):
return (
weights[0] * x ** 4
+ weights[1] * x ** 3
+ weights[2] * x ** 2
+ weights[3] * x
+ weights[4]
)
def __repr__(self):
return "{name}(y = {a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
e=self._params[3],
) )
@ -95,28 +190,29 @@ class SynAdaptiveEnv(data.Dataset):
def __init__( def __init__(
self, self,
num: int = 100, num: int = 100,
num_sin_phase: int = 4, num_sin_phase: int = 7,
min_amplitude: float = 1, min_amplitude: float = 1,
max_amplitude: float = 4, max_amplitude: float = 4,
phase_shift: float = 0, phase_shift: float = 0,
mode: Optional[str] = None, mode: Optional[str] = None,
): ):
self._amplitude_scale = QuadraticFunction( self._amplitude_scale = QuadraticFunc(
[(0, min_amplitude), (0.5, max_amplitude), (0, min_amplitude)] [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
) )
self._num_sin_phase = num_sin_phase self._num_sin_phase = num_sin_phase
self._interval = 1.0 / (float(num) - 1) self._interval = 1.0 / (float(num) - 1)
self._total_num = num self._total_num = num
self._period_phase_shift = QuadraticFunction()
fitting_data = [] fitting_data = []
temp_max_scalar = 2 ** num_sin_phase temp_max_scalar = 2 ** (num_sin_phase - 1)
for i in range(num_sin_phase): for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar value = (2 ** i) / temp_max_scalar
fitting_data.append((value, math.sin(value))) next_value = (2 ** (i + 1)) / temp_max_scalar
self._period_phase_shift.fit(fitting_data, transf=lambda x: torch.sin(x)) for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
self._period_phase_shift = QuarticFunc(fitting_data)
# Training Set 60% # Training Set 60%
num_of_train = int(self._total_num * 0.6) num_of_train = int(self._total_num * 0.6)
@ -135,11 +231,6 @@ class SynAdaptiveEnv(data.Dataset):
self._indexes = all_indexes[num_of_train + num_of_valid :] self._indexes = all_indexes[num_of_train + num_of_valid :]
else: else:
raise ValueError("Unkonwn mode of {:}".format(mode)) raise ValueError("Unkonwn mode of {:}".format(mode))
# transformation function
self._transform = None
def set_transform(self, fn):
self._transform = fn
def __iter__(self): def __iter__(self):
self._iter_num = 0 self._iter_num = 0
@ -164,6 +255,14 @@ class SynAdaptiveEnv(data.Dataset):
return len(self._indexes) return len(self._indexes)
def __repr__(self): def __repr__(self):
return "{name}({cur_num:}/{total} elements)".format( return (
name=self.__class__.__name__, cur_num=self._total_num, total=len(self) "{name}({cur_num:}/{total} elements,\n"
"amplitude={amplitude},\n"
"period_phase_shift={period_phase_shift})".format(
name=self.__class__.__name__,
cur_num=self._total_num,
total=len(self),
amplitude=self._amplitude_scale,
period_phase_shift=self._period_phase_shift,
)
) )

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -13,15 +13,15 @@ print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path: if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from datasets import QuadraticFunction from datasets import QuadraticFunc
from datasets import SynAdaptiveEnv from datasets import SynAdaptiveEnv
class TestQuadraticFunction(unittest.TestCase): class TestQuadraticFunc(unittest.TestCase):
"""Test the quadratic function.""" """Test the quadratic function."""
def test_simple(self): def test_simple(self):
function = QuadraticFunction([[0, 1], [0.5, 4], [1, 1]]) function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]])
print(function) print(function)
for x in (0, 0.5, 1): for x in (0, 0.5, 1):
print("f({:})={:}".format(x, function[x])) print("f({:})={:}".format(x, function[x]))
@ -31,7 +31,7 @@ class TestQuadraticFunction(unittest.TestCase):
self.assertTrue(abs(function[1] - 1) < thresh) self.assertTrue(abs(function[1] - 1) < thresh)
def test_none(self): def test_none(self):
function = QuadraticFunction() function = QuadraticFunc()
function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True) function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True)
print(function) print(function)
thresh = 0.2 thresh = 0.2