Re-organize GeMOSA
This commit is contained in:
		| @@ -1,10 +1,9 @@ | ||||
| ##################################################### | ||||
| # Learning to Generate Model One Step Ahead         # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --workers 0 | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 | ||||
| # python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --workers 0 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| @@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData | ||||
| from meta_model import MetaModelV1 | ||||
|  | ||||
|  | ||||
| def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): | ||||
| def online_evaluate( | ||||
|     env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False | ||||
| ): | ||||
|     logger.log("Online evaluate: {:}".format(env)) | ||||
|     loss_meter = AverageMeter() | ||||
|     w_containers = dict() | ||||
| @@ -46,15 +47,20 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F | ||||
|         with torch.no_grad(): | ||||
|             meta_model.eval() | ||||
|             base_model.eval() | ||||
|             [future_container], time_embeds = meta_model( | ||||
|                 future_time.to(args.device).view(-1), None, False | ||||
|             future_time_embed = meta_model.gen_time_embed( | ||||
|                 future_time.to(args.device).view(-1) | ||||
|             ) | ||||
|             [future_container] = meta_model.gen_model(future_time_embed) | ||||
|             if save: | ||||
|                 w_containers[idx] = future_container.no_grad_clone() | ||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||
|             future_y_hat = base_model.forward_with_container(future_x, future_container) | ||||
|             future_loss = criterion(future_y_hat, future_y) | ||||
|             loss_meter.update(future_loss.item()) | ||||
|         if easy_adapt: | ||||
|             meta_model.easy_adapt(future_time.item(), future_time_embed) | ||||
|             refine, post_refine_loss = False, -1 | ||||
|         else: | ||||
|             refine, post_refine_loss = meta_model.adapt( | ||||
|                 base_model, | ||||
|                 criterion, | ||||
| @@ -63,7 +69,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F | ||||
|                 future_y, | ||||
|                 args.refine_lr, | ||||
|                 args.refine_epochs, | ||||
|             {"param": time_embeds, "loss": future_loss.item()}, | ||||
|                 {"param": future_time_embed, "loss": future_loss.item()}, | ||||
|             ) | ||||
|         logger.log( | ||||
|             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( | ||||
| @@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         ) | ||||
|         optimizer.zero_grad() | ||||
|  | ||||
|         generated_time_embeds = gen_time_embed(meta_model.meta_timestamps) | ||||
|         generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps) | ||||
|  | ||||
|         batch_indexes = random.choices(total_indexes, k=args.meta_batch) | ||||
|  | ||||
| @@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | ||||
|         ) | ||||
|         # future loss | ||||
|         total_future_losses, total_present_losses = [], [] | ||||
|         future_containers, _ = meta_model( | ||||
|             None, generated_time_embeds[batch_indexes], False | ||||
|         ) | ||||
|         present_containers, _ = meta_model( | ||||
|             None, meta_model.super_meta_embed[batch_indexes], False | ||||
|         future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes]) | ||||
|         present_containers = meta_model.gen_model( | ||||
|             meta_model.super_meta_embed[batch_indexes] | ||||
|         ) | ||||
|         for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): | ||||
|             _, (inputs, targets) = xenv(time_step) | ||||
| @@ -216,13 +220,34 @@ def main(args): | ||||
|     # try to evaluate once | ||||
|     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) | ||||
|     # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) | ||||
|     """ | ||||
|     w_containers, loss_meter = online_evaluate( | ||||
|         all_env, meta_model, base_model, criterion, args, logger, True | ||||
|     ) | ||||
|     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) | ||||
|     """ | ||||
|     _, test_loss_meter_adapt_v1 = online_evaluate( | ||||
|         valid_env, meta_model, base_model, criterion, args, logger, False, False | ||||
|     ) | ||||
|     _, test_loss_meter_adapt_v2 = online_evaluate( | ||||
|         valid_env, meta_model, base_model, criterion, args, logger, False, True | ||||
|     ) | ||||
|     logger.log( | ||||
|         "In the online test enviornment, the total loss for refine-adapt is {:}".format( | ||||
|             test_loss_meter_adapt_v1 | ||||
|         ) | ||||
|     ) | ||||
|     logger.log( | ||||
|         "In the online test enviornment, the total loss for easy-adapt is {:}".format( | ||||
|             test_loss_meter_adapt_v2 | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"all_w_containers": w_containers}, | ||||
|         { | ||||
|             "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg, | ||||
|             "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg, | ||||
|         }, | ||||
|         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), | ||||
|         logger, | ||||
|     ) | ||||
|   | ||||
| @@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule): | ||||
|             batch_containers.append( | ||||
|                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||
|             ) | ||||
|         return batch_containers, time_embeds | ||||
|         return batch_containers | ||||
|  | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         raise NotImplementedError | ||||
| @@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule): | ||||
|     def forward_candidate(self, input): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def easy_adapt(self, timestamp, time_embed): | ||||
|         with torch.no_grad(): | ||||
|             timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device) | ||||
|             self.replace_append_learnt(None, None) | ||||
|             self.append_fixed(timestamp, time_embed) | ||||
|  | ||||
|     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): | ||||
|         distance = self.get_closest_meta_distance(timestamp) | ||||
|         if distance + self._interval * 1e-2 <= self._interval: | ||||
| @@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule): | ||||
|                 best_new_param = new_param.detach().clone() | ||||
|             for iepoch in range(epochs): | ||||
|                 optimizer.zero_grad() | ||||
|                 _, time_embed = self(timestamp.view(1), None) | ||||
|                 time_embed = self.gen_time_embed(timestamp.view(1)) | ||||
|                 match_loss = criterion(new_param, time_embed) | ||||
|  | ||||
|                 [container], time_embed = self(None, new_param.view(1, -1)) | ||||
|                 [container] = self.gen_model(new_param.view(1, -1)) | ||||
|                 y_hat = base_model.forward_with_container(x, container) | ||||
|                 meta_loss = criterion(y_hat, y) | ||||
|                 loss = meta_loss + match_loss | ||||
|                 loss.backward() | ||||
|                 optimizer.step() | ||||
|                 # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item())) | ||||
|                 if meta_loss.item() < best_loss: | ||||
|                     with torch.no_grad(): | ||||
|                         best_loss = meta_loss.item() | ||||
|                         best_new_param = new_param.detach().clone() | ||||
|         with torch.no_grad(): | ||||
|             self.replace_append_learnt(None, None) | ||||
|             self.append_fixed(timestamp, best_new_param) | ||||
|         self.easy_adapt(timestamp, best_new_param) | ||||
|         return True, best_loss | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|   | ||||
| @@ -191,6 +191,8 @@ def visualize_env(save_dir, version): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|     print("env: {:}".format(dynamic_env)) | ||||
|     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||
|     print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||
|     print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|   | ||||
| @@ -1,92 +0,0 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
| from .math_base_funcs import FitFunc | ||||
| from .math_base_funcs import QuadraticFunc | ||||
| from .math_base_funcs import QuarticFunc | ||||
|  | ||||
|  | ||||
| class ConstantFunc(FitFunc): | ||||
|     """The constant function: f(x) = c.""" | ||||
|  | ||||
|     def __init__(self, constant=None, xstr="x"): | ||||
|         param = dict() | ||||
|         param[0] = constant | ||||
|         super(ConstantFunc, self).__init__(0, None, param, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] | ||||
|  | ||||
|     def fit(self, **kwargs): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) | ||||
|  | ||||
|  | ||||
| class ComposedSinFunc(FitFunc): | ||||
|     """The composed sin function that outputs: | ||||
|     f(x) = a * sin( b*x ) + c | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params, xstr="x"): | ||||
|         super(ComposedSinFunc, self).__init__(3, None, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         a = self._params[0] | ||||
|         b = self._params[1] | ||||
|         c = self._params[2] | ||||
|         return a * math.sin(b * x) + c | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class ComposedCosFunc(FitFunc): | ||||
|     """The composed sin function that outputs: | ||||
|     f(x) = a * cos( b*x ) + c | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params, xstr="x"): | ||||
|         super(ComposedCosFunc, self).__init__(3, None, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         a = self._params[0] | ||||
|         b = self._params[1] | ||||
|         c = self._params[2] | ||||
|         return a * math.cos(b * x) + c | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
| @@ -5,30 +5,29 @@ import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class FitFunc(abc.ABC): | ||||
|     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
| class MathFunc(abc.ABC): | ||||
|     """The math function -- a virtual class defining some APIs.""" | ||||
|  | ||||
|     def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"): | ||||
|     def __init__(self, freedom: int, params=None, xstr="x"): | ||||
|         # initialize as empty | ||||
|         self._params = dict() | ||||
|         for i in range(freedom): | ||||
|             self._params[i] = None | ||||
|         self._freedom = freedom | ||||
|         if list_of_points is not None and params is not None: | ||||
|             raise ValueError("list_of_points and params can not be set simultaneously") | ||||
|         if list_of_points is not None: | ||||
|             self.fit(list_of_points=list_of_points) | ||||
|         if params is not None: | ||||
|             self.set(params) | ||||
|         self._xstr = str(xstr) | ||||
|         self._skip_check = True | ||||
|  | ||||
|     def set(self, params): | ||||
|         self._params = copy.deepcopy(params) | ||||
|         for key in range(self._freedom): | ||||
|             param = copy.deepcopy(params[key]) | ||||
|             self._params[key] = param | ||||
|  | ||||
|     def check_valid(self): | ||||
|         # for key, value in self._params.items(): | ||||
|         if not self._skip_check: | ||||
|             for key in range(self._freedom): | ||||
|                 value = self._params[key] | ||||
|                 if value is None: | ||||
| @@ -45,7 +44,8 @@ class FitFunc(abc.ABC): | ||||
|     def __call__(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def noise_call(self, x, std=0.1): | ||||
|     @abc.abstractmethod | ||||
|     def noise_call(self, x, std): | ||||
|         clean_y = self.__call__(x) | ||||
|         if isinstance(clean_y, np.ndarray): | ||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||
| @@ -53,169 +53,7 @@ class FitFunc(abc.ABC): | ||||
|             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||
|         return noise_y | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def _getitem(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def fit(self, **kwargs): | ||||
|         list_of_points = kwargs["list_of_points"] | ||||
|         max_iter, lr_max, verbose = ( | ||||
|             kwargs.get("max_iter", 900), | ||||
|             kwargs.get("lr_max", 1.0), | ||||
|             kwargs.get("verbose", False), | ||||
|         ) | ||||
|         with torch.no_grad(): | ||||
|             data = torch.Tensor(list_of_points).type(torch.float32) | ||||
|             assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( | ||||
|                 data.shape | ||||
|             ) | ||||
|             x, y = data[:, 0], data[:, 1] | ||||
|         weights = torch.nn.Parameter(torch.Tensor(self._freedom)) | ||||
|         torch.nn.init.normal_(weights, mean=0.0, std=1.0) | ||||
|         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             milestones=[ | ||||
|                 int(max_iter * 0.25), | ||||
|                 int(max_iter * 0.5), | ||||
|                 int(max_iter * 0.75), | ||||
|             ], | ||||
|             gamma=0.1, | ||||
|         ) | ||||
|         if verbose: | ||||
|             print("The optimizer: {:}".format(optimizer)) | ||||
|  | ||||
|         best_loss = None | ||||
|         for _iter in range(max_iter): | ||||
|             y_hat = self._getitem(x, weights) | ||||
|             loss = torch.mean(torch.abs(y - y_hat)) | ||||
|             optimizer.zero_grad() | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             lr_scheduler.step() | ||||
|             if verbose: | ||||
|                 print( | ||||
|                     "In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format( | ||||
|                         _iter, max_iter, loss.item() | ||||
|                     ) | ||||
|                 ) | ||||
|             # Update the params | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 for i in range(self._freedom): | ||||
|                     self._params[i] = weights[i].item() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(freedom={freedom})".format( | ||||
|             name=self.__class__.__name__, freedom=freedom | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class LinearFunc(FitFunc): | ||||
|     """The linear function that outputs f(x) = a * x + b.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None, params=None, xstr="x"): | ||||
|         super(LinearFunc, self).__init__(2, list_of_points, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x + self._params[1] | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x + weights[1] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x} + {b})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuadraticFunc(FitFunc): | ||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None, params=None, xstr="x"): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x * x + weights[1] * x + weights[2] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class CubicFunc(FitFunc): | ||||
|     """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(CubicFunc, self).__init__(4, list_of_points) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 3 | ||||
|             + self._params[1] * x ** 2 | ||||
|             + self._params[2] * x | ||||
|             + self._params[3] | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             d=self._params[3], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuarticFunc(FitFunc): | ||||
|     """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuarticFunc, self).__init__(5, list_of_points) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 4 | ||||
|             + self._params[1] * x ** 3 | ||||
|             + self._params[2] * x ** 2 | ||||
|             + self._params[3] * x | ||||
|             + self._params[4] | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return ( | ||||
|             weights[0] * x ** 4 | ||||
|             + weights[1] * x ** 3 | ||||
|             + weights[2] * x ** 2 | ||||
|             + weights[3] * x | ||||
|             + weights[4] | ||||
|         ) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             d=self._params[3], | ||||
|             e=self._params[3], | ||||
|         ) | ||||
|   | ||||
| @@ -1,10 +1,14 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_dynamic_funcs import DynamicSinQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
| from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc | ||||
| from .math_static_funcs import ( | ||||
|     LinearSFunc, | ||||
|     QuadraticSFunc, | ||||
|     CubicSFunc, | ||||
|     QuarticSFunc, | ||||
|     ConstantFunc, | ||||
|     ComposedSinSFunc, | ||||
|     ComposedCosSFunc, | ||||
| ) | ||||
| from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc | ||||
| from .math_dynamic_generator import GaussianDGenerator | ||||
|   | ||||
| @@ -6,23 +6,17 @@ import abc | ||||
| import copy | ||||
| import numpy as np | ||||
|  | ||||
| from .math_base_funcs import FitFunc | ||||
| from .math_base_funcs import MathFunc | ||||
|  | ||||
|  | ||||
| class DynamicFunc(FitFunc): | ||||
|     """The dynamic quadratic function, where each param is a function.""" | ||||
| class DynamicFunc(MathFunc): | ||||
|     """The dynamic function, where each param is a function.""" | ||||
|  | ||||
|     def __init__(self, freedom: int, params=None, xstr="x"): | ||||
|         if params is not None: | ||||
|             for param in params: | ||||
|                 param.reset_xstr("t") if isinstance(param, FitFunc) else None | ||||
|         super(DynamicFunc, self).__init__(freedom, None, params, xstr) | ||||
|  | ||||
|     def __call__(self, x, timestamp): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|             for key, param in params.items(): | ||||
|                 param.reset_xstr("t") if isinstance(param, MathFunc) else None | ||||
|         super(DynamicFunc, self).__init__(freedom, params, xstr) | ||||
|  | ||||
|     def noise_call(self, x, timestamp, std): | ||||
|         clean_y = self.__call__(x, timestamp) | ||||
| @@ -33,13 +27,13 @@ class DynamicFunc(FitFunc): | ||||
|         return noise_y | ||||
|  | ||||
|  | ||||
| class DynamicLinearFunc(DynamicFunc): | ||||
| class LinearDFunc(DynamicFunc): | ||||
|     """The dynamic linear function that outputs f(x) = a * x + b. | ||||
|     The a and b is a function of timestamp. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params=None, xstr="x"): | ||||
|         super(DynamicLinearFunc, self).__init__(3, params, xstr) | ||||
|     def __init__(self, params, xstr="x"): | ||||
|         super(LinearDFunc, self).__init__(2, params, xstr) | ||||
|  | ||||
|     def __call__(self, x, timestamp): | ||||
|         a = self._params[0](timestamp) | ||||
| @@ -57,18 +51,15 @@ class DynamicLinearFunc(DynamicFunc): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class DynamicQuadraticFunc(DynamicFunc): | ||||
| class QuadraticDFunc(DynamicFunc): | ||||
|     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. | ||||
|     The a, b, and c is a function of timestamp. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params=None): | ||||
|         super(DynamicQuadraticFunc, self).__init__(3, params) | ||||
|     def __init__(self, params, xstr="x"): | ||||
|         super(QuadraticDFunc, self).__init__(3, params) | ||||
|  | ||||
|     def __call__( | ||||
|         self, | ||||
|         x, | ||||
|     ): | ||||
|     def __call__(self, x, timestamp): | ||||
|         self.check_valid() | ||||
|         a = self._params[0](timestamp) | ||||
|         b = self._params[1](timestamp) | ||||
| @@ -78,38 +69,37 @@ class DynamicQuadraticFunc(DynamicFunc): | ||||
|         return a * x * x + b * x + c | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^2 + {b} * x + {c})".format( | ||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class DynamicSinQuadraticFunc(DynamicFunc): | ||||
| class SinQuadraticDFunc(DynamicFunc): | ||||
|     """The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c). | ||||
|     The a, b, and c is a function of timestamp. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params=None): | ||||
|         super(DynamicSinQuadraticFunc, self).__init__(3, params) | ||||
|         super(SinQuadraticDFunc, self).__init__(3, params) | ||||
|  | ||||
|     def __call__( | ||||
|         self, | ||||
|         x, | ||||
|     ): | ||||
|     def __call__(self, x, timestamp): | ||||
|         self.check_valid() | ||||
|         a = self._params[0](timestamp) | ||||
|         b = self._params[1](timestamp) | ||||
|         c = self._params[2](timestamp) | ||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||
|         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||
|         return math.sin(a * x * x + b * x + c) | ||||
|         return np.sin(a * x * x + b * x + c) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x^2 + {b} * x + {c})".format( | ||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|   | ||||
							
								
								
									
										225
									
								
								xautodl/datasets/math_static_funcs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								xautodl/datasets/math_static_funcs.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,225 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import copy | ||||
| import numpy as np | ||||
|  | ||||
| from .math_base_funcs import MathFunc | ||||
|  | ||||
|  | ||||
| class StaticFunc(MathFunc): | ||||
|     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, freedom: int, params=None, xstr="x"): | ||||
|         super(StaticFunc, self).__init__(freedom, params, xstr) | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def __call__(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def noise_call(self, x, std): | ||||
|         clean_y = self.__call__(x) | ||||
|         if isinstance(clean_y, np.ndarray): | ||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||
|         else: | ||||
|             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||
|         return noise_y | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(freedom={freedom})".format( | ||||
|             name=self.__class__.__name__, freedom=freedom | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class LinearSFunc(StaticFunc): | ||||
|     """The linear function that outputs f(x) = a * x + b.""" | ||||
|  | ||||
|     def __init__(self, params=None, xstr="x"): | ||||
|         super(LinearSFunc, self).__init__(2, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x + self._params[1] | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x + weights[1] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x} + {b})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuadraticSFunc(StaticFunc): | ||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, params=None, xstr="x"): | ||||
|         super(QuadraticSFunc, self).__init__(3, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x * x + weights[1] * x + weights[2] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class CubicSFunc(StaticFunc): | ||||
|     """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" | ||||
|  | ||||
|     def __init__(self, params=None, xstr="x"): | ||||
|         super(CubicSFunc, self).__init__(4, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 3 | ||||
|             + self._params[1] * x ** 2 | ||||
|             + self._params[2] * x | ||||
|             + self._params[3] | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             d=self._params[3], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuarticSFunc(StaticFunc): | ||||
|     """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" | ||||
|  | ||||
|     def __init__(self, params=None, xstr="x"): | ||||
|         super(QuarticSFunc, self).__init__(5, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 4 | ||||
|             + self._params[1] * x ** 3 | ||||
|             + self._params[2] * x ** 2 | ||||
|             + self._params[3] * x | ||||
|             + self._params[4] | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return ( | ||||
|             weights[0] * x ** 4 | ||||
|             + weights[1] * x ** 3 | ||||
|             + weights[2] * x ** 2 | ||||
|             + weights[3] * x | ||||
|             + weights[4] | ||||
|         ) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return ( | ||||
|             "{name}({a} * {x}^4 + {b} * {x}^3 + {c} * {x}^2 + {d} * {x} + {e})".format( | ||||
|                 name=self.__class__.__name__, | ||||
|                 a=self._params[0], | ||||
|                 b=self._params[1], | ||||
|                 c=self._params[2], | ||||
|                 d=self._params[3], | ||||
|                 e=self._params[3], | ||||
|                 x=self.xstr, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| ### advanced functions | ||||
|  | ||||
|  | ||||
| class ConstantFunc(StaticFunc): | ||||
|     """The constant function: f(x) = c.""" | ||||
|  | ||||
|     def __init__(self, constant, xstr="x"): | ||||
|         super(ConstantFunc, self).__init__(1, {0: constant}, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] | ||||
|  | ||||
|     def fit(self, **kwargs): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) | ||||
|  | ||||
|  | ||||
| class ComposedSinSFunc(StaticFunc): | ||||
|     """The composed sin function that outputs: | ||||
|     f(x) = a * sin( b*x ) + c | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params, xstr="x"): | ||||
|         super(ComposedSinSFunc, self).__init__(3, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         a = self._params[0] | ||||
|         b = self._params[1] | ||||
|         c = self._params[2] | ||||
|         return a * math.sin(b * x) + c | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class ComposedCosSFunc(StaticFunc): | ||||
|     """The composed sin function that outputs: | ||||
|     f(x) = a * cos( b*x ) + c | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params, xstr="x"): | ||||
|         super(ComposedCosSFunc, self).__init__(3, params, xstr) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         a = self._params[0] | ||||
|         b = self._params[1] | ||||
|         c = self._params[2] | ||||
|         return a * math.cos(b * x) + c | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
| @@ -1,13 +1,13 @@ | ||||
| import math | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_core import LinearFunc | ||||
| from .math_core import DynamicLinearFunc | ||||
| from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc | ||||
| from .math_core import LinearSFunc | ||||
| from .math_core import LinearDFunc | ||||
| from .math_core import QuadraticDFunc, SinQuadraticDFunc | ||||
| from .math_core import ( | ||||
|     ConstantFunc, | ||||
|     ComposedSinFunc as SinFunc, | ||||
|     ComposedCosFunc as CosFunc, | ||||
|     ComposedSinSFunc as SinFunc, | ||||
|     ComposedCosSFunc as CosFunc, | ||||
| ) | ||||
| from .math_core import GaussianDGenerator | ||||
|  | ||||
| @@ -17,7 +17,7 @@ __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||
|  | ||||
| def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"): | ||||
|     max_time = math.pi * 10 | ||||
|     if version == "v1": | ||||
|     if version.lower() == "v1": | ||||
|         mean_generator = ConstantFunc(0) | ||||
|         std_generator = ConstantFunc(1) | ||||
|         data_generator = GaussianDGenerator( | ||||
| @@ -26,7 +26,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | ||||
|         time_generator = TimeStamp( | ||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||
|         ) | ||||
|         oracle_map = DynamicLinearFunc( | ||||
|         oracle_map = LinearDFunc( | ||||
|             params={ | ||||
|                 0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),  # 2 sin(t) + 2.2 | ||||
|                 1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),  # 1.5 sin(0.6t) + 1.8 | ||||
| @@ -35,7 +35,8 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | ||||
|         dynamic_env = SyntheticDEnv( | ||||
|             data_generator, oracle_map, time_generator, num_per_task | ||||
|         ) | ||||
|     elif version == "v2": | ||||
|         dynamic_env.set_regression() | ||||
|     elif version.lower() == "v2": | ||||
|         mean_generator = ConstantFunc(0) | ||||
|         std_generator = ConstantFunc(1) | ||||
|         data_generator = GaussianDGenerator( | ||||
| @@ -44,16 +45,17 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | ||||
|         time_generator = TimeStamp( | ||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||
|         ) | ||||
|         oracle_map = DynamicQuadraticFunc( | ||||
|         oracle_map = QuadraticDFunc( | ||||
|             params={ | ||||
|                 0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t | ||||
|                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) | ||||
|                 2: ConstantFunc(0), | ||||
|                 0: LinearSFunc(params={0: 0.1, 1: 0}),  # 0.1 * t | ||||
|                 1: ConstantFunc(0), | ||||
|                 2: CosFunc(params={0: 4.0, 1: 10, 2: 0}),  # 4 * cos(10 * t) | ||||
|             } | ||||
|         ) | ||||
|         dynamic_env = SyntheticDEnv( | ||||
|             data_generator, oracle_map, time_generator, num_per_task | ||||
|         ) | ||||
|         dynamic_env.set_regression() | ||||
|     elif version.lower() == "v3": | ||||
|         mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0})  # sin(t) | ||||
|         std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1})  # 0.5 cos(t) + 1 | ||||
| @@ -63,7 +65,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | ||||
|         time_generator = TimeStamp( | ||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||
|         ) | ||||
|         oracle_map = DynamicSinQuadraticFunc( | ||||
|         oracle_map = SinQuadraticDFunc( | ||||
|             params={ | ||||
|                 0: CosFunc(params={0: 0.5, 1: 1, 2: 1}),  # 0.5 cos(t) + 1 | ||||
|                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) | ||||
| @@ -73,6 +75,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | ||||
|         dynamic_env = SyntheticDEnv( | ||||
|             data_generator, oracle_map, time_generator, num_per_task | ||||
|         ) | ||||
|         dynamic_env.set_regression() | ||||
|     elif version.lower() == "v4": | ||||
|         dynamic_env.set_classification(2) | ||||
|     else: | ||||
|         raise ValueError("Unknown version: {:}".format(version)) | ||||
|     return dynamic_env | ||||
|   | ||||
| @@ -49,6 +49,10 @@ class SyntheticDEnv(data.Dataset): | ||||
|         self._meta_info["task"] = "classification" | ||||
|         self._meta_info["num_classes"] = int(num_classes) | ||||
|  | ||||
|     @property | ||||
|     def oracle_map(self): | ||||
|         return self._oracle_map | ||||
|  | ||||
|     @property | ||||
|     def meta_info(self): | ||||
|         return self._meta_info | ||||
|   | ||||
		Reference in New Issue
	
	Block a user