Update the sync data v1

This commit is contained in:
D-X-Y 2021-05-24 13:06:10 +08:00
parent da2575cc6c
commit 3ee0d348af
17 changed files with 228 additions and 274 deletions

View File

@ -56,12 +56,16 @@ jobs:
python -m pytest ./tests/test_basic_space.py -s python -m pytest ./tests/test_basic_space.py -s
shell: bash shell: bash
- name: Test Synthetic Data - name: Test Math
run: | run: |
python -m pip install pytest numpy python -m pip install pytest numpy
python -m pip install parameterized python -m pip install parameterized
python -m pip install torch torchvision python -m pip install torch torchvision
python --version python --version
python -m pytest ./tests/test_math*.py -s python -m pytest ./tests/test_math*.py -s
shell: bash
- name: Test Synthetic Data
run: |
python -m pytest ./tests/test_synthetic*.py -s python -m pytest ./tests/test_synthetic*.py -s
shell: bash shell: bash

View File

@ -222,7 +222,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
def main(args): def main(args):
logger, env_info, model_kwargs = lfna_setup(args) logger, model_kwargs = lfna_setup(args)
train_env = get_synthetic_env(mode="train", version=args.env_version) train_env = get_synthetic_env(mode="train", version=args.env_version)
valid_env = get_synthetic_env(mode="valid", version=args.env_version) valid_env = get_synthetic_env(mode="valid", version=args.env_version)
all_env = get_synthetic_env(mode=None, version=args.env_version) all_env = get_synthetic_env(mode=None, version=args.env_version)

View File

@ -11,33 +11,6 @@ from xautodl.datasets.synthetic_core import get_synthetic_env
def lfna_setup(args): def lfna_setup(args):
prepare_seed(args.rand_seed) prepare_seed(args.rand_seed)
logger = prepare_logger(args) logger = prepare_logger(args)
cache_path = (
logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version)
).resolve()
if cache_path.exists():
env_info = torch.load(cache_path)
else:
env_info = dict()
dynamic_env = get_synthetic_env(version=args.env_version)
env_info["total"] = len(dynamic_env)
for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
env_info["{:}-timestamp".format(idx)] = timestamp
env_info["{:}-x".format(idx)] = _allx
env_info["{:}-y".format(idx)] = _ally
env_info["dynamic_env"] = dynamic_env
torch.save(env_info, cache_path)
"""
model_kwargs = dict(
config=dict(model_type="simple_mlp"),
input_dim=1,
output_dim=1,
hidden_dim=args.hidden_dim,
act_cls="leaky_relu",
norm_cls="identity",
)
"""
model_kwargs = dict( model_kwargs = dict(
config=dict(model_type="norm_mlp"), config=dict(model_type="norm_mlp"),
input_dim=1, input_dim=1,
@ -46,7 +19,7 @@ def lfna_setup(args):
act_cls="gelu", act_cls="gelu",
norm_cls="layer_norm_1d", norm_cls="layer_norm_1d",
) )
return logger, env_info, model_kwargs return logger, model_kwargs
def train_model(model, dataset, lr, epochs): def train_model(model, dataset, lr, epochs):

View File

@ -20,14 +20,13 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as ticker import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path: if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from models.xcore import get_model from xautodl.models.xcore import get_model
from datasets.synthetic_core import get_synthetic_env from xautodl.datasets.synthetic_core import get_synthetic_env
from utils.temp_sync import optimize_fn, evaluate_fn from xautodl.procedures.metric_utils import MSEMetric
from procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
@ -181,10 +180,17 @@ def compare_cl(save_dir):
def visualize_env(save_dir, version): def visualize_env(save_dir, version):
save_dir = Path(str(save_dir)) save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True) for substr in ("pdf", "png"):
sub_save_dir = save_dir / substr
sub_save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env = get_synthetic_env(version=version) dynamic_env = get_synthetic_env(version=version)
min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp # min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
allxs, allys = [], []
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
allxs.append(allx)
allys.append(ally)
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
dpi, width, height = 30, 1800, 1400 dpi, width, height = 30, 1800, 1400
figsize = width / float(dpi), height / float(dpi) figsize = width / float(dpi), height / float(dpi)
@ -201,21 +207,18 @@ def visualize_env(save_dir, version):
tick.label.set_rotation(10) tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks(): for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap) tick.label.set_fontsize(LabelSize - font_gap)
if version == "v1": cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
cur_ax.set_xlim(-2, 2) cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
cur_ax.set_ylim(-8, 8)
elif version == "v2":
cur_ax.set_xlim(-10, 10)
cur_ax.set_ylim(-60, 60)
cur_ax.legend(loc=1, fontsize=LegendFontsize) cur_ax.legend(loc=1, fontsize=LegendFontsize)
save_path = save_dir / "v{:}-{:05d}".format(version, idx) pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
plt.close("all") plt.close("all")
save_dir = save_dir.resolve() save_dir = save_dir.resolve()
base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir, version=version xdir=save_dir / "png", version=version
) )
print(base_cmd) print(base_cmd)
os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
@ -371,7 +374,7 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
# compare_cl(os.path.join(args.save_dir, "compare-cl")) # compare_cl(os.path.join(args.save_dir, "compare-cl"))

