Complete LFNA 1.0
This commit is contained in:
parent
c2fa181bc5
commit
b81ef2dd74
190
exps/LFNA/basic-prev.py
Normal file
190
exps/LFNA/basic-prev.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
|
#####################################################
|
||||||
|
# python exps/LFNA/basic-prev.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
|
||||||
|
# python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
|
||||||
|
#####################################################
|
||||||
|
import sys, time, copy, torch, random, argparse
|
||||||
|
from tqdm import tqdm
|
||||||
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||||
|
if str(lib_dir) not in sys.path:
|
||||||
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||||
|
from log_utils import time_string
|
||||||
|
from log_utils import AverageMeter, convert_secs2time
|
||||||
|
|
||||||
|
from utils import split_str2indexes
|
||||||
|
|
||||||
|
from procedures.advanced_main import basic_train_fn, basic_eval_fn
|
||||||
|
from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
|
||||||
|
from datasets.synthetic_core import get_synthetic_env
|
||||||
|
from models.xcore import get_model
|
||||||
|
|
||||||
|
from lfna_utils import lfna_setup
|
||||||
|
|
||||||
|
|
||||||
|
def subsample(historical_x, historical_y, maxn=10000):
|
||||||
|
total = historical_x.size(0)
|
||||||
|
if total <= maxn:
|
||||||
|
return historical_x, historical_y
|
||||||
|
else:
|
||||||
|
indexes = torch.randint(low=0, high=total, size=[maxn])
|
||||||
|
return historical_x[indexes], historical_y[indexes]
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
logger, env_info, model_kwargs = lfna_setup(args)
|
||||||
|
|
||||||
|
w_container_per_epoch = dict()
|
||||||
|
|
||||||
|
per_timestamp_time, start_time = AverageMeter(), time.time()
|
||||||
|
for idx in range(1, env_info["total"]):
|
||||||
|
|
||||||
|
need_time = "Time Left: {:}".format(
|
||||||
|
convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
|
||||||
|
)
|
||||||
|
logger.log(
|
||||||
|
"[{:}]".format(time_string())
|
||||||
|
+ " [{:04d}/{:04d}]".format(idx, env_info["total"])
|
||||||
|
+ " "
|
||||||
|
+ need_time
|
||||||
|
)
|
||||||
|
# train the same data
|
||||||
|
historical_x = env_info["{:}-x".format(idx - 1)]
|
||||||
|
historical_y = env_info["{:}-y".format(idx - 1)]
|
||||||
|
# build model
|
||||||
|
model = get_model(**model_kwargs)
|
||||||
|
print(model)
|
||||||
|
# build optimizer
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
|
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||||
|
optimizer,
|
||||||
|
milestones=[
|
||||||
|
int(args.epochs * 0.25),
|
||||||
|
int(args.epochs * 0.5),
|
||||||
|
int(args.epochs * 0.75),
|
||||||
|
],
|
||||||
|
gamma=0.3,
|
||||||
|
)
|
||||||
|
train_metric = MSEMetric()
|
||||||
|
best_loss, best_param = None, None
|
||||||
|
for _iepoch in range(args.epochs):
|
||||||
|
preds = model(historical_x)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss = criterion(preds, historical_y)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
# save best
|
||||||
|
if best_loss is None or best_loss > loss.item():
|
||||||
|
best_loss = loss.item()
|
||||||
|
best_param = copy.deepcopy(model.state_dict())
|
||||||
|
model.load_state_dict(best_param)
|
||||||
|
model.analyze_weights()
|
||||||
|
with torch.no_grad():
|
||||||
|
train_metric(preds, historical_y)
|
||||||
|
train_results = train_metric.get_info()
|
||||||
|
|
||||||
|
metric = ComposeMetric(MSEMetric(), SaveMetric())
|
||||||
|
eval_dataset = torch.utils.data.TensorDataset(
|
||||||
|
env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
|
||||||
|
)
|
||||||
|
eval_loader = torch.utils.data.DataLoader(
|
||||||
|
eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
|
||||||
|
)
|
||||||
|
results = basic_eval_fn(eval_loader, model, metric, logger)
|
||||||
|
log_str = (
|
||||||
|
"[{:}]".format(time_string())
|
||||||
|
+ " [{:04d}/{:04d}]".format(idx, env_info["total"])
|
||||||
|
+ " train-mse: {:.5f}, eval-mse: {:.5f}".format(
|
||||||
|
train_results["mse"], results["mse"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.log(log_str)
|
||||||
|
|
||||||
|
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
|
||||||
|
idx, env_info["total"]
|
||||||
|
)
|
||||||
|
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
|
||||||
|
save_checkpoint(
|
||||||
|
{
|
||||||
|
"model_state_dict": model.state_dict(),
|
||||||
|
"model": model,
|
||||||
|
"index": idx,
|
||||||
|
"timestamp": env_info["{:}-timestamp".format(idx)],
|
||||||
|
},
|
||||||
|
save_path,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
logger.log("")
|
||||||
|
per_timestamp_time.update(time.time() - start_time)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
save_checkpoint(
|
||||||
|
{"w_container_per_epoch": w_container_per_epoch},
|
||||||
|
logger.path(None) / "final-ckp.pth",
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.log("-" * 200 + "\n")
|
||||||
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser("Use the data in the last timestamp.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_dir",
|
||||||
|
type=str,
|
||||||
|
default="./outputs/lfna-synthetic/use-prev-timestamp",
|
||||||
|
help="The checkpoint directory.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--env_version",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The synthetic enviornment version.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden_dim",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="The hidden dimension.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--init_lr",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="The initial learning rate for the optimizer (default is Adam)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="The batch size",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--epochs",
|
||||||
|
type=int,
|
||||||
|
default=300,
|
||||||
|
help="The total number of epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--workers",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="The number of data loading workers (default: 4)",
|
||||||
|
)
|
||||||
|
# Random Seed
|
||||||
|
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
|
args.rand_seed = random.randint(1, 100000)
|
||||||
|
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||||
|
args.save_dir = "{:}-{:}-d{:}".format(
|
||||||
|
args.save_dir, args.env_version, args.hidden_dim
|
||||||
|
)
|
||||||
|
main(args)
|
@ -1,6 +1,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||||
#####################################################
|
#####################################################
|
||||||
|
# python exps/LFNA/lfna.py --env_version v1 --workers 0
|
||||||
# python exps/LFNA/lfna.py --env_version v1 --device cuda
|
# python exps/LFNA/lfna.py --env_version v1 --device cuda
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, time, copy, torch, random, argparse
|
import sys, time, copy, torch, random, argparse
|
||||||
@ -156,19 +157,61 @@ def main(args):
|
|||||||
per_epoch_time.update(time.time() - start_time)
|
per_epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# meta-training
|
||||||
|
meta_model.load_best()
|
||||||
|
eval_env = env_info["dynamic_env"]
|
||||||
w_container_per_epoch = dict()
|
w_container_per_epoch = dict()
|
||||||
for idx in range(0, total_bar):
|
for idx in range(args.seq_length, env_info["total"]):
|
||||||
|
# build-timestamp
|
||||||
future_time = env_info["{:}-timestamp".format(idx)]
|
future_time = env_info["{:}-timestamp".format(idx)]
|
||||||
future_x = env_info["{:}-x".format(idx)]
|
time_seqs = []
|
||||||
future_y = env_info["{:}-y".format(idx)]
|
for iseq in range(args.seq_length):
|
||||||
future_container = hypernet(task_embeds[idx])
|
time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
|
||||||
w_container_per_epoch[idx] = future_container.no_grad_clone()
|
time_seqs.reverse()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
future_y_hat = model.forward_with_container(
|
meta_model.eval()
|
||||||
|
base_model.eval()
|
||||||
|
time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
|
||||||
|
[seq_containers] = meta_model(time_seqs)
|
||||||
|
future_container = seq_containers[-1]
|
||||||
|
w_container_per_epoch[idx] = future_container.no_grad_clone()
|
||||||
|
# evaluation
|
||||||
|
future_x = env_info["{:}-x".format(idx)]
|
||||||
|
future_y = env_info["{:}-y".format(idx)]
|
||||||
|
future_y_hat = base_model.forward_with_container(
|
||||||
future_x, w_container_per_epoch[idx]
|
future_x, w_container_per_epoch[idx]
|
||||||
)
|
)
|
||||||
future_loss = criterion(future_y_hat, future_y)
|
future_loss = criterion(future_y_hat, future_y)
|
||||||
logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
|
logger.log(
|
||||||
|
"meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
|
||||||
|
)
|
||||||
|
|
||||||
|
# creating the new meta-time-embedding
|
||||||
|
distance = meta_model.get_closest_meta_distance(future_time)
|
||||||
|
if distance < eval_env.timestamp_interval:
|
||||||
|
continue
|
||||||
|
#
|
||||||
|
new_param = meta_model.create_meta_embed()
|
||||||
|
optimizer = torch.optim.Adam(
|
||||||
|
[new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
|
||||||
|
)
|
||||||
|
meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param)
|
||||||
|
meta_model.eval()
|
||||||
|
base_model.train()
|
||||||
|
for iepoch in range(args.epochs):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
[seq_containers] = meta_model(time_seqs)
|
||||||
|
future_container = seq_containers[-1]
|
||||||
|
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||||
|
future_loss = criterion(future_y_hat, future_y)
|
||||||
|
future_loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
logger.log(
|
||||||
|
"post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
meta_model.replace_append_learnt(None, None)
|
||||||
|
meta_model.append_fixed(torch.Tensor([future_time]), new_param)
|
||||||
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
{"w_container_per_epoch": w_container_per_epoch},
|
{"w_container_per_epoch": w_container_per_epoch},
|
||||||
@ -216,7 +259,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--init_lr",
|
"--init_lr",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.01,
|
default=0.005,
|
||||||
help="The initial learning rate for the optimizer (default is Adam)",
|
help="The initial learning rate for the optimizer (default is Adam)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -235,7 +278,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--early_stop_thresh",
|
"--early_stop_thresh",
|
||||||
type=int,
|
type=int,
|
||||||
default=50,
|
default=25,
|
||||||
help="The maximum epochs for early stop.",
|
help="The maximum epochs for early stop.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -256,7 +299,12 @@ if __name__ == "__main__":
|
|||||||
if args.rand_seed is None or args.rand_seed < 0:
|
if args.rand_seed is None or args.rand_seed < 0:
|
||||||
args.rand_seed = random.randint(1, 100000)
|
args.rand_seed = random.randint(1, 100000)
|
||||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||||
args.save_dir = "{:}-{:}-d{:}_{:}_{:}".format(
|
args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
|
||||||
args.save_dir, args.env_version, args.hidden_dim, args.layer_dim, args.time_dim
|
args.save_dir,
|
||||||
|
args.env_version,
|
||||||
|
args.hidden_dim,
|
||||||
|
args.layer_dim,
|
||||||
|
args.time_dim,
|
||||||
|
args.epochs,
|
||||||
)
|
)
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -17,7 +17,7 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
shape_container,
|
shape_container,
|
||||||
layer_embeding,
|
layer_embedding,
|
||||||
time_embedding,
|
time_embedding,
|
||||||
meta_timestamps,
|
meta_timestamps,
|
||||||
mha_depth: int = 2,
|
mha_depth: int = 2,
|
||||||
@ -33,13 +33,16 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
|
|
||||||
self.register_parameter(
|
self.register_parameter(
|
||||||
"_super_layer_embed",
|
"_super_layer_embed",
|
||||||
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)),
|
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)),
|
||||||
)
|
)
|
||||||
self.register_parameter(
|
self.register_parameter(
|
||||||
"_super_meta_embed",
|
"_super_meta_embed",
|
||||||
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)),
|
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)),
|
||||||
)
|
)
|
||||||
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
|
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
|
||||||
|
self._time_embed_dim = time_embedding
|
||||||
|
self._append_meta_embed = dict(fixed=None, learnt=None)
|
||||||
|
self._append_meta_timestamps = dict(fixed=None, learnt=None)
|
||||||
|
|
||||||
# build transformer
|
# build transformer
|
||||||
layers = []
|
layers = []
|
||||||
@ -60,9 +63,9 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
|
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
config=dict(model_type="dual_norm_mlp"),
|
config=dict(model_type="dual_norm_mlp"),
|
||||||
input_dim=layer_embeding + time_embedding,
|
input_dim=layer_embedding + time_embedding,
|
||||||
output_dim=max(self._numel_per_layer),
|
output_dim=max(self._numel_per_layer),
|
||||||
hidden_dims=[(layer_embeding + time_embedding) * 2] * 3,
|
hidden_dims=[(layer_embedding + time_embedding) * 2] * 3,
|
||||||
act_cls="gelu",
|
act_cls="gelu",
|
||||||
norm_cls="layer_norm_1d",
|
norm_cls="layer_norm_1d",
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
@ -82,21 +85,68 @@ class LFNA_Meta(super_core.SuperModule):
|
|||||||
std=0.02,
|
std=0.02,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def meta_timestamps(self):
|
||||||
|
meta_timestamps = [self._meta_timestamps]
|
||||||
|
for key in ("fixed", "learnt"):
|
||||||
|
if self._append_meta_timestamps[key] is not None:
|
||||||
|
meta_timestamps.append(self._append_meta_timestamps[key])
|
||||||
|
return torch.cat(meta_timestamps)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def super_meta_embed(self):
|
||||||
|
meta_embed = [self._super_meta_embed]
|
||||||
|
for key in ("fixed", "learnt"):
|
||||||
|
if self._append_meta_embed[key] is not None:
|
||||||
|
meta_embed.append(self._append_meta_embed[key])
|
||||||
|
return torch.cat(meta_embed)
|
||||||
|
|
||||||
|
def create_meta_embed(self):
|
||||||
|
param = torch.nn.Parameter(torch.Tensor(1, self._time_embed_dim))
|
||||||
|
trunc_normal_(param, std=0.02)
|
||||||
|
return param.to(self._super_meta_embed.device)
|
||||||
|
|
||||||
|
def get_closest_meta_distance(self, timestamp):
|
||||||
|
with torch.no_grad():
|
||||||
|
distances = torch.abs(self.meta_timestamps - timestamp)
|
||||||
|
return torch.min(distances).item()
|
||||||
|
|
||||||
|
def replace_append_learnt(self, timestamp, meta_embed):
|
||||||
|
self._append_meta_embed["learnt"] = meta_embed
|
||||||
|
self._append_meta_timestamps["learnt"] = timestamp
|
||||||
|
|
||||||
|
def append_fixed(self, timestamp, meta_embed):
|
||||||
|
with torch.no_grad():
|
||||||
|
timestamp, meta_embed = timestamp.clone(), meta_embed.clone()
|
||||||
|
if self._append_meta_timestamps["fixed"] is None:
|
||||||
|
self._append_meta_timestamps["fixed"] = timestamp
|
||||||
|
else:
|
||||||
|
self._append_meta_timestamps["fixed"] = torch.cat(
|
||||||
|
(self._append_meta_timestamps["fixed"], timestamp), dim=0
|
||||||
|
)
|
||||||
|
if self._append_meta_embed["fixed"] is None:
|
||||||
|
self._append_meta_embed["fixed"] = meta_embed
|
||||||
|
else:
|
||||||
|
self._append_meta_embed["fixed"] = torch.cat(
|
||||||
|
(self._append_meta_embed["fixed"], meta_embed), dim=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward_raw(self, timestamps):
|
def forward_raw(self, timestamps):
|
||||||
# timestamps is a batch of sequence of timestamps
|
# timestamps is a batch of sequence of timestamps
|
||||||
batch, seq = timestamps.shape
|
batch, seq = timestamps.shape
|
||||||
timestamps = timestamps.unsqueeze(dim=-1)
|
timestamps = timestamps.unsqueeze(dim=-1)
|
||||||
meta_timestamps = self._meta_timestamps.view(1, 1, -1)
|
meta_timestamps = self.meta_timestamps.view(1, 1, -1)
|
||||||
time_diffs = timestamps - meta_timestamps
|
time_diffs = timestamps - meta_timestamps
|
||||||
time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
|
time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
|
||||||
# select corresponding meta-knowledge
|
# select corresponding meta-knowledge
|
||||||
meta_match = torch.index_select(
|
meta_match = torch.index_select(
|
||||||
self._super_meta_embed, dim=0, index=time_match_i.view(-1)
|
self.super_meta_embed, dim=0, index=time_match_i.view(-1)
|
||||||
)
|
)
|
||||||
meta_match = meta_match.view(batch, seq, -1)
|
meta_match = meta_match.view(batch, seq, -1)
|
||||||
# create the probability
|
# create the probability
|
||||||
time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
|
time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
|
||||||
time_probs[:, -1, :] = 0
|
if self.training:
|
||||||
|
time_probs[:, -1, :] = 0
|
||||||
unknown_token = self._unknown_token.view(1, 1, -1)
|
unknown_token = self._unknown_token.view(1, 1, -1)
|
||||||
raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token
|
raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token
|
||||||
|
|
||||||
|
@ -43,6 +43,7 @@ class SyntheticDEnv(data.Dataset):
|
|||||||
num_per_task: int = 5000,
|
num_per_task: int = 5000,
|
||||||
timestamp_config: Optional[Dict] = None,
|
timestamp_config: Optional[Dict] = None,
|
||||||
mode: Optional[str] = None,
|
mode: Optional[str] = None,
|
||||||
|
timestamp_noise_scale: float = 0.3,
|
||||||
):
|
):
|
||||||
self._ndim = len(mean_functors)
|
self._ndim = len(mean_functors)
|
||||||
assert self._ndim == len(
|
assert self._ndim == len(
|
||||||
@ -59,6 +60,7 @@ class SyntheticDEnv(data.Dataset):
|
|||||||
timestamp_config["mode"] = mode
|
timestamp_config["mode"] = mode
|
||||||
|
|
||||||
self._timestamp_generator = TimeStamp(**timestamp_config)
|
self._timestamp_generator = TimeStamp(**timestamp_config)
|
||||||
|
self._timestamp_noise_scale = timestamp_noise_scale
|
||||||
|
|
||||||
self._mean_functors = mean_functors
|
self._mean_functors = mean_functors
|
||||||
self._cov_functors = cov_functors
|
self._cov_functors = cov_functors
|
||||||
@ -110,7 +112,9 @@ class SyntheticDEnv(data.Dataset):
|
|||||||
if self._seq_length is None:
|
if self._seq_length is None:
|
||||||
return self.__call__(timestamp)
|
return self.__call__(timestamp)
|
||||||
else:
|
else:
|
||||||
noise = random.random() * self.timestamp_interval * 0.3
|
noise = (
|
||||||
|
random.random() * self.timestamp_interval * self._timestamp_noise_scale
|
||||||
|
)
|
||||||
timestamps = [
|
timestamps = [
|
||||||
timestamp + i * self.timestamp_interval + noise
|
timestamp + i * self.timestamp_interval + noise
|
||||||
for i in range(self._seq_length)
|
for i in range(self._seq_length)
|
||||||
|
Loading…
Reference in New Issue
Block a user