From b81ef2dd74496a097a92926ce35a8942b5e89957 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Fri, 14 May 2021 00:36:37 +0800 Subject: [PATCH] Complete LFNA 1.0 --- exps/LFNA/basic-prev.py | 190 ++++++++++++++++++++++++++++++++++ exps/LFNA/lfna.py | 70 +++++++++++-- exps/LFNA/lfna_meta_model.py | 64 ++++++++++-- lib/datasets/synthetic_env.py | 6 +- 4 files changed, 311 insertions(+), 19 deletions(-) create mode 100644 exps/LFNA/basic-prev.py diff --git a/exps/LFNA/basic-prev.py b/exps/LFNA/basic-prev.py new file mode 100644 index 0000000..a1dc1c5 --- /dev/null +++ b/exps/LFNA/basic-prev.py @@ -0,0 +1,190 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # +##################################################### +# python exps/LFNA/basic-prev.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 +# python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05 +##################################################### +import sys, time, copy, torch, random, argparse +from tqdm import tqdm +from copy import deepcopy +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from log_utils import time_string +from log_utils import AverageMeter, convert_secs2time + +from utils import split_str2indexes + +from procedures.advanced_main import basic_train_fn, basic_eval_fn +from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric +from datasets.synthetic_core import get_synthetic_env +from models.xcore import get_model + +from lfna_utils import lfna_setup + + +def subsample(historical_x, historical_y, maxn=10000): + total = historical_x.size(0) + if total <= maxn: + return historical_x, historical_y + else: + indexes = torch.randint(low=0, high=total, size=[maxn]) + return historical_x[indexes], historical_y[indexes] + + +def main(args): + logger, env_info, model_kwargs = lfna_setup(args) + + w_container_per_epoch = dict() + + per_timestamp_time, start_time = AverageMeter(), time.time() + for idx in range(1, env_info["total"]): + + need_time = "Time Left: {:}".format( + convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True) + ) + logger.log( + "[{:}]".format(time_string()) + + " [{:04d}/{:04d}]".format(idx, env_info["total"]) + + " " + + need_time + ) + # train the same data + historical_x = env_info["{:}-x".format(idx - 1)] + historical_y = env_info["{:}-y".format(idx - 1)] + # build model + model = get_model(**model_kwargs) + print(model) + # build optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) + criterion = torch.nn.MSELoss() + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[ + int(args.epochs * 0.25), + int(args.epochs * 0.5), + int(args.epochs * 0.75), + ], + gamma=0.3, + ) + train_metric = MSEMetric() + best_loss, best_param = None, None + for _iepoch in range(args.epochs): + preds = model(historical_x) + optimizer.zero_grad() + loss = criterion(preds, historical_y) + loss.backward() + optimizer.step() + lr_scheduler.step() + # save best + if best_loss is None or best_loss > loss.item(): + best_loss = loss.item() + best_param = copy.deepcopy(model.state_dict()) + model.load_state_dict(best_param) + model.analyze_weights() + with torch.no_grad(): + train_metric(preds, historical_y) + 
train_results = train_metric.get_info() + + metric = ComposeMetric(MSEMetric(), SaveMetric()) + eval_dataset = torch.utils.data.TensorDataset( + env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] + ) + eval_loader = torch.utils.data.DataLoader( + eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 + ) + results = basic_eval_fn(eval_loader, model, metric, logger) + log_str = ( + "[{:}]".format(time_string()) + + " [{:04d}/{:04d}]".format(idx, env_info["total"]) + + " train-mse: {:.5f}, eval-mse: {:.5f}".format( + train_results["mse"], results["mse"] + ) + ) + logger.log(log_str) + + save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( + idx, env_info["total"] + ) + w_container_per_epoch[idx] = model.get_w_container().no_grad_clone() + save_checkpoint( + { + "model_state_dict": model.state_dict(), + "model": model, + "index": idx, + "timestamp": env_info["{:}-timestamp".format(idx)], + }, + save_path, + logger, + ) + logger.log("") + per_timestamp_time.update(time.time() - start_time) + start_time = time.time() + + save_checkpoint( + {"w_container_per_epoch": w_container_per_epoch}, + logger.path(None) / "final-ckp.pth", + logger, + ) + + logger.log("-" * 200 + "\n") + logger.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Use the data in the last timestamp.") + parser.add_argument( + "--save_dir", + type=str, + default="./outputs/lfna-synthetic/use-prev-timestamp", + help="The checkpoint directory.", + ) + parser.add_argument( + "--env_version", + type=str, + required=True, + help="The synthetic enviornment version.", + ) + parser.add_argument( + "--hidden_dim", + type=int, + required=True, + help="The hidden dimension.", + ) + parser.add_argument( + "--init_lr", + type=float, + default=0.1, + help="The initial learning rate for the optimizer (default is Adam)", + ) + parser.add_argument( + "--batch_size", + type=int, + default=512, + help="The batch size", + ) + parser.add_argument( + "--epochs", + type=int, + default=300, + help="The total number of epochs.", + ) + parser.add_argument( + "--workers", + type=int, + default=4, + help="The number of data loading workers (default: 4)", + ) + # Random Seed + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "The save dir argument can not be None" + args.save_dir = "{:}-{:}-d{:}".format( + args.save_dir, args.env_version, args.hidden_dim + ) + main(args) diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index 3bb2c36..422af1a 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -1,6 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### +# python exps/LFNA/lfna.py --env_version v1 --workers 0 # python exps/LFNA/lfna.py --env_version v1 --device cuda ##################################################### import sys, time, copy, torch, random, argparse @@ -156,19 +157,61 @@ def main(args): per_epoch_time.update(time.time() - start_time) start_time = time.time() + # meta-training + meta_model.load_best() + eval_env = env_info["dynamic_env"] w_container_per_epoch = dict() - for idx in range(0, total_bar): + for idx in range(args.seq_length, env_info["total"]): + # build-timestamp future_time = env_info["{:}-timestamp".format(idx)] - future_x = env_info["{:}-x".format(idx)] - future_y 
= env_info["{:}-y".format(idx)] - future_container = hypernet(task_embeds[idx]) - w_container_per_epoch[idx] = future_container.no_grad_clone() + time_seqs = [] + for iseq in range(args.seq_length): + time_seqs.append(future_time - iseq * eval_env.timestamp_interval) + time_seqs.reverse() with torch.no_grad(): - future_y_hat = model.forward_with_container( + meta_model.eval() + base_model.eval() + time_seqs = torch.Tensor(time_seqs).view(1, -1).to(args.device) + [seq_containers] = meta_model(time_seqs) + future_container = seq_containers[-1] + w_container_per_epoch[idx] = future_container.no_grad_clone() + # evaluation + future_x = env_info["{:}-x".format(idx)] + future_y = env_info["{:}-y".format(idx)] + future_y_hat = base_model.forward_with_container( future_x, w_container_per_epoch[idx] ) future_loss = criterion(future_y_hat, future_y) - logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())) + logger.log( + "meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()) + ) + + # creating the new meta-time-embedding + distance = meta_model.get_closest_meta_distance(future_time) + if distance < eval_env.timestamp_interval: + continue + # + new_param = meta_model.create_meta_embed() + optimizer = torch.optim.Adam( + [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True + ) + meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param) + meta_model.eval() + base_model.train() + for iepoch in range(args.epochs): + optimizer.zero_grad() + [seq_containers] = meta_model(time_seqs) + future_container = seq_containers[-1] + future_y_hat = base_model.forward_with_container(future_x, future_container) + future_loss = criterion(future_y_hat, future_y) + future_loss.backward() + optimizer.step() + logger.log( + "post-meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()) + ) + with torch.no_grad(): + meta_model.replace_append_learnt(None, None) + meta_model.append_fixed(torch.Tensor([future_time]), new_param) save_checkpoint( {"w_container_per_epoch": w_container_per_epoch}, @@ -216,7 +259,7 @@ if __name__ == "__main__": parser.add_argument( "--init_lr", type=float, - default=0.01, + default=0.005, help="The initial learning rate for the optimizer (default is Adam)", ) parser.add_argument( @@ -235,7 +278,7 @@ if __name__ == "__main__": parser.add_argument( "--early_stop_thresh", type=int, - default=50, + default=25, help="The maximum epochs for early stop.", ) parser.add_argument( @@ -256,7 +299,12 @@ if __name__ == "__main__": if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) assert args.save_dir is not None, "The save dir argument can not be None" - args.save_dir = "{:}-{:}-d{:}_{:}_{:}".format( - args.save_dir, args.env_version, args.hidden_dim, args.layer_dim, args.time_dim + args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format( + args.save_dir, + args.env_version, + args.hidden_dim, + args.layer_dim, + args.time_dim, + args.epochs, ) main(args) diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py index 91649b8..c25e01e 100644 --- a/exps/LFNA/lfna_meta_model.py +++ b/exps/LFNA/lfna_meta_model.py @@ -17,7 +17,7 @@ class LFNA_Meta(super_core.SuperModule): def __init__( self, shape_container, - layer_embeding, + layer_embedding, time_embedding, meta_timestamps, mha_depth: int = 2, @@ -33,13 +33,16 @@ class LFNA_Meta(super_core.SuperModule): self.register_parameter( "_super_layer_embed", - torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)), + 
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)), ) self.register_parameter( "_super_meta_embed", torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)), ) self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) + self._time_embed_dim = time_embedding + self._append_meta_embed = dict(fixed=None, learnt=None) + self._append_meta_timestamps = dict(fixed=None, learnt=None) # build transformer layers = [] @@ -60,9 +63,9 @@ class LFNA_Meta(super_core.SuperModule): model_kwargs = dict( config=dict(model_type="dual_norm_mlp"), - input_dim=layer_embeding + time_embedding, + input_dim=layer_embedding + time_embedding, output_dim=max(self._numel_per_layer), - hidden_dims=[(layer_embeding + time_embedding) * 2] * 3, + hidden_dims=[(layer_embedding + time_embedding) * 2] * 3, act_cls="gelu", norm_cls="layer_norm_1d", dropout=dropout, @@ -82,21 +85,68 @@ class LFNA_Meta(super_core.SuperModule): std=0.02, ) + @property + def meta_timestamps(self): + meta_timestamps = [self._meta_timestamps] + for key in ("fixed", "learnt"): + if self._append_meta_timestamps[key] is not None: + meta_timestamps.append(self._append_meta_timestamps[key]) + return torch.cat(meta_timestamps) + + @property + def super_meta_embed(self): + meta_embed = [self._super_meta_embed] + for key in ("fixed", "learnt"): + if self._append_meta_embed[key] is not None: + meta_embed.append(self._append_meta_embed[key]) + return torch.cat(meta_embed) + + def create_meta_embed(self): + param = torch.nn.Parameter(torch.Tensor(1, self._time_embed_dim)) + trunc_normal_(param, std=0.02) + return param.to(self._super_meta_embed.device) + + def get_closest_meta_distance(self, timestamp): + with torch.no_grad(): + distances = torch.abs(self.meta_timestamps - timestamp) + return torch.min(distances).item() + + def replace_append_learnt(self, timestamp, meta_embed): + self._append_meta_embed["learnt"] = meta_embed + self._append_meta_timestamps["learnt"] = timestamp + + def append_fixed(self, timestamp, meta_embed): + with torch.no_grad(): + timestamp, meta_embed = timestamp.clone(), meta_embed.clone() + if self._append_meta_timestamps["fixed"] is None: + self._append_meta_timestamps["fixed"] = timestamp + else: + self._append_meta_timestamps["fixed"] = torch.cat( + (self._append_meta_timestamps["fixed"], timestamp), dim=0 + ) + if self._append_meta_embed["fixed"] is None: + self._append_meta_embed["fixed"] = meta_embed + else: + self._append_meta_embed["fixed"] = torch.cat( + (self._append_meta_embed["fixed"], meta_embed), dim=0 + ) + def forward_raw(self, timestamps): # timestamps is a batch of sequence of timestamps batch, seq = timestamps.shape timestamps = timestamps.unsqueeze(dim=-1) - meta_timestamps = self._meta_timestamps.view(1, 1, -1) + meta_timestamps = self.meta_timestamps.view(1, 1, -1) time_diffs = timestamps - meta_timestamps time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1) # select corresponding meta-knowledge meta_match = torch.index_select( - self._super_meta_embed, dim=0, index=time_match_i.view(-1) + self.super_meta_embed, dim=0, index=time_match_i.view(-1) ) meta_match = meta_match.view(batch, seq, -1) # create the probability time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1) - time_probs[:, -1, :] = 0 + if self.training: + time_probs[:, -1, :] = 0 unknown_token = self._unknown_token.view(1, 1, -1) raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token diff --git a/lib/datasets/synthetic_env.py 
b/lib/datasets/synthetic_env.py index 767f817..d4fcf24 100644 --- a/lib/datasets/synthetic_env.py +++ b/lib/datasets/synthetic_env.py @@ -43,6 +43,7 @@ class SyntheticDEnv(data.Dataset): num_per_task: int = 5000, timestamp_config: Optional[Dict] = None, mode: Optional[str] = None, + timestamp_noise_scale: float = 0.3, ): self._ndim = len(mean_functors) assert self._ndim == len( @@ -59,6 +60,7 @@ class SyntheticDEnv(data.Dataset): timestamp_config["mode"] = mode self._timestamp_generator = TimeStamp(**timestamp_config) + self._timestamp_noise_scale = timestamp_noise_scale self._mean_functors = mean_functors self._cov_functors = cov_functors @@ -110,7 +112,9 @@ class SyntheticDEnv(data.Dataset): if self._seq_length is None: return self.__call__(timestamp) else: - noise = random.random() * self.timestamp_interval * 0.3 + noise = ( + random.random() * self.timestamp_interval * self._timestamp_noise_scale + ) timestamps = [ timestamp + i * self.timestamp_interval + noise for i in range(self._seq_length)
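
Notes on the main pieces of this patch follow. First, the new exps/LFNA/basic-prev.py baseline: at each timestamp idx a fresh model is fit on the data observed at idx - 1 and then evaluated on the data at idx. Below is a condensed sketch of that protocol; eval_use_prev_timestamp is an illustrative name rather than a repository function, and the logging, checkpointing, and LR schedule of the script are omitted.

    import copy
    import torch
    from models.xcore import get_model  # repo-local import, as used by basic-prev.py

    def eval_use_prev_timestamp(env_info, model_kwargs, epochs=300, init_lr=0.1):
        """Train on timestamp idx-1, report MSE on timestamp idx."""
        criterion = torch.nn.MSELoss()
        results = {}
        for idx in range(1, env_info["total"]):
            prev_x = env_info["{:}-x".format(idx - 1)]
            prev_y = env_info["{:}-y".format(idx - 1)]
            model = get_model(**model_kwargs)
            optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, amsgrad=True)
            best_loss, best_param = None, None
            for _ in range(epochs):
                optimizer.zero_grad()
                loss = criterion(model(prev_x), prev_y)
                loss.backward()
                optimizer.step()
                # keep the weights with the lowest training loss, as the script does
                if best_loss is None or loss.item() < best_loss:
                    best_loss = loss.item()
                    best_param = copy.deepcopy(model.state_dict())
            model.load_state_dict(best_param)
            with torch.no_grad():
                cur_x = env_info["{:}-x".format(idx)]
                cur_y = env_info["{:}-y".format(idx)]
                results[idx] = criterion(model(cur_x), cur_y).item()
        return results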
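
Second, the core addition to exps/LFNA/lfna.py is the online adaptation step performed during meta-testing: when the queried timestamp is farther than one timestamp_interval from every stored meta timestamp, a new learnable time embedding is created, optimized on the observed (x, y) pair through the otherwise untouched meta model, and then frozen as a "fixed" entry so later queries can reuse it. A condensed sketch, assuming meta_model is an LFNA_Meta instance and base_model, criterion, time_seqs, future_x, and future_y come from the surrounding script (adapt_to_new_timestamp is an illustrative name, not a repository function):

    import torch

    def adapt_to_new_timestamp(meta_model, base_model, criterion, time_seqs,
                               future_time, future_x, future_y,
                               interval, epochs, init_lr=0.005):
        # skip adaptation if an existing meta embedding is already close enough
        if meta_model.get_closest_meta_distance(future_time) < interval:
            return
        new_param = meta_model.create_meta_embed()
        optimizer = torch.optim.Adam([new_param], lr=init_lr,
                                     weight_decay=1e-5, amsgrad=True)
        # expose the new embedding through the temporary "learnt" slot
        meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param)
        meta_model.eval()
        base_model.train()
        for _ in range(epochs):
            optimizer.zero_grad()
            [seq_containers] = meta_model(time_seqs)
            y_hat = base_model.forward_with_container(future_x, seq_containers[-1])
            loss = criterion(y_hat, future_y)
            loss.backward()
            optimizer.step()
        # promote the optimized embedding from "learnt" to "fixed"
        with torch.no_grad():
            meta_model.replace_append_learnt(None, None)
            meta_model.append_fixed(torch.Tensor([future_time]), new_param)

Only new_param is updated by the optimizer here; the meta-model and base-model weights are left as trained, so the per-timestamp evaluation still measures the meta-learned mapping itself.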
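
Third, LFNA_Meta.forward_raw now looks up meta knowledge from the concatenation of the original, "fixed", and "learnt" embeddings and blends the nearest match with a learned unknown token, weighted by how close the queried timestamp is to a stored one. The patch writes the weight as 1 / exp(10 * |dt|), which equals the exp(-10 * |dt|) used in this self-contained sketch of just the blending step (soft_time_embed is an illustrative name):

    import torch

    def soft_time_embed(timestamps, meta_timestamps, super_meta_embed,
                        unknown_token, training):
        # timestamps: (batch, seq); meta_timestamps: (M,);
        # super_meta_embed: (M, D); unknown_token: (D,)
        batch, seq = timestamps.shape
        diffs = timestamps.unsqueeze(-1) - meta_timestamps.view(1, 1, -1)
        match_v, match_i = torch.min(torch.abs(diffs), dim=-1)  # nearest stored stamp
        meta_match = super_meta_embed[match_i.view(-1)].view(batch, seq, -1)
        # confidence decays exponentially with distance to the nearest stamp
        time_probs = torch.exp(-10 * match_v).view(batch, seq, 1)
        if training:
            # during meta-training the last step of each sequence is treated as
            # unseen, so it relies entirely on the unknown token
            time_probs = time_probs.clone()
            time_probs[:, -1, :] = 0
        return time_probs * meta_match + (1 - time_probs) * unknown_token.view(1, 1, -1)

The new "if self.training:" guard is what lets the meta-test loop above actually use an appended embedding for the final step once one has been learned, instead of always masking it out.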
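
Finally, lib/datasets/synthetic_env.py replaces the hard-coded 0.3 jitter factor with a timestamp_noise_scale constructor argument. The sampling of a timestamp sequence reduces to the following (make_timestamp_sequence is an illustrative name):

    import random

    def make_timestamp_sequence(timestamp, seq_length, timestamp_interval,
                                timestamp_noise_scale=0.3):
        # one shared jitter per sequence, bounded by scale * interval
        noise = random.random() * timestamp_interval * timestamp_noise_scale
        return [timestamp + i * timestamp_interval + noise
                for i in range(seq_length)]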