View File

@ -13,7 +13,10 @@ from xautodl.config_utils import dict2config
# NAS-Bench-201 related module or function # NAS-Bench-201 related module or function
from xautodl.models import CellStructure, get_cell_based_tiny_net from xautodl.models import CellStructure, get_cell_based_tiny_net
from xautodl.procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders from xautodl.procedures import (
bench_pure_evaluate as pure_evaluate,
get_nas_bench_loaders,
)
from nas_201_api import NASBench201API, ArchResults, ResultsCount from nas_201_api import NASBench201API, ArchResults, ResultsCount
api = NASBench201API( api = NASBench201API(

View File

@ -0,0 +1,21 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
# python test-dynamic.py
#####################################################
import sys
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.datasets.math_core import ConstantFunc
from xautodl.datasets.math_core import GaussianDGenerator
mean_generator = ConstantFunc(0)
cov_generator = ConstantFunc(1)
generator = GaussianDGenerator([mean_generator], [[cov_generator]], (-1, 1))
generator(0, 10)

View File

@ -19,9 +19,11 @@ import seaborn as sns
matplotlib.use("agg") matplotlib.use("agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path: if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from log_utils import time_string from log_utils import time_string
from nats_bench import create from nats_bench import create
from models import get_cell_based_tiny_net from models import get_cell_based_tiny_net

View File

@ -3,11 +3,7 @@ from copy import deepcopy
import torchvision.models as models import torchvision.models as models
from pathlib import Path from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() from xautodl.utils import weight_watcher
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from utils import weight_watcher
def main(): def main():

View File

@ -17,10 +17,10 @@ from .math_base_funcs import QuarticFunc
class ConstantFunc(FitFunc): class ConstantFunc(FitFunc):
"""The constant function: f(x) = c.""" """The constant function: f(x) = c."""
def __init__(self, constant=None): def __init__(self, constant=None, xstr="x"):
param = dict() param = dict()
param[0] = constant param[0] = constant
super(ConstantFunc, self).__init__(0, None, param) super(ConstantFunc, self).__init__(0, None, param, xstr)
def __call__(self, x): def __call__(self, x):
self.check_valid() self.check_valid()
@ -37,6 +37,34 @@ class ConstantFunc(FitFunc):
class ComposedSinFunc(FitFunc): class ComposedSinFunc(FitFunc):
"""The composed sin function that outputs:
f(x) = a * sin( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedSinFunc, self).__init__(3, None, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.sin(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class ComposedSinFuncV2(FitFunc):
"""The composed sin function that outputs: """The composed sin function that outputs:
f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
- the amplitude scale is a quadratic function of x - the amplitude scale is a quadratic function of x
@ -44,7 +72,7 @@ class ComposedSinFunc(FitFunc):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(ComposedSinFunc, self).__init__(0, None) super(ComposedSinFuncV2, self).__init__(0, None)
self.fit(**kwargs) self.fit(**kwargs)
def __call__(self, x): def __call__(self, x):

View File

@ -5,15 +5,13 @@ import math
import abc import abc
import copy import copy
import numpy as np import numpy as np
from typing import Optional
import torch import torch
import torch.utils.data as data
class FitFunc(abc.ABC): class FitFunc(abc.ABC):
"""The fit function that outputs f(x) = a * x^2 + b * x + c.""" """The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, freedom: int, list_of_points=None, params=None): def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
self._params = dict() self._params = dict()
for i in range(freedom): for i in range(freedom):
self._params[i] = None self._params[i] = None
@ -24,6 +22,7 @@ class FitFunc(abc.ABC):
self.fit(list_of_points=list_of_points) self.fit(list_of_points=list_of_points)
if params is not None: if params is not None:
self.set(params) self.set(params)
self._xstr = str(xstr)
def set(self, params): def set(self, params):
self._params = copy.deepcopy(params) self._params = copy.deepcopy(params)
@ -33,6 +32,13 @@ class FitFunc(abc.ABC):
if value is None: if value is None:
raise ValueError("The {:} is None".format(key)) raise ValueError("The {:} is None".format(key))
@property
def xstr(self):
return self._xstr
def reset_xstr(self, xstr):
self._xstr = str(xstr)
@abc.abstractmethod @abc.abstractmethod
def __call__(self, x): def __call__(self, x):
raise NotImplementedError raise NotImplementedError
@ -106,8 +112,8 @@ class FitFunc(abc.ABC):
class LinearFunc(FitFunc): class LinearFunc(FitFunc):
"""The linear function that outputs f(x) = a * x + b.""" """The linear function that outputs f(x) = a * x + b."""
def __init__(self, list_of_points=None, params=None): def __init__(self, list_of_points=None, params=None, xstr="x"):
super(LinearFunc, self).__init__(2, list_of_points, params) super(LinearFunc, self).__init__(2, list_of_points, params, xstr)
def __call__(self, x): def __call__(self, x):
self.check_valid() self.check_valid()
@ -117,18 +123,19 @@ class LinearFunc(FitFunc):
return weights[0] * x + weights[1] return weights[0] * x + weights[1]
def __repr__(self): def __repr__(self):
return "{name}({a} * x + {b})".format( return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
a=self._params[0], a=self._params[0],
b=self._params[1], b=self._params[1],
x=self.xstr,
) )
class QuadraticFunc(FitFunc): class QuadraticFunc(FitFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" """The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None, params=None): def __init__(self, list_of_points=None, params=None, xstr="x"):
super(QuadraticFunc, self).__init__(3, list_of_points, params) super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)
def __call__(self, x): def __call__(self, x):
self.check_valid() self.check_valid()
@ -138,11 +145,12 @@ class QuadraticFunc(FitFunc):
return weights[0] * x * x + weights[1] * x + weights[2] return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self): def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format( return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
a=self._params[0], a=self._params[0],
b=self._params[1], b=self._params[1],
c=self._params[2], c=self._params[2],
x=self.xstr,
) )
@ -165,12 +173,13 @@ class CubicFunc(FitFunc):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self): def __repr__(self):
return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format( return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
a=self._params[0], a=self._params[0],
b=self._params[1], b=self._params[1],
c=self._params[2], c=self._params[2],
d=self._params[3], d=self._params[3],
x=self.xstr,
) )

