From 9a5663390ad17f3d9cc2ad19400a0713814dda37 Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Thu, 27 May 2021 23:04:27 +0800
Subject: [PATCH] Updates

---
 exps/GeMOSA/baselines/slbm-nof.py | 224 ++++++++++++++++++++++++++++++
 xautodl/datasets/synthetic_env.py |   8 ++
 2 files changed, 232 insertions(+)
 create mode 100644 exps/GeMOSA/baselines/slbm-nof.py

diff --git a/exps/GeMOSA/baselines/slbm-nof.py b/exps/GeMOSA/baselines/slbm-nof.py
new file mode 100644
index 0000000..2aef4e6
--- /dev/null
+++ b/exps/GeMOSA/baselines/slbm-nof.py
@@ -0,0 +1,224 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+#####################################################
+# python exps/GeMOSA/baselines/slbm-nof.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
+# python exps/GeMOSA/baselines/slbm-nof.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
+# python exps/GeMOSA/baselines/slbm-nof.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
+# python exps/GeMOSA/baselines/slbm-nof.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
+#####################################################
+import sys, time, copy, torch, random, argparse
+from tqdm import tqdm
+from copy import deepcopy
+from pathlib import Path
+
+lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
+
+
+from xautodl.procedures import (
+    prepare_seed,
+    prepare_logger,
+    save_checkpoint,
+    copy_checkpoint,
+)
+from xautodl.log_utils import time_string
+from xautodl.log_utils import AverageMeter, convert_secs2time
+
+from xautodl.procedures.metric_utils import (
+    SaveMetric,
+    MSEMetric,
+    Top1AccMetric,
+    ComposeMetric,
+)
+from xautodl.datasets.synthetic_core import get_synthetic_env
+from xautodl.models.xcore import get_model
+from xautodl.utils import show_mean_var
+
+
+def subsample(historical_x, historical_y, maxn=10000):
+    total = historical_x.size(0)
+    if total <= maxn:
+        return historical_x, historical_y
+    else:
+        indexes = torch.randint(low=0, high=total, size=[maxn])
+        return historical_x[indexes], historical_y[indexes]
+
+
+def main(args):
+    prepare_seed(args.rand_seed)
+    logger = prepare_logger(args)
+    env = get_synthetic_env(mode="test", version=args.env_version)
+    model_kwargs = dict(
+        config=dict(model_type="norm_mlp"),
+        input_dim=env.meta_info["input_dim"],
+        output_dim=env.meta_info["output_dim"],
+        hidden_dims=[args.hidden_dim] * 2,
+        act_cls="relu",
+        norm_cls="layer_norm_1d",
+    )
+    logger.log("The total environment: {:}".format(env))
+    w_containers = dict()
+
+    if env.meta_info["task"] == "regression":
+        criterion = torch.nn.MSELoss()
+        metric_cls = MSEMetric
+    elif env.meta_info["task"] == "classification":
+        criterion = torch.nn.CrossEntropyLoss()
+        metric_cls = Top1AccMetric
+    else:
+        raise ValueError(
+            "This task ({:}) is not supported.".format(env.meta_info["task"])
+        )
+
+    seq_length = 10
+    seq_times = env.get_seq_times(0, seq_length)
+    _, (allxs, allys) = env.seq_call(seq_times)
+    allxs, allys = allxs.view(-1, 1), allys.view(-1, 1)
+
+    historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
+    model = get_model(**model_kwargs)
+    model = model.to(args.device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer,
+        milestones=[
+            int(args.epochs * 0.25),
+            int(args.epochs * 0.5),
+            int(args.epochs * 0.75),
+        ],
+        gamma=0.3,
+    )
+
+    train_metric = metric_cls(True)
+    best_loss, best_param = None, None
+    for _iepoch in range(args.epochs):
+        preds = model(historical_x)
+        optimizer.zero_grad()
+        loss = criterion(preds, historical_y)
+        loss.backward()
+        optimizer.step()
+        lr_scheduler.step()
+        # save best
+        if best_loss is None or best_loss > loss.item():
+            best_loss = loss.item()
+            best_param = copy.deepcopy(model.state_dict())
+    model.load_state_dict(best_param)
+    model.analyze_weights()
+    with torch.no_grad():
+        train_metric(preds, historical_y)
+    train_results = train_metric.get_info()
+    print(train_results)
+
+    metric = metric_cls(True)
+    per_timestamp_time, start_time = AverageMeter(), time.time()
+    for idx, (future_time, (future_x, future_y)) in enumerate(env):
+
+        need_time = "Time Left: {:}".format(
+            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
+        )
+        logger.log(
+            "[{:}]".format(time_string())
+            + " [{:04d}/{:04d}]".format(idx, len(env))
+            + " "
+            + need_time
+        )
+        # train the same data
+
+        # build optimizer
+        xmetric = ComposeMetric(metric_cls(True), SaveMetric())
+        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
+        future_y_hat = model(future_x)
+        future_loss = criterion(future_y_hat, future_y)
+        metric(future_y_hat, future_y)
+        log_str = (
+            "[{:}]".format(time_string())
+            + " [{:04d}/{:04d}]".format(idx, len(env))
+            + " train-score: {:.5f}, eval-score: {:.5f}".format(
+                train_results["score"], metric.get_info()["score"]
+            )
+        )
+        logger.log(log_str)
+        logger.log("")
+        per_timestamp_time.update(time.time() - start_time)
+        start_time = time.time()
+
+    save_checkpoint(
+        {"w_containers": w_containers},
+        logger.path(None) / "final-ckp.pth",
+        logger,
+    )
+
+    logger.log("-" * 200 + "\n")
+    logger.close()
+    return metric.get_info()["score"]
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Use the data in the past.")
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="./outputs/GeMOSA-synthetic/use-same-nof-timestamp",
+        help="The checkpoint directory.",
+    )
+    parser.add_argument(
+        "--env_version",
+        type=str,
+        required=True,
+        help="The synthetic environment version.",
+    )
+    parser.add_argument(
+        "--hidden_dim",
+        type=int,
+        required=True,
+        help="The hidden dimension.",
+    )
+    parser.add_argument(
+        "--init_lr",
+        type=float,
+        default=0.1,
+        help="The initial learning rate for the optimizer (default is Adam)",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=512,
+        help="The batch size",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=300,
+        help="The total number of epochs.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="",
+    )
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=4,
+        help="The number of data loading workers (default: 4)",
+    )
+    # Random Seed
+    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
+    args = parser.parse_args()
+    args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
+        args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
+    )
+    if args.rand_seed is None or args.rand_seed < 0:
+        args.rand_seed = random.randint(1, 100000)
+        main(args)
+    else:
+        results = []
+        for iseed in range(3):
+            args.rand_seed = random.randint(1, 100000)
+            result = main(args)
+            results.append(result)
+        show_mean_var(results)
diff --git a/xautodl/datasets/synthetic_env.py b/xautodl/datasets/synthetic_env.py
index 077e826..4b6e7ae 100644
--- a/xautodl/datasets/synthetic_env.py
+++ b/xautodl/datasets/synthetic_env.py
@@ -84,6 +84,14 @@ class SyntheticDEnv(data.Dataset):
     def mode(self):
         return self._time_generator.mode
 
+    def get_seq_times(self, index, seq_length):
+        index, timestamp = self._time_generator[index]
+        xtimes = []
+        for i in range(1, seq_length + 1):
+            xtimes.append(timestamp - i * self.time_interval)
+        xtimes.reverse()
+        return xtimes
+
     def get_timestamp(self, index):
         if index is None:
             timestamps = []
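
For reference, below is a minimal usage sketch of the new SyntheticDEnv.get_seq_times helper, mirroring the calls made in exps/GeMOSA/baselines/slbm-nof.py above. It assumes the xautodl package from this repository is importable; the "v1" version string follows the example commands in the script header, and the final print is only illustrative.

    # Minimal sketch (assumes xautodl from this repo is on PYTHONPATH).
    from xautodl.datasets.synthetic_core import get_synthetic_env

    env = get_synthetic_env(mode="test", version="v1")  # "v1" as in the example commands
    # The seq_length timestamps that precede the timestamp at index 0,
    # spaced by env.time_interval and returned oldest-first.
    seq_times = env.get_seq_times(0, seq_length=10)
    # seq_call is used the same way in slbm-nof.py to fetch the data at those times.
    _, (allxs, allys) = env.seq_call(seq_times)
    print(len(seq_times), allxs.shape, allys.shape)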