Complete LFNA 1.0

exps/LFNA/basic-prev.py (new file, 190 lines):
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-prev.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path

lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from log_utils import time_string
from log_utils import AverageMeter, convert_secs2time

from utils import split_str2indexes

from procedures.advanced_main import basic_train_fn, basic_eval_fn
from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from datasets.synthetic_core import get_synthetic_env
from models.xcore import get_model

from lfna_utils import lfna_setup


def subsample(historical_x, historical_y, maxn=10000):
    total = historical_x.size(0)
    if total <= maxn:
        return historical_x, historical_y
    else:
        # draw `maxn` indexes uniformly (with replacement)
        indexes = torch.randint(low=0, high=total, size=[maxn])
        return historical_x[indexes], historical_y[indexes]


def main(args):
    logger, env_info, model_kwargs = lfna_setup(args)

    w_container_per_epoch = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx in range(1, env_info["total"]):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " "
            + need_time
        )
        # train on the data from the previous timestamp
        historical_x = env_info["{:}-x".format(idx - 1)]
        historical_y = env_info["{:}-y".format(idx - 1)]
        # build model
        model = get_model(**model_kwargs)
        print(model)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # save best
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        model.analyze_weights()
        with torch.no_grad():
            # note: preds come from the final epoch, not necessarily the best checkpoint
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Use the data in the last timestamp.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/lfna-synthetic/use-prev-timestamp",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic environment version.",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        required=True,
        help="The hidden dimension.",
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.1,
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="The batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="The number of data loading workers (default: 4)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()
    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "The save dir argument can not be None"
    args.save_dir = "{:}-{:}-d{:}".format(
        args.save_dir, args.env_version, args.hidden_dim
    )
    main(args)
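
A quick way to inspect what this script writes at the end is sketched below; it assumes save_checkpoint stores the dict via torch.save, and the path is only a placeholder for wherever logger.path(None) resolves on your run.

# Hypothetical inspection snippet, not part of the commit; adjust the path first.
import torch

ckpt = torch.load("final-ckp.pth", map_location="cpu")
containers = ckpt["w_container_per_epoch"]  # dict: timestamp index -> weight container
print(len(containers), "timestamps saved")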

exps/LFNA/lfna.py (modified):

@@ -1,6 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
+# python exps/LFNA/lfna.py --env_version v1 --workers 0
 # python exps/LFNA/lfna.py --env_version v1 --device cuda
 #####################################################
 import sys, time, copy, torch, random, argparse
@@ -156,19 +157,61 @@ def main(args):
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()

+    # meta-training
+    meta_model.load_best()
+    eval_env = env_info["dynamic_env"]
     w_container_per_epoch = dict()
-    for idx in range(0, total_bar):
+    for idx in range(args.seq_length, env_info["total"]):
+        # build-timestamp
         future_time = env_info["{:}-timestamp".format(idx)]
-        future_x = env_info["{:}-x".format(idx)]
-        future_y = env_info["{:}-y".format(idx)]
-        future_container = hypernet(task_embeds[idx])
-        w_container_per_epoch[idx] = future_container.no_grad_clone()
+        time_seqs = []
+        for iseq in range(args.seq_length):
+            time_seqs.append(future_time - iseq * eval_env.timestamp_interval)
+        time_seqs.reverse()
         with torch.no_grad():
-            future_y_hat = model.forward_with_container(
+            meta_model.eval()
+            base_model.eval()
+            time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device)
+            [seq_containers] = meta_model(time_seqs)
+            future_container = seq_containers[-1]
+            w_container_per_epoch[idx] = future_container.no_grad_clone()
+            # evaluation
+            future_x = env_info["{:}-x".format(idx)]
+            future_y = env_info["{:}-y".format(idx)]
+            future_y_hat = base_model.forward_with_container(
                 future_x, w_container_per_epoch[idx]
             )
             future_loss = criterion(future_y_hat, future_y)
-        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
+            logger.log(
+                "meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
+            )
+
+        # creating the new meta-time-embedding
+        distance = meta_model.get_closest_meta_distance(future_time)
+        if distance < eval_env.timestamp_interval:
+            continue
+        #
+        new_param = meta_model.create_meta_embed()
+        optimizer = torch.optim.Adam(
+            [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
+        )
+        meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param)
+        meta_model.eval()
+        base_model.train()
+        for iepoch in range(args.epochs):
+            optimizer.zero_grad()
+            [seq_containers] = meta_model(time_seqs)
+            future_container = seq_containers[-1]
+            future_y_hat = base_model.forward_with_container(future_x, future_container)
+            future_loss = criterion(future_y_hat, future_y)
+            future_loss.backward()
+            optimizer.step()
+        logger.log(
+            "post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())
+        )
+        with torch.no_grad():
+            meta_model.replace_append_learnt(None, None)
+            meta_model.append_fixed(torch.Tensor([future_time]), new_param)

     save_checkpoint(
         {"w_container_per_epoch": w_container_per_epoch},
@@ -216,7 +259,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--init_lr",
         type=float,
-        default=0.01,
+        default=0.005,
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(
@@ -235,7 +278,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--early_stop_thresh",
         type=int,
-        default=50,
+        default=25,
         help="The maximum epochs for early stop.",
     )
     parser.add_argument(
@@ -256,7 +299,12 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}_{:}_{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim, args.layer_dim, args.time_dim
+    args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
+        args.save_dir,
+        args.env_version,
+        args.hidden_dim,
+        args.layer_dim,
+        args.time_dim,
+        args.epochs,
     )
     main(args)
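
The new meta-test block above adapts a single tensor at test time: a freshly created time embedding is optimized against the revealed data, and only that embedding is passed to the optimizer, so everything else is left untouched. A self-contained sketch of that pattern with stand-in tensors (embed_dim, the frozen projection, and the data below are illustrative, not the repo's objects):

# Sketch of test-time adaptation of a single embedding with hypothetical
# stand-in data; only `new_param` receives gradient updates.
import torch

embed_dim, steps = 16, 50
new_param = torch.nn.Parameter(torch.zeros(1, embed_dim))
torch.nn.init.trunc_normal_(new_param, std=0.02)
optimizer = torch.optim.Adam([new_param], lr=0.005, weight_decay=1e-5, amsgrad=True)

frozen_proj = torch.randn(embed_dim, 1)  # stand-in for the frozen meta-model
inputs = torch.randn(8, embed_dim)       # stand-in for the revealed inputs
target = torch.randn(8, 1)               # stand-in for the revealed targets

for _ in range(steps):
    optimizer.zero_grad()
    pred = (inputs + new_param) @ frozen_proj  # frozen_proj has requires_grad=False
    loss = torch.nn.functional.mse_loss(pred, target)
    loss.backward()
    optimizer.step()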

Meta-model definition (class LFNA_Meta), modified:

@@ -17,7 +17,7 @@ class LFNA_Meta(super_core.SuperModule):
     def __init__(
         self,
         shape_container,
-        layer_embeding,
+        layer_embedding,
         time_embedding,
         meta_timestamps,
         mha_depth: int = 2,
@@ -33,13 +33,16 @@ class LFNA_Meta(super_core.SuperModule):

         self.register_parameter(
             "_super_layer_embed",
-            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)),
+            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)),
         )
         self.register_parameter(
             "_super_meta_embed",
             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)),
         )
         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
+        self._time_embed_dim = time_embedding
+        self._append_meta_embed = dict(fixed=None, learnt=None)
+        self._append_meta_timestamps = dict(fixed=None, learnt=None)

         # build transformer
         layers = []
@@ -60,9 +63,9 @@ class LFNA_Meta(super_core.SuperModule):

         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
-            input_dim=layer_embeding + time_embedding,
+            input_dim=layer_embedding + time_embedding,
             output_dim=max(self._numel_per_layer),
-            hidden_dims=[(layer_embeding + time_embedding) * 2] * 3,
+            hidden_dims=[(layer_embedding + time_embedding) * 2] * 3,
             act_cls="gelu",
             norm_cls="layer_norm_1d",
             dropout=dropout,
@@ -82,21 +85,68 @@ class LFNA_Meta(super_core.SuperModule):
             std=0.02,
         )

+    @property
+    def meta_timestamps(self):
+        meta_timestamps = [self._meta_timestamps]
+        for key in ("fixed", "learnt"):
+            if self._append_meta_timestamps[key] is not None:
+                meta_timestamps.append(self._append_meta_timestamps[key])
+        return torch.cat(meta_timestamps)
+
+    @property
+    def super_meta_embed(self):
+        meta_embed = [self._super_meta_embed]
+        for key in ("fixed", "learnt"):
+            if self._append_meta_embed[key] is not None:
+                meta_embed.append(self._append_meta_embed[key])
+        return torch.cat(meta_embed)
+
+    def create_meta_embed(self):
+        param = torch.nn.Parameter(torch.Tensor(1, self._time_embed_dim))
+        trunc_normal_(param, std=0.02)
+        return param.to(self._super_meta_embed.device)
+
+    def get_closest_meta_distance(self, timestamp):
+        with torch.no_grad():
+            distances = torch.abs(self.meta_timestamps - timestamp)
+            return torch.min(distances).item()
+
+    def replace_append_learnt(self, timestamp, meta_embed):
+        self._append_meta_embed["learnt"] = meta_embed
+        self._append_meta_timestamps["learnt"] = timestamp
+
+    def append_fixed(self, timestamp, meta_embed):
+        with torch.no_grad():
+            timestamp, meta_embed = timestamp.clone(), meta_embed.clone()
+            if self._append_meta_timestamps["fixed"] is None:
+                self._append_meta_timestamps["fixed"] = timestamp
+            else:
+                self._append_meta_timestamps["fixed"] = torch.cat(
+                    (self._append_meta_timestamps["fixed"], timestamp), dim=0
+                )
+            if self._append_meta_embed["fixed"] is None:
+                self._append_meta_embed["fixed"] = meta_embed
+            else:
+                self._append_meta_embed["fixed"] = torch.cat(
+                    (self._append_meta_embed["fixed"], meta_embed), dim=0
+                )
+
     def forward_raw(self, timestamps):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
         timestamps = timestamps.unsqueeze(dim=-1)
-        meta_timestamps = self._meta_timestamps.view(1, 1, -1)
+        meta_timestamps = self.meta_timestamps.view(1, 1, -1)
         time_diffs = timestamps - meta_timestamps
         time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1)
         # select corresponding meta-knowledge
         meta_match = torch.index_select(
-            self._super_meta_embed, dim=0, index=time_match_i.view(-1)
+            self.super_meta_embed, dim=0, index=time_match_i.view(-1)
         )
         meta_match = meta_match.view(batch, seq, -1)
         # create the probability
         time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1)
-        time_probs[:, -1, :] = 0
+        if self.training:
+            time_probs[:, -1, :] = 0
         unknown_token = self._unknown_token.view(1, 1, -1)
         raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token
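
For reference, the soft matching in forward_raw weights the closest stored meta-embedding by exp(-10 * |t - t_closest|) and fills the remainder with the learnable unknown token; the new `if self.training:` guard means the last step of a sequence is forced onto the unknown token only during training. A toy illustration with made-up timestamps (not repo code):

# Toy numbers (assumed) showing the soft-matching weight used in forward_raw:
# a query close to a known meta-timestamp mostly keeps that embedding,
# while a distant query falls back to the unknown token.
import torch

meta_timestamps = torch.tensor([0.0, 1.0, 2.0])         # stored timestamps
queries = torch.tensor([[0.02, 1.4]])                   # batch=1, seq=2
diffs = queries.unsqueeze(-1) - meta_timestamps.view(1, 1, -1)
match_v, match_i = torch.min(torch.abs(diffs), dim=-1)  # closest distance and index
time_probs = 1.0 / torch.exp(match_v * 10)              # exp(-10 * |dt|)
print(match_i)     # tensor([[0, 1]])
print(time_probs)  # approximately tensor([[0.8187, 0.0183]])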

Synthetic environment (class SyntheticDEnv), modified:

@@ -43,6 +43,7 @@ class SyntheticDEnv(data.Dataset):
         num_per_task: int = 5000,
         timestamp_config: Optional[Dict] = None,
         mode: Optional[str] = None,
+        timestamp_noise_scale: float = 0.3,
     ):
         self._ndim = len(mean_functors)
         assert self._ndim == len(
@@ -59,6 +60,7 @@ class SyntheticDEnv(data.Dataset):
             timestamp_config["mode"] = mode

         self._timestamp_generator = TimeStamp(**timestamp_config)
+        self._timestamp_noise_scale = timestamp_noise_scale

         self._mean_functors = mean_functors
         self._cov_functors = cov_functors
@@ -110,7 +112,9 @@ class SyntheticDEnv(data.Dataset):
         if self._seq_length is None:
             return self.__call__(timestamp)
         else:
-            noise = random.random() * self.timestamp_interval * 0.3
+            noise = (
+                random.random() * self.timestamp_interval * self._timestamp_noise_scale
+            )
             timestamps = [
                 timestamp + i * self.timestamp_interval + noise
                 for i in range(self._seq_length)
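
The new timestamp_noise_scale argument only generalizes the previously hard-coded 0.3 factor. A toy illustration of the sequence sampling it affects, with assumed values:

# Illustration with made-up values: one shared random offset, bounded by
# timestamp_interval * timestamp_noise_scale, shifts an evenly spaced sequence.
import random

timestamp, timestamp_interval, seq_length = 10.0, 0.1, 4
timestamp_noise_scale = 0.3
noise = random.random() * timestamp_interval * timestamp_noise_scale
timestamps = [timestamp + i * timestamp_interval + noise for i in range(seq_length)]
print(timestamps)  # e.g. [10.021, 10.121, 10.221, 10.321]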