View File

@ -6,3 +6,4 @@ from .math_dynamic_funcs import DynamicLinearFunc
from .math_dynamic_funcs import DynamicQuadraticFunc from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc from .math_adv_funcs import ConstantFunc
from .math_adv_funcs import ComposedSinFunc from .math_adv_funcs import ComposedSinFunc
from .math_dynamic_generator import GaussianDGenerator

View File

@ -15,20 +15,19 @@ from .math_base_funcs import FitFunc
class DynamicFunc(FitFunc): class DynamicFunc(FitFunc):
"""The dynamic quadratic function, where each param is a function.""" """The dynamic quadratic function, where each param is a function."""
def __init__(self, freedom: int, params=None): def __init__(self, freedom: int, params=None, xstr="x"):
super(DynamicFunc, self).__init__(freedom, None, params) if params is not None:
self._timestamp = None for param in params:
param.reset_xstr("t") if isinstance(param, FitFunc) else None
super(DynamicFunc, self).__init__(freedom, None, params, xstr)
def __call__(self, x, timestamp=None): def __call__(self, x, timestamp):
raise NotImplementedError raise NotImplementedError
def _getitem(self, x, weights): def _getitem(self, x, weights):
raise NotImplementedError raise NotImplementedError
def set_timestamp(self, timestamp): def noise_call(self, x, timestamp, std):
self._timestamp = timestamp
def noise_call(self, x, timestamp=None, std=0.1):
clean_y = self.__call__(x, timestamp) clean_y = self.__call__(x, timestamp)
if isinstance(clean_y, np.ndarray): if isinstance(clean_y, np.ndarray):
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
@ -42,13 +41,10 @@ class DynamicLinearFunc(DynamicFunc):
The a and b is a function of timestamp. The a and b is a function of timestamp.
""" """
def __init__(self, params=None): def __init__(self, params=None, xstr="x"):
super(DynamicLinearFunc, self).__init__(3, params) super(DynamicLinearFunc, self).__init__(3, params, xstr)
def __call__(self, x, timestamp=None): def __call__(self, x, timestamp):
self.check_valid()
if timestamp is None:
timestamp = self._timestamp
a = self._params[0](timestamp) a = self._params[0](timestamp)
b = self._params[1](timestamp) b = self._params[1](timestamp)
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
@ -56,11 +52,11 @@ class DynamicLinearFunc(DynamicFunc):
return a * x + b return a * x + b
def __repr__(self): def __repr__(self):
return "{name}({a} * x + {b}, timestamp={timestamp})".format( return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
a=self._params[0], a=self._params[0],
b=self._params[1], b=self._params[1],
timestamp=self._timestamp, x=self.xstr,
) )

