From 9a2c9fc435bf56ec45e4021e3a2569c9f7fc9358 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 15 May 2021 12:28:41 +0000 Subject: [PATCH] Update LFAN to device --- exps/LFNA/lfna.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index 9743678..c088a1b 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -190,7 +190,7 @@ def main(args): ) # creating the new meta-time-embedding - distance = meta_model.get_closest_meta_distance(future_time) + distance = meta_model.get_closest_meta_distance(future_time.item()) if distance < eval_env.timestamp_interval: continue # @@ -198,7 +198,7 @@ def main(args): optimizer = torch.optim.Adam( [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True ) - meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param) + meta_model.replace_append_learnt(torch.Tensor([future_time]).to(args.device), new_param) meta_model.eval() base_model.train() for iepoch in range(args.epochs):