Reformulate the synthetic codes

This commit is contained in:
D-X-Y 2021-04-22 23:08:43 +08:00
parent 78ca90459c
commit 731458f890
11 changed files with 568 additions and 362 deletions

View File

@ -56,5 +56,5 @@ jobs:
python -m pip install parameterized python -m pip install parameterized
python -m pip install torch torchvision python -m pip install torch torchvision
python --version python --version
python -m pytest ./tests/test_synthetic.py -s python -m pytest ./tests/test_synthetic*.py -s
shell: bash shell: bash

@ -1 +1 @@
Subproject commit f955e2ba13ae92ce5af6d28bb47d58eb6d5be249 Subproject commit 47de7e1508536512ece82e0add082e0547cc7996

View File

@ -4,5 +4,7 @@
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset from .SearchDatasetWrap import SearchDataset
from .synthetic_adaptive_environment import QuadraticFunc, CubicFunc, QuarticFunc from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
from .synthetic_adaptive_environment import SynAdaptiveEnv from .math_base_funcs import DynamicQuadraticFunc
from .synthetic_utils import SinGenerator, ConstantGenerator
from .synthetic_env import SyntheticDEnv

View File

@ -176,93 +176,31 @@ class QuarticFunc(FitFunc):
) )
class SynAdaptiveEnv(data.Dataset): class DynamicQuadraticFunc(FitFunc):
"""The synethtic dataset for adaptive environment. """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c."""
- x in [0, 1] def __init__(self, list_of_points=None):
- y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) super(DynamicQuadraticFunc, self).__init__(3, list_of_points)
- where self._timestamp = None
- the amplitude scale is a quadratic function of x
- the period-phase-shift is another quadratic function of x
""" def __getitem__(self, x):
self.check_valid()
def __init__( return (
self, self._params[0][self._timestamp] * x * x
num: int = 100, + self._params[1][self._timestamp] * x
num_sin_phase: int = 7, + self._params[2][self._timestamp]
min_amplitude: float = 1,
max_amplitude: float = 4,
phase_shift: float = 0,
mode: Optional[str] = None,
):
self._amplitude_scale = QuadraticFunc(
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
) )
self._num_sin_phase = num_sin_phase def _getitem(self, x, weights):
self._interval = 1.0 / (float(num) - 1) raise NotImplementedError
self._total_num = num
fitting_data = [] def set_timestamp(self, timestamp):
temp_max_scalar = 2 ** (num_sin_phase - 1) self._timestamp = timestamp
for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar
next_value = (2 ** (i + 1)) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
self._period_phase_shift = QuarticFunc(fitting_data)
# Training Set 60%
num_of_train = int(self._total_num * 0.6)
# Validation Set 20%
num_of_valid = int(self._total_num * 0.2)
# Test Set 20%
num_of_set = self._total_num - num_of_train - num_of_valid
all_indexes = list(range(self._total_num))
if mode is None:
self._indexes = all_indexes
elif mode.lower() in ("train", "training"):
self._indexes = all_indexes[:num_of_train]
elif mode.lower() in ("valid", "validation"):
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
elif mode.lower() in ("test", "testing"):
self._indexes = all_indexes[num_of_train + num_of_valid :]
else:
raise ValueError("Unkonwn mode of {:}".format(mode))
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
position = self._interval * index
value = self._amplitude_scale[position] * math.sin(
self._period_phase_shift[position]
)
return index, position, value
def __len__(self):
return len(self._indexes)
def __repr__(self): def __repr__(self):
return ( return "{name}(y = {a} * x^2 + {b} * x + {c})".format(
"{name}({cur_num:}/{total} elements,\n" name=self.__class__.__name__,
"amplitude={amplitude},\n" a=self._params[0],
"period_phase_shift={period_phase_shift})".format( b=self._params[1],
name=self.__class__.__name__, c=self._params[2],
cur_num=self._total_num,
total=len(self),
amplitude=self._amplitude_scale,
period_phase_shift=self._period_phase_shift,
)
) )

View File

@ -0,0 +1,81 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import math
import abc
import numpy as np
from typing import List, Optional
import torch
import torch.utils.data as data
from .synthetic_utils import UnifiedSplit
class SyntheticDEnv(UnifiedSplit, data.Dataset):
"""The synethtic dynamic environment."""
def __init__(
self,
mean_generators: List[data.Dataset],
cov_generators: List[List[data.Dataset]],
num_per_task: int = 5000,
mode: Optional[str] = None,
):
self._ndim = len(mean_generators)
assert self._ndim == len(
cov_generators
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_generators))
for cov_generator in cov_generators:
assert self._ndim == len(
cov_generator
), "length does not match {:} vs. {:}".format(
self._ndim, len(cov_generator)
)
self._num_per_task = num_per_task
self._total_num = len(mean_generators[0])
for mean_generator in mean_generators:
assert self._total_num == len(mean_generator)
for cov_generator in cov_generators:
for cov_g in cov_generator:
assert self._total_num == len(cov_g)
self._mean_generators = mean_generators
self._cov_generators = cov_generators
UnifiedSplit.__init__(self, self._total_num, mode)
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
mean_list = [generator[index][-1] for generator in self._mean_generators]
cov_matrix = [
[cov_gen[index][-1] for cov_gen in cov_generator]
for cov_generator in self._cov_generators
]
dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task
)
return index, torch.Tensor(dataset)
def __len__(self):
return len(self._indexes)
def __repr__(self):
return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task})".format(
name=self.__class__.__name__,
cur_num=len(self),
total=self._total_num,
ndim=self._ndim,
num_per_task=self._num_per_task,
)

