From da4b61f3abbd7cf961d9ab8a27d5271e6a01c44a Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 23 May 2021 19:26:09 +0000 Subject: [PATCH] Update LFNA -- refine --- exps/LFNA/lfna.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index c46a083..f0ec021 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -93,8 +93,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): return loss_meter -def online_evaluate(env, meta_model, base_model, criterion, args, logger): +def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): logger.log("Online evaluate: {:}".format(env)) + loss_meter = AverageMeter() + w_containers = dict() for idx, (future_time, (future_x, future_y)) in enumerate(env): with torch.no_grad(): meta_model.eval() @@ -102,9 +104,12 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): _, [future_container], time_embeds = meta_model( future_time.to(args.device).view(1, 1), None, True ) + if save: + w_containers[idx] = future_container.no_grad_clone() future_x, future_y = future_x.to(args.device), future_y.to(args.device) future_y_hat = base_model.forward_with_container(future_x, future_container) future_loss = criterion(future_y_hat, future_y) + loss_meter.update(future_loss.item()) refine, post_refine_loss = meta_model.adapt( base_model, criterion, @@ -123,6 +128,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): ) meta_model.clear_fixed() meta_model.clear_learnt() + return w_containers, loss_meter def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): @@ -219,8 +225,10 @@ def main(args): logger, env_info, model_kwargs = lfna_setup(args) train_env = get_synthetic_env(mode="train", version=args.env_version) valid_env = get_synthetic_env(mode="valid", version=args.env_version) - logger.log("training enviornment: {:}".format(train_env)) - logger.log("validation enviornment: {:}".format(valid_env)) + all_env = get_synthetic_env(mode=None, version=args.env_version) + logger.log("The training enviornment: {:}".format(train_env)) + logger.log("The validation enviornment: {:}".format(valid_env)) + logger.log("The total enviornment: {:}".format(all_env)) base_model = get_model(**model_kwargs) base_model = base_model.to(args.device) @@ -249,10 +257,20 @@ def main(args): pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) # try to evaluate once - online_evaluate(train_env, meta_model, base_model, criterion, args, logger) - online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) + # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) + # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) + w_containers, loss_meter = online_evaluate( + all_env, meta_model, base_model, criterion, args, logger, True + ) + logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter)) - pdb.set_trace() + save_checkpoint( + {"w_containers": w_containers}, + logger.path(None) / "final-ckp.pth", + logger, + ) + return + """ optimizer = torch.optim.Adam( meta_model.get_parameters(True, True, False), # fix hypernet lr=args.lr, @@ -364,7 +382,6 @@ def main(args): # meta-test meta_model.load_best() eval_env = env_info["dynamic_env"] - w_container_per_epoch = dict() for idx in range(args.seq_length, len(eval_env)): # build-timestamp future_time = env_info["{:}-timestamp".format(idx)].item() @@ -424,6 +441,7 @@ def main(args): logger.path(None) / "final-ckp.pth", logger, ) + """ logger.log("-" * 200 + "\n") logger.close() @@ -494,7 +512,7 @@ if __name__ == "__main__": help="The learning rate for the optimizer, during refine", ) parser.add_argument( - "--refine_epochs", type=int, default=40, help="The final refine #epochs." + "--refine_epochs", type=int, default=50, help="The final refine #epochs." ) parser.add_argument( "--early_stop_thresh",