Update LFAN to device
This commit is contained in:
		| @@ -190,7 +190,7 @@ def main(args): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         # creating the new meta-time-embedding |         # creating the new meta-time-embedding | ||||||
|         distance = meta_model.get_closest_meta_distance(future_time) |         distance = meta_model.get_closest_meta_distance(future_time.item()) | ||||||
|         if distance < eval_env.timestamp_interval: |         if distance < eval_env.timestamp_interval: | ||||||
|             continue |             continue | ||||||
|         # |         # | ||||||
| @@ -198,7 +198,7 @@ def main(args): | |||||||
|         optimizer = torch.optim.Adam( |         optimizer = torch.optim.Adam( | ||||||
|             [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True |             [new_param], lr=args.init_lr, weight_decay=1e-5, amsgrad=True | ||||||
|         ) |         ) | ||||||
|         meta_model.replace_append_learnt(torch.Tensor([future_time]), new_param) |         meta_model.replace_append_learnt(torch.Tensor([future_time]).to(args.device), new_param) | ||||||
|         meta_model.eval() |         meta_model.eval() | ||||||
|         base_model.train() |         base_model.train() | ||||||
|         for iepoch in range(args.epochs): |         for iepoch in range(args.epochs): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user