Re-organize GeMOSA
This commit is contained in:
parent
8961215416
commit
6da60664f5
@ -1,10 +1,9 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Learning to Generate Model One Step Ahead #
|
# Learning to Generate Model One Step Ahead #
|
||||||
#####################################################
|
#####################################################
|
||||||
# python exps/GeMOSA/lfna.py --env_version v1 --workers 0
|
# python exps/GeMOSA/main.py --env_version v1 --workers 0
|
||||||
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001
|
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256
|
||||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
|
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
|
||||||
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
|
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, time, copy, torch, random, argparse
|
import sys, time, copy, torch, random, argparse
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData
|
|||||||
from meta_model import MetaModelV1
|
from meta_model import MetaModelV1
|
||||||
|
|
||||||
|
|
||||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
|
def online_evaluate(
|
||||||
|
env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
|
||||||
|
):
|
||||||
logger.log("Online evaluate: {:}".format(env))
|
logger.log("Online evaluate: {:}".format(env))
|
||||||
loss_meter = AverageMeter()
|
loss_meter = AverageMeter()
|
||||||
w_containers = dict()
|
w_containers = dict()
|
||||||
@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
meta_model.eval()
|
meta_model.eval()
|
||||||
base_model.eval()
|
base_model.eval()
|
||||||
[future_container], time_embeds = meta_model(
|
future_time_embed = meta_model.gen_time_embed(
|
||||||
future_time.to(args.device).view(-1), None, False
|
future_time.to(args.device).view(-1)
|
||||||
)
|
)
|
||||||
|
[future_container] = meta_model.gen_model(future_time_embed)
|
||||||
if save:
|
if save:
|
||||||
w_containers[idx] = future_container.no_grad_clone()
|
w_containers[idx] = future_container.no_grad_clone()
|
||||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||||
future_loss = criterion(future_y_hat, future_y)
|
future_loss = criterion(future_y_hat, future_y)
|
||||||
loss_meter.update(future_loss.item())
|
loss_meter.update(future_loss.item())
|
||||||
refine, post_refine_loss = meta_model.adapt(
|
if easy_adapt:
|
||||||
base_model,
|
meta_model.easy_adapt(future_time.item(), future_time_embed)
|
||||||
criterion,
|
refine, post_refine_loss = False, -1
|
||||||
future_time.item(),
|
else:
|
||||||
future_x,
|
refine, post_refine_loss = meta_model.adapt(
|
||||||
future_y,
|
base_model,
|
||||||
args.refine_lr,
|
criterion,
|
||||||
args.refine_epochs,
|
future_time.item(),
|
||||||
{"param": time_embeds, "loss": future_loss.item()},
|
future_x,
|
||||||
)
|
future_y,
|
||||||
|
args.refine_lr,
|
||||||
|
args.refine_epochs,
|
||||||
|
{"param": future_time_embed, "loss": future_loss.item()},
|
||||||
|
)
|
||||||
logger.log(
|
logger.log(
|
||||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||||
idx, len(env), future_loss.item()
|
idx, len(env), future_loss.item()
|
||||||
@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
|||||||
)
|
)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
|
generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
|
||||||
|
|
||||||
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
|
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
|
||||||
|
|
||||||
@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
|||||||
)
|
)
|
||||||
# future loss
|
# future loss
|
||||||
total_future_losses, total_present_losses = [], []
|
total_future_losses, total_present_losses = [], []
|
||||||
future_containers, _ = meta_model(
|
future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
|
||||||
None, generated_time_embeds[batch_indexes], False
|
present_containers = meta_model.gen_model(
|
||||||
)
|
meta_model.super_meta_embed[batch_indexes]
|
||||||
present_containers, _ = meta_model(
|
|
||||||
None, meta_model.super_meta_embed[batch_indexes], False
|
|
||||||
)
|
)
|
||||||
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
|
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
|
||||||
_, (inputs, targets) = xenv(time_step)
|
_, (inputs, targets) = xenv(time_step)
|
||||||
@ -216,13 +220,34 @@ def main(args):
|
|||||||
# try to evaluate once
|
# try to evaluate once
|
||||||
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
||||||
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
||||||
|
"""
|
||||||
w_containers, loss_meter = online_evaluate(
|
w_containers, loss_meter = online_evaluate(
|
||||||
all_env, meta_model, base_model, criterion, args, logger, True
|
all_env, meta_model, base_model, criterion, args, logger, True
|
||||||
)
|
)
|
||||||
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
||||||
|
"""
|
||||||
|
_, test_loss_meter_adapt_v1 = online_evaluate(
|
||||||
|
valid_env, meta_model, base_model, criterion, args, logger, False, False
|
||||||
|
)
|
||||||
|
_, test_loss_meter_adapt_v2 = online_evaluate(
|
||||||
|
valid_env, meta_model, base_model, criterion, args, logger, False, True
|
||||||
|
)
|
||||||
|
logger.log(
|
||||||
|
"In the online test enviornment, the total loss for refine-adapt is {:}".format(
|
||||||
|
test_loss_meter_adapt_v1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.log(
|
||||||
|
"In the online test enviornment, the total loss for easy-adapt is {:}".format(
|
||||||
|
test_loss_meter_adapt_v2
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
{"all_w_containers": w_containers},
|
{
|
||||||
|
"test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
|
||||||
|
"test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
|
||||||
|
},
|
||||||
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
|
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
|
@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
batch_containers.append(
|
batch_containers.append(
|
||||||
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
|
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
|
||||||
)
|
)
|
||||||
return batch_containers, time_embeds
|
return batch_containers
|
||||||
|
|
||||||
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
|
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
def forward_candidate(self, input):
|
def forward_candidate(self, input):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def easy_adapt(self, timestamp, time_embed):
|
||||||
|
with torch.no_grad():
|
||||||
|
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
|
||||||
|
self.replace_append_learnt(None, None)
|
||||||
|
self.append_fixed(timestamp, time_embed)
|
||||||
|
|
||||||
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
|
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
|
||||||
distance = self.get_closest_meta_distance(timestamp)
|
distance = self.get_closest_meta_distance(timestamp)
|
||||||
if distance + self._interval * 1e-2 <= self._interval:
|
if distance + self._interval * 1e-2 <= self._interval:
|
||||||
@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule):
|
|||||||
best_new_param = new_param.detach().clone()
|
best_new_param = new_param.detach().clone()
|
||||||
for iepoch in range(epochs):
|
for iepoch in range(epochs):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
_, time_embed = self(timestamp.view(1), None)
|
time_embed = self.gen_time_embed(timestamp.view(1))
|
||||||
match_loss = criterion(new_param, time_embed)
|
match_loss = criterion(new_param, time_embed)
|
||||||
|
|
||||||
[container], time_embed = self(None, new_param.view(1, -1))
|
[container] = self.gen_model(new_param.view(1, -1))
|
||||||
y_hat = base_model.forward_with_container(x, container)
|
y_hat = base_model.forward_with_container(x, container)
|
||||||
meta_loss = criterion(y_hat, y)
|
meta_loss = criterion(y_hat, y)
|
||||||
loss = meta_loss + match_loss
|
loss = meta_loss + match_loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
# print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
|
|
||||||
if meta_loss.item() < best_loss:
|
if meta_loss.item() < best_loss:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
best_loss = meta_loss.item()
|
best_loss = meta_loss.item()
|
||||||
best_new_param = new_param.detach().clone()
|
best_new_param = new_param.detach().clone()
|
||||||
with torch.no_grad():
|
self.easy_adapt(timestamp, best_new_param)
|
||||||
self.replace_append_learnt(None, None)
|
|
||||||
self.append_fixed(timestamp, best_new_param)
|
|
||||||
return True, best_loss
|
return True, best_loss
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
|
@ -191,6 +191,8 @@ def visualize_env(save_dir, version):
|
|||||||
allxs.append(allx)
|
allxs.append(allx)
|
||||||
allys.append(ally)
|
allys.append(ally)
|
||||||
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
|
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
|
||||||
|
print("env: {:}".format(dynamic_env))
|
||||||
|
print("oracle_map: {:}".format(dynamic_env.oracle_map))
|
||||||
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
|
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
|
||||||
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
|
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
|
||||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||||
|
@ -1,92 +0,0 @@
|
|||||||
#####################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
|
||||||
#####################################################
|
|
||||||
import math
|
|
||||||
import abc
|
|
||||||
import copy
|
|
||||||
import numpy as np
|
|
||||||
from typing import Optional
|
|
||||||
import torch
|
|
||||||
import torch.utils.data as data
|
|
||||||
|
|
||||||
from .math_base_funcs import FitFunc
|
|
||||||
from .math_base_funcs import QuadraticFunc
|
|
||||||
from .math_base_funcs import QuarticFunc
|
|
||||||
|
|
||||||
|
|
||||||
class ConstantFunc(FitFunc):
|
|
||||||
"""The constant function: f(x) = c."""
|
|
||||||
|
|
||||||
def __init__(self, constant=None, xstr="x"):
|
|
||||||
param = dict()
|
|
||||||
param[0] = constant
|
|
||||||
super(ConstantFunc, self).__init__(0, None, param, xstr)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
self.check_valid()
|
|
||||||
return self._params[0]
|
|
||||||
|
|
||||||
def fit(self, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _getitem(self, x, weights):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
|
|
||||||
|
|
||||||
|
|
||||||
class ComposedSinFunc(FitFunc):
|
|
||||||
"""The composed sin function that outputs:
|
|
||||||
f(x) = a * sin( b*x ) + c
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, params, xstr="x"):
|
|
||||||
super(ComposedSinFunc, self).__init__(3, None, params, xstr)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
self.check_valid()
|
|
||||||
a = self._params[0]
|
|
||||||
b = self._params[1]
|
|
||||||
c = self._params[2]
|
|
||||||
return a * math.sin(b * x) + c
|
|
||||||
|
|
||||||
def _getitem(self, x, weights):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "{name}({a} * sin({b} * {x}) + {c})".format(
|
|
||||||
name=self.__class__.__name__,
|
|
||||||
a=self._params[0],
|
|
||||||
b=self._params[1],
|
|
||||||
c=self._params[2],
|
|
||||||
x=self.xstr,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ComposedCosFunc(FitFunc):
|
|
||||||
"""The composed sin function that outputs:
|
|
||||||
f(x) = a * cos( b*x ) + c
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, params, xstr="x"):
|
|
||||||
super(ComposedCosFunc, self).__init__(3, None, params, xstr)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
self.check_valid()
|
|
||||||
a = self._params[0]
|
|
||||||
b = self._params[1]
|
|
||||||
c = self._params[2]
|
|
||||||
return a * math.cos(b * x) + c
|
|
||||||
|
|
||||||
def _getitem(self, x, weights):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "{name}({a} * sin({b} * {x}) + {c})".format(
|
|
||||||
name=self.__class__.__name__,
|
|
||||||
a=self._params[0],
|
|
||||||
b=self._params[1],
|
|
||||||
c=self._params[2],
|
|
||||||
x=self.xstr,
|
|
||||||
)
|
|
@ -5,34 +5,33 @@ import math
|
|||||||
import abc
|
import abc
|
||||||
import copy
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class FitFunc(abc.ABC):
|
class MathFunc(abc.ABC):
|
||||||
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
|
"""The math function -- a virtual class defining some APIs."""
|
||||||
|
|
||||||
def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
|
def __init__(self, freedom: int, params=None, xstr="x"):
|
||||||
|
# initialize as empty
|
||||||
self._params = dict()
|
self._params = dict()
|
||||||
for i in range(freedom):
|
for i in range(freedom):
|
||||||
self._params[i] = None
|
self._params[i] = None
|
||||||
self._freedom = freedom
|
self._freedom = freedom
|
||||||
if list_of_points is not None and params is not None:
|
|
||||||
raise ValueError("list_of_points and params can not be set simultaneously")
|
|
||||||
if list_of_points is not None:
|
|
||||||
self.fit(list_of_points=list_of_points)
|
|
||||||
if params is not None:
|
if params is not None:
|
||||||
self.set(params)
|
self.set(params)
|
||||||
self._xstr = str(xstr)
|
self._xstr = str(xstr)
|
||||||
|
self._skip_check = True
|
||||||
|
|
||||||
def set(self, params):
|
def set(self, params):
|
||||||
self._params = copy.deepcopy(params)
|
for key in range(self._freedom):
|
||||||
|
param = copy.deepcopy(params[key])
|
||||||
|
self._params[key] = param
|
||||||
|
|
||||||
def check_valid(self):
|
def check_valid(self):
|
||||||
# for key, value in self._params.items():
|
if not self._skip_check:
|
||||||
for key in range(self._freedom):
|
for key in range(self._freedom):
|
||||||
value = self._params[key]
|
value = self._params[key]
|
||||||
if value is None:
|
if value is None:
|
||||||
raise ValueError("The {:} is None".format(key))
|
raise ValueError("The {:} is None".format(key))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def xstr(self):
|
def xstr(self):
|
||||||
@ -45,7 +44,8 @@ class FitFunc(abc.ABC):
|
|||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def noise_call(self, x, std=0.1):
|
@abc.abstractmethod
|
||||||
|
def noise_call(self, x, std):
|
||||||
clean_y = self.__call__(x)
|
clean_y = self.__call__(x)
|
||||||
if isinstance(clean_y, np.ndarray):
|
if isinstance(clean_y, np.ndarray):
|
||||||
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
|
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
|
||||||
@ -53,169 +53,7 @@ class FitFunc(abc.ABC):
|
|||||||
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
|
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
|
||||||
return noise_y
|
return noise_y
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _getitem(self, x):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def fit(self, **kwargs):
|
|
||||||
list_of_points = kwargs["list_of_points"]
|
|
||||||
max_iter, lr_max, verbose = (
|
|
||||||
kwargs.get("max_iter", 900),
|
|
||||||
kwargs.get("lr_max", 1.0),
|
|
||||||
kwargs.get("verbose", False),
|
|
||||||
)
|
|
||||||
with torch.no_grad():
|
|
||||||
data = torch.Tensor(list_of_points).type(torch.float32)
|
|
||||||
assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format(
|
|
||||||
data.shape
|
|
||||||
)
|
|
||||||
x, y = data[:, 0], data[:, 1]
|
|
||||||
weights = torch.nn.Parameter(torch.Tensor(self._freedom))
|
|
||||||
torch.nn.init.normal_(weights, mean=0.0, std=1.0)
|
|
||||||
optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True)
|
|
||||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
|
||||||
optimizer,
|
|
||||||
milestones=[
|
|
||||||
int(max_iter * 0.25),
|
|
||||||
int(max_iter * 0.5),
|
|
||||||
int(max_iter * 0.75),
|
|
||||||
],
|
|
||||||
gamma=0.1,
|
|
||||||
)
|
|
||||||
if verbose:
|
|
||||||
print("The optimizer: {:}".format(optimizer))
|
|
||||||
|
|
||||||
best_loss = None
|
|
||||||
for _iter in range(max_iter):
|
|
||||||
y_hat = self._getitem(x, weights)
|
|
||||||
loss = torch.mean(torch.abs(y - y_hat))
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
lr_scheduler.step()
|
|
||||||
if verbose:
|
|
||||||
print(
|
|
||||||
"In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
|
|
||||||
_iter, max_iter, loss.item()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Update the params
|
|
||||||
if best_loss is None or best_loss > loss.item():
|
|
||||||
best_loss = loss.item()
|
|
||||||
for i in range(self._freedom):
|
|
||||||
self._params[i] = weights[i].item()
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{name}(freedom={freedom})".format(
|
return "{name}(freedom={freedom})".format(
|
||||||
name=self.__class__.__name__, freedom=freedom
|
name=self.__class__.__name__, freedom=freedom
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LinearFunc(FitFunc):
|
|
||||||
"""The linear function that outputs f(x) = a * x + b."""
|
|
||||||
|
|
||||||
def __init__(self, list_of_points=None, params=None, xstr="x"):
|
|
||||||
super(LinearFunc, self).__init__(2, list_of_points, params, xstr)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
self.check_valid()
|
|
||||||
return self._params[0] * x + self._params[1]
|
|
||||||
|
|
||||||
def _getitem(self, x, weights):
|
|
||||||
return weights[0] * x + weights[1]
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "{name}({a} * {x} + {b})".format(
|
|
||||||
name=self.__class__.__name__,
|
|
||||||
a=self._params[0],
|
|
||||||
b=self._params[1],
|
|
||||||
x=self.xstr,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class QuadraticFunc(FitFunc):
|
|
||||||
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
|
|
||||||
|
|
||||||
def __init__(self, list_of_points=None, params=None, xstr="x"):
|
|
||||||
super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
self.check_valid()
|
|
||||||
return self._params[0] * x * x + self._params[1] * x + self._params[2]
|
|
||||||
|
|
||||||
def _getitem(self, x, weights):
|
|
||||||
return weights[0] * x * x + weights[1] * x + weights[2]
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
|
||||||
name=self.__class__.__name__,
|
|
||||||
a=self._params[0],
|
|
||||||
b=self._params[1],
|
|
||||||
c=self._params[2],
|
|
||||||
x=self.xstr,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CubicFunc(FitFunc):
|
|
||||||
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
|
|
||||||
|
|
||||||
def __init__(self, list_of_points=None):
|
|
||||||
super(CubicFunc, self).__init__(4, list_of_points)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
self.check_valid()
|
|
||||||
return (
|
|
||||||
self._params[0] * x ** 3
|
|
||||||
+ self._params[1] * x ** 2
|
|
||||||
+ self._params[2] * x
|
|
||||||
+ self._params[3]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _getitem(self, x, weights):
|
|
||||||
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
|
|
||||||
name=self.__class__.__name__,
|
|
||||||
a=self._params[0],
|
|
||||||
b=self._params[1],
|
|
||||||
c=self._params[2],
|
|
||||||
d=self._params[3],
|
|
||||||
x=self.xstr,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class QuarticFunc(FitFunc):
|
|
||||||
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
|
|
||||||
|
|
||||||
def __init__(self, list_of_points=None):
|
|
||||||
super(QuarticFunc, self).__init__(5, list_of_points)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
self.check_valid()
|
|
||||||
return (
|
|
||||||
self._params[0] * x ** 4
|
|
||||||
+ self._params[1] * x ** 3
|
|
||||||
+ self._params[2] * x ** 2
|
|
||||||
+ self._params[3] * x
|
|
||||||
+ self._params[4]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _getitem(self, x, weights):
|
|
||||||
return (
|
|
||||||
weights[0] * x ** 4
|
|
||||||
+ weights[1] * x ** 3
|
|
||||||
+ weights[2] * x ** 2
|
|
||||||
+ weights[3] * x
|
|
||||||
+ weights[4]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format(
|
|
||||||
name=self.__class__.__name__,
|
|
||||||
a=self._params[0],
|
|
||||||
b=self._params[1],
|
|
||||||
c=self._params[2],
|
|
||||||
d=self._params[3],
|
|
||||||
e=self._params[3],
|
|
||||||
)
|
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
||||||
#####################################################
|
#####################################################
|
||||||
from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
|
from .math_static_funcs import (
|
||||||
from .math_dynamic_funcs import DynamicLinearFunc
|
LinearSFunc,
|
||||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
QuadraticSFunc,
|
||||||
from .math_dynamic_funcs import DynamicSinQuadraticFunc
|
CubicSFunc,
|
||||||
from .math_adv_funcs import ConstantFunc
|
QuarticSFunc,
|
||||||
from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc
|
ConstantFunc,
|
||||||
|
ComposedSinSFunc,
|
||||||
|
ComposedCosSFunc,
|
||||||
|
)
|
||||||
|
from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc
|
||||||
from .math_dynamic_generator import GaussianDGenerator
|
from .math_dynamic_generator import GaussianDGenerator
|
||||||
|
@ -6,23 +6,17 @@ import abc
|
|||||||
import copy
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .math_base_funcs import FitFunc
|
from .math_base_funcs import MathFunc
|
||||||
|
|
||||||
|
|
||||||
class DynamicFunc(FitFunc):
|
class DynamicFunc(MathFunc):
|
||||||
"""The dynamic quadratic function, where each param is a function."""
|
"""The dynamic function, where each param is a function."""
|
||||||
|
|
||||||
def __init__(self, freedom: int, params=None, xstr="x"):
|
def __init__(self, freedom: int, params=None, xstr="x"):
|
||||||
if params is not None:
|
if params is not None:
|
||||||
for param in params:
|
for key, param in params.items():
|
||||||
param.reset_xstr("t") if isinstance(param, FitFunc) else None
|
param.reset_xstr("t") if isinstance(param, MathFunc) else None
|
||||||
super(DynamicFunc, self).__init__(freedom, None, params, xstr)
|
super(DynamicFunc, self).__init__(freedom, params, xstr)
|
||||||
|
|
||||||
def __call__(self, x, timestamp):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _getitem(self, x, weights):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def noise_call(self, x, timestamp, std):
|
def noise_call(self, x, timestamp, std):
|
||||||
clean_y = self.__call__(x, timestamp)
|
clean_y = self.__call__(x, timestamp)
|
||||||
@ -33,13 +27,13 @@ class DynamicFunc(FitFunc):
|
|||||||
return noise_y
|
return noise_y
|
||||||
|
|
||||||
|
|
||||||
class DynamicLinearFunc(DynamicFunc):
|
class LinearDFunc(DynamicFunc):
|
||||||
"""The dynamic linear function that outputs f(x) = a * x + b.
|
"""The dynamic linear function that outputs f(x) = a * x + b.
|
||||||
The a and b is a function of timestamp.
|
The a and b is a function of timestamp.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params=None, xstr="x"):
|
def __init__(self, params, xstr="x"):
|
||||||
super(DynamicLinearFunc, self).__init__(3, params, xstr)
|
super(LinearDFunc, self).__init__(2, params, xstr)
|
||||||
|
|
||||||
def __call__(self, x, timestamp):
|
def __call__(self, x, timestamp):
|
||||||
a = self._params[0](timestamp)
|
a = self._params[0](timestamp)
|
||||||
@ -57,18 +51,15 @@ class DynamicLinearFunc(DynamicFunc):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DynamicQuadraticFunc(DynamicFunc):
|
class QuadraticDFunc(DynamicFunc):
|
||||||
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
|
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
|
||||||
The a, b, and c is a function of timestamp.
|
The a, b, and c is a function of timestamp.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params=None):
|
def __init__(self, params, xstr="x"):
|
||||||
super(DynamicQuadraticFunc, self).__init__(3, params)
|
super(QuadraticDFunc, self).__init__(3, params)
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, x, timestamp):
|
||||||
self,
|
|
||||||
x,
|
|
||||||
):
|
|
||||||
self.check_valid()
|
self.check_valid()
|
||||||
a = self._params[0](timestamp)
|
a = self._params[0](timestamp)
|
||||||
b = self._params[1](timestamp)
|
b = self._params[1](timestamp)
|
||||||
@ -78,38 +69,37 @@ class DynamicQuadraticFunc(DynamicFunc):
|
|||||||
return a * x * x + b * x + c
|
return a * x * x + b * x + c
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{name}({a} * x^2 + {b} * x + {c})".format(
|
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
||||||
name=self.__class__.__name__,
|
name=self.__class__.__name__,
|
||||||
a=self._params[0],
|
a=self._params[0],
|
||||||
b=self._params[1],
|
b=self._params[1],
|
||||||
c=self._params[2],
|
c=self._params[2],
|
||||||
|
x=self.xstr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DynamicSinQuadraticFunc(DynamicFunc):
|
class SinQuadraticDFunc(DynamicFunc):
|
||||||
"""The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c).
|
"""The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c).
|
||||||
The a, b, and c is a function of timestamp.
|
The a, b, and c is a function of timestamp.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params=None):
|
def __init__(self, params=None):
|
||||||
super(DynamicSinQuadraticFunc, self).__init__(3, params)
|
super(SinQuadraticDFunc, self).__init__(3, params)
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, x, timestamp):
|
||||||
self,
|
|
||||||
x,
|
|
||||||
):
|
|
||||||
self.check_valid()
|
self.check_valid()
|
||||||
a = self._params[0](timestamp)
|
a = self._params[0](timestamp)
|
||||||
b = self._params[1](timestamp)
|
b = self._params[1](timestamp)
|
||||||
c = self._params[2](timestamp)
|
c = self._params[2](timestamp)
|
||||||
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
|
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
|
||||||
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
|
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
|
||||||
return math.sin(a * x * x + b * x + c)
|
return np.sin(a * x * x + b * x + c)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{name}({a} * x^2 + {b} * x + {c})".format(
|
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
||||||
name=self.__class__.__name__,
|
name=self.__class__.__name__,
|
||||||
a=self._params[0],
|
a=self._params[0],
|
||||||
b=self._params[1],
|
b=self._params[1],
|
||||||
c=self._params[2],
|
c=self._params[2],
|
||||||
|
x=self.xstr,
|
||||||
)
|
)
|
||||||
|
225
xautodl/datasets/math_static_funcs.py
Normal file
225
xautodl/datasets/math_static_funcs.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||||
|
#####################################################
|
||||||
|
import math
|
||||||
|
import abc
|
||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .math_base_funcs import MathFunc
|
||||||
|
|
||||||
|
|
||||||
|
class StaticFunc(MathFunc):
|
||||||
|
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
|
||||||
|
|
||||||
|
def __init__(self, freedom: int, params=None, xstr="x"):
|
||||||
|
super(StaticFunc, self).__init__(freedom, params, xstr)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __call__(self, x):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def noise_call(self, x, std):
|
||||||
|
clean_y = self.__call__(x)
|
||||||
|
if isinstance(clean_y, np.ndarray):
|
||||||
|
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
|
||||||
|
return noise_y
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}(freedom={freedom})".format(
|
||||||
|
name=self.__class__.__name__, freedom=freedom
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearSFunc(StaticFunc):
|
||||||
|
"""The linear function that outputs f(x) = a * x + b."""
|
||||||
|
|
||||||
|
def __init__(self, params=None, xstr="x"):
|
||||||
|
super(LinearSFunc, self).__init__(2, params, xstr)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
self.check_valid()
|
||||||
|
return self._params[0] * x + self._params[1]
|
||||||
|
|
||||||
|
def _getitem(self, x, weights):
|
||||||
|
return weights[0] * x + weights[1]
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({a} * {x} + {b})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
a=self._params[0],
|
||||||
|
b=self._params[1],
|
||||||
|
x=self.xstr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuadraticSFunc(StaticFunc):
|
||||||
|
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
|
||||||
|
|
||||||
|
def __init__(self, params=None, xstr="x"):
|
||||||
|
super(QuadraticSFunc, self).__init__(3, params, xstr)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
self.check_valid()
|
||||||
|
return self._params[0] * x * x + self._params[1] * x + self._params[2]
|
||||||
|
|
||||||
|
def _getitem(self, x, weights):
|
||||||
|
return weights[0] * x * x + weights[1] * x + weights[2]
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
a=self._params[0],
|
||||||
|
b=self._params[1],
|
||||||
|
c=self._params[2],
|
||||||
|
x=self.xstr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CubicSFunc(StaticFunc):
|
||||||
|
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
|
||||||
|
|
||||||
|
def __init__(self, params=None, xstr="x"):
|
||||||
|
super(CubicSFunc, self).__init__(4, params, xstr)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
self.check_valid()
|
||||||
|
return (
|
||||||
|
self._params[0] * x ** 3
|
||||||
|
+ self._params[1] * x ** 2
|
||||||
|
+ self._params[2] * x
|
||||||
|
+ self._params[3]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _getitem(self, x, weights):
|
||||||
|
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
a=self._params[0],
|
||||||
|
b=self._params[1],
|
||||||
|
c=self._params[2],
|
||||||
|
d=self._params[3],
|
||||||
|
x=self.xstr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuarticSFunc(StaticFunc):
|
||||||
|
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
|
||||||
|
|
||||||
|
def __init__(self, params=None, xstr="x"):
|
||||||
|
super(QuarticSFunc, self).__init__(5, params, xstr)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
self.check_valid()
|
||||||
|
return (
|
||||||
|
self._params[0] * x ** 4
|
||||||
|
+ self._params[1] * x ** 3
|
||||||
|
+ self._params[2] * x ** 2
|
||||||
|
+ self._params[3] * x
|
||||||
|
+ self._params[4]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _getitem(self, x, weights):
|
||||||
|
return (
|
||||||
|
weights[0] * x ** 4
|
||||||
|
+ weights[1] * x ** 3
|
||||||
|
+ weights[2] * x ** 2
|
||||||
|
+ weights[3] * x
|
||||||
|
+ weights[4]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"{name}({a} * {x}^4 + {b} * {x}^3 + {c} * {x}^2 + {d} * {x} + {e})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
a=self._params[0],
|
||||||
|
b=self._params[1],
|
||||||
|
c=self._params[2],
|
||||||
|
d=self._params[3],
|
||||||
|
e=self._params[3],
|
||||||
|
x=self.xstr,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
### advanced functions
|
||||||
|
|
||||||
|
|
||||||
|
class ConstantFunc(StaticFunc):
|
||||||
|
"""The constant function: f(x) = c."""
|
||||||
|
|
||||||
|
def __init__(self, constant, xstr="x"):
|
||||||
|
super(ConstantFunc, self).__init__(1, {0: constant}, xstr)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
self.check_valid()
|
||||||
|
return self._params[0]
|
||||||
|
|
||||||
|
def fit(self, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _getitem(self, x, weights):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
|
||||||
|
|
||||||
|
|
||||||
|
class ComposedSinSFunc(StaticFunc):
|
||||||
|
"""The composed sin function that outputs:
|
||||||
|
f(x) = a * sin( b*x ) + c
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, xstr="x"):
|
||||||
|
super(ComposedSinSFunc, self).__init__(3, params, xstr)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
self.check_valid()
|
||||||
|
a = self._params[0]
|
||||||
|
b = self._params[1]
|
||||||
|
c = self._params[2]
|
||||||
|
return a * math.sin(b * x) + c
|
||||||
|
|
||||||
|
def _getitem(self, x, weights):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({a} * sin({b} * {x}) + {c})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
a=self._params[0],
|
||||||
|
b=self._params[1],
|
||||||
|
c=self._params[2],
|
||||||
|
x=self.xstr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ComposedCosSFunc(StaticFunc):
|
||||||
|
"""The composed sin function that outputs:
|
||||||
|
f(x) = a * cos( b*x ) + c
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params, xstr="x"):
|
||||||
|
super(ComposedCosSFunc, self).__init__(3, params, xstr)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
self.check_valid()
|
||||||
|
a = self._params[0]
|
||||||
|
b = self._params[1]
|
||||||
|
c = self._params[2]
|
||||||
|
return a * math.cos(b * x) + c
|
||||||
|
|
||||||
|
def _getitem(self, x, weights):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({a} * sin({b} * {x}) + {c})".format(
|
||||||
|
name=self.__class__.__name__,
|
||||||
|
a=self._params[0],
|
||||||
|
b=self._params[1],
|
||||||
|
c=self._params[2],
|
||||||
|
x=self.xstr,
|
||||||
|
)
|
@ -1,13 +1,13 @@
|
|||||||
import math
|
import math
|
||||||
from .synthetic_utils import TimeStamp
|
from .synthetic_utils import TimeStamp
|
||||||
from .synthetic_env import SyntheticDEnv
|
from .synthetic_env import SyntheticDEnv
|
||||||
from .math_core import LinearFunc
|
from .math_core import LinearSFunc
|
||||||
from .math_core import DynamicLinearFunc
|
from .math_core import LinearDFunc
|
||||||
from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc
|
from .math_core import QuadraticDFunc, SinQuadraticDFunc
|
||||||
from .math_core import (
|
from .math_core import (
|
||||||
ConstantFunc,
|
ConstantFunc,
|
||||||
ComposedSinFunc as SinFunc,
|
ComposedSinSFunc as SinFunc,
|
||||||
ComposedCosFunc as CosFunc,
|
ComposedCosSFunc as CosFunc,
|
||||||
)
|
)
|
||||||
from .math_core import GaussianDGenerator
|
from .math_core import GaussianDGenerator
|
||||||
|
|
||||||
@ -17,7 +17,7 @@ __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
|
|||||||
|
|
||||||
def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"):
|
def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"):
|
||||||
max_time = math.pi * 10
|
max_time = math.pi * 10
|
||||||
if version == "v1":
|
if version.lower() == "v1":
|
||||||
mean_generator = ConstantFunc(0)
|
mean_generator = ConstantFunc(0)
|
||||||
std_generator = ConstantFunc(1)
|
std_generator = ConstantFunc(1)
|
||||||
data_generator = GaussianDGenerator(
|
data_generator = GaussianDGenerator(
|
||||||
@ -26,7 +26,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
|||||||
time_generator = TimeStamp(
|
time_generator = TimeStamp(
|
||||||
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
||||||
)
|
)
|
||||||
oracle_map = DynamicLinearFunc(
|
oracle_map = LinearDFunc(
|
||||||
params={
|
params={
|
||||||
0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), # 2 sin(t) + 2.2
|
0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), # 2 sin(t) + 2.2
|
||||||
1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), # 1.5 sin(0.6t) + 1.8
|
1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), # 1.5 sin(0.6t) + 1.8
|
||||||
@ -35,7 +35,8 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
|||||||
dynamic_env = SyntheticDEnv(
|
dynamic_env = SyntheticDEnv(
|
||||||
data_generator, oracle_map, time_generator, num_per_task
|
data_generator, oracle_map, time_generator, num_per_task
|
||||||
)
|
)
|
||||||
elif version == "v2":
|
dynamic_env.set_regression()
|
||||||
|
elif version.lower() == "v2":
|
||||||
mean_generator = ConstantFunc(0)
|
mean_generator = ConstantFunc(0)
|
||||||
std_generator = ConstantFunc(1)
|
std_generator = ConstantFunc(1)
|
||||||
data_generator = GaussianDGenerator(
|
data_generator = GaussianDGenerator(
|
||||||
@ -44,16 +45,17 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
|||||||
time_generator = TimeStamp(
|
time_generator = TimeStamp(
|
||||||
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
||||||
)
|
)
|
||||||
oracle_map = DynamicQuadraticFunc(
|
oracle_map = QuadraticDFunc(
|
||||||
params={
|
params={
|
||||||
0: LinearFunc(params={0: 0.1, 1: 0}), # 0.1 * t
|
0: LinearSFunc(params={0: 0.1, 1: 0}), # 0.1 * t
|
||||||
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
|
1: ConstantFunc(0),
|
||||||
2: ConstantFunc(0),
|
2: CosFunc(params={0: 4.0, 1: 10, 2: 0}), # 4 * cos(10 * t)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
dynamic_env = SyntheticDEnv(
|
dynamic_env = SyntheticDEnv(
|
||||||
data_generator, oracle_map, time_generator, num_per_task
|
data_generator, oracle_map, time_generator, num_per_task
|
||||||
)
|
)
|
||||||
|
dynamic_env.set_regression()
|
||||||
elif version.lower() == "v3":
|
elif version.lower() == "v3":
|
||||||
mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0}) # sin(t)
|
mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0}) # sin(t)
|
||||||
std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1}) # 0.5 cos(t) + 1
|
std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1}) # 0.5 cos(t) + 1
|
||||||
@ -63,7 +65,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
|||||||
time_generator = TimeStamp(
|
time_generator = TimeStamp(
|
||||||
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
||||||
)
|
)
|
||||||
oracle_map = DynamicSinQuadraticFunc(
|
oracle_map = SinQuadraticDFunc(
|
||||||
params={
|
params={
|
||||||
0: CosFunc(params={0: 0.5, 1: 1, 2: 1}), # 0.5 cos(t) + 1
|
0: CosFunc(params={0: 0.5, 1: 1, 2: 1}), # 0.5 cos(t) + 1
|
||||||
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
|
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
|
||||||
@ -73,6 +75,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
|||||||
dynamic_env = SyntheticDEnv(
|
dynamic_env = SyntheticDEnv(
|
||||||
data_generator, oracle_map, time_generator, num_per_task
|
data_generator, oracle_map, time_generator, num_per_task
|
||||||
)
|
)
|
||||||
|
dynamic_env.set_regression()
|
||||||
|
elif version.lower() == "v4":
|
||||||
|
dynamic_env.set_classification(2)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown version: {:}".format(version))
|
raise ValueError("Unknown version: {:}".format(version))
|
||||||
return dynamic_env
|
return dynamic_env
|
||||||
|
@ -49,6 +49,10 @@ class SyntheticDEnv(data.Dataset):
|
|||||||
self._meta_info["task"] = "classification"
|
self._meta_info["task"] = "classification"
|
||||||
self._meta_info["num_classes"] = int(num_classes)
|
self._meta_info["num_classes"] = int(num_classes)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def oracle_map(self):
|
||||||
|
return self._oracle_map
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def meta_info(self):
|
def meta_info(self):
|
||||||
return self._meta_info
|
return self._meta_info
|
||||||
|
Loading…
Reference in New Issue
Block a user