diff --git a/exps/LFNA/lfna-tall-hpnet.py b/exps/LFNA/lfna-tall-hpnet.py index 99f8f49..162a1cc 100644 --- a/exps/LFNA/lfna-tall-hpnet.py +++ b/exps/LFNA/lfna-tall-hpnet.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 +# python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 16 ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -42,7 +42,7 @@ def main(args): hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) total_bar = env_info["total"] - 1 task_embeds = [] - for i in range(total_bar): + for i in range(env_info["total"]): task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim))) for task_embed in task_embeds: trunc_normal_(task_embed, std=0.02) @@ -97,7 +97,7 @@ def main(args): if iepoch % 200 == 0: logger.log( head_str - + "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}, limit={:}".format( + + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}, limit={:}".format( loss_meter.avg, loss_meter.val, min(lr_scheduler.get_last_lr()), @@ -109,7 +109,7 @@ def main(args): save_checkpoint( { "hypernet": hypernet.state_dict(), - "task_embed": task_embed, + "task_embeds": task_embeds, "lr_scheduler": lr_scheduler.state_dict(), "iepoch": iepoch, }, @@ -122,6 +122,25 @@ def main(args): print(model) print(hypernet) + w_container_per_epoch = dict() + for idx in range(0, env_info["total"]): + future_time = env_info["{:}-timestamp".format(idx)] + future_x = env_info["{:}-x".format(idx)] + future_y = env_info["{:}-y".format(idx)] + future_container = hypernet(task_embeds[idx]) + w_container_per_epoch[idx] = future_container.no_grad_clone() + with torch.no_grad(): + future_y_hat = model.forward_with_container( + future_x, w_container_per_epoch[idx] + ) + future_loss = criterion(future_y_hat, future_y) + logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())) + + save_checkpoint( + {"w_container_per_epoch": w_container_per_epoch}, + logger.path(None) / "final-ckp.pth", + logger, + ) logger.log("-" * 200 + "\n") logger.close() diff --git a/exps/LFNA/lfna-test-hpnet.py b/exps/LFNA/lfna-test-hpnet.py index ce7715d..b4aa9c1 100644 --- a/exps/LFNA/lfna-test-hpnet.py +++ b/exps/LFNA/lfna-test-hpnet.py @@ -34,17 +34,20 @@ def main(args): logger, env_info, model_kwargs = lfna_setup(args) dynamic_env = env_info["dynamic_env"] model = get_model(dict(model_type="simple_mlp"), **model_kwargs) + model = model.to(args.device) criterion = torch.nn.MSELoss() logger.log("There are {:} weights.".format(model.get_w_container().numel())) shape_container = model.get_w_container().to_shape_container() hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) + hypernet = hypernet.to(args.device) # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) total_bar = 10 task_embeds = [] for i in range(total_bar): - task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim))) + tensor = torch.Tensor(1, args.task_dim).to(args.device) + task_embeds.append(torch.nn.Parameter(tensor)) for task_embed in task_embeds: trunc_normal_(task_embed, std=0.02) @@ -79,8 +82,8 @@ def main(args): # cur_time = random.randint(0, total_bar) cur_task_embed = task_embeds[cur_time] cur_container = hypernet(cur_task_embed) - cur_x = env_info["{:}-x".format(cur_time)] - cur_y = env_info["{:}-y".format(cur_time)] + cur_x = env_info["{:}-x".format(cur_time)].to(args.device) + cur_y = env_info["{:}-y".format(cur_time)].to(args.device) cur_dataset = TimeData(cur_time, cur_x, cur_y) preds = model.forward_with_container(cur_dataset.x, cur_container) @@ -98,7 +101,7 @@ def main(args): if iepoch % 200 == 0: logger.log( head_str - + "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( + + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( loss_meter.avg, loss_meter.val, min(lr_scheduler.get_last_lr()), @@ -166,6 +169,12 @@ if __name__ == "__main__": default=2000, help="The total number of epochs.", ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="", + ) # Random Seed parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") args = parser.parse_args()