From 6da60664f5cd89211392ab16df199ada8ae9db47 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 27 May 2021 15:44:01 +0800 Subject: [PATCH] Re-organize GeMOSA --- exps/GeMOSA/main.py | 73 +++++--- exps/GeMOSA/meta_model.py | 17 +- exps/GeMOSA/vis-synthetic.py | 2 + xautodl/datasets/math_adv_funcs.py | 92 ---------- xautodl/datasets/math_base_funcs.py | 192 ++------------------- xautodl/datasets/math_core.py | 16 +- xautodl/datasets/math_dynamic_funcs.py | 52 +++--- xautodl/datasets/math_static_funcs.py | 225 +++++++++++++++++++++++++ xautodl/datasets/synthetic_core.py | 31 ++-- xautodl/datasets/synthetic_env.py | 4 + 10 files changed, 354 insertions(+), 350 deletions(-) delete mode 100644 xautodl/datasets/math_adv_funcs.py create mode 100644 xautodl/datasets/math_static_funcs.py diff --git a/exps/GeMOSA/main.py b/exps/GeMOSA/main.py index 5a0b893..d7fe8a3 100644 --- a/exps/GeMOSA/main.py +++ b/exps/GeMOSA/main.py @@ -1,10 +1,9 @@ ##################################################### # Learning to Generate Model One Step Ahead # ##################################################### -# python exps/GeMOSA/lfna.py --env_version v1 --workers 0 -# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001 -# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 -# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 +# python exps/GeMOSA/main.py --env_version v1 --workers 0 +# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256 +# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData from meta_model import MetaModelV1 -def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): +def online_evaluate( + env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False +): logger.log("Online evaluate: {:}".format(env)) loss_meter = AverageMeter() w_containers = dict() @@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F with torch.no_grad(): meta_model.eval() base_model.eval() - [future_container], time_embeds = meta_model( - future_time.to(args.device).view(-1), None, False + future_time_embed = meta_model.gen_time_embed( + future_time.to(args.device).view(-1) ) + [future_container] = meta_model.gen_model(future_time_embed) if save: w_containers[idx] = future_container.no_grad_clone() future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_y_hat = base_model.forward_with_container(future_x, future_container) future_loss = criterion(future_y_hat, future_y) loss_meter.update(future_loss.item()) - refine, post_refine_loss = meta_model.adapt( - base_model, - criterion, - future_time.item(), - future_x, - future_y, - args.refine_lr, - args.refine_epochs, - {"param": time_embeds, "loss": future_loss.item()}, - ) + if easy_adapt: + meta_model.easy_adapt(future_time.item(), future_time_embed) + refine, post_refine_loss = False, -1 + else: + refine, post_refine_loss = meta_model.adapt( + base_model, + criterion, + future_time.item(), + future_x, + future_y, + args.refine_lr, + args.refine_epochs, + {"param": future_time_embed, "loss": future_loss.item()}, + ) logger.log( "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( idx, len(env), future_loss.item() @@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): ) optimizer.zero_grad() - generated_time_embeds = gen_time_embed(meta_model.meta_timestamps) + generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps) batch_indexes = random.choices(total_indexes, k=args.meta_batch) @@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): ) # future loss total_future_losses, total_present_losses = [], [] - future_containers, _ = meta_model( - None, generated_time_embeds[batch_indexes], False - ) - present_containers, _ = meta_model( - None, meta_model.super_meta_embed[batch_indexes], False + future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes]) + present_containers = meta_model.gen_model( + meta_model.super_meta_embed[batch_indexes] ) for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): _, (inputs, targets) = xenv(time_step) @@ -216,13 +220,34 @@ def main(args): # try to evaluate once # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) + """ w_containers, loss_meter = online_evaluate( all_env, meta_model, base_model, criterion, args, logger, True ) logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) + """ + _, test_loss_meter_adapt_v1 = online_evaluate( + valid_env, meta_model, base_model, criterion, args, logger, False, False + ) + _, test_loss_meter_adapt_v2 = online_evaluate( + valid_env, meta_model, base_model, criterion, args, logger, False, True + ) + logger.log( + "In the online test enviornment, the total loss for refine-adapt is {:}".format( + test_loss_meter_adapt_v1 + ) + ) + logger.log( + "In the online test enviornment, the total loss for easy-adapt is {:}".format( + test_loss_meter_adapt_v2 + ) + ) save_checkpoint( - {"all_w_containers": w_containers}, + { + "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg, + "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg, + }, logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), logger, ) diff --git a/exps/GeMOSA/meta_model.py b/exps/GeMOSA/meta_model.py index 79f9fb9..0d2e1bf 100644 --- a/exps/GeMOSA/meta_model.py +++ b/exps/GeMOSA/meta_model.py @@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule): batch_containers.append( self._shape_container.translate(torch.split(weights.squeeze(0), 1)) ) - return batch_containers, time_embeds + return batch_containers def forward_raw(self, timestamps, time_embeds, tembed_only=False): raise NotImplementedError @@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule): def forward_candidate(self, input): raise NotImplementedError + def easy_adapt(self, timestamp, time_embed): + with torch.no_grad(): + timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device) + self.replace_append_learnt(None, None) + self.append_fixed(timestamp, time_embed) + def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): distance = self.get_closest_meta_distance(timestamp) if distance + self._interval * 1e-2 <= self._interval: @@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule): best_new_param = new_param.detach().clone() for iepoch in range(epochs): optimizer.zero_grad() - _, time_embed = self(timestamp.view(1), None) + time_embed = self.gen_time_embed(timestamp.view(1)) match_loss = criterion(new_param, time_embed) - [container], time_embed = self(None, new_param.view(1, -1)) + [container] = self.gen_model(new_param.view(1, -1)) y_hat = base_model.forward_with_container(x, container) meta_loss = criterion(y_hat, y) loss = meta_loss + match_loss loss.backward() optimizer.step() - # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item())) if meta_loss.item() < best_loss: with torch.no_grad(): best_loss = meta_loss.item() best_new_param = new_param.detach().clone() - with torch.no_grad(): - self.replace_append_learnt(None, None) - self.append_fixed(timestamp, best_new_param) + self.easy_adapt(timestamp, best_new_param) return True, best_loss def extra_repr(self) -> str: diff --git a/exps/GeMOSA/vis-synthetic.py b/exps/GeMOSA/vis-synthetic.py index a3f2885..8666851 100644 --- a/exps/GeMOSA/vis-synthetic.py +++ b/exps/GeMOSA/vis-synthetic.py @@ -191,6 +191,8 @@ def visualize_env(save_dir, version): allxs.append(allx) allys.append(ally) allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) + print("env: {:}".format(dynamic_env)) + print("oracle_map: {:}".format(dynamic_env.oracle_map)) print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): diff --git a/xautodl/datasets/math_adv_funcs.py b/xautodl/datasets/math_adv_funcs.py deleted file mode 100644 index f180078..0000000 --- a/xautodl/datasets/math_adv_funcs.py +++ /dev/null @@ -1,92 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # -##################################################### -import math -import abc -import copy -import numpy as np -from typing import Optional -import torch -import torch.utils.data as data - -from .math_base_funcs import FitFunc -from .math_base_funcs import QuadraticFunc -from .math_base_funcs import QuarticFunc - - -class ConstantFunc(FitFunc): - """The constant function: f(x) = c.""" - - def __init__(self, constant=None, xstr="x"): - param = dict() - param[0] = constant - super(ConstantFunc, self).__init__(0, None, param, xstr) - - def __call__(self, x): - self.check_valid() - return self._params[0] - - def fit(self, **kwargs): - raise NotImplementedError - - def _getitem(self, x, weights): - raise NotImplementedError - - def __repr__(self): - return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) - - -class ComposedSinFunc(FitFunc): - """The composed sin function that outputs: - f(x) = a * sin( b*x ) + c - """ - - def __init__(self, params, xstr="x"): - super(ComposedSinFunc, self).__init__(3, None, params, xstr) - - def __call__(self, x): - self.check_valid() - a = self._params[0] - b = self._params[1] - c = self._params[2] - return a * math.sin(b * x) + c - - def _getitem(self, x, weights): - raise NotImplementedError - - def __repr__(self): - return "{name}({a} * sin({b} * {x}) + {c})".format( - name=self.__class__.__name__, - a=self._params[0], - b=self._params[1], - c=self._params[2], - x=self.xstr, - ) - - -class ComposedCosFunc(FitFunc): - """The composed sin function that outputs: - f(x) = a * cos( b*x ) + c - """ - - def __init__(self, params, xstr="x"): - super(ComposedCosFunc, self).__init__(3, None, params, xstr) - - def __call__(self, x): - self.check_valid() - a = self._params[0] - b = self._params[1] - c = self._params[2] - return a * math.cos(b * x) + c - - def _getitem(self, x, weights): - raise NotImplementedError - - def __repr__(self): - return "{name}({a} * sin({b} * {x}) + {c})".format( - name=self.__class__.__name__, - a=self._params[0], - b=self._params[1], - c=self._params[2], - x=self.xstr, - ) diff --git a/xautodl/datasets/math_base_funcs.py b/xautodl/datasets/math_base_funcs.py index 485281d..8bc48de 100644 --- a/xautodl/datasets/math_base_funcs.py +++ b/xautodl/datasets/math_base_funcs.py @@ -5,34 +5,33 @@ import math import abc import copy import numpy as np -import torch -class FitFunc(abc.ABC): - """The fit function that outputs f(x) = a * x^2 + b * x + c.""" +class MathFunc(abc.ABC): + """The math function -- a virtual class defining some APIs.""" - def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"): + def __init__(self, freedom: int, params=None, xstr="x"): + # initialize as empty self._params = dict() for i in range(freedom): self._params[i] = None self._freedom = freedom - if list_of_points is not None and params is not None: - raise ValueError("list_of_points and params can not be set simultaneously") - if list_of_points is not None: - self.fit(list_of_points=list_of_points) if params is not None: self.set(params) self._xstr = str(xstr) + self._skip_check = True def set(self, params): - self._params = copy.deepcopy(params) + for key in range(self._freedom): + param = copy.deepcopy(params[key]) + self._params[key] = param def check_valid(self): - # for key, value in self._params.items(): - for key in range(self._freedom): - value = self._params[key] - if value is None: - raise ValueError("The {:} is None".format(key)) + if not self._skip_check: + for key in range(self._freedom): + value = self._params[key] + if value is None: + raise ValueError("The {:} is None".format(key)) @property def xstr(self): @@ -45,7 +44,8 @@ class FitFunc(abc.ABC): def __call__(self, x): raise NotImplementedError - def noise_call(self, x, std=0.1): + @abc.abstractmethod + def noise_call(self, x, std): clean_y = self.__call__(x) if isinstance(clean_y, np.ndarray): noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) @@ -53,169 +53,7 @@ class FitFunc(abc.ABC): raise ValueError("Unkonwn type: {:}".format(type(clean_y))) return noise_y - @abc.abstractmethod - def _getitem(self, x): - raise NotImplementedError - - def fit(self, **kwargs): - list_of_points = kwargs["list_of_points"] - max_iter, lr_max, verbose = ( - kwargs.get("max_iter", 900), - kwargs.get("lr_max", 1.0), - kwargs.get("verbose", False), - ) - with torch.no_grad(): - data = torch.Tensor(list_of_points).type(torch.float32) - assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( - data.shape - ) - x, y = data[:, 0], data[:, 1] - weights = torch.nn.Parameter(torch.Tensor(self._freedom)) - torch.nn.init.normal_(weights, mean=0.0, std=1.0) - optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) - lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, - milestones=[ - int(max_iter * 0.25), - int(max_iter * 0.5), - int(max_iter * 0.75), - ], - gamma=0.1, - ) - if verbose: - print("The optimizer: {:}".format(optimizer)) - - best_loss = None - for _iter in range(max_iter): - y_hat = self._getitem(x, weights) - loss = torch.mean(torch.abs(y - y_hat)) - optimizer.zero_grad() - loss.backward() - optimizer.step() - lr_scheduler.step() - if verbose: - print( - "In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format( - _iter, max_iter, loss.item() - ) - ) - # Update the params - if best_loss is None or best_loss > loss.item(): - best_loss = loss.item() - for i in range(self._freedom): - self._params[i] = weights[i].item() - def __repr__(self): return "{name}(freedom={freedom})".format( name=self.__class__.__name__, freedom=freedom ) - - -class LinearFunc(FitFunc): - """The linear function that outputs f(x) = a * x + b.""" - - def __init__(self, list_of_points=None, params=None, xstr="x"): - super(LinearFunc, self).__init__(2, list_of_points, params, xstr) - - def __call__(self, x): - self.check_valid() - return self._params[0] * x + self._params[1] - - def _getitem(self, x, weights): - return weights[0] * x + weights[1] - - def __repr__(self): - return "{name}({a} * {x} + {b})".format( - name=self.__class__.__name__, - a=self._params[0], - b=self._params[1], - x=self.xstr, - ) - - -class QuadraticFunc(FitFunc): - """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" - - def __init__(self, list_of_points=None, params=None, xstr="x"): - super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr) - - def __call__(self, x): - self.check_valid() - return self._params[0] * x * x + self._params[1] * x + self._params[2] - - def _getitem(self, x, weights): - return weights[0] * x * x + weights[1] * x + weights[2] - - def __repr__(self): - return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( - name=self.__class__.__name__, - a=self._params[0], - b=self._params[1], - c=self._params[2], - x=self.xstr, - ) - - -class CubicFunc(FitFunc): - """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" - - def __init__(self, list_of_points=None): - super(CubicFunc, self).__init__(4, list_of_points) - - def __call__(self, x): - self.check_valid() - return ( - self._params[0] * x ** 3 - + self._params[1] * x ** 2 - + self._params[2] * x - + self._params[3] - ) - - def _getitem(self, x, weights): - return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] - - def __repr__(self): - return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( - name=self.__class__.__name__, - a=self._params[0], - b=self._params[1], - c=self._params[2], - d=self._params[3], - x=self.xstr, - ) - - -class QuarticFunc(FitFunc): - """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" - - def __init__(self, list_of_points=None): - super(QuarticFunc, self).__init__(5, list_of_points) - - def __call__(self, x): - self.check_valid() - return ( - self._params[0] * x ** 4 - + self._params[1] * x ** 3 - + self._params[2] * x ** 2 - + self._params[3] * x - + self._params[4] - ) - - def _getitem(self, x, weights): - return ( - weights[0] * x ** 4 - + weights[1] * x ** 3 - + weights[2] * x ** 2 - + weights[3] * x - + weights[4] - ) - - def __repr__(self): - return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( - name=self.__class__.__name__, - a=self._params[0], - b=self._params[1], - c=self._params[2], - d=self._params[3], - e=self._params[3], - ) diff --git a/xautodl/datasets/math_core.py b/xautodl/datasets/math_core.py index 9e8929c..7cdda0d 100644 --- a/xautodl/datasets/math_core.py +++ b/xautodl/datasets/math_core.py @@ -1,10 +1,14 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # ##################################################### -from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc -from .math_dynamic_funcs import DynamicLinearFunc -from .math_dynamic_funcs import DynamicQuadraticFunc -from .math_dynamic_funcs import DynamicSinQuadraticFunc -from .math_adv_funcs import ConstantFunc -from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc +from .math_static_funcs import ( + LinearSFunc, + QuadraticSFunc, + CubicSFunc, + QuarticSFunc, + ConstantFunc, + ComposedSinSFunc, + ComposedCosSFunc, +) +from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc from .math_dynamic_generator import GaussianDGenerator diff --git a/xautodl/datasets/math_dynamic_funcs.py b/xautodl/datasets/math_dynamic_funcs.py index f475ff8..900a6c9 100644 --- a/xautodl/datasets/math_dynamic_funcs.py +++ b/xautodl/datasets/math_dynamic_funcs.py @@ -6,23 +6,17 @@ import abc import copy import numpy as np -from .math_base_funcs import FitFunc +from .math_base_funcs import MathFunc -class DynamicFunc(FitFunc): - """The dynamic quadratic function, where each param is a function.""" +class DynamicFunc(MathFunc): + """The dynamic function, where each param is a function.""" def __init__(self, freedom: int, params=None, xstr="x"): if params is not None: - for param in params: - param.reset_xstr("t") if isinstance(param, FitFunc) else None - super(DynamicFunc, self).__init__(freedom, None, params, xstr) - - def __call__(self, x, timestamp): - raise NotImplementedError - - def _getitem(self, x, weights): - raise NotImplementedError + for key, param in params.items(): + param.reset_xstr("t") if isinstance(param, MathFunc) else None + super(DynamicFunc, self).__init__(freedom, params, xstr) def noise_call(self, x, timestamp, std): clean_y = self.__call__(x, timestamp) @@ -33,13 +27,13 @@ class DynamicFunc(FitFunc): return noise_y -class DynamicLinearFunc(DynamicFunc): +class LinearDFunc(DynamicFunc): """The dynamic linear function that outputs f(x) = a * x + b. The a and b is a function of timestamp. """ - def __init__(self, params=None, xstr="x"): - super(DynamicLinearFunc, self).__init__(3, params, xstr) + def __init__(self, params, xstr="x"): + super(LinearDFunc, self).__init__(2, params, xstr) def __call__(self, x, timestamp): a = self._params[0](timestamp) @@ -57,18 +51,15 @@ class DynamicLinearFunc(DynamicFunc): ) -class DynamicQuadraticFunc(DynamicFunc): +class QuadraticDFunc(DynamicFunc): """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. The a, b, and c is a function of timestamp. """ - def __init__(self, params=None): - super(DynamicQuadraticFunc, self).__init__(3, params) + def __init__(self, params, xstr="x"): + super(QuadraticDFunc, self).__init__(3, params) - def __call__( - self, - x, - ): + def __call__(self, x, timestamp): self.check_valid() a = self._params[0](timestamp) b = self._params[1](timestamp) @@ -78,38 +69,37 @@ class DynamicQuadraticFunc(DynamicFunc): return a * x * x + b * x + c def __repr__(self): - return "{name}({a} * x^2 + {b} * x + {c})".format( + return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], c=self._params[2], + x=self.xstr, ) -class DynamicSinQuadraticFunc(DynamicFunc): +class SinQuadraticDFunc(DynamicFunc): """The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c). The a, b, and c is a function of timestamp. """ def __init__(self, params=None): - super(DynamicSinQuadraticFunc, self).__init__(3, params) + super(SinQuadraticDFunc, self).__init__(3, params) - def __call__( - self, - x, - ): + def __call__(self, x, timestamp): self.check_valid() a = self._params[0](timestamp) b = self._params[1](timestamp) c = self._params[2](timestamp) convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) - return math.sin(a * x * x + b * x + c) + return np.sin(a * x * x + b * x + c) def __repr__(self): - return "{name}({a} * x^2 + {b} * x + {c})".format( + return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], c=self._params[2], + x=self.xstr, ) diff --git a/xautodl/datasets/math_static_funcs.py b/xautodl/datasets/math_static_funcs.py new file mode 100644 index 0000000..230ccd6 --- /dev/null +++ b/xautodl/datasets/math_static_funcs.py @@ -0,0 +1,225 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +import math +import abc +import copy +import numpy as np + +from .math_base_funcs import MathFunc + + +class StaticFunc(MathFunc): + """The fit function that outputs f(x) = a * x^2 + b * x + c.""" + + def __init__(self, freedom: int, params=None, xstr="x"): + super(StaticFunc, self).__init__(freedom, params, xstr) + + @abc.abstractmethod + def __call__(self, x): + raise NotImplementedError + + def noise_call(self, x, std): + clean_y = self.__call__(x) + if isinstance(clean_y, np.ndarray): + noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) + else: + raise ValueError("Unkonwn type: {:}".format(type(clean_y))) + return noise_y + + def __repr__(self): + return "{name}(freedom={freedom})".format( + name=self.__class__.__name__, freedom=freedom + ) + + +class LinearSFunc(StaticFunc): + """The linear function that outputs f(x) = a * x + b.""" + + def __init__(self, params=None, xstr="x"): + super(LinearSFunc, self).__init__(2, params, xstr) + + def __call__(self, x): + self.check_valid() + return self._params[0] * x + self._params[1] + + def _getitem(self, x, weights): + return weights[0] * x + weights[1] + + def __repr__(self): + return "{name}({a} * {x} + {b})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + x=self.xstr, + ) + + +class QuadraticSFunc(StaticFunc): + """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" + + def __init__(self, params=None, xstr="x"): + super(QuadraticSFunc, self).__init__(3, params, xstr) + + def __call__(self, x): + self.check_valid() + return self._params[0] * x * x + self._params[1] * x + self._params[2] + + def _getitem(self, x, weights): + return weights[0] * x * x + weights[1] * x + weights[2] + + def __repr__(self): + return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + c=self._params[2], + x=self.xstr, + ) + + +class CubicSFunc(StaticFunc): + """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" + + def __init__(self, params=None, xstr="x"): + super(CubicSFunc, self).__init__(4, params, xstr) + + def __call__(self, x): + self.check_valid() + return ( + self._params[0] * x ** 3 + + self._params[1] * x ** 2 + + self._params[2] * x + + self._params[3] + ) + + def _getitem(self, x, weights): + return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] + + def __repr__(self): + return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + c=self._params[2], + d=self._params[3], + x=self.xstr, + ) + + +class QuarticSFunc(StaticFunc): + """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" + + def __init__(self, params=None, xstr="x"): + super(QuarticSFunc, self).__init__(5, params, xstr) + + def __call__(self, x): + self.check_valid() + return ( + self._params[0] * x ** 4 + + self._params[1] * x ** 3 + + self._params[2] * x ** 2 + + self._params[3] * x + + self._params[4] + ) + + def _getitem(self, x, weights): + return ( + weights[0] * x ** 4 + + weights[1] * x ** 3 + + weights[2] * x ** 2 + + weights[3] * x + + weights[4] + ) + + def __repr__(self): + return ( + "{name}({a} * {x}^4 + {b} * {x}^3 + {c} * {x}^2 + {d} * {x} + {e})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + c=self._params[2], + d=self._params[3], + e=self._params[3], + x=self.xstr, + ) + ) + + +### advanced functions + + +class ConstantFunc(StaticFunc): + """The constant function: f(x) = c.""" + + def __init__(self, constant, xstr="x"): + super(ConstantFunc, self).__init__(1, {0: constant}, xstr) + + def __call__(self, x): + self.check_valid() + return self._params[0] + + def fit(self, **kwargs): + raise NotImplementedError + + def _getitem(self, x, weights): + raise NotImplementedError + + def __repr__(self): + return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) + + +class ComposedSinSFunc(StaticFunc): + """The composed sin function that outputs: + f(x) = a * sin( b*x ) + c + """ + + def __init__(self, params, xstr="x"): + super(ComposedSinSFunc, self).__init__(3, params, xstr) + + def __call__(self, x): + self.check_valid() + a = self._params[0] + b = self._params[1] + c = self._params[2] + return a * math.sin(b * x) + c + + def _getitem(self, x, weights): + raise NotImplementedError + + def __repr__(self): + return "{name}({a} * sin({b} * {x}) + {c})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + c=self._params[2], + x=self.xstr, + ) + + +class ComposedCosSFunc(StaticFunc): + """The composed sin function that outputs: + f(x) = a * cos( b*x ) + c + """ + + def __init__(self, params, xstr="x"): + super(ComposedCosSFunc, self).__init__(3, params, xstr) + + def __call__(self, x): + self.check_valid() + a = self._params[0] + b = self._params[1] + c = self._params[2] + return a * math.cos(b * x) + c + + def _getitem(self, x, weights): + raise NotImplementedError + + def __repr__(self): + return "{name}({a} * sin({b} * {x}) + {c})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + c=self._params[2], + x=self.xstr, + ) diff --git a/xautodl/datasets/synthetic_core.py b/xautodl/datasets/synthetic_core.py index c5b2da0..5df6fe8 100644 --- a/xautodl/datasets/synthetic_core.py +++ b/xautodl/datasets/synthetic_core.py @@ -1,13 +1,13 @@ import math from .synthetic_utils import TimeStamp from .synthetic_env import SyntheticDEnv -from .math_core import LinearFunc -from .math_core import DynamicLinearFunc -from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc +from .math_core import LinearSFunc +from .math_core import LinearDFunc +from .math_core import QuadraticDFunc, SinQuadraticDFunc from .math_core import ( ConstantFunc, - ComposedSinFunc as SinFunc, - ComposedCosFunc as CosFunc, + ComposedSinSFunc as SinFunc, + ComposedCosSFunc as CosFunc, ) from .math_core import GaussianDGenerator @@ -17,7 +17,7 @@ __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"): max_time = math.pi * 10 - if version == "v1": + if version.lower() == "v1": mean_generator = ConstantFunc(0) std_generator = ConstantFunc(1) data_generator = GaussianDGenerator( @@ -26,7 +26,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio time_generator = TimeStamp( min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode ) - oracle_map = DynamicLinearFunc( + oracle_map = LinearDFunc( params={ 0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), # 2 sin(t) + 2.2 1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), # 1.5 sin(0.6t) + 1.8 @@ -35,7 +35,8 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio dynamic_env = SyntheticDEnv( data_generator, oracle_map, time_generator, num_per_task ) - elif version == "v2": + dynamic_env.set_regression() + elif version.lower() == "v2": mean_generator = ConstantFunc(0) std_generator = ConstantFunc(1) data_generator = GaussianDGenerator( @@ -44,16 +45,17 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio time_generator = TimeStamp( min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode ) - oracle_map = DynamicQuadraticFunc( + oracle_map = QuadraticDFunc( params={ - 0: LinearFunc(params={0: 0.1, 1: 0}), # 0.1 * t - 1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t) - 2: ConstantFunc(0), + 0: LinearSFunc(params={0: 0.1, 1: 0}), # 0.1 * t + 1: ConstantFunc(0), + 2: CosFunc(params={0: 4.0, 1: 10, 2: 0}), # 4 * cos(10 * t) } ) dynamic_env = SyntheticDEnv( data_generator, oracle_map, time_generator, num_per_task ) + dynamic_env.set_regression() elif version.lower() == "v3": mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0}) # sin(t) std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1}) # 0.5 cos(t) + 1 @@ -63,7 +65,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio time_generator = TimeStamp( min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode ) - oracle_map = DynamicSinQuadraticFunc( + oracle_map = SinQuadraticDFunc( params={ 0: CosFunc(params={0: 0.5, 1: 1, 2: 1}), # 0.5 cos(t) + 1 1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t) @@ -73,6 +75,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio dynamic_env = SyntheticDEnv( data_generator, oracle_map, time_generator, num_per_task ) + dynamic_env.set_regression() + elif version.lower() == "v4": + dynamic_env.set_classification(2) else: raise ValueError("Unknown version: {:}".format(version)) return dynamic_env diff --git a/xautodl/datasets/synthetic_env.py b/xautodl/datasets/synthetic_env.py index a434018..65e274e 100644 --- a/xautodl/datasets/synthetic_env.py +++ b/xautodl/datasets/synthetic_env.py @@ -49,6 +49,10 @@ class SyntheticDEnv(data.Dataset): self._meta_info["task"] = "classification" self._meta_info["num_classes"] = int(num_classes) + @property + def oracle_map(self): + return self._oracle_map + @property def meta_info(self): return self._meta_info