View File

@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import abc
import numpy as np
def assert_list_tuple(x):
assert isinstance(x, (list, tuple))
return len(x)
class DynamicGenerator(abc.ABC):
"""The dynamic quadratic function, where each param is a function."""
def __init__(self):
self._ndim = None
def __call__(self, time, num):
raise NotImplementedError
class GaussianDGenerator(DynamicGenerator):
def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)):
super(GaussianDGenerator, self).__init__()
self._ndim = assert_list_tuple(mean_functors)
assert self._ndim == len(
cov_functors
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors))
assert_list_tuple(cov_functors)
for cov_functor in cov_functors:
assert self._ndim == assert_list_tuple(
cov_functor
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
assert (
isinstance(trunc, (list, tuple)) and len(trunc) == 2 and trunc[0] < trunc[1]
)
self._mean_functors = mean_functors
self._cov_functors = cov_functors
if trunc is not None:
assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1]
self._trunc = trunc
def __call__(self, time, num):
mean_list = [functor(time) for functor in self._mean_functors]
cov_matrix = [
[abs(cov_gen(time)) for cov_gen in cov_functor]
for cov_functor in self._cov_functors
]
values = np.random.multivariate_normal(mean_list, cov_matrix, size=num)
if self._trunc is not None:
np.clip(values, self._trunc[0], self._trunc[1], out=values)
return values
def __repr__(self):
return "{name}({ndim} dims, trunc={trunc})".format(
name=self.__class__.__name__, ndim=self._ndim, trunc=self._trunc
)

View File

@ -1,13 +1,14 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
##################################################### #####################################################
import math
from .synthetic_utils import TimeStamp from .synthetic_utils import TimeStamp
from .synthetic_env import EnvSampler
from .synthetic_env import SyntheticDEnv from .synthetic_env import SyntheticDEnv
from .math_core import LinearFunc from .math_core import LinearFunc
from .math_core import DynamicLinearFunc from .math_core import DynamicLinearFunc
from .math_core import DynamicQuadraticFunc from .math_core import DynamicQuadraticFunc
from .math_core import ConstantFunc, ComposedSinFunc from .math_core import ConstantFunc, ComposedSinFunc
from .math_core import GaussianDGenerator
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
@ -17,42 +18,21 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
if version == "v1": if version == "v1":
mean_generator = ConstantFunc(0) mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1) std_generator = ConstantFunc(1)
elif version == "v2": data_generator = GaussianDGenerator(
mean_generator = ComposedSinFunc() [mean_generator], [[std_generator]], (-2, 2)
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)
else:
raise ValueError("Unknown version: {:}".format(version))
dynamic_env = SyntheticDEnv(
[mean_generator],
[[std_generator]],
num_per_task=num_per_task,
timestamp_config=dict(
min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode
),
)
if version == "v1":
function = DynamicLinearFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
amplitude_scale=ConstantFunc(3.0),
num_sin_phase=9,
sin_speed_use_power=False,
) )
function_param[1] = ConstantFunc(constant=0.9) time_generator = TimeStamp(
elif version == "v2": min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode
function = DynamicQuadraticFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
) )
function_param[1] = ConstantFunc(constant=0.9) oracle_map = DynamicLinearFunc(
function_param[2] = ComposedSinFunc( params={
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
}
)
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
) )
else: else:
raise ValueError("Unknown version: {:}".format(version)) raise ValueError("Unknown version: {:}".format(version))
function.set(function_param)
# dynamic_env.set_oracle_map(copy.deepcopy(function))
dynamic_env.set_oracle_map(function)
return dynamic_env return dynamic_env

