Reformulate the synthetic codes
This commit is contained in:
parent
78ca90459c
commit
731458f890
2
.github/workflows/basic_test.yml
vendored
2
.github/workflows/basic_test.yml
vendored
@ -56,5 +56,5 @@ jobs:
|
|||||||
python -m pip install parameterized
|
python -m pip install parameterized
|
||||||
python -m pip install torch torchvision
|
python -m pip install torch torchvision
|
||||||
python --version
|
python --version
|
||||||
python -m pytest ./tests/test_synthetic.py -s
|
python -m pytest ./tests/test_synthetic*.py -s
|
||||||
shell: bash
|
shell: bash
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit f955e2ba13ae92ce5af6d28bb47d58eb6d5be249
|
Subproject commit 47de7e1508536512ece82e0add082e0547cc7996
|
@ -4,5 +4,7 @@
|
|||||||
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
|
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
|
||||||
from .SearchDatasetWrap import SearchDataset
|
from .SearchDatasetWrap import SearchDataset
|
||||||
|
|
||||||
from .synthetic_adaptive_environment import QuadraticFunc, CubicFunc, QuarticFunc
|
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
|
||||||
from .synthetic_adaptive_environment import SynAdaptiveEnv
|
from .math_base_funcs import DynamicQuadraticFunc
|
||||||
|
from .synthetic_utils import SinGenerator, ConstantGenerator
|
||||||
|
from .synthetic_env import SyntheticDEnv
|
||||||
|
@ -176,93 +176,31 @@ class QuarticFunc(FitFunc):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SynAdaptiveEnv(data.Dataset):
|
class DynamicQuadraticFunc(FitFunc):
|
||||||
"""The synethtic dataset for adaptive environment.
|
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c."""
|
||||||
|
|
||||||
- x in [0, 1]
|
def __init__(self, list_of_points=None):
|
||||||
- y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
|
super(DynamicQuadraticFunc, self).__init__(3, list_of_points)
|
||||||
- where
|
self._timestamp = None
|
||||||
- the amplitude scale is a quadratic function of x
|
|
||||||
- the period-phase-shift is another quadratic function of x
|
|
||||||
|
|
||||||
"""
|
def __getitem__(self, x):
|
||||||
|
self.check_valid()
|
||||||
def __init__(
|
return (
|
||||||
self,
|
self._params[0][self._timestamp] * x * x
|
||||||
num: int = 100,
|
+ self._params[1][self._timestamp] * x
|
||||||
num_sin_phase: int = 7,
|
+ self._params[2][self._timestamp]
|
||||||
min_amplitude: float = 1,
|
|
||||||
max_amplitude: float = 4,
|
|
||||||
phase_shift: float = 0,
|
|
||||||
mode: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self._amplitude_scale = QuadraticFunc(
|
|
||||||
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._num_sin_phase = num_sin_phase
|
def _getitem(self, x, weights):
|
||||||
self._interval = 1.0 / (float(num) - 1)
|
raise NotImplementedError
|
||||||
self._total_num = num
|
|
||||||
|
|
||||||
fitting_data = []
|
def set_timestamp(self, timestamp):
|
||||||
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
self._timestamp = timestamp
|
||||||
for i in range(num_sin_phase):
|
|
||||||
value = (2 ** i) / temp_max_scalar
|
|
||||||
next_value = (2 ** (i + 1)) / temp_max_scalar
|
|
||||||
for _phase in (0, 0.25, 0.5, 0.75):
|
|
||||||
inter_value = value + (next_value - value) * _phase
|
|
||||||
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
|
|
||||||
self._period_phase_shift = QuarticFunc(fitting_data)
|
|
||||||
|
|
||||||
# Training Set 60%
|
|
||||||
num_of_train = int(self._total_num * 0.6)
|
|
||||||
# Validation Set 20%
|
|
||||||
num_of_valid = int(self._total_num * 0.2)
|
|
||||||
# Test Set 20%
|
|
||||||
num_of_set = self._total_num - num_of_train - num_of_valid
|
|
||||||
all_indexes = list(range(self._total_num))
|
|
||||||
if mode is None:
|
|
||||||
self._indexes = all_indexes
|
|
||||||
elif mode.lower() in ("train", "training"):
|
|
||||||
self._indexes = all_indexes[:num_of_train]
|
|
||||||
elif mode.lower() in ("valid", "validation"):
|
|
||||||
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
|
|
||||||
elif mode.lower() in ("test", "testing"):
|
|
||||||
self._indexes = all_indexes[num_of_train + num_of_valid :]
|
|
||||||
else:
|
|
||||||
raise ValueError("Unkonwn mode of {:}".format(mode))
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
self._iter_num = 0
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
if self._iter_num >= len(self):
|
|
||||||
raise StopIteration
|
|
||||||
self._iter_num += 1
|
|
||||||
return self.__getitem__(self._iter_num - 1)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
|
|
||||||
index = self._indexes[index]
|
|
||||||
position = self._interval * index
|
|
||||||
value = self._amplitude_scale[position] * math.sin(
|
|
||||||
self._period_phase_shift[position]
|
|
||||||
)
|
|
||||||
return index, position, value
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._indexes)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return "{name}(y = {a} * x^2 + {b} * x + {c})".format(
|
||||||
"{name}({cur_num:}/{total} elements,\n"
|
name=self.__class__.__name__,
|
||||||
"amplitude={amplitude},\n"
|
a=self._params[0],
|
||||||
"period_phase_shift={period_phase_shift})".format(
|
b=self._params[1],
|
||||||
name=self.__class__.__name__,
|
c=self._params[2],
|
||||||
cur_num=self._total_num,
|
|
||||||
total=len(self),
|
|
||||||
amplitude=self._amplitude_scale,
|
|
||||||
period_phase_shift=self._period_phase_shift,
|
|
||||||
)
|
|
||||||
)
|
)
|
81
lib/datasets/synthetic_env.py
Normal file
81
lib/datasets/synthetic_env.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
|
#####################################################
|
||||||
|
import math
|
||||||
|
import abc
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Optional
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
|
||||||
|
from .synthetic_utils import UnifiedSplit
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticDEnv(UnifiedSplit, data.Dataset):
|
||||||
|
"""The synethtic dynamic environment."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mean_generators: List[data.Dataset],
|
||||||
|
cov_generators: List[List[data.Dataset]],
|
||||||
|
num_per_task: int = 5000,
|
||||||
|
mode: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self._ndim = len(mean_generators)
|
||||||
|
assert self._ndim == len(
|
||||||
|
cov_generators
|
||||||
|
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_generators))
|
||||||
|
for cov_generator in cov_generators:
|
||||||
|
assert self._ndim == len(
|
||||||
|
cov_generator
|
||||||
|
), "length does not match {:} vs. {:}".format(
|
||||||
|
self._ndim, len(cov_generator)
|
||||||
|
)
|
||||||
|
self._num_per_task = num_per_task
|
||||||
|
self._total_num = len(mean_generators[0])
|
||||||
|
for mean_generator in mean_generators:
|
||||||
|
assert self._total_num == len(mean_generator)
|
||||||
|
for cov_generator in cov_generators:
|
||||||
|
for cov_g in cov_generator:
|
||||||
|
assert self._total_num == len(cov_g)
|
||||||
|
|
||||||
|
self._mean_generators = mean_generators
|
||||||
|
self._cov_generators = cov_generators
|
||||||
|
|
||||||
|
UnifiedSplit.__init__(self, self._total_num, mode)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self._iter_num = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self._iter_num >= len(self):
|
||||||
|
raise StopIteration
|
||||||
|
self._iter_num += 1
|
||||||
|
return self.__getitem__(self._iter_num - 1)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
|
||||||
|
index = self._indexes[index]
|
||||||
|
mean_list = [generator[index][-1] for generator in self._mean_generators]
|
||||||
|
cov_matrix = [
|
||||||
|
[cov_gen[index][-1] for cov_gen in cov_generator]
|
||||||
|
for cov_generator in self._cov_generators
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = np.random.multivariate_normal(
|
||||||
|
mean_list, cov_matrix, size=self._num_per_task
|
||||||
|
)
|
||||||
|
return index, torch.Tensor(dataset)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._indexes)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
cur_num=len(self),
|
||||||
|
total=self._total_num,
|
||||||
|
ndim=self._ndim,
|
||||||
|
num_per_task=self._num_per_task,
|
||||||
|
)
|
157
lib/datasets/synthetic_utils.py
Normal file
157
lib/datasets/synthetic_utils.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
|
#####################################################
|
||||||
|
import math
|
||||||
|
import abc
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
|
||||||
|
from .math_base_funcs import QuadraticFunc, QuarticFunc
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedSplit:
|
||||||
|
"""A class to unify the split strategy."""
|
||||||
|
|
||||||
|
def __init__(self, total_num, mode):
|
||||||
|
# Training Set 60%
|
||||||
|
num_of_train = int(total_num * 0.6)
|
||||||
|
# Validation Set 20%
|
||||||
|
num_of_valid = int(total_num * 0.2)
|
||||||
|
# Test Set 20%
|
||||||
|
num_of_set = total_num - num_of_train - num_of_valid
|
||||||
|
all_indexes = list(range(total_num))
|
||||||
|
if mode is None:
|
||||||
|
self._indexes = all_indexes
|
||||||
|
elif mode.lower() in ("train", "training"):
|
||||||
|
self._indexes = all_indexes[:num_of_train]
|
||||||
|
elif mode.lower() in ("valid", "validation"):
|
||||||
|
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
|
||||||
|
elif mode.lower() in ("test", "testing"):
|
||||||
|
self._indexes = all_indexes[num_of_train + num_of_valid :]
|
||||||
|
else:
|
||||||
|
raise ValueError("Unkonwn mode of {:}".format(mode))
|
||||||
|
self._mode = mode
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mode(self):
|
||||||
|
return self._mode
|
||||||
|
|
||||||
|
|
||||||
|
class SinGenerator(UnifiedSplit, data.Dataset):
|
||||||
|
"""The synethtic generator for the dynamically changing environment.
|
||||||
|
|
||||||
|
- x in [0, 1]
|
||||||
|
- y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
|
||||||
|
- where
|
||||||
|
- the amplitude scale is a quadratic function of x
|
||||||
|
- the period-phase-shift is another quadratic function of x
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num: int = 100,
|
||||||
|
num_sin_phase: int = 7,
|
||||||
|
min_amplitude: float = 1,
|
||||||
|
max_amplitude: float = 4,
|
||||||
|
phase_shift: float = 0,
|
||||||
|
mode: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self._amplitude_scale = QuadraticFunc(
|
||||||
|
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self._num_sin_phase = num_sin_phase
|
||||||
|
self._interval = 1.0 / (float(num) - 1)
|
||||||
|
self._total_num = num
|
||||||
|
|
||||||
|
fitting_data = []
|
||||||
|
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
||||||
|
for i in range(num_sin_phase):
|
||||||
|
value = (2 ** i) / temp_max_scalar
|
||||||
|
next_value = (2 ** (i + 1)) / temp_max_scalar
|
||||||
|
for _phase in (0, 0.25, 0.5, 0.75):
|
||||||
|
inter_value = value + (next_value - value) * _phase
|
||||||
|
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
|
||||||
|
self._period_phase_shift = QuarticFunc(fitting_data)
|
||||||
|
UnifiedSplit.__init__(self, self._total_num, mode)
|
||||||
|
self._transform = lambda x: x
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self._iter_num = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self._iter_num >= len(self):
|
||||||
|
raise StopIteration
|
||||||
|
self._iter_num += 1
|
||||||
|
return self.__getitem__(self._iter_num - 1)
|
||||||
|
|
||||||
|
def set_transform(self, transform):
|
||||||
|
self._transform = transform
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
|
||||||
|
index = self._indexes[index]
|
||||||
|
position = self._interval * index
|
||||||
|
value = self._amplitude_scale[position] * math.sin(
|
||||||
|
self._period_phase_shift[position]
|
||||||
|
)
|
||||||
|
return index, position, self._transform(value)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._indexes)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}({cur_num:}/{total} elements,\n"
|
||||||
|
"amplitude={amplitude},\n"
|
||||||
|
"period_phase_shift={period_phase_shift})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
cur_num=len(self),
|
||||||
|
total=self._total_num,
|
||||||
|
amplitude=self._amplitude_scale,
|
||||||
|
period_phase_shift=self._period_phase_shift,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConstantGenerator(UnifiedSplit, data.Dataset):
|
||||||
|
"""The constant generator."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num: int = 100,
|
||||||
|
constant: float = 0.1,
|
||||||
|
mode: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self._total_num = num
|
||||||
|
self._constant = constant
|
||||||
|
UnifiedSplit.__init__(self, self._total_num, mode)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self._iter_num = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self._iter_num >= len(self):
|
||||||
|
raise StopIteration
|
||||||
|
self._iter_num += 1
|
||||||
|
return self.__getitem__(self._iter_num - 1)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
|
||||||
|
index = self._indexes[index]
|
||||||
|
return index, index, self._constant
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._indexes)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({cur_num:}/{total} elements)".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
cur_num=len(self),
|
||||||
|
total=self._total_num,
|
||||||
|
)
|
File diff suppressed because one or more lines are too long
112
notebooks/TOT/synthetic-data.ipynb
Normal file
112
notebooks/TOT/synthetic-data.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
30
tests/test_synthetic_env.py
Normal file
30
tests/test_synthetic_env.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
|
#####################################################
|
||||||
|
# pytest tests/test_synthetic_env.py -s #
|
||||||
|
#####################################################
|
||||||
|
import sys, random
|
||||||
|
import unittest
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
|
||||||
|
print("library path: {:}".format(lib_dir))
|
||||||
|
if str(lib_dir) not in sys.path:
|
||||||
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
|
from datasets import ConstantGenerator, SinGenerator
|
||||||
|
from datasets import SyntheticDEnv
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynethicEnv(unittest.TestCase):
|
||||||
|
"""Test the synethtic environment."""
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
mean_generator = SinGenerator()
|
||||||
|
std_generator = ConstantGenerator(constant=0.5)
|
||||||
|
|
||||||
|
dataset = SyntheticDEnv([mean_generator], [[std_generator]])
|
||||||
|
print(dataset)
|
||||||
|
for timestamp, tau in dataset:
|
||||||
|
assert tau.shape == (5000, 1)
|
@ -1,7 +1,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
#####################################################
|
#####################################################
|
||||||
# pytest tests/test_synthetic.py -s #
|
# pytest tests/test_synthetic_utils.py -s #
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, random
|
import sys, random
|
||||||
import unittest
|
import unittest
|
||||||
@ -14,7 +14,7 @@ if str(lib_dir) not in sys.path:
|
|||||||
sys.path.insert(0, str(lib_dir))
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
from datasets import QuadraticFunc
|
from datasets import QuadraticFunc
|
||||||
from datasets import SynAdaptiveEnv
|
from datasets import ConstantGenerator, SinGenerator
|
||||||
|
|
||||||
|
|
||||||
class TestQuadraticFunc(unittest.TestCase):
|
class TestQuadraticFunc(unittest.TestCase):
|
||||||
@ -40,11 +40,21 @@ class TestQuadraticFunc(unittest.TestCase):
|
|||||||
self.assertTrue(abs(function[1] - 1) < thresh)
|
self.assertTrue(abs(function[1] - 1) < thresh)
|
||||||
|
|
||||||
|
|
||||||
class TestSynAdaptiveEnv(unittest.TestCase):
|
class TestConstantGenerator(unittest.TestCase):
|
||||||
"""Test the synethtic adaptive environment."""
|
"""Test the constant data generator."""
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
dataset = SynAdaptiveEnv()
|
dataset = ConstantGenerator()
|
||||||
|
for i, (idx, t, x) in enumerate(dataset):
|
||||||
|
assert i == idx, "First loop: {:} vs {:}".format(i, idx)
|
||||||
|
assert x == 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class TestSinGenerator(unittest.TestCase):
|
||||||
|
"""Test the synethtic data generator."""
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
dataset = SinGenerator()
|
||||||
for i, (idx, t, x) in enumerate(dataset):
|
for i, (idx, t, x) in enumerate(dataset):
|
||||||
assert i == idx, "First loop: {:} vs {:}".format(i, idx)
|
assert i == idx, "First loop: {:} vs {:}".format(i, idx)
|
||||||
for i, (idx, t, x) in enumerate(dataset):
|
for i, (idx, t, x) in enumerate(dataset):
|
Loading…
Reference in New Issue
Block a user