View File

@ -0,0 +1,157 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
from .math_base_funcs import QuadraticFunc, QuarticFunc
class UnifiedSplit:
"""A class to unify the split strategy."""
def __init__(self, total_num, mode):
# Training Set 60%
num_of_train = int(total_num * 0.6)
# Validation Set 20%
num_of_valid = int(total_num * 0.2)
# Test Set 20%
num_of_set = total_num - num_of_train - num_of_valid
all_indexes = list(range(total_num))
if mode is None:
self._indexes = all_indexes
elif mode.lower() in ("train", "training"):
self._indexes = all_indexes[:num_of_train]
elif mode.lower() in ("valid", "validation"):
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
elif mode.lower() in ("test", "testing"):
self._indexes = all_indexes[num_of_train + num_of_valid :]
else:
raise ValueError("Unkonwn mode of {:}".format(mode))
self._mode = mode
@property
def mode(self):
return self._mode
class SinGenerator(UnifiedSplit, data.Dataset):
"""The synethtic generator for the dynamically changing environment.
- x in [0, 1]
- y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
- where
- the amplitude scale is a quadratic function of x
- the period-phase-shift is another quadratic function of x
"""
def __init__(
self,
num: int = 100,
num_sin_phase: int = 7,
min_amplitude: float = 1,
max_amplitude: float = 4,
phase_shift: float = 0,
mode: Optional[str] = None,
):
self._amplitude_scale = QuadraticFunc(
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
)
self._num_sin_phase = num_sin_phase
self._interval = 1.0 / (float(num) - 1)
self._total_num = num
fitting_data = []
temp_max_scalar = 2 ** (num_sin_phase - 1)
for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar
next_value = (2 ** (i + 1)) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
self._period_phase_shift = QuarticFunc(fitting_data)
UnifiedSplit.__init__(self, self._total_num, mode)
self._transform = lambda x: x
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def set_transform(self, transform):
self._transform = transform
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
position = self._interval * index
value = self._amplitude_scale[position] * math.sin(
self._period_phase_shift[position]
)
return index, position, self._transform(value)
def __len__(self):
return len(self._indexes)
def __repr__(self):
return (
"{name}({cur_num:}/{total} elements,\n"
"amplitude={amplitude},\n"
"period_phase_shift={period_phase_shift})".format(
name=self.__class__.__name__,
cur_num=len(self),
total=self._total_num,
amplitude=self._amplitude_scale,
period_phase_shift=self._period_phase_shift,
)
)
class ConstantGenerator(UnifiedSplit, data.Dataset):
"""The constant generator."""
def __init__(
self,
num: int = 100,
constant: float = 0.1,
mode: Optional[str] = None,
):
self._total_num = num
self._constant = constant
UnifiedSplit.__init__(self, self._total_num, mode)
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
return index, index, self._constant
def __len__(self):
return len(self._indexes)
def __repr__(self):
return "{name}({cur_num:}/{total} elements)".format(
name=self.__class__.__name__,
cur_num=len(self),
total=self._total_num,
)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,30 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# pytest tests/test_synthetic_env.py -s #
#####################################################
import sys, random
import unittest
import pytest
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from datasets import ConstantGenerator, SinGenerator
from datasets import SyntheticDEnv
class TestSynethicEnv(unittest.TestCase):
"""Test the synethtic environment."""
def test_simple(self):
mean_generator = SinGenerator()
std_generator = ConstantGenerator(constant=0.5)
dataset = SyntheticDEnv([mean_generator], [[std_generator]])
print(dataset)
for timestamp, tau in dataset:
assert tau.shape == (5000, 1)

View File

@ -1,7 +1,7 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
##################################################### #####################################################
# pytest tests/test_synthetic.py -s # # pytest tests/test_synthetic_utils.py -s #
##################################################### #####################################################
import sys, random import sys, random
import unittest import unittest
@ -14,7 +14,7 @@ if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from datasets import QuadraticFunc from datasets import QuadraticFunc
from datasets import SynAdaptiveEnv from datasets import ConstantGenerator, SinGenerator
class TestQuadraticFunc(unittest.TestCase): class TestQuadraticFunc(unittest.TestCase):
@ -40,11 +40,21 @@ class TestQuadraticFunc(unittest.TestCase):
self.assertTrue(abs(function[1] - 1) < thresh) self.assertTrue(abs(function[1] - 1) < thresh)
class TestSynAdaptiveEnv(unittest.TestCase): class TestConstantGenerator(unittest.TestCase):
"""Test the synethtic adaptive environment.""" """Test the constant data generator."""
def test_simple(self): def test_simple(self):
dataset = SynAdaptiveEnv() dataset = ConstantGenerator()
for i, (idx, t, x) in enumerate(dataset):
assert i == idx, "First loop: {:} vs {:}".format(i, idx)
assert x == 0.1
class TestSinGenerator(unittest.TestCase):
"""Test the synethtic data generator."""
def test_simple(self):
dataset = SinGenerator()
for i, (idx, t, x) in enumerate(dataset): for i, (idx, t, x) in enumerate(dataset):
assert i == idx, "First loop: {:} vs {:}".format(i, idx) assert i == idx, "First loop: {:} vs {:}".format(i, idx)
for i, (idx, t, x) in enumerate(dataset): for i, (idx, t, x) in enumerate(dataset):