From c6db1ef65aaddc4e6b0e0f69150fa55ca804cb6a Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Fri, 28 May 2021 01:03:02 +0800
Subject: [PATCH] X

---
 .../{basic-maml.py => baselines/maml-ft.py}   | 202 ++++++-----
 exps/GeMOSA/baselines/maml-nof.py             | 317 ++++++++++++++++++
 exps/GeMOSA/basic-same.py                     |   1 -
 exps/GeMOSA/side_utils.py                     |  50 ---
 4 files changed, 441 insertions(+), 129 deletions(-)
 rename exps/GeMOSA/{basic-maml.py => baselines/maml-ft.py} (53%)
 create mode 100644 exps/GeMOSA/baselines/maml-nof.py
 delete mode 100644 exps/GeMOSA/side_utils.py

diff --git a/exps/GeMOSA/basic-maml.py b/exps/GeMOSA/baselines/maml-ft.py
similarity index 53%
rename from exps/GeMOSA/basic-maml.py
rename to exps/GeMOSA/baselines/maml-ft.py
index b3fcce3..17c4ef2 100644
--- a/exps/GeMOSA/basic-maml.py
+++ b/exps/GeMOSA/baselines/maml-ft.py
@@ -1,30 +1,33 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/basic-maml.py --env_version v1 --inner_step 5
-# python exps/LFNA/basic-maml.py --env_version v2
+# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5
+# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16
+# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32
+# python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32
 #####################################################
 import sys, time, copy, torch, random, argparse
 from tqdm import tqdm
 from copy import deepcopy
 from pathlib import Path
 
-lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
+lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
+print(lib_dir)
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
-from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
-from log_utils import time_string
-from log_utils import AverageMeter, convert_secs2time
+from xautodl.procedures import (
+    prepare_seed,
+    prepare_logger,
+    save_checkpoint,
+    copy_checkpoint,
+)
+from xautodl.log_utils import time_string
+from xautodl.log_utils import AverageMeter, convert_secs2time
 
-from utils import split_str2indexes
-
-from procedures.advanced_main import basic_train_fn, basic_eval_fn
-from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
-from datasets.synthetic_core import get_synthetic_env, EnvSampler
-from models.xcore import get_model
-from xlayers import super_core
-
-from lfna_utils import lfna_setup, TimeData
+from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, Top1AccMetric, ComposeMetric
+from xautodl.datasets.synthetic_core import get_synthetic_env
+from xautodl.models.xcore import get_model
+from xautodl.xlayers import super_core
 
 
 class MAML:
@@ -34,31 +37,22 @@ class MAML:
     def __init__(
         self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
     ):
         self.criterion = criterion
-        # self.container = container
         self.network = network
         self.meta_optimizer = torch.optim.Adam(
             self.network.parameters(), lr=meta_lr, amsgrad=True
         )
-        self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            self.meta_optimizer,
-            milestones=[
-                int(epochs * 0.8),
-                int(epochs * 0.9),
-            ],
-            gamma=0.1,
-        )
         self.inner_lr = inner_lr
         self.inner_step = inner_step
         self._best_info = dict(state_dict=None, iepoch=None, score=None)
         print("There are {:} weights.".format(self.network.get_w_container().numel()))
 
-    def adapt(self, dataset):
+    def adapt(self, x, y):
+        # create a container for the future timestamp
         container = self.network.get_w_container()
 
         for k in range(0, self.inner_step):
-            y_hat = self.network.forward_with_container(dataset.x, container)
-            loss = self.criterion(y_hat, dataset.y)
+            y_hat = self.network.forward_with_container(x, container)
+            loss = self.criterion(y_hat, y)
             grads = torch.autograd.grad(loss, container.parameters())
             container = container.additive([-self.inner_lr * grad for grad in grads])
         return container
@@ -73,7 +67,6 @@ class MAML:
     def step(self):
         torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
         self.meta_optimizer.step()
-        self.meta_lr_scheduler.step()
 
     def zero_grad(self):
         self.meta_optimizer.zero_grad()
@@ -82,14 +75,12 @@ class MAML:
         self.criterion.load_state_dict(state_dict["criterion"])
         self.network.load_state_dict(state_dict["network"])
         self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])
-        self.meta_lr_scheduler.load_state_dict(state_dict["meta_lr_scheduler"])
 
     def state_dict(self):
         state_dict = dict()
         state_dict["criterion"] = self.criterion.state_dict()
         state_dict["network"] = self.network.state_dict()
         state_dict["meta_optimizer"] = self.meta_optimizer.state_dict()
-        state_dict["meta_lr_scheduler"] = self.meta_lr_scheduler.state_dict()
         return state_dict
 
     def save_best(self, score):
@@ -101,12 +92,39 @@ class MAML:
 
 
 def main(args):
-    logger, env_info, model_kwargs = lfna_setup(args)
+    prepare_seed(args.rand_seed)
+    logger = prepare_logger(args)
+    train_env = get_synthetic_env(mode="train", version=args.env_version)
+    valid_env = get_synthetic_env(mode="valid", version=args.env_version)
+    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
+    test_env = get_synthetic_env(mode="test", version=args.env_version)
+    all_env = get_synthetic_env(mode=None, version=args.env_version)
+    logger.log("The training environment: {:}".format(train_env))
+    logger.log("The validation environment: {:}".format(valid_env))
+    logger.log("The trainval environment: {:}".format(trainval_env))
+    logger.log("The total environment: {:}".format(all_env))
+    logger.log("The test environment: {:}".format(test_env))
+    model_kwargs = dict(
+        config=dict(model_type="norm_mlp"),
+        input_dim=all_env.meta_info["input_dim"],
+        output_dim=all_env.meta_info["output_dim"],
+        hidden_dims=[args.hidden_dim] * 2,
+        act_cls="relu",
+        norm_cls="layer_norm_1d",
+    )
+
     model = get_model(**model_kwargs)
-
-    dynamic_env = get_synthetic_env(mode="train", version=args.env_version)
-
-    criterion = torch.nn.MSELoss()
+    model = model.to(args.device)
+    if all_env.meta_info["task"] == "regression":
+        criterion = torch.nn.MSELoss()
+        metric_cls = MSEMetric
+    elif all_env.meta_info["task"] == "classification":
+        criterion = torch.nn.CrossEntropyLoss()
+        metric_cls = Top1AccMetric
+    else:
+        raise ValueError(
+            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
+        )
 
     maml = MAML(
         model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
@@ -127,14 +145,16 @@ def main(args):
         maml.zero_grad()
         meta_losses = []
         for ibatch in range(args.meta_batch):
-            future_timestamp = dynamic_env.random_timestamp()
-            _, (future_x, future_y) = dynamic_env(future_timestamp)
-            past_timestamp = (
-                future_timestamp - args.prev_time * dynamic_env.timestamp_interval
-            )
-            _, (past_x, past_y) = dynamic_env(past_timestamp)
-
-            future_container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
+            future_idx = random.randint(0, len(trainval_env) - 1)
+            future_t, (future_x, future_y) = trainval_env[future_idx]
+            # -->>
+            seq_times = trainval_env.get_seq_times(future_idx, args.seq_length)
+            _, (allxs, allys) = trainval_env.seq_call(seq_times)
+            allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
+            if trainval_env.meta_info["task"] == "classification":
+                allys = allys.view(-1)
+            historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
+            future_container = maml.adapt(historical_x, historical_y)
             future_y_hat = maml.predict(future_x, future_container)
             future_loss = maml.criterion(future_y_hat, future_y)
             meta_losses.append(future_loss)
@@ -157,37 +177,67 @@ def main(args):
 
     # meta-test
     maml.load_best()
-    eval_env = env_info["dynamic_env"]
-    assert eval_env.timestamp_interval == dynamic_env.timestamp_interval
-    w_container_per_epoch = dict()
-    for idx in range(args.prev_time, len(eval_env)):
-        future_timestamp, (future_x, future_y) = eval_env[idx]
-        past_timestamp = (
-            future_timestamp.item() - args.prev_time * eval_env.timestamp_interval
-        )
-        _, (past_x, past_y) = eval_env(past_timestamp)
-        future_container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
-        w_container_per_epoch[idx] = future_container.no_grad_clone()
+
+    def finetune(index):
+        seq_times = test_env.get_seq_times(index, args.seq_length)
+        _, (allxs, allys) = test_env.seq_call(seq_times)
+        allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
+        if test_env.meta_info["task"] == "classification":
+            allys = allys.view(-1)
+        historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
+        future_container = maml.adapt(historical_x, historical_y)
+
+        historical_y_hat = maml.predict(historical_x, future_container)
+        train_metric = metric_cls(True)
+        # model.analyze_weights()
         with torch.no_grad():
-            future_y_hat = maml.predict(future_x, w_container_per_epoch[idx])
-            future_loss = maml.criterion(future_y_hat, future_y)
-        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
-    save_checkpoint(
-        {"w_container_per_epoch": w_container_per_epoch},
-        logger.path(None) / "final-ckp.pth",
-        logger,
-    )
+            train_metric(historical_y_hat, historical_y)
+        train_results = train_metric.get_info()
+        return train_results, future_container
+
+    train_results, future_container = finetune(0)
+
+    metric = metric_cls(True)
+    per_timestamp_time, start_time = AverageMeter(), time.time()
+    for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
+
+        need_time = "Time Left: {:}".format(
+            convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True)
+        )
+        logger.log(
+            "[{:}]".format(time_string())
+            + " [{:04d}/{:04d}]".format(idx, len(test_env))
+            + " "
+            + need_time
+        )
+
+        # evaluate the adapted model on the current test timestamp
+        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
+        future_y_hat = maml.predict(future_x, future_container)
+        future_loss = criterion(future_y_hat, future_y)
+        metric(future_y_hat, future_y)
+        log_str = (
+            "[{:}]".format(time_string())
+            + " [{:04d}/{:04d}]".format(idx, len(test_env))
+            + " train-score: {:.5f}, eval-score: {:.5f}".format(
+                train_results["score"], metric.get_info()["score"]
+            )
        )
+        logger.log(log_str)
+        logger.log("")
+        per_timestamp_time.update(time.time() - start_time)
+        start_time = time.time()
 
     logger.log("-" * 200 + "\n")
     logger.close()
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser("Use the data in the past.")
+    parser = argparse.ArgumentParser("Use the maml.")
     parser.add_argument(
         "--save_dir",
         type=str,
-        default="./outputs/lfna-synthetic/use-maml",
+        default="./outputs/lfna-synthetic/use-maml-nft",
         help="The checkpoint directory.",
     )
     parser.add_argument(
@@ -205,15 +255,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--meta_lr",
         type=float,
-        default=0.01,
+        default=0.02,
        help="The learning rate for the MAML optimizer (default is Adam)",
     )
-    parser.add_argument(
-        "--fail_thresh",
-        type=float,
-        default=1000,
-        help="The threshold for the failure, which we reuse the previous best model",
-    )
     parser.add_argument(
         "--inner_lr",
         type=float,
@@ -224,15 +268,12 @@ if __name__ == "__main__":
         "--inner_step", type=int, default=1, help="The inner loop steps for MAML."
     )
     parser.add_argument(
-        "--prev_time",
-        type=int,
-        default=5,
-        help="The gap between prev_time and current_timestamp",
+        "--seq_length", type=int, default=20, help="The sequence length."
     )
     parser.add_argument(
         "--meta_batch",
         type=int,
-        default=64,
+        default=256,
         help="The batch size for the meta-model",
     )
     parser.add_argument(
@@ -247,6 +288,12 @@ if __name__ == "__main__":
         default=50,
         help="The maximum epochs for early stop.",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="The device to run on (e.g., cpu or cuda).",
+    )
     parser.add_argument(
         "--workers",
         type=int,
@@ -259,12 +306,11 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-prev{:}-e{:}-env{:}".format(
+    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-e{:}-env{:}".format(
         args.save_dir,
         args.inner_step,
         args.meta_lr,
         args.hidden_dim,
-        args.prev_time,
         args.epochs,
         args.env_version,
     )
diff --git a/exps/GeMOSA/baselines/maml-nof.py b/exps/GeMOSA/baselines/maml-nof.py
new file mode 100644
index 0000000..17c4ef2
--- /dev/null
+++ b/exps/GeMOSA/baselines/maml-nof.py
@@ -0,0 +1,317 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+#####################################################
+# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5
+# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16
+# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32
+# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32
+#####################################################
+import sys, time, copy, torch, random, argparse
+from tqdm import tqdm
+from copy import deepcopy
+from pathlib import Path
+
+lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
+print(lib_dir)
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
+from xautodl.procedures import (
+    prepare_seed,
+    prepare_logger,
+    save_checkpoint,
+    copy_checkpoint,
+)
+from xautodl.log_utils import time_string
+from xautodl.log_utils import AverageMeter, convert_secs2time
+
+from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, Top1AccMetric, ComposeMetric
+from xautodl.datasets.synthetic_core import get_synthetic_env
+from xautodl.models.xcore import get_model
+from xautodl.xlayers import super_core
+
+
+class MAML:
+    """A LFNA meta-model that uses the MLP as delta-net."""
+
+    def __init__(
+        self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
+    ):
+        self.criterion = criterion
+        self.network = network
+        self.meta_optimizer = torch.optim.Adam(
+            self.network.parameters(), lr=meta_lr, amsgrad=True
+        )
+        self.inner_lr = inner_lr
+        self.inner_step = inner_step
+        self._best_info = dict(state_dict=None, iepoch=None, score=None)
+        print("There are {:} weights.".format(self.network.get_w_container().numel()))
+
+    def adapt(self, x, y):
+        # create a container for the future timestamp
+        container = self.network.get_w_container()
+
+        for k in range(0, self.inner_step):
+            y_hat = self.network.forward_with_container(x, container)
+            loss = self.criterion(y_hat, y)
+            grads = torch.autograd.grad(loss, container.parameters())
+            container = container.additive([-self.inner_lr * grad for grad in grads])
+        return container
+
+    def predict(self, x, container=None):
+        if container is not None:
+            y_hat = self.network.forward_with_container(x, container)
+        else:
+            y_hat = self.network(x)
+        return y_hat
+
+    def step(self):
+        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
+        self.meta_optimizer.step()
+
+    def zero_grad(self):
+        self.meta_optimizer.zero_grad()
+
+    def load_state_dict(self, state_dict):
+        self.criterion.load_state_dict(state_dict["criterion"])
+        self.network.load_state_dict(state_dict["network"])
+        self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])
+
+    def state_dict(self):
+        state_dict = dict()
+        state_dict["criterion"] = self.criterion.state_dict()
+        state_dict["network"] = self.network.state_dict()
+        state_dict["meta_optimizer"] = self.meta_optimizer.state_dict()
+        return state_dict
+
+    def save_best(self, score):
+        success, best_score = self.network.save_best(score)
+        return success, best_score
+
+    def load_best(self):
+        self.network.load_best()
+
+
+def main(args):
+    prepare_seed(args.rand_seed)
+    logger = prepare_logger(args)
+    train_env = get_synthetic_env(mode="train", version=args.env_version)
+    valid_env = get_synthetic_env(mode="valid", version=args.env_version)
+    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
+    test_env = get_synthetic_env(mode="test", version=args.env_version)
+    all_env = get_synthetic_env(mode=None, version=args.env_version)
+    logger.log("The training environment: {:}".format(train_env))
+    logger.log("The validation environment: {:}".format(valid_env))
+    logger.log("The trainval environment: {:}".format(trainval_env))
+    logger.log("The total environment: {:}".format(all_env))
+    logger.log("The test environment: {:}".format(test_env))
+    model_kwargs = dict(
+        config=dict(model_type="norm_mlp"),
+        input_dim=all_env.meta_info["input_dim"],
+        output_dim=all_env.meta_info["output_dim"],
+        hidden_dims=[args.hidden_dim] * 2,
+        act_cls="relu",
+        norm_cls="layer_norm_1d",
+    )
+
+    model = get_model(**model_kwargs)
+    model = model.to(args.device)
+    if all_env.meta_info["task"] == "regression":
+        criterion = torch.nn.MSELoss()
+        metric_cls = MSEMetric
+    elif all_env.meta_info["task"] == "classification":
+        criterion = torch.nn.CrossEntropyLoss()
+        metric_cls = Top1AccMetric
+    else:
+        raise ValueError(
+            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
+        )
+
+    maml = MAML(
+        model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
+    )
+
+    # meta-training
+    last_success_epoch = 0
+    per_epoch_time, start_time = AverageMeter(), time.time()
+    for iepoch in range(args.epochs):
+        need_time = "Time Left: {:}".format(
+            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
+        )
+        head_str = (
+            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
+            + need_time
+        )
+
+        maml.zero_grad()
+        meta_losses = []
+        for ibatch in range(args.meta_batch):
+            future_idx = random.randint(0, len(trainval_env) - 1)
+            future_t, (future_x, future_y) = trainval_env[future_idx]
+            # -->>
+            seq_times = trainval_env.get_seq_times(future_idx, args.seq_length)
+            _, (allxs, allys) = trainval_env.seq_call(seq_times)
+            allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
+            if trainval_env.meta_info["task"] == "classification":
+                allys = allys.view(-1)
+            historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
+            future_container = maml.adapt(historical_x, historical_y)
+            future_y_hat = maml.predict(future_x, future_container)
+            future_loss = maml.criterion(future_y_hat, future_y)
+            meta_losses.append(future_loss)
+        meta_loss = torch.stack(meta_losses).mean()
+        meta_loss.backward()
+        maml.step()
+
+        logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
+        success, best_score = maml.save_best(-meta_loss.item())
+        if success:
+            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
+            save_checkpoint(maml.state_dict(), logger.path("model"), logger)
+            last_success_epoch = iepoch
+        if iepoch - last_success_epoch >= args.early_stop_thresh:
+            logger.log("Early stop at {:}".format(iepoch))
+            break
+
+        per_epoch_time.update(time.time() - start_time)
+        start_time = time.time()
+
+    # meta-test
+    maml.load_best()
+
+    def finetune(index):
+        seq_times = test_env.get_seq_times(index, args.seq_length)
+        _, (allxs, allys) = test_env.seq_call(seq_times)
+        allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
+        if test_env.meta_info["task"] == "classification":
+            allys = allys.view(-1)
+        historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
+        future_container = maml.adapt(historical_x, historical_y)
+
+        historical_y_hat = maml.predict(historical_x, future_container)
+        train_metric = metric_cls(True)
+        # model.analyze_weights()
+        with torch.no_grad():
+            train_metric(historical_y_hat, historical_y)
+        train_results = train_metric.get_info()
+        return train_results, future_container
+
+    train_results, future_container = finetune(0)
+
+    metric = metric_cls(True)
+    per_timestamp_time, start_time = AverageMeter(), time.time()
+    for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
+
+        need_time = "Time Left: {:}".format(
+            convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True)
+        )
+        logger.log(
+            "[{:}]".format(time_string())
+            + " [{:04d}/{:04d}]".format(idx, len(test_env))
+            + " "
+            + need_time
+        )
+
+        # evaluate the adapted model on the current test timestamp
+        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
+        future_y_hat = maml.predict(future_x, future_container)
+        future_loss = criterion(future_y_hat, future_y)
+        metric(future_y_hat, future_y)
+        log_str = (
+            "[{:}]".format(time_string())
+            + " [{:04d}/{:04d}]".format(idx, len(test_env))
+            + " train-score: {:.5f}, eval-score: {:.5f}".format(
+                train_results["score"], metric.get_info()["score"]
+            )
+        )
+        logger.log(log_str)
+        logger.log("")
+        per_timestamp_time.update(time.time() - start_time)
+        start_time = time.time()
+
+    logger.log("-" * 200 + "\n")
+    logger.close()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Use the maml.")
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="./outputs/lfna-synthetic/use-maml-nft",
+        help="The checkpoint directory.",
+    )
+    parser.add_argument(
+        "--env_version",
+        type=str,
+        required=True,
+        help="The synthetic environment version.",
+    )
+    parser.add_argument(
+        "--hidden_dim",
+        type=int,
+        default=16,
+        help="The hidden dimension.",
+    )
+    parser.add_argument(
+        "--meta_lr",
+        type=float,
+        default=0.02,
+        help="The learning rate for the MAML optimizer (default is Adam)",
+    )
+    parser.add_argument(
+        "--inner_lr",
+        type=float,
+        default=0.005,
+        help="The learning rate for the inner optimization",
+    )
+    parser.add_argument(
+        "--inner_step", type=int, default=1, help="The inner loop steps for MAML."
+    )
+    parser.add_argument(
+        "--seq_length", type=int, default=20, help="The sequence length."
+    )
+    parser.add_argument(
+        "--meta_batch",
+        type=int,
+        default=256,
+        help="The batch size for the meta-model",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=2000,
+        help="The total number of epochs.",
+    )
+    parser.add_argument(
+        "--early_stop_thresh",
+        type=int,
+        default=50,
+        help="The maximum epochs for early stop.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="The device to run on (e.g., cpu or cuda).",
+    )
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=4,
+        help="The number of data loading workers (default: 4)",
+    )
+    # Random Seed
+    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
+    args = parser.parse_args()
+    if args.rand_seed is None or args.rand_seed < 0:
+        args.rand_seed = random.randint(1, 100000)
+    assert args.save_dir is not None, "The save dir argument can not be None"
+    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-e{:}-env{:}".format(
+        args.save_dir,
+        args.inner_step,
+        args.meta_lr,
+        args.hidden_dim,
+        args.epochs,
+        args.env_version,
+    )
+    main(args)
diff --git a/exps/GeMOSA/basic-same.py b/exps/GeMOSA/basic-same.py
index 25cd445..5e7f739 100644
--- a/exps/GeMOSA/basic-same.py
+++ b/exps/GeMOSA/basic-same.py
@@ -28,7 +28,6 @@ from xautodl.log_utils import AverageMeter, convert_secs2time
 
 from xautodl.utils import split_str2indexes
 
-from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
 from xautodl.procedures.metric_utils import (
     SaveMetric,
     MSEMetric,
diff --git a/exps/GeMOSA/side_utils.py b/exps/GeMOSA/side_utils.py
deleted file mode 100644
index a9fe522..0000000
--- a/exps/GeMOSA/side_utils.py
+++ /dev/null
@@ -1,50 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
-import copy
-import torch
-from tqdm import tqdm
-from xautodl.procedures import prepare_seed, prepare_logger
-from xautodl.datasets.synthetic_core import get_synthetic_env
-
-
-def train_model(model, dataset, lr, epochs):
-    criterion = torch.nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
-    best_loss, best_param = None, None
-    for _iepoch in range(epochs):
-        preds = model(dataset.x)
-        optimizer.zero_grad()
-        loss = criterion(preds, dataset.y)
-        loss.backward()
-        optimizer.step()
-        # save best
-        if best_loss is None or best_loss > loss.item():
-            best_loss = loss.item()
-            best_param = copy.deepcopy(model.state_dict())
-    model.load_state_dict(best_param)
-    return best_loss
-
-
-class TimeData:
-    def __init__(self, timestamp, xs, ys):
-        self._timestamp = timestamp
-        self._xs = xs
-        self._ys = ys
-
-    @property
-    def x(self):
-        return self._xs
-
-    @property
-    def y(self):
-        return self._ys
-
-    @property
-    def timestamp(self):
-        return self._timestamp
-
-    def __repr__(self):
-        return "{name}(timestamp={timestamp}, with {num} samples)".format(
-            name=self.__class__.__name__, timestamp=self._timestamp, num=len(self._xs)
-        )
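
For context on the adapt/meta-update pattern that both baselines above implement, the following is a minimal, self-contained sketch of the same inner-/outer-loop structure written against plain PyTorch (it assumes torch >= 2.0 for torch.func.functional_call) instead of xautodl's weight-container helpers. The helper names (make_mlp, adapt) and the synthetic sine data are illustrative assumptions, not part of this patch; note that MAML.adapt in the patch calls torch.autograd.grad without create_graph=True, i.e. a first-order variant of the inner step shown here.

import torch
from torch.func import functional_call


def make_mlp(input_dim=1, hidden=16, output_dim=1):
    # a tiny stand-in for xautodl's norm_mlp model
    return torch.nn.Sequential(
        torch.nn.Linear(input_dim, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, output_dim),
    )


def adapt(network, criterion, x, y, inner_lr=0.005, inner_step=1):
    # start from the current meta-parameters and take a few gradient steps on (x, y);
    # create_graph=True keeps the graph so the outer loss can backpropagate through
    # the inner updates (drop it for a first-order variant)
    params = {name: p for name, p in network.named_parameters()}
    for _ in range(inner_step):
        y_hat = functional_call(network, params, (x,))
        loss = criterion(y_hat, y)
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        params = {
            name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)
        }
    return params


if __name__ == "__main__":
    net, criterion = make_mlp(), torch.nn.MSELoss()
    meta_opt = torch.optim.Adam(net.parameters(), lr=0.02, amsgrad=True)
    # one meta-step over a small batch of synthetic "timestamps"
    meta_losses = []
    for _ in range(4):
        hist_x, fut_x = torch.randn(32, 1), torch.randn(32, 1)
        hist_y, fut_y = torch.sin(hist_x), torch.sin(fut_x)
        fast_params = adapt(net, criterion, hist_x, hist_y, inner_step=5)
        fut_y_hat = functional_call(net, fast_params, (fut_x,))
        meta_losses.append(criterion(fut_y_hat, fut_y))
    meta_loss = torch.stack(meta_losses).mean()
    meta_opt.zero_grad()
    meta_loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    meta_opt.step()
    print("meta-loss: {:.4f}".format(meta_loss.item()))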