Update synthetic environment

This commit is contained in:
D-X-Y 2021-04-22 20:31:20 +08:00
parent 275831b375
commit 78ca90459c
8 changed files with 526 additions and 271 deletions

View File

@ -54,7 +54,7 @@ jobs:
run: |
python -m pip install pytest numpy
python -m pip install parameterized
python -m pip install torch
python -m pip install torch torchvision
python --version
python -m pytest ./tests/test_synthetic.py -s
shell: bash

@ -1 +1 @@
Subproject commit 33bfb2eb1388f0273d4cc492091b1f983340879b
Subproject commit f955e2ba13ae92ce5af6d28bb47d58eb6d5be249

View File

@ -4,5 +4,5 @@
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset
from .synthetic_adaptive_environment import QuadraticFunction
from .synthetic_adaptive_environment import QuadraticFunc, CubicFunc, QuarticFunc
from .synthetic_adaptive_environment import SynAdaptiveEnv

View File

@ -2,38 +2,43 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
class QuadraticFunction:
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
class FitFunc(abc.ABC):
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None):
self._params = dict(a=None, b=None, c=None)
def __init__(self, freedom: int, list_of_points=None):
self._params = dict()
for i in range(freedom):
self._params[i] = None
self._freedom = freedom
if list_of_points is not None:
self.fit(list_of_points)
def set(self, a, b, c):
self._params["a"] = a
self._params["b"] = b
self._params["c"] = c
def set(self, _params):
self._params = copy.deepcopy(_params)
def check_valid(self):
for key, value in self._params.items():
if value is None:
raise ValueError("The {:} is None".format(key))
@abc.abstractmethod
def __getitem__(self, x):
self.check_valid()
return self._params["a"] * x * x + self._params["b"] * x + self._params["c"]
raise NotImplementedError
@abc.abstractmethod
def _getitem(self, x):
raise NotImplementedError
def fit(
self,
list_of_points,
transf=lambda x: x,
max_iter=900,
lr_max=1.0,
verbose=False,
@ -44,16 +49,24 @@ class QuadraticFunction:
data.shape
)
x, y = data[:, 0], data[:, 1]
weights = torch.nn.Parameter(torch.Tensor(3))
weights = torch.nn.Parameter(torch.Tensor(self._freedom))
torch.nn.init.normal_(weights, mean=0.0, std=1.0)
optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(max_iter*0.25), int(max_iter*0.5), int(max_iter*0.75)], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(max_iter * 0.25),
int(max_iter * 0.5),
int(max_iter * 0.75),
],
gamma=0.1,
)
if verbose:
print("The optimizer: {:}".format(optimizer))
best_loss = None
for _iter in range(max_iter):
y_hat = transf(weights[0] * x * x + weights[1] * x + weights[2])
y_hat = self._getitem(x, weights)
loss = torch.mean(torch.abs(y - y_hat))
optimizer.zero_grad()
loss.backward()
@ -61,23 +74,105 @@ class QuadraticFunction:
lr_scheduler.step()
if verbose:
print(
"In QuadraticFunction's fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
"In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
_iter, max_iter, loss.item()
)
)
# Update the params
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
self._params["a"] = weights[0].item()
self._params["b"] = weights[1].item()
self._params["c"] = weights[2].item()
for i in range(self._freedom):
self._params[i] = weights[i].item()
def __repr__(self):
return "{name}(freedom={freedom})".format(
name=self.__class__.__name__, freedom=freedom
)
class QuadraticFunc(FitFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None):
super(QuadraticFunc, self).__init__(3, list_of_points)
def __getitem__(self, x):
self.check_valid()
return self._params[0] * x * x + self._params[1] * x + self._params[2]
def _getitem(self, x, weights):
return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self):
return "{name}(y = {a} * x^2 + {b} * x + {c})".format(
name=self.__class__.__name__,
a=self._params["a"],
b=self._params["b"],
c=self._params["c"],
a=self._params[0],
b=self._params[1],
c=self._params[2],
)
class CubicFunc(FitFunc):
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
def __init__(self, list_of_points=None):
super(CubicFunc, self).__init__(4, list_of_points)
def __getitem__(self, x):
self.check_valid()
return (
self._params[0] * x ** 3
+ self._params[1] * x ** 2
+ self._params[2] * x
+ self._params[3]
)
def _getitem(self, x, weights):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self):
return "{name}(y = {a} * x^3 + {b} * x^2 + {c} * x + {d})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
)
class QuarticFunc(FitFunc):
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
def __init__(self, list_of_points=None):
super(QuarticFunc, self).__init__(5, list_of_points)
def __getitem__(self, x):
self.check_valid()
return (
self._params[0] * x ** 4
+ self._params[1] * x ** 3
+ self._params[2] * x ** 2
+ self._params[3] * x
+ self._params[4]
)
def _getitem(self, x, weights):
return (
weights[0] * x ** 4
+ weights[1] * x ** 3
+ weights[2] * x ** 2
+ weights[3] * x
+ weights[4]
)
def __repr__(self):
return "{name}(y = {a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
e=self._params[3],
)
@ -95,28 +190,29 @@ class SynAdaptiveEnv(data.Dataset):
def __init__(
self,
num: int = 100,
num_sin_phase: int = 4,
num_sin_phase: int = 7,
min_amplitude: float = 1,
max_amplitude: float = 4,
phase_shift: float = 0,
mode: Optional[str] = None,
):
self._amplitude_scale = QuadraticFunction(
[(0, min_amplitude), (0.5, max_amplitude), (0, min_amplitude)]
self._amplitude_scale = QuadraticFunc(
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
)
self._num_sin_phase = num_sin_phase
self._interval = 1.0 / (float(num) - 1)
self._total_num = num
self._period_phase_shift = QuadraticFunction()
fitting_data = []
temp_max_scalar = 2 ** num_sin_phase
temp_max_scalar = 2 ** (num_sin_phase - 1)
for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar
fitting_data.append((value, math.sin(value)))
self._period_phase_shift.fit(fitting_data, transf=lambda x: torch.sin(x))
next_value = (2 ** (i + 1)) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
self._period_phase_shift = QuarticFunc(fitting_data)
# Training Set 60%
num_of_train = int(self._total_num * 0.6)
@ -135,11 +231,6 @@ class SynAdaptiveEnv(data.Dataset):
self._indexes = all_indexes[num_of_train + num_of_valid :]
else:
raise ValueError("Unkonwn mode of {:}".format(mode))
# transformation function
self._transform = None
def set_transform(self, fn):
self._transform = fn
def __iter__(self):
self._iter_num = 0
@ -164,6 +255,14 @@ class SynAdaptiveEnv(data.Dataset):
return len(self._indexes)
def __repr__(self):
return "{name}({cur_num:}/{total} elements)".format(
name=self.__class__.__name__, cur_num=self._total_num, total=len(self)
return (
"{name}({cur_num:}/{total} elements,\n"
"amplitude={amplitude},\n"
"period_phase_shift={period_phase_shift})".format(
name=self.__class__.__name__,
cur_num=self._total_num,
total=len(self),
amplitude=self._amplitude_scale,
period_phase_shift=self._period_phase_shift,
)
)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -13,15 +13,15 @@ print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from datasets import QuadraticFunction
from datasets import QuadraticFunc
from datasets import SynAdaptiveEnv
class TestQuadraticFunction(unittest.TestCase):
class TestQuadraticFunc(unittest.TestCase):
"""Test the quadratic function."""
def test_simple(self):
function = QuadraticFunction([[0, 1], [0.5, 4], [1, 1]])
function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]])
print(function)
for x in (0, 0.5, 1):
print("f({:})={:}".format(x, function[x]))
@ -31,7 +31,7 @@ class TestQuadraticFunction(unittest.TestCase):
self.assertTrue(abs(function[1] - 1) < thresh)
def test_none(self):
function = QuadraticFunction()
function = QuadraticFunc()
function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True)
print(function)
thresh = 0.2