View File

@ -1,15 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import math import math
import random import random
import numpy as np
from typing import List, Optional, Dict from typing import List, Optional, Dict
import torch import torch
import torch.utils.data as data import torch.utils.data as data
from .synthetic_utils import TimeStamp
def is_list_tuple(x): def is_list_tuple(x):
return isinstance(x, (tuple, list)) return isinstance(x, (tuple, list))
@ -38,46 +32,33 @@ class SyntheticDEnv(data.Dataset):
def __init__( def __init__(
self, self,
mean_functors: List[data.Dataset], data_generator,
cov_functors: List[List[data.Dataset]], oracle_map,
time_generator,
num_per_task: int = 5000, num_per_task: int = 5000,
timestamp_config: Optional[Dict] = None, noise: float = 0.1,
mode: Optional[str] = None,
timestamp_noise_scale: float = 0.3,
): ):
self._ndim = len(mean_functors) self._data_generator = data_generator
assert self._ndim == len( self._time_generator = time_generator
cov_functors self._oracle_map = oracle_map
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors))
for cov_functor in cov_functors:
assert self._ndim == len(
cov_functor
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
self._num_per_task = num_per_task self._num_per_task = num_per_task
if timestamp_config is None: self._noise = noise
timestamp_config = dict(mode=mode)
elif "mode" not in timestamp_config:
timestamp_config["mode"] = mode
self._timestamp_generator = TimeStamp(**timestamp_config)
self._timestamp_noise_scale = timestamp_noise_scale
self._mean_functors = mean_functors
self._cov_functors = cov_functors
self._oracle_map = None
@property @property
def min_timestamp(self): def min_timestamp(self):
return self._timestamp_generator.min_timestamp return self._time_generator.min_timestamp
@property @property
def max_timestamp(self): def max_timestamp(self):
return self._timestamp_generator.max_timestamp return self._time_generator.max_timestamp
@property @property
def timestamp_interval(self): def time_interval(self):
return self._timestamp_generator.interval return self._time_generator.interval
@property
def mode(self):
return self._time_generator.mode
def random_timestamp(self, min_timestamp=None, max_timestamp=None): def random_timestamp(self, min_timestamp=None, max_timestamp=None):
if min_timestamp is None: if min_timestamp is None:
@ -89,16 +70,13 @@ class SyntheticDEnv(data.Dataset):
def get_timestamp(self, index): def get_timestamp(self, index):
if index is None: if index is None:
timestamps = [] timestamps = []
for index in range(len(self._timestamp_generator)): for index in range(len(self._time_generator)):
timestamps.append(self._timestamp_generator[index][1]) timestamps.append(self._time_generator[index][1])
return tuple(timestamps) return tuple(timestamps)
else: else:
index, timestamp = self._timestamp_generator[index] index, timestamp = self._time_generator[index]
return timestamp return timestamp
def set_oracle_map(self, functor):
self._oracle_map = functor
def __iter__(self): def __iter__(self):
self._iter_num = 0 self._iter_num = 0
return self return self
@ -111,7 +89,7 @@ class SyntheticDEnv(data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index, timestamp = self._timestamp_generator[index] index, timestamp = self._time_generator[index]
return self.__call__(timestamp) return self.__call__(timestamp)
def seq_call(self, timestamps): def seq_call(self, timestamps):
@ -122,52 +100,24 @@ class SyntheticDEnv(data.Dataset):
return zip_sequence(xdata) return zip_sequence(xdata)
def __call__(self, timestamp): def __call__(self, timestamp):
mean_list = [functor(timestamp) for functor in self._mean_functors] dataset = self._data_generator(timestamp, self._num_per_task)
cov_matrix = [ targets = self._oracle_map.noise_call(dataset, timestamp, self._noise)
[abs(cov_gen(timestamp)) for cov_gen in cov_functor] return torch.Tensor([timestamp]), (
for cov_functor in self._cov_functors torch.Tensor(dataset),
] torch.Tensor(targets),
dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task
) )
if self._oracle_map is None:
return torch.Tensor([timestamp]), torch.Tensor(dataset)
else:
targets = self._oracle_map.noise_call(dataset, timestamp)
return torch.Tensor([timestamp]), (
torch.Tensor(dataset),
torch.Tensor(targets),
)
def __len__(self): def __len__(self):
return len(self._timestamp_generator) return len(self._time_generator)
def __repr__(self): def __repr__(self):
return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format( return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
cur_num=len(self), cur_num=len(self),
total=len(self._timestamp_generator), total=len(self._time_generator),
ndim=self._ndim, ndim=self._ndim,
num_per_task=self._num_per_task, num_per_task=self._num_per_task,
xrange_min=self.min_timestamp, xrange_min=self.min_timestamp,
xrange_max=self.max_timestamp, xrange_max=self.max_timestamp,
mode=self._timestamp_generator.mode, mode=self.mode,
) )
class EnvSampler:
def __init__(self, env, batch, enlarge):
indexes = list(range(len(env)))
self._indexes = indexes * enlarge
self._batch = batch
self._iterations = len(self._indexes) // self._batch
def __iter__(self):
random.shuffle(self._indexes)
for it in range(self._iterations):
indexes = self._indexes[it * self._batch : (it + 1) * self._batch]
yield indexes
def __len__(self):
return self._iterations

