Re-organize GeMOSA
This commit is contained in:
parent
8961215416
commit
6da60664f5
@ -1,10 +1,9 @@
|
||||
#####################################################
|
||||
# Learning to Generate Model One Step Ahead #
|
||||
#####################################################
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --workers 0
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128
|
||||
# python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
|
||||
# python exps/GeMOSA/main.py --env_version v1 --workers 0
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256
|
||||
# python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData
|
||||
from meta_model import MetaModelV1
|
||||
|
||||
|
||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
|
||||
def online_evaluate(
|
||||
env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
|
||||
):
|
||||
logger.log("Online evaluate: {:}".format(env))
|
||||
loss_meter = AverageMeter()
|
||||
w_containers = dict()
|
||||
@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F
|
||||
with torch.no_grad():
|
||||
meta_model.eval()
|
||||
base_model.eval()
|
||||
[future_container], time_embeds = meta_model(
|
||||
future_time.to(args.device).view(-1), None, False
|
||||
future_time_embed = meta_model.gen_time_embed(
|
||||
future_time.to(args.device).view(-1)
|
||||
)
|
||||
[future_container] = meta_model.gen_model(future_time_embed)
|
||||
if save:
|
||||
w_containers[idx] = future_container.no_grad_clone()
|
||||
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
loss_meter.update(future_loss.item())
|
||||
refine, post_refine_loss = meta_model.adapt(
|
||||
base_model,
|
||||
criterion,
|
||||
future_time.item(),
|
||||
future_x,
|
||||
future_y,
|
||||
args.refine_lr,
|
||||
args.refine_epochs,
|
||||
{"param": time_embeds, "loss": future_loss.item()},
|
||||
)
|
||||
if easy_adapt:
|
||||
meta_model.easy_adapt(future_time.item(), future_time_embed)
|
||||
refine, post_refine_loss = False, -1
|
||||
else:
|
||||
refine, post_refine_loss = meta_model.adapt(
|
||||
base_model,
|
||||
criterion,
|
||||
future_time.item(),
|
||||
future_x,
|
||||
future_y,
|
||||
args.refine_lr,
|
||||
args.refine_epochs,
|
||||
{"param": future_time_embed, "loss": future_loss.item()},
|
||||
)
|
||||
logger.log(
|
||||
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
|
||||
idx, len(env), future_loss.item()
|
||||
@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
|
||||
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
|
||||
generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
|
||||
|
||||
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
|
||||
|
||||
@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
||||
)
|
||||
# future loss
|
||||
total_future_losses, total_present_losses = [], []
|
||||
future_containers, _ = meta_model(
|
||||
None, generated_time_embeds[batch_indexes], False
|
||||
)
|
||||
present_containers, _ = meta_model(
|
||||
None, meta_model.super_meta_embed[batch_indexes], False
|
||||
future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
|
||||
present_containers = meta_model.gen_model(
|
||||
meta_model.super_meta_embed[batch_indexes]
|
||||
)
|
||||
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
|
||||
_, (inputs, targets) = xenv(time_step)
|
||||
@ -216,13 +220,34 @@ def main(args):
|
||||
# try to evaluate once
|
||||
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
|
||||
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
|
||||
"""
|
||||
w_containers, loss_meter = online_evaluate(
|
||||
all_env, meta_model, base_model, criterion, args, logger, True
|
||||
)
|
||||
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
||||
"""
|
||||
_, test_loss_meter_adapt_v1 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, args, logger, False, False
|
||||
)
|
||||
_, test_loss_meter_adapt_v2 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, args, logger, False, True
|
||||
)
|
||||
logger.log(
|
||||
"In the online test enviornment, the total loss for refine-adapt is {:}".format(
|
||||
test_loss_meter_adapt_v1
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
"In the online test enviornment, the total loss for easy-adapt is {:}".format(
|
||||
test_loss_meter_adapt_v2
|
||||
)
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
{"all_w_containers": w_containers},
|
||||
{
|
||||
"test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
|
||||
"test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
|
||||
},
|
||||
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
|
||||
logger,
|
||||
)
|
||||
|
@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule):
|
||||
batch_containers.append(
|
||||
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
|
||||
)
|
||||
return batch_containers, time_embeds
|
||||
return batch_containers
|
||||
|
||||
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
|
||||
raise NotImplementedError
|
||||
@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule):
|
||||
def forward_candidate(self, input):
|
||||
raise NotImplementedError
|
||||
|
||||
def easy_adapt(self, timestamp, time_embed):
|
||||
with torch.no_grad():
|
||||
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
|
||||
self.replace_append_learnt(None, None)
|
||||
self.append_fixed(timestamp, time_embed)
|
||||
|
||||
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
|
||||
distance = self.get_closest_meta_distance(timestamp)
|
||||
if distance + self._interval * 1e-2 <= self._interval:
|
||||
@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule):
|
||||
best_new_param = new_param.detach().clone()
|
||||
for iepoch in range(epochs):
|
||||
optimizer.zero_grad()
|
||||
_, time_embed = self(timestamp.view(1), None)
|
||||
time_embed = self.gen_time_embed(timestamp.view(1))
|
||||
match_loss = criterion(new_param, time_embed)
|
||||
|
||||
[container], time_embed = self(None, new_param.view(1, -1))
|
||||
[container] = self.gen_model(new_param.view(1, -1))
|
||||
y_hat = base_model.forward_with_container(x, container)
|
||||
meta_loss = criterion(y_hat, y)
|
||||
loss = meta_loss + match_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
|
||||
if meta_loss.item() < best_loss:
|
||||
with torch.no_grad():
|
||||
best_loss = meta_loss.item()
|
||||
best_new_param = new_param.detach().clone()
|
||||
with torch.no_grad():
|
||||
self.replace_append_learnt(None, None)
|
||||
self.append_fixed(timestamp, best_new_param)
|
||||
self.easy_adapt(timestamp, best_new_param)
|
||||
return True, best_loss
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
@ -191,6 +191,8 @@ def visualize_env(save_dir, version):
|
||||
allxs.append(allx)
|
||||
allys.append(ally)
|
||||
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
|
||||
print("env: {:}".format(dynamic_env))
|
||||
print("oracle_map: {:}".format(dynamic_env.oracle_map))
|
||||
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
|
||||
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
|
||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
|
@ -1,92 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
import math
|
||||
import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
from .math_base_funcs import FitFunc
|
||||
from .math_base_funcs import QuadraticFunc
|
||||
from .math_base_funcs import QuarticFunc
|
||||
|
||||
|
||||
class ConstantFunc(FitFunc):
|
||||
"""The constant function: f(x) = c."""
|
||||
|
||||
def __init__(self, constant=None, xstr="x"):
|
||||
param = dict()
|
||||
param[0] = constant
|
||||
super(ConstantFunc, self).__init__(0, None, param, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0]
|
||||
|
||||
def fit(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
|
||||
|
||||
|
||||
class ComposedSinFunc(FitFunc):
|
||||
"""The composed sin function that outputs:
|
||||
f(x) = a * sin( b*x ) + c
|
||||
"""
|
||||
|
||||
def __init__(self, params, xstr="x"):
|
||||
super(ComposedSinFunc, self).__init__(3, None, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
a = self._params[0]
|
||||
b = self._params[1]
|
||||
c = self._params[2]
|
||||
return a * math.sin(b * x) + c
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * sin({b} * {x}) + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class ComposedCosFunc(FitFunc):
|
||||
"""The composed sin function that outputs:
|
||||
f(x) = a * cos( b*x ) + c
|
||||
"""
|
||||
|
||||
def __init__(self, params, xstr="x"):
|
||||
super(ComposedCosFunc, self).__init__(3, None, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
a = self._params[0]
|
||||
b = self._params[1]
|
||||
c = self._params[2]
|
||||
return a * math.cos(b * x) + c
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * sin({b} * {x}) + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
@ -5,34 +5,33 @@ import math
|
||||
import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class FitFunc(abc.ABC):
|
||||
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
class MathFunc(abc.ABC):
|
||||
"""The math function -- a virtual class defining some APIs."""
|
||||
|
||||
def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
|
||||
def __init__(self, freedom: int, params=None, xstr="x"):
|
||||
# initialize as empty
|
||||
self._params = dict()
|
||||
for i in range(freedom):
|
||||
self._params[i] = None
|
||||
self._freedom = freedom
|
||||
if list_of_points is not None and params is not None:
|
||||
raise ValueError("list_of_points and params can not be set simultaneously")
|
||||
if list_of_points is not None:
|
||||
self.fit(list_of_points=list_of_points)
|
||||
if params is not None:
|
||||
self.set(params)
|
||||
self._xstr = str(xstr)
|
||||
self._skip_check = True
|
||||
|
||||
def set(self, params):
|
||||
self._params = copy.deepcopy(params)
|
||||
for key in range(self._freedom):
|
||||
param = copy.deepcopy(params[key])
|
||||
self._params[key] = param
|
||||
|
||||
def check_valid(self):
|
||||
# for key, value in self._params.items():
|
||||
for key in range(self._freedom):
|
||||
value = self._params[key]
|
||||
if value is None:
|
||||
raise ValueError("The {:} is None".format(key))
|
||||
if not self._skip_check:
|
||||
for key in range(self._freedom):
|
||||
value = self._params[key]
|
||||
if value is None:
|
||||
raise ValueError("The {:} is None".format(key))
|
||||
|
||||
@property
|
||||
def xstr(self):
|
||||
@ -45,7 +44,8 @@ class FitFunc(abc.ABC):
|
||||
def __call__(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def noise_call(self, x, std=0.1):
|
||||
@abc.abstractmethod
|
||||
def noise_call(self, x, std):
|
||||
clean_y = self.__call__(x)
|
||||
if isinstance(clean_y, np.ndarray):
|
||||
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
|
||||
@ -53,169 +53,7 @@ class FitFunc(abc.ABC):
|
||||
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
|
||||
return noise_y
|
||||
|
||||
@abc.abstractmethod
|
||||
def _getitem(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def fit(self, **kwargs):
|
||||
list_of_points = kwargs["list_of_points"]
|
||||
max_iter, lr_max, verbose = (
|
||||
kwargs.get("max_iter", 900),
|
||||
kwargs.get("lr_max", 1.0),
|
||||
kwargs.get("verbose", False),
|
||||
)
|
||||
with torch.no_grad():
|
||||
data = torch.Tensor(list_of_points).type(torch.float32)
|
||||
assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format(
|
||||
data.shape
|
||||
)
|
||||
x, y = data[:, 0], data[:, 1]
|
||||
weights = torch.nn.Parameter(torch.Tensor(self._freedom))
|
||||
torch.nn.init.normal_(weights, mean=0.0, std=1.0)
|
||||
optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True)
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
milestones=[
|
||||
int(max_iter * 0.25),
|
||||
int(max_iter * 0.5),
|
||||
int(max_iter * 0.75),
|
||||
],
|
||||
gamma=0.1,
|
||||
)
|
||||
if verbose:
|
||||
print("The optimizer: {:}".format(optimizer))
|
||||
|
||||
best_loss = None
|
||||
for _iter in range(max_iter):
|
||||
y_hat = self._getitem(x, weights)
|
||||
loss = torch.mean(torch.abs(y - y_hat))
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if verbose:
|
||||
print(
|
||||
"In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
|
||||
_iter, max_iter, loss.item()
|
||||
)
|
||||
)
|
||||
# Update the params
|
||||
if best_loss is None or best_loss > loss.item():
|
||||
best_loss = loss.item()
|
||||
for i in range(self._freedom):
|
||||
self._params[i] = weights[i].item()
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(freedom={freedom})".format(
|
||||
name=self.__class__.__name__, freedom=freedom
|
||||
)
|
||||
|
||||
|
||||
class LinearFunc(FitFunc):
|
||||
"""The linear function that outputs f(x) = a * x + b."""
|
||||
|
||||
def __init__(self, list_of_points=None, params=None, xstr="x"):
|
||||
super(LinearFunc, self).__init__(2, list_of_points, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0] * x + self._params[1]
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x + weights[1]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * {x} + {b})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class QuadraticFunc(FitFunc):
|
||||
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, list_of_points=None, params=None, xstr="x"):
|
||||
super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0] * x * x + self._params[1] * x + self._params[2]
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x * x + weights[1] * x + weights[2]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class CubicFunc(FitFunc):
|
||||
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
|
||||
|
||||
def __init__(self, list_of_points=None):
|
||||
super(CubicFunc, self).__init__(4, list_of_points)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return (
|
||||
self._params[0] * x ** 3
|
||||
+ self._params[1] * x ** 2
|
||||
+ self._params[2] * x
|
||||
+ self._params[3]
|
||||
)
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
d=self._params[3],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class QuarticFunc(FitFunc):
|
||||
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
|
||||
|
||||
def __init__(self, list_of_points=None):
|
||||
super(QuarticFunc, self).__init__(5, list_of_points)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return (
|
||||
self._params[0] * x ** 4
|
||||
+ self._params[1] * x ** 3
|
||||
+ self._params[2] * x ** 2
|
||||
+ self._params[3] * x
|
||||
+ self._params[4]
|
||||
)
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return (
|
||||
weights[0] * x ** 4
|
||||
+ weights[1] * x ** 3
|
||||
+ weights[2] * x ** 2
|
||||
+ weights[3] * x
|
||||
+ weights[4]
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
d=self._params[3],
|
||||
e=self._params[3],
|
||||
)
|
||||
|
@ -1,10 +1,14 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
||||
#####################################################
|
||||
from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
|
||||
from .math_dynamic_funcs import DynamicLinearFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_dynamic_funcs import DynamicSinQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc
|
||||
from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc
|
||||
from .math_static_funcs import (
|
||||
LinearSFunc,
|
||||
QuadraticSFunc,
|
||||
CubicSFunc,
|
||||
QuarticSFunc,
|
||||
ConstantFunc,
|
||||
ComposedSinSFunc,
|
||||
ComposedCosSFunc,
|
||||
)
|
||||
from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc
|
||||
from .math_dynamic_generator import GaussianDGenerator
|
||||
|
@ -6,23 +6,17 @@ import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
|
||||
from .math_base_funcs import FitFunc
|
||||
from .math_base_funcs import MathFunc
|
||||
|
||||
|
||||
class DynamicFunc(FitFunc):
|
||||
"""The dynamic quadratic function, where each param is a function."""
|
||||
class DynamicFunc(MathFunc):
|
||||
"""The dynamic function, where each param is a function."""
|
||||
|
||||
def __init__(self, freedom: int, params=None, xstr="x"):
|
||||
if params is not None:
|
||||
for param in params:
|
||||
param.reset_xstr("t") if isinstance(param, FitFunc) else None
|
||||
super(DynamicFunc, self).__init__(freedom, None, params, xstr)
|
||||
|
||||
def __call__(self, x, timestamp):
|
||||
raise NotImplementedError
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
for key, param in params.items():
|
||||
param.reset_xstr("t") if isinstance(param, MathFunc) else None
|
||||
super(DynamicFunc, self).__init__(freedom, params, xstr)
|
||||
|
||||
def noise_call(self, x, timestamp, std):
|
||||
clean_y = self.__call__(x, timestamp)
|
||||
@ -33,13 +27,13 @@ class DynamicFunc(FitFunc):
|
||||
return noise_y
|
||||
|
||||
|
||||
class DynamicLinearFunc(DynamicFunc):
|
||||
class LinearDFunc(DynamicFunc):
|
||||
"""The dynamic linear function that outputs f(x) = a * x + b.
|
||||
The a and b is a function of timestamp.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None, xstr="x"):
|
||||
super(DynamicLinearFunc, self).__init__(3, params, xstr)
|
||||
def __init__(self, params, xstr="x"):
|
||||
super(LinearDFunc, self).__init__(2, params, xstr)
|
||||
|
||||
def __call__(self, x, timestamp):
|
||||
a = self._params[0](timestamp)
|
||||
@ -57,18 +51,15 @@ class DynamicLinearFunc(DynamicFunc):
|
||||
)
|
||||
|
||||
|
||||
class DynamicQuadraticFunc(DynamicFunc):
|
||||
class QuadraticDFunc(DynamicFunc):
|
||||
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
|
||||
The a, b, and c is a function of timestamp.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
super(DynamicQuadraticFunc, self).__init__(3, params)
|
||||
def __init__(self, params, xstr="x"):
|
||||
super(QuadraticDFunc, self).__init__(3, params)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
def __call__(self, x, timestamp):
|
||||
self.check_valid()
|
||||
a = self._params[0](timestamp)
|
||||
b = self._params[1](timestamp)
|
||||
@ -78,38 +69,37 @@ class DynamicQuadraticFunc(DynamicFunc):
|
||||
return a * x * x + b * x + c
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^2 + {b} * x + {c})".format(
|
||||
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class DynamicSinQuadraticFunc(DynamicFunc):
|
||||
class SinQuadraticDFunc(DynamicFunc):
|
||||
"""The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c).
|
||||
The a, b, and c is a function of timestamp.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
super(DynamicSinQuadraticFunc, self).__init__(3, params)
|
||||
super(SinQuadraticDFunc, self).__init__(3, params)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
def __call__(self, x, timestamp):
|
||||
self.check_valid()
|
||||
a = self._params[0](timestamp)
|
||||
b = self._params[1](timestamp)
|
||||
c = self._params[2](timestamp)
|
||||
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
|
||||
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
|
||||
return math.sin(a * x * x + b * x + c)
|
||||
return np.sin(a * x * x + b * x + c)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^2 + {b} * x + {c})".format(
|
||||
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
225
xautodl/datasets/math_static_funcs.py
Normal file
225
xautodl/datasets/math_static_funcs.py
Normal file
@ -0,0 +1,225 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
import math
|
||||
import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
|
||||
from .math_base_funcs import MathFunc
|
||||
|
||||
|
||||
class StaticFunc(MathFunc):
|
||||
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, freedom: int, params=None, xstr="x"):
|
||||
super(StaticFunc, self).__init__(freedom, params, xstr)
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def noise_call(self, x, std):
|
||||
clean_y = self.__call__(x)
|
||||
if isinstance(clean_y, np.ndarray):
|
||||
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
|
||||
else:
|
||||
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
|
||||
return noise_y
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(freedom={freedom})".format(
|
||||
name=self.__class__.__name__, freedom=freedom
|
||||
)
|
||||
|
||||
|
||||
class LinearSFunc(StaticFunc):
|
||||
"""The linear function that outputs f(x) = a * x + b."""
|
||||
|
||||
def __init__(self, params=None, xstr="x"):
|
||||
super(LinearSFunc, self).__init__(2, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0] * x + self._params[1]
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x + weights[1]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * {x} + {b})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class QuadraticSFunc(StaticFunc):
|
||||
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, params=None, xstr="x"):
|
||||
super(QuadraticSFunc, self).__init__(3, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0] * x * x + self._params[1] * x + self._params[2]
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x * x + weights[1] * x + weights[2]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class CubicSFunc(StaticFunc):
|
||||
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
|
||||
|
||||
def __init__(self, params=None, xstr="x"):
|
||||
super(CubicSFunc, self).__init__(4, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return (
|
||||
self._params[0] * x ** 3
|
||||
+ self._params[1] * x ** 2
|
||||
+ self._params[2] * x
|
||||
+ self._params[3]
|
||||
)
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
d=self._params[3],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class QuarticSFunc(StaticFunc):
|
||||
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
|
||||
|
||||
def __init__(self, params=None, xstr="x"):
|
||||
super(QuarticSFunc, self).__init__(5, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return (
|
||||
self._params[0] * x ** 4
|
||||
+ self._params[1] * x ** 3
|
||||
+ self._params[2] * x ** 2
|
||||
+ self._params[3] * x
|
||||
+ self._params[4]
|
||||
)
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return (
|
||||
weights[0] * x ** 4
|
||||
+ weights[1] * x ** 3
|
||||
+ weights[2] * x ** 2
|
||||
+ weights[3] * x
|
||||
+ weights[4]
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{name}({a} * {x}^4 + {b} * {x}^3 + {c} * {x}^2 + {d} * {x} + {e})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
d=self._params[3],
|
||||
e=self._params[3],
|
||||
x=self.xstr,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
### advanced functions
|
||||
|
||||
|
||||
class ConstantFunc(StaticFunc):
|
||||
"""The constant function: f(x) = c."""
|
||||
|
||||
def __init__(self, constant, xstr="x"):
|
||||
super(ConstantFunc, self).__init__(1, {0: constant}, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0]
|
||||
|
||||
def fit(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
|
||||
|
||||
|
||||
class ComposedSinSFunc(StaticFunc):
|
||||
"""The composed sin function that outputs:
|
||||
f(x) = a * sin( b*x ) + c
|
||||
"""
|
||||
|
||||
def __init__(self, params, xstr="x"):
|
||||
super(ComposedSinSFunc, self).__init__(3, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
a = self._params[0]
|
||||
b = self._params[1]
|
||||
c = self._params[2]
|
||||
return a * math.sin(b * x) + c
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * sin({b} * {x}) + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class ComposedCosSFunc(StaticFunc):
|
||||
"""The composed sin function that outputs:
|
||||
f(x) = a * cos( b*x ) + c
|
||||
"""
|
||||
|
||||
def __init__(self, params, xstr="x"):
|
||||
super(ComposedCosSFunc, self).__init__(3, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
a = self._params[0]
|
||||
b = self._params[1]
|
||||
c = self._params[2]
|
||||
return a * math.cos(b * x) + c
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * sin({b} * {x}) + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
@ -1,13 +1,13 @@
|
||||
import math
|
||||
from .synthetic_utils import TimeStamp
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
from .math_core import LinearFunc
|
||||
from .math_core import DynamicLinearFunc
|
||||
from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc
|
||||
from .math_core import LinearSFunc
|
||||
from .math_core import LinearDFunc
|
||||
from .math_core import QuadraticDFunc, SinQuadraticDFunc
|
||||
from .math_core import (
|
||||
ConstantFunc,
|
||||
ComposedSinFunc as SinFunc,
|
||||
ComposedCosFunc as CosFunc,
|
||||
ComposedSinSFunc as SinFunc,
|
||||
ComposedCosSFunc as CosFunc,
|
||||
)
|
||||
from .math_core import GaussianDGenerator
|
||||
|
||||
@ -17,7 +17,7 @@ __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
|
||||
|
||||
def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"):
|
||||
max_time = math.pi * 10
|
||||
if version == "v1":
|
||||
if version.lower() == "v1":
|
||||
mean_generator = ConstantFunc(0)
|
||||
std_generator = ConstantFunc(1)
|
||||
data_generator = GaussianDGenerator(
|
||||
@ -26,7 +26,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
||||
time_generator = TimeStamp(
|
||||
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
||||
)
|
||||
oracle_map = DynamicLinearFunc(
|
||||
oracle_map = LinearDFunc(
|
||||
params={
|
||||
0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), # 2 sin(t) + 2.2
|
||||
1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), # 1.5 sin(0.6t) + 1.8
|
||||
@ -35,7 +35,8 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
||||
dynamic_env = SyntheticDEnv(
|
||||
data_generator, oracle_map, time_generator, num_per_task
|
||||
)
|
||||
elif version == "v2":
|
||||
dynamic_env.set_regression()
|
||||
elif version.lower() == "v2":
|
||||
mean_generator = ConstantFunc(0)
|
||||
std_generator = ConstantFunc(1)
|
||||
data_generator = GaussianDGenerator(
|
||||
@ -44,16 +45,17 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
||||
time_generator = TimeStamp(
|
||||
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
||||
)
|
||||
oracle_map = DynamicQuadraticFunc(
|
||||
oracle_map = QuadraticDFunc(
|
||||
params={
|
||||
0: LinearFunc(params={0: 0.1, 1: 0}), # 0.1 * t
|
||||
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
|
||||
2: ConstantFunc(0),
|
||||
0: LinearSFunc(params={0: 0.1, 1: 0}), # 0.1 * t
|
||||
1: ConstantFunc(0),
|
||||
2: CosFunc(params={0: 4.0, 1: 10, 2: 0}), # 4 * cos(10 * t)
|
||||
}
|
||||
)
|
||||
dynamic_env = SyntheticDEnv(
|
||||
data_generator, oracle_map, time_generator, num_per_task
|
||||
)
|
||||
dynamic_env.set_regression()
|
||||
elif version.lower() == "v3":
|
||||
mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0}) # sin(t)
|
||||
std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1}) # 0.5 cos(t) + 1
|
||||
@ -63,7 +65,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
||||
time_generator = TimeStamp(
|
||||
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
|
||||
)
|
||||
oracle_map = DynamicSinQuadraticFunc(
|
||||
oracle_map = SinQuadraticDFunc(
|
||||
params={
|
||||
0: CosFunc(params={0: 0.5, 1: 1, 2: 1}), # 0.5 cos(t) + 1
|
||||
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
|
||||
@ -73,6 +75,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
|
||||
dynamic_env = SyntheticDEnv(
|
||||
data_generator, oracle_map, time_generator, num_per_task
|
||||
)
|
||||
dynamic_env.set_regression()
|
||||
elif version.lower() == "v4":
|
||||
dynamic_env.set_classification(2)
|
||||
else:
|
||||
raise ValueError("Unknown version: {:}".format(version))
|
||||
return dynamic_env
|
||||
|
@ -49,6 +49,10 @@ class SyntheticDEnv(data.Dataset):
|
||||
self._meta_info["task"] = "classification"
|
||||
self._meta_info["num_classes"] = int(num_classes)
|
||||
|
||||
@property
|
||||
def oracle_map(self):
|
||||
return self._oracle_map
|
||||
|
||||
@property
|
||||
def meta_info(self):
|
||||
return self._meta_info
|
||||
|
Loading…
Reference in New Issue
Block a user