Re-organize GeMOSA

This commit is contained in:
D-X-Y 2021-05-27 15:44:01 +08:00
parent 8961215416
commit 6da60664f5
10 changed files with 354 additions and 350 deletions

View File

@ -1,10 +1,9 @@
#####################################################
# Learning to Generate Model One Step Ahead #
#####################################################
# python exps/GeMOSA/lfna.py --env_version v1 --workers 0
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
# python exps/GeMOSA/main.py --env_version v1 --workers 0
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData
from meta_model import MetaModelV1
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
def online_evaluate(
env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
):
logger.log("Online evaluate: {:}".format(env))
loss_meter = AverageMeter()
w_containers = dict()
@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
with torch.no_grad():
meta_model.eval()
base_model.eval()
[future_container], time_embeds = meta_model(
future_time.to(args.device).view(-1), None, False
future_time_embed = meta_model.gen_time_embed(
future_time.to(args.device).view(-1)
)
[future_container] = meta_model.gen_model(future_time_embed)
if save:
w_containers[idx] = future_container.no_grad_clone()
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
loss_meter.update(future_loss.item())
refine, post_refine_loss = meta_model.adapt(
base_model,
criterion,
future_time.item(),
future_x,
future_y,
args.refine_lr,
args.refine_epochs,
{"param": time_embeds, "loss": future_loss.item()},
)
if easy_adapt:
meta_model.easy_adapt(future_time.item(), future_time_embed)
refine, post_refine_loss = False, -1
else:
refine, post_refine_loss = meta_model.adapt(
base_model,
criterion,
future_time.item(),
future_x,
future_y,
args.refine_lr,
args.refine_epochs,
{"param": future_time_embed, "loss": future_loss.item()},
)
logger.log(
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
idx, len(env), future_loss.item()
@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
optimizer.zero_grad()
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
# future loss
total_future_losses, total_present_losses = [], []
future_containers, _ = meta_model(
None, generated_time_embeds[batch_indexes], False
)
present_containers, _ = meta_model(
None, meta_model.super_meta_embed[batch_indexes], False
future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
present_containers = meta_model.gen_model(
meta_model.super_meta_embed[batch_indexes]
)
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
_, (inputs, targets) = xenv(time_step)
@ -216,13 +220,34 @@ def main(args):
# try to evaluate once
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
"""
w_containers, loss_meter = online_evaluate(
all_env, meta_model, base_model, criterion, args, logger, True
)
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
"""
_, test_loss_meter_adapt_v1 = online_evaluate(
valid_env, meta_model, base_model, criterion, args, logger, False, False
)
_, test_loss_meter_adapt_v2 = online_evaluate(
valid_env, meta_model, base_model, criterion, args, logger, False, True
)
logger.log(
"In the online test enviornment, the total loss for refine-adapt is {:}".format(
test_loss_meter_adapt_v1
)
)
logger.log(
"In the online test enviornment, the total loss for easy-adapt is {:}".format(
test_loss_meter_adapt_v2
)
)
save_checkpoint(
{"all_w_containers": w_containers},
{
"test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
"test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
},
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
logger,
)
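The refactor above replaces the overloaded meta_model forward call with two explicit stages: gen_time_embed maps raw timestamps to time embeddings, and gen_model decodes embeddings into weight containers. A minimal sketch of the resulting evaluation step, assuming the objects constructed in this script (predict_at is a hypothetical helper, not part of the commit):

import torch

def predict_at(meta_model, base_model, future_time, future_x, device):
    with torch.no_grad():
        meta_model.eval()
        base_model.eval()
        # Stage 1: raw timestamp -> time embedding.
        time_embed = meta_model.gen_time_embed(future_time.to(device).view(-1))
        # Stage 2: time embedding -> per-layer weight container.
        [container] = meta_model.gen_model(time_embed)
        # Run the base model with the generated weights.
        return base_model.forward_with_container(future_x.to(device), container)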

View File

@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule):
batch_containers.append(
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
)
return batch_containers, time_embeds
return batch_containers
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
raise NotImplementedError
@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule):
def forward_candidate(self, input):
raise NotImplementedError
def easy_adapt(self, timestamp, time_embed):
with torch.no_grad():
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, time_embed)
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
distance = self.get_closest_meta_distance(timestamp)
if distance + self._interval * 1e-2 <= self._interval:
@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule):
best_new_param = new_param.detach().clone()
for iepoch in range(epochs):
optimizer.zero_grad()
_, time_embed = self(timestamp.view(1), None)
time_embed = self.gen_time_embed(timestamp.view(1))
match_loss = criterion(new_param, time_embed)
[container], time_embed = self(None, new_param.view(1, -1))
[container] = self.gen_model(new_param.view(1, -1))
y_hat = base_model.forward_with_container(x, container)
meta_loss = criterion(y_hat, y)
loss = meta_loss + match_loss
loss.backward()
optimizer.step()
# print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
if meta_loss.item() < best_loss:
with torch.no_grad():
best_loss = meta_loss.item()
best_new_param = new_param.detach().clone()
with torch.no_grad():
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, best_new_param)
self.easy_adapt(timestamp, best_new_param)
return True, best_loss
def extra_repr(self) -> str:
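Note that adapt now ends by delegating to easy_adapt, so both paths append the chosen embedding the same way: easy_adapt trusts the generated embedding and appends it as a fixed anchor, while adapt first refines a fresh embedding by gradient descent on the new data. A hedged sketch of a caller choosing between them (adapt_at_test_time is a hypothetical wrapper; the method signatures are the ones shown in this diff):

import torch

def adapt_at_test_time(meta_model, base_model, criterion, timestamp,
                       x, y, refine_lr, refine_epochs, easy=False):
    # Generate the embedding for the unseen timestamp.
    time_embed = meta_model.gen_time_embed(torch.Tensor([timestamp]).view(-1))
    if easy:
        # Cheap path: append the generated embedding without refinement.
        meta_model.easy_adapt(timestamp, time_embed)
        return False, -1
    # Refinement path: measure the current loss, then let adapt() optimize
    # a new embedding on (x, y) before appending it.
    [container] = meta_model.gen_model(time_embed)
    loss = criterion(base_model.forward_with_container(x, container), y)
    return meta_model.adapt(
        base_model, criterion, timestamp, x, y,
        refine_lr, refine_epochs,
        {"param": time_embed, "loss": loss.item()},
    )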

View File

@ -191,6 +191,8 @@ def visualize_env(save_dir, version):
allxs.append(allx)
allys.append(ally)
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
print("env: {:}".format(dynamic_env))
print("oracle_map: {:}".format(dynamic_env.oracle_map))
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):

View File

@ -1,92 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import copy
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
from .math_base_funcs import FitFunc
from .math_base_funcs import QuadraticFunc
from .math_base_funcs import QuarticFunc
class ConstantFunc(FitFunc):
"""The constant function: f(x) = c."""
def __init__(self, constant=None, xstr="x"):
param = dict()
param[0] = constant
super(ConstantFunc, self).__init__(0, None, param, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0]
def fit(self, **kwargs):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
class ComposedSinFunc(FitFunc):
"""The composed sin function that outputs:
f(x) = a * sin( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedSinFunc, self).__init__(3, None, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.sin(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class ComposedCosFunc(FitFunc):
"""The composed sin function that outputs:
f(x) = a * cos( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedCosFunc, self).__init__(3, None, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.cos(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)

View File

@ -5,34 +5,33 @@ import math
import abc
import copy
import numpy as np
import torch
class FitFunc(abc.ABC):
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
class MathFunc(abc.ABC):
"""The math function -- a virtual class defining some APIs."""
def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
def __init__(self, freedom: int, params=None, xstr="x"):
# initialize as empty
self._params = dict()
for i in range(freedom):
self._params[i] = None
self._freedom = freedom
if list_of_points is not None and params is not None:
raise ValueError("list_of_points and params can not be set simultaneously")
if list_of_points is not None:
self.fit(list_of_points=list_of_points)
if params is not None:
self.set(params)
self._xstr = str(xstr)
self._skip_check = True
def set(self, params):
self._params = copy.deepcopy(params)
for key in range(self._freedom):
param = copy.deepcopy(params[key])
self._params[key] = param
def check_valid(self):
# for key, value in self._params.items():
for key in range(self._freedom):
value = self._params[key]
if value is None:
raise ValueError("The {:} is None".format(key))
if not self._skip_check:
for key in range(self._freedom):
value = self._params[key]
if value is None:
raise ValueError("The {:} is None".format(key))
@property
def xstr(self):
@ -45,7 +44,8 @@ class FitFunc(abc.ABC):
def __call__(self, x):
raise NotImplementedError
def noise_call(self, x, std=0.1):
@abc.abstractmethod
def noise_call(self, x, std):
clean_y = self.__call__(x)
if isinstance(clean_y, np.ndarray):
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
@ -53,169 +53,7 @@ class FitFunc(abc.ABC):
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
return noise_y
@abc.abstractmethod
def _getitem(self, x):
raise NotImplementedError
def fit(self, **kwargs):
list_of_points = kwargs["list_of_points"]
max_iter, lr_max, verbose = (
kwargs.get("max_iter", 900),
kwargs.get("lr_max", 1.0),
kwargs.get("verbose", False),
)
with torch.no_grad():
data = torch.Tensor(list_of_points).type(torch.float32)
assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format(
data.shape
)
x, y = data[:, 0], data[:, 1]
weights = torch.nn.Parameter(torch.Tensor(self._freedom))
torch.nn.init.normal_(weights, mean=0.0, std=1.0)
optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(max_iter * 0.25),
int(max_iter * 0.5),
int(max_iter * 0.75),
],
gamma=0.1,
)
if verbose:
print("The optimizer: {:}".format(optimizer))
best_loss = None
for _iter in range(max_iter):
y_hat = self._getitem(x, weights)
loss = torch.mean(torch.abs(y - y_hat))
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
if verbose:
print(
"In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
_iter, max_iter, loss.item()
)
)
# Update the params
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
for i in range(self._freedom):
self._params[i] = weights[i].item()
def __repr__(self):
return "{name}(freedom={freedom})".format(
name=self.__class__.__name__, freedom=freedom
)
class LinearFunc(FitFunc):
"""The linear function that outputs f(x) = a * x + b."""
def __init__(self, list_of_points=None, params=None, xstr="x"):
super(LinearFunc, self).__init__(2, list_of_points, params, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0] * x + self._params[1]
def _getitem(self, x, weights):
return weights[0] * x + weights[1]
def __repr__(self):
return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
x=self.xstr,
)
class QuadraticFunc(FitFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None, params=None, xstr="x"):
super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0] * x * x + self._params[1] * x + self._params[2]
def _getitem(self, x, weights):
return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self):
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class CubicFunc(FitFunc):
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
def __init__(self, list_of_points=None):
super(CubicFunc, self).__init__(4, list_of_points)
def __call__(self, x):
self.check_valid()
return (
self._params[0] * x ** 3
+ self._params[1] * x ** 2
+ self._params[2] * x
+ self._params[3]
)
def _getitem(self, x, weights):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self):
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
x=self.xstr,
)
class QuarticFunc(FitFunc):
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
def __init__(self, list_of_points=None):
super(QuarticFunc, self).__init__(5, list_of_points)
def __call__(self, x):
self.check_valid()
return (
self._params[0] * x ** 4
+ self._params[1] * x ** 3
+ self._params[2] * x ** 2
+ self._params[3] * x
+ self._params[4]
)
def _getitem(self, x, weights):
return (
weights[0] * x ** 4
+ weights[1] * x ** 3
+ weights[2] * x ** 2
+ weights[3] * x
+ weights[4]
)
def __repr__(self):
return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
e=self._params[3],
)
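After this commit, the base class keeps only parameter bookkeeping (set, check_valid, the xstr label) and declares noise_call abstract; the gradient-based fit() machinery is removed and the concrete polynomials move to math_static_funcs.py below. A minimal sketch of the new contract (ScaleFunc and the flat import path are illustrative, not part of the commit):

import numpy as np
from math_base_funcs import MathFunc  # illustrative; the package uses a relative import

class ScaleFunc(MathFunc):
    """A toy subclass: f(x) = a * x, with a stored under integer key 0."""

    def __init__(self, scale, xstr="x"):
        super(ScaleFunc, self).__init__(1, {0: scale}, xstr)

    def __call__(self, x):
        self.check_valid()
        return self._params[0] * x

    def noise_call(self, x, std):
        # Add Gaussian noise on top of the clean output.
        return self.__call__(x) + np.random.normal(scale=std, size=np.shape(x))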

View File

@ -1,10 +1,14 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
from .math_dynamic_funcs import DynamicLinearFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_dynamic_funcs import DynamicSinQuadraticFunc
from .math_adv_funcs import ConstantFunc
from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc
from .math_static_funcs import (
LinearSFunc,
QuadraticSFunc,
CubicSFunc,
QuarticSFunc,
ConstantFunc,
ComposedSinSFunc,
ComposedCosSFunc,
)
from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc
from .math_dynamic_generator import GaussianDGenerator
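The renaming scheme is uniform: an S suffix marks static functions and a D suffix marks dynamic (time-varying) ones. For reference, the mapping implied by the import and class changes in this diff:

# Renaming map introduced by this commit (S = static, D = dynamic):
#   LinearFunc              -> LinearSFunc
#   QuadraticFunc           -> QuadraticSFunc
#   CubicFunc               -> CubicSFunc
#   QuarticFunc             -> QuarticSFunc
#   ComposedSinFunc         -> ComposedSinSFunc
#   ComposedCosFunc         -> ComposedCosSFunc
#   DynamicLinearFunc       -> LinearDFunc
#   DynamicQuadraticFunc    -> QuadraticDFunc
#   DynamicSinQuadraticFunc -> SinQuadraticDFunc
#   ConstantFunc keeps its name but moves into math_static_funcs.py.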

View File

@ -6,23 +6,17 @@ import abc
import copy
import numpy as np
from .math_base_funcs import FitFunc
from .math_base_funcs import MathFunc
class DynamicFunc(FitFunc):
"""The dynamic quadratic function, where each param is a function."""
class DynamicFunc(MathFunc):
"""The dynamic function, where each param is a function."""
def __init__(self, freedom: int, params=None, xstr="x"):
if params is not None:
for param in params:
param.reset_xstr("t") if isinstance(param, FitFunc) else None
super(DynamicFunc, self).__init__(freedom, None, params, xstr)
def __call__(self, x, timestamp):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
for key, param in params.items():
param.reset_xstr("t") if isinstance(param, MathFunc) else None
super(DynamicFunc, self).__init__(freedom, params, xstr)
def noise_call(self, x, timestamp, std):
clean_y = self.__call__(x, timestamp)
@ -33,13 +27,13 @@ class DynamicFunc(FitFunc):
return noise_y
class DynamicLinearFunc(DynamicFunc):
class LinearDFunc(DynamicFunc):
"""The dynamic linear function that outputs f(x) = a * x + b.
The a and b are functions of the timestamp.
"""
def __init__(self, params=None, xstr="x"):
super(DynamicLinearFunc, self).__init__(3, params, xstr)
def __init__(self, params, xstr="x"):
super(LinearDFunc, self).__init__(2, params, xstr)
def __call__(self, x, timestamp):
a = self._params[0](timestamp)
@ -57,18 +51,15 @@ class DynamicLinearFunc(DynamicFunc):
)
class DynamicQuadraticFunc(DynamicFunc):
class QuadraticDFunc(DynamicFunc):
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
The a, b, and c are functions of the timestamp.
"""
def __init__(self, params=None):
super(DynamicQuadraticFunc, self).__init__(3, params)
def __init__(self, params, xstr="x"):
super(QuadraticDFunc, self).__init__(3, params, xstr)
def __call__(
self,
x,
):
def __call__(self, x, timestamp):
self.check_valid()
a = self._params[0](timestamp)
b = self._params[1](timestamp)
@ -78,38 +69,37 @@ class DynamicQuadraticFunc(DynamicFunc):
return a * x * x + b * x + c
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format(
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class DynamicSinQuadraticFunc(DynamicFunc):
class SinQuadraticDFunc(DynamicFunc):
"""The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c).
The a, b, and c is a function of timestamp.
"""
def __init__(self, params=None):
super(DynamicSinQuadraticFunc, self).__init__(3, params)
super(SinQuadraticDFunc, self).__init__(3, params)
def __call__(
self,
x,
):
def __call__(self, x, timestamp):
self.check_valid()
a = self._params[0](timestamp)
b = self._params[1](timestamp)
c = self._params[2](timestamp)
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
return math.sin(a * x * x + b * x + c)
return np.sin(a * x * x + b * x + c)
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format(
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
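A small usage sketch of the renamed dynamic functions, where each parameter is itself a static function evaluated at the timestamp (the flat import paths are illustrative):

from math_static_funcs import ConstantFunc, LinearSFunc  # illustrative paths
from math_dynamic_funcs import LinearDFunc

# f(x, t) = a(t) * x + b(t), with a and b static functions of t.
fn = LinearDFunc(params={
    0: LinearSFunc(params={0: 0.5, 1: 1.0}),  # a(t) = 0.5 * t + 1
    1: ConstantFunc(2.0),                     # b(t) = 2
})
print(fn(3.0, timestamp=1.0))  # (0.5 * 1 + 1) * 3 + 2 = 6.5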

View File

@ -0,0 +1,225 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import copy
import numpy as np
from .math_base_funcs import MathFunc
class StaticFunc(MathFunc):
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, freedom: int, params=None, xstr="x"):
super(StaticFunc, self).__init__(freedom, params, xstr)
@abc.abstractmethod
def __call__(self, x):
raise NotImplementedError
def noise_call(self, x, std):
clean_y = self.__call__(x)
if isinstance(clean_y, np.ndarray):
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
else:
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
return noise_y
def __repr__(self):
return "{name}(freedom={freedom})".format(
name=self.__class__.__name__, freedom=self._freedom
)
class LinearSFunc(StaticFunc):
"""The linear function that outputs f(x) = a * x + b."""
def __init__(self, params=None, xstr="x"):
super(LinearSFunc, self).__init__(2, params, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0] * x + self._params[1]
def _getitem(self, x, weights):
return weights[0] * x + weights[1]
def __repr__(self):
return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
x=self.xstr,
)
class QuadraticSFunc(StaticFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, params=None, xstr="x"):
super(QuadraticSFunc, self).__init__(3, params, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0] * x * x + self._params[1] * x + self._params[2]
def _getitem(self, x, weights):
return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self):
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class CubicSFunc(StaticFunc):
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
def __init__(self, params=None, xstr="x"):
super(CubicSFunc, self).__init__(4, params, xstr)
def __call__(self, x):
self.check_valid()
return (
self._params[0] * x ** 3
+ self._params[1] * x ** 2
+ self._params[2] * x
+ self._params[3]
)
def _getitem(self, x, weights):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self):
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
x=self.xstr,
)
class QuarticSFunc(StaticFunc):
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
def __init__(self, params=None, xstr="x"):
super(QuarticSFunc, self).__init__(5, params, xstr)
def __call__(self, x):
self.check_valid()
return (
self._params[0] * x ** 4
+ self._params[1] * x ** 3
+ self._params[2] * x ** 2
+ self._params[3] * x
+ self._params[4]
)
def _getitem(self, x, weights):
return (
weights[0] * x ** 4
+ weights[1] * x ** 3
+ weights[2] * x ** 2
+ weights[3] * x
+ weights[4]
)
def __repr__(self):
return (
"{name}({a} * {x}^4 + {b} * {x}^3 + {c} * {x}^2 + {d} * {x} + {e})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
e=self._params[4],
x=self.xstr,
)
)
### advanced functions
class ConstantFunc(StaticFunc):
"""The constant function: f(x) = c."""
def __init__(self, constant, xstr="x"):
super(ConstantFunc, self).__init__(1, {0: constant}, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0]
def fit(self, **kwargs):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
class ComposedSinSFunc(StaticFunc):
"""The composed sin function that outputs:
f(x) = a * sin( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedSinSFunc, self).__init__(3, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.sin(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class ComposedCosSFunc(StaticFunc):
"""The composed sin function that outputs:
f(x) = a * cos( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedCosSFunc, self).__init__(3, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.cos(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
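A quick usage sketch of the reorganized static functions; note that noise_call requires an np.ndarray input, per the type check in StaticFunc above (the flat import path is illustrative):

import numpy as np
from math_static_funcs import QuadraticSFunc  # illustrative path

fn = QuadraticSFunc(params={0: 1.0, 1: -2.0, 2: 1.0})  # f(x) = x^2 - 2x + 1
print(fn)        # QuadraticSFunc(1.0 * x^2 + -2.0 * x + 1.0)
print(fn(3.0))   # 4.0
print(fn.noise_call(np.array([3.0]), std=0.1))  # ~4.0 plus Gaussian noise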

View File

@ -1,13 +1,13 @@
import math
from .synthetic_utils import TimeStamp
from .synthetic_env import SyntheticDEnv
from .math_core import LinearFunc
from .math_core import DynamicLinearFunc
from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc
from .math_core import LinearSFunc
from .math_core import LinearDFunc
from .math_core import QuadraticDFunc, SinQuadraticDFunc
from .math_core import (
ConstantFunc,
ComposedSinFunc as SinFunc,
ComposedCosFunc as CosFunc,
ComposedSinSFunc as SinFunc,
ComposedCosSFunc as CosFunc,
)
from .math_core import GaussianDGenerator
@ -17,7 +17,7 @@ __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"):
max_time = math.pi * 10
if version == "v1":
if version.lower() == "v1":
mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1)
data_generator = GaussianDGenerator(
@ -26,7 +26,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
time_generator = TimeStamp(
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
)
oracle_map = DynamicLinearFunc(
oracle_map = LinearDFunc(
params={
0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), # 2 sin(t) + 2.2
1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), # 1.5 sin(0.6t) + 1.8
@ -35,7 +35,8 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
elif version == "v2":
dynamic_env.set_regression()
elif version.lower() == "v2":
mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1)
data_generator = GaussianDGenerator(
@ -44,16 +45,17 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
time_generator = TimeStamp(
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
)
oracle_map = DynamicQuadraticFunc(
oracle_map = QuadraticDFunc(
params={
0: LinearFunc(params={0: 0.1, 1: 0}), # 0.1 * t
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
2: ConstantFunc(0),
0: LinearSFunc(params={0: 0.1, 1: 0}), # 0.1 * t
1: ConstantFunc(0),
2: CosFunc(params={0: 4.0, 1: 10, 2: 0}), # 4 * cos(10 * t)
}
)
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
dynamic_env.set_regression()
elif version.lower() == "v3":
mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0}) # sin(t)
std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1}) # 0.5 cos(t) + 1
@ -63,7 +65,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
time_generator = TimeStamp(
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
)
oracle_map = DynamicSinQuadraticFunc(
oracle_map = SinQuadraticDFunc(
params={
0: CosFunc(params={0: 0.5, 1: 1, 2: 1}), # 0.5 cos(t) + 1
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
@ -73,6 +75,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
dynamic_env.set_regression()
elif version.lower() == "v4":
dynamic_env.set_classification(2)
else:
raise ValueError("Unknown version: {:}".format(version))
return dynamic_env
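A sketch of driving the factory above; the iteration protocol matches the visualize_env loop earlier in this diff, and the argument values here are arbitrary:

env = get_synthetic_env(total_timestamp=100, num_per_task=50, version="v1")
print(env.oracle_map)  # the LinearDFunc that generates the regression targets
for timestamp, (x, y) in env:
    # Each step is one regression task: predict y from x at this timestamp.
    print(timestamp, x.shape, y.shape)
    break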

View File

@ -49,6 +49,10 @@ class SyntheticDEnv(data.Dataset):
self._meta_info["task"] = "classification"
self._meta_info["num_classes"] = int(num_classes)
@property
def oracle_map(self):
return self._oracle_map
@property
def meta_info(self):
return self._meta_info
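Both accessors added here are read-only views onto the environment. A minimal sketch of their use (the "regression" value is an assumption, mirroring how set_classification records the task above):

env = get_synthetic_env(version="v1")
print(env.oracle_map)             # inspect the generating function directly
print(env.meta_info.get("task"))  # presumably "regression" after set_regression()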