Update LFAN to device

This commit is contained in:
D-X-Y 2021-05-15 12:28:41 +00:00
parent 5e766603be
commit 9a2c9fc435

View File

@ -190,7 +190,7 @@ def main(args):
) )
# creating the new meta-time-embedding # creating the new meta-time-embedding
distance = meta_model.get_closest_meta_distance(future_time) distance = meta_model.get_closest_meta_distance(future_time.item())
if distance < eval_env.timestamp_interval: if distance < eval_env.timestamp_interval:
continue continue
# #
@ -198,7 +198,7 @@ def main(args):
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
[new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
) )
meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param) meta_model.replace_append_learnt(torch.Tensor([future_time]).to(args.device), new_param)
meta_model.eval() meta_model.eval()
base_model.train() base_model.train()
for iepoch in range(args.epochs): for iepoch in range(args.epochs):