Update LFAN to device
This commit is contained in:
parent
5e766603be
commit
9a2c9fc435
@ -190,7 +190,7 @@ def main(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# creating the new meta-time-embedding
|
# creating the new meta-time-embedding
|
||||||
distance = meta_model.get_closest_meta_distance(future_time)
|
distance = meta_model.get_closest_meta_distance(future_time.item())
|
||||||
if distance < eval_env.timestamp_interval:
|
if distance < eval_env.timestamp_interval:
|
||||||
continue
|
continue
|
||||||
#
|
#
|
||||||
@ -198,7 +198,7 @@ def main(args):
|
|||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
[new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
|
[new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True
|
||||||
)
|
)
|
||||||
meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param)
|
meta_model.replace_append_learnt(torch.Tensor([future_time]).to(args.device), new_param)
|
||||||
meta_model.eval()
|
meta_model.eval()
|
||||||
base_model.train()
|
base_model.train()
|
||||||
for iepoch in range(args.epochs):
|
for iepoch in range(args.epochs):
|
||||||
|
Loading…
Reference in New Issue
Block a user