View File

@ -1,72 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
from .synthetic_env import SyntheticDEnv
def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"):
if indicator == "v1":
return create_example_v1(timestamp_config, num_per_task)
elif indicator == "v2":
return create_example_v2(timestamp_config, num_per_task)
else:
raise ValueError("Unkonwn indicator: {:}".format(indicator))
def create_example_v1(
timestamp_config=None,
num_per_task=5000,
):
mean_generator = ComposedSinFunc()
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
dynamic_env = SyntheticDEnv(
[mean_generator],
[[std_generator]],
num_per_task=num_per_task,
timestamp_config=timestamp_config,
)
function = DynamicQuadraticFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
)
function_param[1] = ConstantFunc(constant=0.9)
function_param[2] = ComposedSinFunc(
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
)
function.set(function_param)
dynamic_env.set_oracle_map(copy.deepcopy(function))
return dynamic_env, function
def create_example_v2(
timestamp_config=None,
num_per_task=5000,
):
mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1)
dynamic_env = SyntheticDEnv(
[mean_generator],
[[std_generator]],
num_per_task=num_per_task,
timestamp_config=timestamp_config,
)
function = DynamicLinearFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0)
)
function_param[1] = ConstantFunc(constant=0.9)
function.set(function_param)
dynamic_env.set_oracle_map(copy.deepcopy(function))
return dynamic_env, function

View File

@ -13,11 +13,11 @@ class UnifiedSplit:
"""A class to unify the split strategy.""" """A class to unify the split strategy."""
def __init__(self, total_num, mode): def __init__(self, total_num, mode):
# Training Set 60% # Training Set 65%
num_of_train = int(total_num * 0.6) num_of_train = int(total_num * 0.65)
# Validation Set 20% # Validation Set 05%
num_of_valid = int(total_num * 0.2) num_of_valid = int(total_num * 0.05)
# Test Set 20% # Test Set 30%
num_of_set = total_num - num_of_train - num_of_valid num_of_set = total_num - num_of_train - num_of_valid
all_indexes = list(range(total_num)) all_indexes = list(range(total_num))
if mode is None: if mode is None:
@ -28,6 +28,8 @@ class UnifiedSplit:
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid] self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
elif mode.lower() in ("test", "testing"): elif mode.lower() in ("test", "testing"):
self._indexes = all_indexes[num_of_train + num_of_valid :] self._indexes = all_indexes[num_of_train + num_of_valid :]
elif mode.lower() in ("trainval", "trainvalidation"):
self._indexes = all_indexes[: num_of_train + num_of_valid]
else: else:
raise ValueError("Unkonwn mode of {:}".format(mode)) raise ValueError("Unkonwn mode of {:}".format(mode))
self._all_indexes = all_indexes self._all_indexes = all_indexes